| |
| |
|
|
import itertools
import sys

import torch
|
|
|
|
| |
def persistent_lean_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    Mp: torch.Tensor,
    Lp: torch.Tensor,
    Op: torch.Tensor,
    locks: torch.Tensor,
    batch_num_block_n: torch.Tensor,
    total_programs: int,
    BLOCK_M: int,
    BLOCK_N: int,
    causal: bool,
    batch_size: int,
    sm_scale: float,
):
    """Trace the lean-attention tile schedule for every program id.

    Validates the head dimensions, derives the tile partitioning with
    ``get_num_splits_and_buffer_sizes`` and then calls ``la_persistent`` once
    per program id, sequentially.  Everything is reported through ``print``.

    Args:
        q, k, v: Tensors indexed as (rows, heads, head_dim); the last
            dimension of all three must match and be one of
            16/32/64/128/256.  ``q.shape[0]`` is divided by ``batch_size`` to
            get the per-batch query length, ``k.shape[0]`` is used as the
            key length as-is.
        Mp, Lp, Op: Scratch tensors forwarded unchanged to ``la_persistent``.
        locks: Forwarded unchanged to ``la_persistent``.
        batch_num_block_n: Forwarded unchanged to ``la_persistent``.
        total_programs: Number of program ids to simulate.
        BLOCK_M, BLOCK_N: Tile sizes along the query / key rows.
        causal: Selects the causal scheduling path.
        batch_size: Number of batches packed into ``q``.
        sm_scale: Softmax scale; multiplied by 1.44269504 (~log2(e)) before
            being forwarded.

    Returns:
        None.  The output tensor ``o`` is allocated and handed to
        ``la_persistent`` but is not returned.

    Raises:
        AssertionError: If the head dimensions disagree or are unsupported.
        ValueError: If ``causal`` is set and ``BLOCK_M < BLOCK_N``.
    """
    HEAD_DIM_Q, HEAD_DIM_K, HEAD_DIM_V = q.shape[-1], k.shape[-1], v.shape[-1]
    assert (
        HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    ), "Incompatible Q/K/V Hidden Dimensions"
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}

    N_CTX_Q = q.shape[0] // batch_size
    N_CTX_K = k.shape[0]
    H = q.shape[1]

    BLOCK_RATIO = BLOCK_M // BLOCK_N
    print(f"BLOCK_RATIO={BLOCK_RATIO}")
    if causal and BLOCK_RATIO < 1:
        # The causal path groups tiles in multiples of BLOCK_RATIO via
        # find_group(); a ratio of 0 would make that search loop forever.
        raise ValueError(
            f"causal scheduling requires BLOCK_M >= BLOCK_N, "
            f"got BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}"
        )

    # 1.44269504 ~= log2(e).
    qk_scale = sm_scale * 1.44269504

    # The same head count is passed for both num_heads and num_heads_k.
    (
        num_m_blocks,
        num_n_blocks,
        high_load_wgs,
        max_tiles_per_wg,
        tiles_per_head,
        total_programs,
        num_splits,
        even_split,
    ) = get_num_splits_and_buffer_sizes(
        causal,
        batch_size,
        N_CTX_Q,
        N_CTX_K,
        H,
        H,
        HEAD_DIM_Q,
        BLOCK_M,
        BLOCK_N,
        total_programs,
    )
    print(
        f"high_load_wgs={high_load_wgs}, max_tiles_per_wg={max_tiles_per_wg}, tiles_per_head={tiles_per_head}"
    )
    print(
        f"total_programs={total_programs}, num_splits={num_splits}, even_split={even_split}"
    )
    print(f"num_m_blocks={num_m_blocks}, num_n_blocks={num_n_blocks}")

    o = torch.empty_like(q, dtype=v.dtype)

    print(
        f"q.stride(0)={q.stride(0)}, q.stride(1)={q.stride(1)}, q.stride(2)={q.stride(2)}"
    )
    print(
        f"k.stride(0)={k.stride(0)}, k.stride(1)={k.stride(1)}, k.stride(2)={k.stride(2)}"
    )

    # Run the programs one after another.
    for pid in range(total_programs):
        la_persistent(
            pid,
            q,
            k,
            v,
            qk_scale,
            Mp,
            Lp,
            Op,
            o,
            batch_num_block_n,
            locks,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            Op.stride(0),
            Op.stride(1),
            Op.stride(2),
            HEAD_DIM=HEAD_DIM_K,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_RATIO=BLOCK_RATIO,
            batch_size=batch_size,
            causal=causal,
            num_m_blocks=num_m_blocks,
            num_n_blocks=num_n_blocks,
            high_load_wgs=high_load_wgs,
            max_tiles_per_wg=max_tiles_per_wg,
            tiles_per_head=tiles_per_head,
            num_splits=num_splits,
        )
|
|
|
|
def get_num_splits_and_buffer_sizes(
    causal,
    batch_size,
    max_seqlen_q,
    max_seqlen_k,
    num_heads,
    num_heads_k,
    head_size,
    BLOCK_M,
    BLOCK_N,
    num_SMs,
):
    """Compute how attention tiles are divided among ``num_SMs`` programs.

    Args:
        causal: Use the causal (lower-triangular) tile count.  Forced to
            ``False`` when the rescaled query length equals 1.
        batch_size: Multiplies the causal per-head tile count and divides the
            returned ``num_n_blocks``.
        max_seqlen_q: Query length; rescaled by ``num_heads // num_heads_k``
            after the block counts are taken.
        max_seqlen_k: Key length used for ``num_n_blocks``.
        num_heads, num_heads_k: Query / key head counts.
        head_size: Accepted for interface compatibility; not used.
        BLOCK_M, BLOCK_N: Tile sizes along the query / key rows.
        num_SMs: Number of programs the tiles are distributed over.

    Returns:
        Tuple ``(num_m_blocks, num_n_blocks, high_load_tbs, max_tiles_per_tb,
        tiles_per_head, lean_griddimz, num_splits, even_split)``.
    """
    num_m_blocks = (max_seqlen_q + BLOCK_M - 1) // BLOCK_M
    num_n_blocks = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N

    max_seqlen_q = max_seqlen_q * num_heads // num_heads_k

    print(f"block_m: {BLOCK_M}, block_n: {BLOCK_N} ")
    print(f"num_m_block: {num_m_blocks}, num_n_block: {num_n_blocks} ")
    print(f"max_seqlen_q: {max_seqlen_q}, max_seqlen_k: {max_seqlen_k}")
    print(f"num_heads: {num_heads}, num_heads_k: {num_heads_k} ")

    if max_seqlen_q == 1:
        causal = False

    if causal:
        # Query block i needs ceil((i + 1) * BLOCK_M / BLOCK_N) key tiles.
        tiles_per_head = sum(
            ((i + 1) * BLOCK_M + BLOCK_N - 1) // BLOCK_N
            for i in range(num_m_blocks)
        )
        print(f"tiles_per_head={tiles_per_head}")
        tiles_per_head = tiles_per_head * batch_size
        print(f"batch_size={batch_size}, tiles_per_head={tiles_per_head}")
    else:
        tiles_per_head = num_m_blocks * num_n_blocks

    total_tiles = tiles_per_head * num_heads_k
    print(f"total_tiles={total_tiles}")

    lean_griddimz = num_SMs

    max_tiles_per_tb = (total_tiles + lean_griddimz - 1) // lean_griddimz

    even_split = total_tiles % lean_griddimz == 0
    # An even split gives every program max_tiles_per_tb tiles; otherwise the
    # low-load programs get one fewer, and that smaller count drives the
    # split estimate.  Clamp to 1 so the division stays defined when the
    # low-load programs would get zero tiles (fewer tiles than programs) or
    # when there are no tiles at all.
    tiles_per_split = max_tiles_per_tb if even_split else max_tiles_per_tb - 1
    tiles_per_split = max(tiles_per_split, 1)
    num_splits = 1 + (num_n_blocks + tiles_per_split - 2) // tiles_per_split

    # Number of programs that carry the full max_tiles_per_tb load.
    high_load_tbs = total_tiles - ((max_tiles_per_tb - 1) * lean_griddimz)

    num_n_blocks = num_n_blocks // batch_size

    return (
        num_m_blocks,
        num_n_blocks,
        high_load_tbs,
        max_tiles_per_tb,
        tiles_per_head,
        lean_griddimz,
        num_splits,
        even_split,
    )
|
|
|
|
def find_group(x, BLOCK_RATIO):
    """Locate the group that contains flat block index ``x``.

    Groups grow linearly: group ``g`` holds ``(g + 1) * BLOCK_RATIO`` blocks.

    Args:
        x: Flat block index within the sequence of groups.
        BLOCK_RATIO: Positive growth step of the group sizes.

    Returns:
        Tuple ``(group_id, group_size, total_blocks)`` where ``total_blocks``
        is the number of blocks in all groups before ``group_id``.

    Raises:
        ValueError: If ``BLOCK_RATIO`` is not positive; the search below
            would otherwise never terminate for ``x >= 0``.
    """
    if BLOCK_RATIO <= 0:
        raise ValueError(f"BLOCK_RATIO must be positive, got {BLOCK_RATIO}")

    group_id = 0
    total_blocks = 0
    # Skip every group that ends at or before x.
    while total_blocks + (group_id + 1) * BLOCK_RATIO <= x:
        total_blocks += (group_id + 1) * BLOCK_RATIO
        group_id += 1
    print(f"find_group(): x={x}, group_id={group_id}, total_blocks={total_blocks}")
    group_size = (group_id + 1) * BLOCK_RATIO
    return group_id, group_size, total_blocks
|
|
|
|
def la_persistent(
    pid,
    Q,
    K,
    V,
    qk_scale,
    Mp,
    Lp,
    Op,
    Out,
    batch_num_block_n,
    locks,
    stride_qm,
    stride_qh,
    stride_qk,
    stride_kn,
    stride_kh,
    stride_kk,
    stride_vn,
    stride_vh,
    stride_vk,
    stride_om,
    stride_oh,
    stride_on,
    stride_oph,
    stride_opm,
    stride_opn,
    HEAD_DIM,
    BLOCK_M,
    BLOCK_N,
    BLOCK_RATIO,
    batch_size,
    causal,
    num_m_blocks,
    num_n_blocks,
    high_load_wgs,
    max_tiles_per_wg,
    tiles_per_head,
    num_splits,
):
    """Print the tile walk performed by program ``pid``.

    The program owns a contiguous range of global tile ids.  For each output
    tile that range touches, the function prints which tile it is, the local
    iteration range, whether this program is the tile's host (it starts the
    tile) and/or finishes it, and the offset tensors built from the strides.

    ``Mp``, ``Lp`` and ``Op`` are combined with offsets by ordinary tensor
    addition.  ``Q``, ``K``, ``V``, ``Out``, ``qk_scale`` and ``locks`` are
    accepted but not read.  Nothing is returned.
    """
    current_pid = pid
    low_load = max_tiles_per_wg - 1

    # The first `high_load_wgs` programs take max_tiles_per_wg tiles each;
    # the remaining ones take one tile fewer.
    if current_pid < high_load_wgs:
        tile_gid = max_tiles_per_wg * current_pid
        cta_end_tile_gid = tile_gid + max_tiles_per_wg
    else:
        tile_gid = (
            low_load * (current_pid - high_load_wgs)
            + high_load_wgs * max_tiles_per_wg
        )
        cta_end_tile_gid = tile_gid + low_load
    print(
        f"current_pid={current_pid}, iter={tile_gid}, cta_end_tile_gid={cta_end_tile_gid}"
    )

    while tile_gid < cta_end_tile_gid:
        tile_head_idx = tile_gid // tiles_per_head
        print(f" tile_head_idx={tile_head_idx}")
        head_base = tile_head_idx * tiles_per_head

        if causal:
            tiles_per_batch = tiles_per_head // batch_size
            tile_batch_idx = (tile_gid % tiles_per_head) // tiles_per_batch
            batch_base = tile_batch_idx * tiles_per_batch
            per_head_tile_idx, per_head_tile_size, total_blocks = find_group(
                tile_gid - head_base - batch_base,
                BLOCK_RATIO,
            )
            tile_iter = head_base + batch_base + total_blocks
            tile_iter_end = tile_iter + per_head_tile_size
            tile_idx = (
                tile_head_idx * batch_size + tile_batch_idx
            ) * num_m_blocks + per_head_tile_idx
            print(f" causal: per_head_tile_idx={per_head_tile_idx}")
            print(f" causal: per_head_tile_size={per_head_tile_size},")
            print(f" causal: total_blocks={total_blocks}")
            print(f" causal: tile_batch_idx={tile_batch_idx}")
        else:
            tile_idx = tile_head_idx * batch_size
            tile_iter = head_base
            req_size = tiles_per_head if batch_size == 1 else batch_num_block_n[0]
            tile_iter_end = tile_iter + req_size
            local_head_iter = tile_gid % tiles_per_head
            # Walk the entries of batch_num_block_n and pick the batch whose
            # [req_size, next_req_size) window contains this tile.
            for b in range(1, batch_size):
                next_req_size = batch_num_block_n[b]
                if (local_head_iter >= req_size) and (local_head_iter < next_req_size):
                    tile_iter = tile_iter + req_size
                    tile_idx = tile_idx + b
                    tile_iter_end = tile_iter + (next_req_size - req_size)
                req_size = next_req_size
        print(
            f" tile_idx={tile_idx}, tile_iter={tile_iter}, tile_iter_end={tile_iter_end}"
        )

        local_iter = tile_gid - tile_iter
        local_iter_end = min(tile_iter_end, cta_end_tile_gid) - tile_iter
        print(f" local_iter={local_iter}, local_iter_end={local_iter_end}")

        # bool() keeps the flags plain Python booleans even when the operands
        # are tensors.
        host_block = bool(tile_gid == tile_iter)
        finishing_block = bool(cta_end_tile_gid >= tile_iter_end)
        print(f" host_block={host_block}, finishing_block={finishing_block}")

        offs_m = torch.arange(0, BLOCK_M)
        offs_n = torch.arange(0, BLOCK_N)
        offs_k = torch.arange(0, HEAD_DIM)

        if causal:
            b_seq_size = tile_batch_idx * num_n_blocks
        else:
            tile_batch_idx = tile_idx % batch_size
            b_seq_size = 1 if tile_batch_idx > 0 else 0

        kv_block_start = (b_seq_size + local_iter) * BLOCK_N
        k_offs = (
            kv_block_start * stride_kn
            + tile_head_idx * stride_kh
            + offs_n[None, :] * stride_kn
            + offs_k[:, None] * stride_kk
        )
        v_offs = (
            kv_block_start * stride_vn
            + tile_head_idx * stride_vh
            + offs_n[:, None] * stride_vn
            + offs_k[None, :] * stride_vk
        )
        print(
            f" b_seq_size={b_seq_size}, k_offs.shape={k_offs.shape}, k_offs={k_offs}"
        )
        print(
            f" b_seq_size={b_seq_size}, v_offs.shape={v_offs.shape}, v_offs={v_offs}"
        )

        if causal:
            q_idx = per_head_tile_idx + tile_batch_idx * num_m_blocks
        else:
            q_idx = tile_batch_idx
        q_offs = (
            q_idx * BLOCK_M * stride_qm
            + tile_head_idx * stride_qh
            + offs_m[:, None] * stride_qm
            + offs_k[None, :] * stride_qk
        )
        print(f" q_idx={q_idx}, q_offs.shape={q_offs.shape}, q_offs={q_offs}")

        tile_len = tile_iter_end - tile_iter
        for l_iter in range(local_iter, local_iter_end):
            if not causal:
                continue
            if BLOCK_RATIO > 1:
                # The last two iterations of a tile each print their own mask.
                if l_iter == tile_len - 2:
                    mask = offs_m[:, None] >= offs_n[None, :]
                    torch.set_printoptions(threshold=10_000)
                    print(f" Inner loop: l_iter={l_iter}, mask.shape={mask.shape}")
                    print(f" mask = {mask}")
                if l_iter == tile_len - 1:
                    mask = (offs_m[:, None] >= BLOCK_N) & (
                        offs_n[None, :] <= (offs_m[:, None] - BLOCK_N)
                    )
                    torch.set_printoptions(threshold=10_000)
                    print(f" Inner loop: l_iter={l_iter}, mask.shape={mask.shape}")
                    print(f" mask = {mask}")
            elif BLOCK_RATIO == 1:
                if (tile_gid + (l_iter - local_iter)) == (tile_iter_end - 1):
                    # Built but not consumed anywhere in this trace.
                    mask = offs_m[:, None] >= offs_n[None, :]
        print(f" Inner loop: {local_iter} to {local_iter_end}")

        if host_block:
            o_h_offs = (
                q_idx * BLOCK_M * stride_om
                + tile_head_idx * stride_oh
                + offs_m[:, None] * stride_om
                + offs_k[None, :] * stride_on
            )
            print(f"o_h_offs={o_h_offs}")

            if not finishing_block:
                # Find how many following programs hold the rest of this tile.
                last_cta = current_pid + 1
                temp_end_gid = cta_end_tile_gid
                split = 1
                while (split < num_splits) and (temp_end_gid < tile_iter_end):
                    step = max_tiles_per_wg if last_cta < high_load_wgs else low_load
                    remaining = tile_iter_end - temp_end_gid
                    temp_end_gid = temp_end_gid + (
                        remaining if remaining < step else step
                    )
                    last_cta += 1
                    split += 1

                for cta in range((current_pid + 1), last_cta):
                    print(
                        f" Host-NonFinishing block cta{cta} loop {current_pid + 1} to {last_cta}"
                    )
                    offs_mplp = cta * BLOCK_M + offs_m
                    mp_ptrs = Mp + offs_mplp
                    lp_ptrs = Lp + offs_mplp
                    op_h_offs = (
                        cta * stride_oph
                        + offs_m[:, None] * stride_opm
                        + offs_k[None, :] * stride_opn
                    )
                    print(f" Host-NonFinishing block offs_mplp={offs_mplp}")
                    print(f" Host-NonFinishing block mp_ptrs={mp_ptrs}")
                    print(f" Host-NonFinishing block lp_ptrs={lp_ptrs}")
                    print(f" Host-NonFinishing block op_h_offs={op_h_offs}")
        else:
            # Non-host program: offsets into this program's own partial slots.
            mp_ptrs = Mp + current_pid * BLOCK_M + offs_m
            lp_ptrs = Lp + current_pid * BLOCK_M + offs_m
            op_ptrs = (
                Op
                + current_pid * stride_oph
                + offs_m[:, None] * stride_opm
                + offs_k[None, :] * stride_opn
            )
            print(" Non host block write partial result")
            print(f"mp_ptrs.shape={mp_ptrs.shape}")
            print(f"mp_ptrs={mp_ptrs}")
            print(f"op_ptrs={op_ptrs}")

        # Advance past the iterations consumed for this tile.
        tile_gid = tile_gid + (local_iter_end - local_iter)
|
|
|
|
def main() -> int:
    """Run the lean-attention schedule trace on a fixed configuration.

    Builds random float16 Q/K/V tensors and the scratch buffers, then calls
    ``persistent_lean_attention``.

    Returns:
        0 on success, 1 when ``n_ctx`` holds a value that cannot be converted
        to an integer.
    """
    batch = 1
    causal = True
    h = 1
    n_ctx_q = 512
    n_ctx = [512]
    d = 128
    total_programs = 4

    init_dtype = torch.float16
    BLOCK_M = 128
    BLOCK_N = 64
    assert batch == len(n_ctx)

    try:
        seq_lens = [int(n) for n in n_ctx]
    except ValueError:
        print(f"N_CTX contains non-numeric values: {n_ctx}")
        # Nothing below can run without the sequence lengths.
        return 1
    sum_n_ctx = sum(seq_lens)

    print(f"causal={causal}, batch={batch}")

    # Per-batch key block counts and their running totals.
    list_num_block_n = [(n + BLOCK_N - 1) // BLOCK_N for n in seq_lens]
    list_sum_block_n = list(itertools.accumulate(list_num_block_n))
    batch_num_block_n = torch.tensor(list_sum_block_n, dtype=torch.int32)

    sm_scale = 0.5

    q = torch.empty((n_ctx_q * batch, h, d), dtype=init_dtype).normal_(
        mean=0.0, std=0.5
    )
    k = torch.empty((sum_n_ctx, h, d), dtype=init_dtype).normal_(mean=0.0, std=0.5)
    v = torch.empty((sum_n_ctx, h, d), dtype=init_dtype).normal_(mean=0.0, std=0.5)

    # One scratch slot per program; left uninitialised.
    Mp = torch.empty((total_programs, BLOCK_M), device=q.device, dtype=torch.float32)
    Lp = torch.empty((total_programs, BLOCK_M), device=q.device, dtype=torch.float32)
    Op = torch.empty((total_programs, BLOCK_M, d), device=q.device, dtype=torch.float32)

    locks = torch.zeros((total_programs,), device=q.device, dtype=torch.int32)

    persistent_lean_attention(
        q,
        k,
        v,
        Mp,
        Lp,
        Op,
        locks,
        batch_num_block_n,
        total_programs,
        BLOCK_M,
        BLOCK_N,
        causal,
        batch,
        sm_scale,
    )
    return 0
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
| |
| |
| |
|
|