import torch
from typing import Optional

from .triton_attention import (
    fused_mha_with_paged_cache,
    fused_mha_with_cache,
)
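# NOTE: the two imported names are not called directly below; the kernels are invoked
# through torch.ops.attention.*. Importing the module is assumed to be what registers
# those custom ops as a side effect.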

dtype_int = torch.int32

def fused_mha_interface(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    max_seq_len: Optional[int] = None,
) -> torch.Tensor:
    """
    Replacement for _flash_attention_forward(...) that uses Triton's
    fused_mha_with_cache / fused_mha_with_paged_cache under the hood.

    Returns: [batch, q_len, n_heads, head_dim]
    """
    b, ql, n_heads, head_dim = query_states.shape
    _, kvl, n_kv_heads, _ = key_states.shape

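    # Flatten the head dimension ([b, seq, heads, head_dim] -> [b, seq, heads * head_dim]);
    # this is the layout passed to the fused kernels below.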
    q = query_states.reshape(b, ql, n_heads * head_dim)
    k = key_states.reshape(b, kvl, n_kv_heads * head_dim)
    v = value_states.reshape(b, kvl, n_kv_heads * head_dim)

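    # Cache write position per sequence: the last position id when ql == 1 (decode),
    # otherwise the first (prefill chunk); defaults to 0 if position_ids is not given.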
    if position_ids is not None:
        if ql == 1:
            input_pos = position_ids[:, -1]
        else:
            input_pos = position_ids[:, 0]
    else:
        input_pos = torch.zeros(b, device=q.device, dtype=torch.int32)

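    # No rotary frequencies are supplied here; the assumption is that RoPE (if any)
    # has already been applied to q/k before this call.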
    freqs_cis = None

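    # Contiguous (non-paged) KV-cache path.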
    if page_table is None:
        y = torch.ops.attention.fused_mha_with_cache(
            q, k, v,
            input_pos,
            k_cache, v_cache,
            freqs_cis,
        )
    else:
        batch_size = b

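        # Per-sequence metadata for the paged kernel (semantics assumed from the argument
        # names): cache_loc maps each sequence to a cache slot, input_positions is the
        # write offset within that slot (hard-coded to 0 here, unlike the non-paged path),
        # seq_len is each sequence's current length, and seq_start is its start offset
        # in the flattened token stream.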
        cache_loc = torch.arange(batch_size, device=q.device, dtype=dtype_int)
        input_positions = torch.zeros(batch_size, device=q.device, dtype=dtype_int)
        seq_len = torch.full((batch_size,), kvl, device=q.device, dtype=dtype_int)
        seq_start = (seq_len.cumsum(0) - seq_len).to(dtype=dtype_int)

        assert max_seq_len is not None, "max_seq_len must be provided when using paged attention."

        y = torch.ops.attention.fused_mha_with_paged_cache(
            q, k, v,
            input_positions, cache_loc,
            seq_len, seq_start,
            page_table, max_seq_len,
            k_cache, v_cache,
            freqs_cis,
        )

    # Restore the [batch, q_len, n_heads, head_dim] layout.
    y = y.view(b, ql, n_heads, head_dim)
    return y


def main():
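    """Smoke test: run the non-paged path on random tensors and check the output shape."""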
    batch_size = 1
    q_len = 1
    kv_len = 1
    num_heads = 16
    n_kv_heads = 16
    head_dim = 128

    max_batch_size = 1
    max_seq_len = 1024

    page_size = 256  # unused here; presumably intended for a paged-cache variant of this test

    device = "cuda"

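    # Inputs follow the [batch, seq, heads, head_dim] layout that fused_mha_interface
    # unpacks; the KV caches are preallocated up to max_seq_len positions.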
    query_states = torch.randn(batch_size, q_len, num_heads, head_dim, device=device)
    key_states = torch.randn(batch_size, kv_len, n_kv_heads, head_dim, device=device)
    value_states = torch.randn(batch_size, kv_len, n_kv_heads, head_dim, device=device)

    k_cache = torch.randn(max_batch_size, max_seq_len, n_kv_heads, head_dim, device=device)
    v_cache = torch.randn(max_batch_size, max_seq_len, n_kv_heads, head_dim, device=device)

    attn_out = fused_mha_interface(
        query_states,
        key_states,
        value_states,
        k_cache=k_cache,
        v_cache=v_cache,
    )

    expected_shape = (batch_size, q_len, num_heads, head_dim)
    print(f"[test] output shape: {attn_out.shape} (expected {expected_shape})")

    if attn_out.shape == expected_shape:
        print("[test] Success: output tensor has correct shape.")
    else:
        print("[test] Failure: shape mismatch.")

if __name__ == "__main__":
    main()