diff --git a/host.py b/host.py index 13dd2c0..1841e24 100644 --- a/host.py +++ b/host.py @@ -72,12 +72,15 @@ def _flash_attn_forward( FP8_MAX: tl.constexpr = torch.finfo(q.dtype).max is_varlen = True if cu_seqlens_q is not None else False + # The kernel writes every (row < seqlen_q, full head_dim) element of the + # output (causal early-exit rows are explicitly zeroed inside the kernel), + # so we can skip the redundant memset of torch.zeros and use torch.empty. if IS_FP8: - o = torch.zeros( + o = torch.empty( (q.shape[:-1] + v.shape[-1:]), dtype=torch.float32, device=q.device ) else: - o = torch.zeros((q.shape[:-1] + v.shape[-1:]), dtype=q.dtype, device=q.device) + o = torch.empty((q.shape[:-1] + v.shape[-1:]), dtype=q.dtype, device=q.device) if is_varlen: # Layout is thd. # q and k are [total_tokens, num_head, head_dim_qk]. diff --git a/kernel_jit.py b/kernel_jit.py index 610b36c..a87a2fd 100644 --- a/kernel_jit.py +++ b/kernel_jit.py @@ -895,8 +895,25 @@ def _get_config( # TODO: pe + dropout is not tuned if has_pe and has_dropout_or_fp32 and "pe_dropout_or_fp32" in fwd_cfg: return fwd_cfg["pe_dropout_or_fp32"] - elif has_pe and "pe" in fwd_cfg: - return fwd_cfg["pe"] + elif has_pe: + # MLA prefill (head_dim_qk=192/v=128) tuned for gfx942 (MI300X). + # The stock "pe" config uses BLOCK_M=256, which produces too few + # workgroups (batch*heads*cdiv(seqlen,256)) to fill the 304 CUs for the + # short prefill seqlens seen here, leaving the GPU under-occupied. + # Halving BLOCK_M to 128 doubles workgroup count (better occupancy) and + # enabling 2-stage software pipelining (num_stages=2) overlaps the K/V + # loads with the QK/PV MFMA chain. waves_per_eu=1 + num_warps=4 keeps + # register/LDS pressure low enough to actually realize 2 pipeline stages + # (num_stages>=3 overflows the 64KB LDS for this 192/128 head config). + return { + "BLOCK_M": 128, + "BLOCK_N": 64, + "PRELOAD_V": True, + "waves_per_eu": 1, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2, + } elif enable_dropout or dtype == torch.float32: return fwd_cfg["dropout_or_fp32"] else: