| |
| |
| |
| |
| @@ -72,12 +72,15 @@ def _flash_attn_forward( |
| FP8_MAX: tl.constexpr = torch.finfo(q.dtype).max |
| is_varlen = True if cu_seqlens_q is not None else False |
| |
| + # The kernel writes every (row < seqlen_q, full head_dim) element of the |
| + # output (causal early-exit rows are explicitly zeroed inside the kernel), |
| + # so we can skip the redundant memset of torch.zeros and use torch.empty. |
| if IS_FP8: |
| - o = torch.zeros( |
| + o = torch.empty( |
| (q.shape[:-1] + v.shape[-1:]), dtype=torch.float32, device=q.device |
| ) |
| else: |
| - o = torch.zeros((q.shape[:-1] + v.shape[-1:]), dtype=q.dtype, device=q.device) |
| + o = torch.empty((q.shape[:-1] + v.shape[-1:]), dtype=q.dtype, device=q.device) |
| if is_varlen: |
| # Layout is thd. |
| # q and k are [total_tokens, num_head, head_dim_qk]. |
| |
| |
| |
| |
| @@ -895,8 +895,25 @@ def _get_config( |
| # TODO: pe + dropout is not tuned |
| if has_pe and has_dropout_or_fp32 and "pe_dropout_or_fp32" in fwd_cfg: |
| return fwd_cfg["pe_dropout_or_fp32"] |
| - elif has_pe and "pe" in fwd_cfg: |
| - return fwd_cfg["pe"] |
| + elif has_pe: |
| + # MLA prefill (head_dim_qk=192/v=128) tuned for gfx942 (MI300X). |
| + # The stock "pe" config uses BLOCK_M=256, which produces too few |
| + # workgroups (batch*heads*cdiv(seqlen,256)) to fill the 304 CUs for the |
| + # short prefill seqlens seen here, leaving the GPU under-occupied. |
| + # Halving BLOCK_M to 128 doubles workgroup count (better occupancy) and |
| + # enabling 2-stage software pipelining (num_stages=2) overlaps the K/V |
| + # loads with the QK/PV MFMA chain. waves_per_eu=1 + num_warps=4 keeps |
| + # register/LDS pressure low enough to actually realize 2 pipeline stages |
| + # (num_stages>=3 overflows the 64KB LDS for this 192/128 head config). |
| + return { |
| + "BLOCK_M": 128, |
| + "BLOCK_N": 64, |
| + "PRELOAD_V": True, |
| + "waves_per_eu": 1, |
| + "num_warps": 4, |
| + "num_ctas": 1, |
| + "num_stages": 2, |
| + } |
| elif enable_dropout or dtype == torch.float32: |
| return fwd_cfg["dropout_or_fp32"] |
| else: |
|
|