avo_test_cases / final_patch.diff
jiliu1's picture
Upload folder using huggingface_hub
2622da8 verified
Raw
History Blame Contribute Delete
2.39 kB
diff --git a/host.py b/host.py
index 13dd2c0..1841e24 100644
--- a/host.py
+++ b/host.py
@@ -72,12 +72,15 @@ def _flash_attn_forward(
FP8_MAX: tl.constexpr = torch.finfo(q.dtype).max
is_varlen = True if cu_seqlens_q is not None else False
+ # The kernel writes every (row < seqlen_q, full head_dim) element of the
+ # output (causal early-exit rows are explicitly zeroed inside the kernel),
+ # so we can skip the redundant memset of torch.zeros and use torch.empty.
if IS_FP8:
- o = torch.zeros(
+ o = torch.empty(
(q.shape[:-1] + v.shape[-1:]), dtype=torch.float32, device=q.device
)
else:
- o = torch.zeros((q.shape[:-1] + v.shape[-1:]), dtype=q.dtype, device=q.device)
+ o = torch.empty((q.shape[:-1] + v.shape[-1:]), dtype=q.dtype, device=q.device)
if is_varlen:
# Layout is thd.
# q and k are [total_tokens, num_head, head_dim_qk].
diff --git a/kernel_jit.py b/kernel_jit.py
index 610b36c..a87a2fd 100644
--- a/kernel_jit.py
+++ b/kernel_jit.py
@@ -895,8 +895,25 @@ def _get_config(
# TODO: pe + dropout is not tuned
if has_pe and has_dropout_or_fp32 and "pe_dropout_or_fp32" in fwd_cfg:
return fwd_cfg["pe_dropout_or_fp32"]
- elif has_pe and "pe" in fwd_cfg:
- return fwd_cfg["pe"]
+ elif has_pe:
+ # MLA prefill (head_dim_qk=192/v=128) tuned for gfx942 (MI300X).
+ # The stock "pe" config uses BLOCK_M=256, which produces too few
+ # workgroups (batch*heads*cdiv(seqlen,256)) to fill the 304 CUs for the
+ # short prefill seqlens seen here, leaving the GPU under-occupied.
+ # Halving BLOCK_M to 128 doubles workgroup count (better occupancy) and
+ # enabling 2-stage software pipelining (num_stages=2) overlaps the K/V
+ # loads with the QK/PV MFMA chain. waves_per_eu=1 + num_warps=4 keeps
+ # register/LDS pressure low enough to actually realize 2 pipeline stages
+ # (num_stages>=3 overflows the 64KB LDS for this 192/128 head config).
+ return {
+ "BLOCK_M": 128,
+ "BLOCK_N": 64,
+ "PRELOAD_V": True,
+ "waves_per_eu": 1,
+ "num_warps": 4,
+ "num_ctas": 1,
+ "num_stages": 2,
+ }
elif enable_dropout or dtype == torch.float32:
return fwd_cfg["dropout_or_fp32"]
else: