# Adapted from https://github.com/Dao-AILab/flash-attention/blob/5d4c9537a1e0f1adcc3e4c3e11ae46fe94a18b11/flash_attn/cute/interface.py
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# [2025-10-14] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.2.1.
import copy
import gc
import logging
import math
import os
from functools import lru_cache, wraps
from typing import Callable, Optional, Tuple, Union

import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import torch
from cutlass.cute.runtime import from_dlpack

from flash_attn_origin.cute import utils
from flash_attn_origin.cute.block_sparsity import (
    BlockSparseTensorsTorch,
    get_block_sparse_expected_shapes,
    normalize_block_sparse_tensors,
    to_cute_block_sparse_tensors,
)
from flash_attn_origin.cute.flash_fwd import FlashAttentionForwardSm90
from flash_attn_origin.cute.flash_fwd_combine import FlashAttentionForwardCombine
from flash_attn_origin.cute.flash_fwd_sm100 import FlashAttentionForwardSm100

logger = logging.getLogger(__name__)
@lru_cache(maxsize=None)
def _get_device_capability():
    """Return the major CUDA compute-capability number of the current device.

    The value is memoized for the lifetime of the process (the function takes
    no arguments), so it reflects whichever device was current on first call.
    """
    major, _minor = torch.cuda.get_device_capability()
    return major
def maybe_contiguous(x):
    """Return ``x`` with unit stride in its last dimension.

    ``None`` is passed through. A tensor whose innermost stride is already 1
    is returned unchanged (even if it is not fully contiguous); otherwise a
    contiguous copy is returned.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device):
    """Assert that a caller-supplied tensor has the expected shape, dtype and device.

    Checks run in the order shape, dtype, device, and finally that the tensor
    lives on CUDA; the first failing check raises ``AssertionError`` with a
    message naming ``name``.
    """
    for label, actual, expected in (
        ("shape", t.shape, expected_shape),
        ("dtype", t.dtype, expected_dtype),
        ("device", t.device, expected_device),
    ):
        assert actual == expected, f"{name} {label} {actual} != expected {expected}"
    assert t.is_cuda, f"{name} must be on CUDA"
def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False):
    """Wrap a torch tensor as a CuTe tensor (TVM FFI enabled) with a dynamic layout.

    Args:
        t: source torch tensor; it is detached before conversion.
        assumed_align: alignment (bytes) promised to ``from_dlpack``.
        leading_dim: dimension passed to ``mark_layout_dynamic(leading_dim=...)``;
            ``-1`` is shorthand for the last dimension, ``t.ndim - 1``.
        fully_dynamic: when True, the whole layout is marked dynamic and
            ``leading_dim`` is ignored.
    """
    cute_tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=True)
    if not fully_dynamic:
        dim = t.ndim - 1 if leading_dim == -1 else leading_dim
        return cute_tensor.mark_layout_dynamic(leading_dim=dim)
    return cute_tensor.mark_layout_dynamic()
# Translation table from the torch dtypes handled by this module to the
# corresponding CUTLASS DSL numeric types (used when building compile keys
# and kernel configurations).
torch2cute_dtype_map = dict(
    [
        (torch.float16, cutlass.Float16),
        (torch.bfloat16, cutlass.BFloat16),
        (torch.float32, cutlass.Float32),
    ]
)
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits):
    """Choose the number of KV splits for split-KV attention.

    Args:
        total_mblocks: number of work tiles without splitting
            (batch * kv-heads * m-blocks at the call site).
        num_SMs: number of streaming multiprocessors on the device.
        num_n_blocks: number of KV blocks each tile iterates over.
        max_splits: upper bound on the number of splits.

    Returns:
        An int >= 1. It is 1 when there are too few KV blocks to be worth
        splitting, when there is no work at all, or when the unsplit tiles
        already cover every SM.
    """
    # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
    # total_mblocks <= 0 (e.g. empty batch) would otherwise divide by zero below.
    if num_n_blocks <= 4 or total_mblocks <= 0:
        return 1
    # NOTE: We should revisit this heuristic after persistence is supported for split KV.
    # Sometimes, it's ideal to over-schedule splits for better efficiency.
    # Clamp to >= 1: when total_mblocks > num_SMs the integer division yields 0,
    # which is not a valid split count.
    return max(1, min(num_SMs // total_mblocks, max_splits, num_n_blocks))
def _flash_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    softcap: Optional[float] = None,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    learnable_sink: Optional[torch.Tensor] = None,
    m_block_size: int = 128,
    n_block_size: int = 128,
    num_threads: int = 384,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    _compute_capability: Optional[int] = None,
    score_mod: Optional[Callable] = None,
    mask_mod: Optional[Callable] = None,
    block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
    return_lse: bool = False,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    aux_tensors: Optional[list[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Forward pass for FlashAttention (CuTe DSL kernels for SM 9.x / 10.x / 11.x).

    ``q`` is either batched, (batch, seqlen_q, nheads, headdim), or packed varlen,
    (total_q, nheads, headdim), when ``cu_seqlens_q`` is given. ``k``/``v`` are
    batched (batch, seqlen_k, nheads_kv, headdim), packed (total_k, nheads_kv,
    headdim) with ``cu_seqlens_k``, or paged (num_pages, page_size, nheads_kv,
    headdim) with ``page_table``. Kernels are compiled on first use and cached
    in ``_flash_attn_fwd.compile_cache``.

    Args:
        q, k, v: fp16/bf16 CUDA tensors of the same dtype.
        cu_seqlens_q, cu_seqlens_k: optional int32 cumulative lengths, shape (batch + 1,).
        seqused_q, seqused_k: optional int32 per-batch used lengths, shape (batch,).
        max_seqlen_q, max_seqlen_k: used only for q_stage selection and the
            split heuristic; derived from the inputs when None.
        page_table: optional int32 (batch, max_num_pages_per_seq) page table;
            not supported together with cu_seqlens_k.
        softmax_scale: defaults to 1 / sqrt(headdim).
        causal: apply a causal mask (ignored when mask_mod is given).
        softcap: None or 0.0 disables; otherwise converted into a score_mod,
            so it cannot be combined with score_mod.
        window_size_left, window_size_right: sliding-window extents, None meaning
            unbounded on that side. (None, 0) is treated as causal. When mask_mod
            is given, the causal/local flags are both disabled.
        learnable_sink: optional bf16 tensor of shape (nheads,).
        m_block_size, n_block_size, num_threads: tile configuration. On SM 9.x,
            n_block_size is raised to 192 for dense non-causal hdim-128.
        num_splits: number of KV splits; < 1 chooses heuristically. More than
            one split writes fp32 partials that are merged by
            ``_flash_attn_fwd_combine``.
        pack_gqa: defaults to True when nheads > nheads_kv; may be disabled
            internally for unsupported configurations.
        _compute_capability: override for the detected major compute capability.
        score_mod: callable that modifies attention scores.
        mask_mod: callable that selectively masks scores from token positions.
        block_sparse_tensors: block-sparsity metadata (fixed-length only, no SplitKV).
        return_lse: force allocation/computation of the log-sum-exp output.
        out: optional pre-allocated output; validated against the expected shape.
        lse: optional pre-allocated fp32 LSE; validated against the expected shape.
        aux_tensors: extra global tensors threaded through to score/mask mods.

    Returns:
        ``(out, lse)``. ``lse`` has shape (batch, nheads, seqlen_q), or
        (nheads, total_q) for varlen q. It is None unless an input requires
        grad, ``return_lse`` is set, or the caller passed ``lse``.

    Raises:
        AssertionError: invalid inputs or unsupported compute capability.
        NotImplementedError: unsupported feature combinations (varlen with
            mask_mod/block sparsity, block sparsity with SplitKV).
    """
    q, k, v = [maybe_contiguous(t) for t in (q, k, v)]
    num_head, head_dim = q.shape[-2:]
    if cu_seqlens_q is None:
        batch_size, seqlen_q = q.shape[:2]
        total_q = batch_size * seqlen_q
    else:
        batch_size = cu_seqlens_q.shape[0] - 1
        seqlen_q = None
        total_q = q.shape[0]
    if page_table is not None:
        assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k"
        assert page_table.dtype == torch.int32, "page_table must be int32"
        assert (
            page_table.stride(-1) == 1
        ), "page_table must be contiguous in the last dimension"
        max_num_pages_per_seq = page_table.shape[1]
        assert page_table.shape == (batch_size, max_num_pages_per_seq)
        num_pages, page_size = k.shape[:2]
        seqlen_k = num_pages * page_size
    else:
        num_pages, page_size = None, None
        seqlen_k = k.shape[-3]
    num_head_kv = k.shape[-2]
    head_dim_v = v.shape[-1]
    if cu_seqlens_k is None:
        if page_table is None:
            assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
            assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
        else:
            assert k.shape == (num_pages, page_size, num_head_kv, head_dim)
            assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v)
    else:
        assert k.shape == (seqlen_k, num_head_kv, head_dim)
        assert v.shape == (seqlen_k, num_head_kv, head_dim_v)
        assert cu_seqlens_k.shape == (
            batch_size + 1,
        ), "cu_seqlens_k must have shape (batch_size + 1,)"
    if cu_seqlens_q is not None:
        assert cu_seqlens_q.shape == (
            batch_size + 1,
        ), "cu_seqlens_q must have shape (batch_size + 1,)"
    assert seqused_q is None or seqused_q.shape == (
        batch_size,
    ), "seqused_q must have shape (batch_size,)"
    assert seqused_k is None or seqused_k.shape == (
        batch_size,
    ), "seqused_k must have shape (batch_size,)"
    assert q.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "inputs must be float16 or bfloat16"
    assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype"
    for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]:
        if t is not None:
            assert (
                t.dtype == torch.int32
            ), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32"
            assert (
                t.stride(0) == 1
            ), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous"
    if learnable_sink is not None:
        assert learnable_sink.shape == (num_head,)
        assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
    assert all(
        t is None or t.is_cuda
        for t in (
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            page_table,
            learnable_sink,
        )
    ), "inputs must be on CUDA device"
    assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
    assert head_dim <= 256, "head_dim must be less than or equal to 256"
    alignment = 16 // q.element_size()
    assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
    assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    if softcap == 0.0:
        softcap = None
    qhead_per_kvhead = num_head // num_head_kv
    if pack_gqa is None:
        pack_gqa = qhead_per_kvhead > 1

    # ---- Output / LSE allocation (or validation of caller-provided buffers) ----
    out_torch_dtype = q.dtype
    device = q.device
    q_batch_seqlen_shape = (
        (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)
    )
    lse_shape = (
        (batch_size, num_head, seqlen_q)
        if cu_seqlens_q is None
        else (num_head, total_q)
    )
    requires_grad = q.requires_grad or k.requires_grad or v.requires_grad
    if out is None:
        out = torch.empty(
            *q_batch_seqlen_shape,
            num_head,
            head_dim_v,
            dtype=out_torch_dtype,
            device=device,
        )
    else:
        _validate_tensor(
            out,
            "out",
            (*q_batch_seqlen_shape, num_head, head_dim_v),
            out_torch_dtype,
            device,
        )
    if lse is None:
        # LSE is only materialized when backward may need it or the caller asked.
        lse = (
            torch.empty(lse_shape, dtype=torch.float32, device=device)
            if requires_grad or return_lse
            else None
        )
    else:
        _validate_tensor(lse, "lse", lse_shape, torch.float32, device)

    dtype = torch2cute_dtype_map[q.dtype]
    compute_capability = (
        _get_device_capability() if _compute_capability is None else _compute_capability
    )
    assert compute_capability in [
        9,
        10,
        11,
    ], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"

    use_block_sparsity = block_sparse_tensors is not None

    # ---- Normalize causal / sliding-window flags ----
    if mask_mod is None:
        if causal:
            window_size_right = 0
        local = False
        if window_size_left is not None or window_size_right is not None:
            if window_size_left is None and window_size_right == 0:
                # (None, 0) is exactly a causal mask, not a sliding window.
                causal, local = True, False
                window_size_right = None
            else:
                causal, local = False, True
    else:
        causal, local = False, False

    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # ---- Architecture-specific tuning ----
    if compute_capability == 9:  # TODO: tune block size according to hdim.
        if (
            head_dim == head_dim_v == 128
            and not causal
            and not local
            and not use_block_sparsity
        ):
            n_block_size = 192
    if compute_capability in [10, 11]:
        if pack_gqa and (128 % qhead_per_kvhead != 0):
            pack_gqa = False
        # TODO: fix GQA + SplitKV + non-varlen
        if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
            pack_gqa = False

    if max_seqlen_q is None:
        max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q
    if max_seqlen_k is None:
        max_seqlen_k = seqlen_k
    seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead
    if compute_capability == 10:
        q_stage = 2 if seqlen_q_packgqa > m_block_size else 1
    else:
        q_stage = 1

    # ---- Split-KV selection ----
    if num_splits < 1:
        m_block_size_effective = q_stage * m_block_size
        if local:
            # A None window side is unbounded; substitute max_seqlen_k so the
            # sum is well-defined (it is then clamped to max_seqlen_k anyway).
            left = max_seqlen_k if window_size_left is None else window_size_left
            right = max_seqlen_k if window_size_right is None else window_size_right
            seqlen_k_loaded = max(
                0, min(max_seqlen_k, right + left + 1 + m_block_size)
            )
        else:
            seqlen_k_loaded = max_seqlen_k
        num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size
        num_m_blocks = (
            seqlen_q_packgqa + m_block_size_effective - 1
        ) // m_block_size_effective
        total_mblocks = batch_size * num_head_kv * num_m_blocks
        num_splits = num_splits_heuristic(
            total_mblocks,
            torch.cuda.get_device_properties(device).multi_processor_count,
            num_n_blocks,
            128,
        )

    is_split_kv = num_splits > 1
    out_partial = lse_partial = None
    if is_split_kv:
        # fp32 per-split partials, merged by _flash_attn_fwd_combine below.
        out_partial = torch.empty(
            num_splits,
            *q_batch_seqlen_shape,
            num_head,
            head_dim_v,
            dtype=torch.float32,
            device=device,
        )
        lse_partial = torch.empty(
            num_splits, *lse_shape, dtype=torch.float32, device=device
        )

    # hash score and mask mods for compile cache
    score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False

    if softcap is not None:
        assert score_mod is None, "softcap and score_mod cannot be used together"
        score_mod = utils.create_softcap_scoremod(softcap)

    is_varlen = (
        cu_seqlens_q is not None
        or cu_seqlens_k is not None
        or seqused_q is not None
        or seqused_k is not None
    )

    if mask_mod is not None:
        if is_varlen:
            raise NotImplementedError(
                "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR."
            )

    if use_block_sparsity:
        if is_varlen:
            raise NotImplementedError(
                "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR."
            )
        # NB: pack_gqa requires block sparse head dim == 1 (broadcasted)
        if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1:
            pack_gqa = False
        if is_split_kv:
            raise NotImplementedError(
                "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
            )

    compile_key = (
        dtype,
        head_dim,
        head_dim_v,
        qhead_per_kvhead,
        causal,
        score_mod_hash,
        # softcap reaches the kernel only through the score_mod built above,
        # which is consumed at compile time (the launch below does not take
        # softcap). score_mod_hash was computed before that conversion, so the
        # softcap value itself must be part of the key.
        softcap,
        mask_mod_hash,
        use_block_sparsity,
        len(aux_tensors) if aux_tensors is not None else 0,
        lse is None,
        cu_seqlens_q is None,
        cu_seqlens_k is None,
        seqused_q is None,
        seqused_k is None,
        page_table is not None,
        window_size_left is not None,
        window_size_right is not None,
        learnable_sink is not None,
        m_block_size,
        n_block_size,
        q_stage,
        num_threads,
        is_split_kv,
        pack_gqa,
        compute_capability,
        page_size not in [None, 128],  # paged KV non-TMA
    )

    # Expand block sparse tensors to match actual head count (may be broadcast
    # from 1). Computed once and reused for both compilation and the launch.
    normalized_block_sparse_tensors = None
    if block_sparse_tensors is not None:
        if seqlen_q is None:
            raise ValueError(
                "Block sparsity requires fixed-length sequences (seqlen_q must be known)."
            )
        expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
            batch_size,
            num_head,
            seqlen_q,
            seqlen_k,
            m_block_size,
            n_block_size,
            q_stage,
        )
        normalized_block_sparse_tensors = normalize_block_sparse_tensors(
            block_sparse_tensors,
            expected_count_shape=expected_count_shape,
            expected_index_shape=expected_index_shape,
        )

    if compile_key not in _flash_attn_fwd.compile_cache:
        (
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            learnable_sink_tensor,
        ) = [
            to_cute_tensor(t, assumed_align=4, leading_dim=0) if t is not None else None
            for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
        ]
        page_table_tensor = (
            to_cute_tensor(page_table, assumed_align=4, leading_dim=1)
            if page_table is not None
            else None
        )
        q_tensor, k_tensor, v_tensor, o_tensor = [
            to_cute_tensor(t)
            for t in (q, k, v, out if not is_split_kv else out_partial)
        ]
        if is_split_kv:
            lse_tensor = to_cute_tensor(lse_partial, assumed_align=4)
        elif lse is not None:
            lse_tensor = to_cute_tensor(lse, assumed_align=4)
        else:
            lse_tensor = None

        sparse_tensors = None
        if normalized_block_sparse_tensors is not None:
            sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)

        cute_aux_tensors = None
        if aux_tensors is not None:
            cute_aux_tensors = [
                to_cute_tensor(buf, assumed_align=None, fully_dynamic=True)
                for buf in aux_tensors
            ]

        if compute_capability == 9:
            assert page_table is None, "paged KV not supported on SM 9.0"
            assert not is_split_kv, "SplitKV not supported on SM 9.0"
            fa_fwd = FlashAttentionForwardSm90(
                dtype,
                head_dim,
                head_dim_v,
                qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                pack_gqa=pack_gqa,
                tile_m=m_block_size,
                tile_n=n_block_size,
                num_stages=2,
                num_threads=num_threads,
                Q_in_regs=False,
                intra_wg_overlap=True,
                mma_pv_is_rs=True,
                mask_mod=mask_mod,
                score_mod=score_mod,
                has_aux_tensors=aux_tensors is not None,
            )
        elif compute_capability in [10, 11]:
            fa_fwd = FlashAttentionForwardSm100(
                head_dim,
                head_dim_v,
                qhead_per_kvhead=qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                is_split_kv=is_split_kv,
                pack_gqa=pack_gqa,
                m_block_size=m_block_size,
                n_block_size=n_block_size,
                q_stage=q_stage,
                is_persistent=not causal
                and not local
                and cu_seqlens_q is None
                and seqused_q is None
                and not is_split_kv,
                score_mod=score_mod,
                mask_mod=mask_mod,
                has_aux_tensors=aux_tensors is not None,
                paged_kv_non_tma=page_size not in [None, 128],
                is_varlen_q=cu_seqlens_q is not None or seqused_q is not None,
            )
        else:
            raise ValueError(
                f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x, 11.x"
            )
        # TODO: check @can_implement
        _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
            fa_fwd,
            q_tensor,
            k_tensor,
            v_tensor,
            o_tensor,
            lse_tensor,
            softmax_scale,
            current_stream,
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            page_table_tensor,
            window_size_left,
            window_size_right,
            learnable_sink_tensor,
            sparse_tensors,
            cute_aux_tensors,
            options="--enable-tvm-ffi",
        )

    _flash_attn_fwd.compile_cache[compile_key](
        q,
        k,
        v,
        out if not is_split_kv else out_partial,
        lse_partial if is_split_kv else lse,
        softmax_scale,
        current_stream,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        page_table,
        window_size_left,
        window_size_right,
        learnable_sink,
        normalized_block_sparse_tensors,
        aux_tensors,
    )
    if is_split_kv:
        # The combine kernel expects seqlen-major LSE (unit stride in dim -2),
        # hence the transposed views.
        _flash_attn_fwd_combine(
            out_partial,
            lse_partial.transpose(-1, -2),
            out,
            lse.transpose(-1, -2) if lse is not None else None,
            cu_seqlens_q,
            seqused_q,
        )
    return out, lse


# Compiled forward kernels, keyed by the configuration tuple built above.
_flash_attn_fwd.compile_cache = {}
def _flash_attn_fwd_combine(
    out_partial: torch.Tensor,
    lse_partial: torch.Tensor,
    out: torch.Tensor,
    lse: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    seqused: Optional[torch.Tensor] = None,
    num_splits_dynamic_ptr: Optional[torch.Tensor] = None,
    semaphore_to_reset: Optional[torch.Tensor] = None,
) -> None:
    """Forward combine kernel for split attention computation.

    Combines partial outputs and log-sum-exp values from multiple splits
    of attention computation into final outputs. Results are written into
    ``out`` (and ``lse`` when given); the function returns None. Kernels are
    compiled on first use and cached in ``_flash_attn_fwd_combine.compile_cache``.

    Args:
        out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or
            (num_splits, total_q, nheads, headdim) if there's cu_seqlens. fp16/bf16/fp32,
            unit stride in the last dimension.
        lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or
            (num_splits, total_q, nheads) if there's cu_seqlens. fp32, unit stride
            along the seqlen dimension (dim -2).
        out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens.
        lse: Optional fp32 output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens.
        cu_seqlens: Cumulative sequence lengths for variable length sequences (int32).
        seqused: Used sequence lengths for each batch (int32).
        num_splits_dynamic_ptr: Dynamic number of splits per batch (int32).
        semaphore_to_reset: Semaphore for synchronization.

    The head-dimension tile (k_block_size) and row tile (m_block_size) are
    derived from headdim internally; they are not parameters.

    Returns:
        None

    Raises:
        AssertionError: invalid shapes, dtypes, devices, layouts or split count.
        RuntimeError: if the combine kernel reports it cannot be implemented.
    """
    # Input validation
    assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
    assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
    assert out_partial.dtype in [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ], "out_partial must be fp16, bf16, or fp32"
    assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
    assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device"
    assert (
        out_partial.stride(-1) == 1
    ), "out_partial must be contiguous in the last dimension"
    assert (
        lse_partial.stride(-2) == 1
    ), "lse_partial must be contiguous in the seqlen dimension"
    assert lse_partial.shape == out_partial.shape[:-1]

    # Determine if this is variable length based on dimensions
    is_varlen = out_partial.dim() == 4

    # Validate output tensor shapes and types
    assert out.shape == out_partial.shape[1:], "out shape mismatch"
    if lse is not None:
        assert lse.shape == lse_partial.shape[1:], "lse shape mismatch"
        assert lse.dtype == torch.float32, "lse must be fp32"

    # Validate optional tensors
    for t, name in [
        (cu_seqlens, "cu_seqlens"),
        (seqused, "seqused"),
        (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
    ]:
        if t is not None:
            assert t.dtype == torch.int32, f"{name} must be int32"
            assert t.is_cuda, f"{name} must be on CUDA device"
            assert t.is_contiguous(), f"{name} must be contiguous"

    head_dim = out_partial.shape[-1]
    num_splits = out_partial.shape[0]
    # num_splits == 0 would otherwise fail inside math.log2 below.
    assert 0 < num_splits <= 256, "num_splits must be in [1, 256]"
    # Tile over head_dim: 64 for hdim <= 64, otherwise 128 (so e.g. hdim 96
    # is processed with a 128-wide tile).
    k_block_size = 64 if head_dim <= 64 else 128
    # We want kBlockM to be as small as possible to maximize parallelism.
    # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
    m_block_size = (
        8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)
    )
    log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
    if m_block_size == 8:
        # If kBlockM == 8 then the minimum number of splits is 32.
        # TODO: we can deal w this by using 128 threads instead
        log_max_splits = max(log_max_splits, 5)

    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Create combine kernel configuration
    dtype = torch2cute_dtype_map[out.dtype]
    dtype_partial = torch2cute_dtype_map[out_partial.dtype]

    # Everything that changes the compiled signature must be in the key:
    # tensor rank (is_varlen) and the presence of every optional tensor.
    compile_key = (
        dtype,
        dtype_partial,
        head_dim,
        m_block_size,
        k_block_size,
        log_max_splits,
        is_varlen,
        cu_seqlens is not None,
        seqused is not None,
        lse is not None,
        num_splits_dynamic_ptr is not None,
        semaphore_to_reset is not None,
    )

    if compile_key not in _flash_attn_fwd_combine.compile_cache:
        out_partial_tensor = to_cute_tensor(
            out_partial, leading_dim=4 if not is_varlen else 3
        )
        lse_partial_tensor = to_cute_tensor(
            lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2
        )
        out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2)
        lse_tensor = (
            to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2)
            if lse is not None
            else None
        )

        (
            cu_seqlens_tensor,
            seqused_tensor,
            num_splits_dynamic_tensor,
            semaphore_tensor,
        ) = [
            to_cute_tensor(t, assumed_align=4, leading_dim=0) if t is not None else None
            for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset)
        ]

        fa_combine = FlashAttentionForwardCombine(
            dtype=dtype,
            dtype_partial=dtype_partial,
            head_dim=head_dim,
            m_block_size=m_block_size,
            k_block_size=k_block_size,
            log_max_splits=log_max_splits,
        )

        # Check if implementation is supported
        if not fa_combine.can_implement(
            dtype,
            dtype_partial,
            head_dim,
            m_block_size,
            k_block_size,
            log_max_splits,
            num_threads=256,
        ):
            raise RuntimeError(
                "FlashAttention combine kernel cannot be implemented with given parameters"
            )

        _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile(
            fa_combine,
            out_partial_tensor,
            lse_partial_tensor,
            out_tensor,
            lse_tensor,
            cu_seqlens_tensor,
            seqused_tensor,
            num_splits_dynamic_tensor,
            semaphore_tensor,
            current_stream,
            options="--enable-tvm-ffi",
        )
    _flash_attn_fwd_combine.compile_cache[compile_key](
        out_partial,
        lse_partial,
        out,
        lse,
        cu_seqlens,
        seqused,
        num_splits_dynamic_ptr,
        semaphore_to_reset,
        current_stream,
    )


# Compiled combine kernels, keyed by the configuration tuple built above.
_flash_attn_fwd_combine.compile_cache = {}
def warmup_flash_attn(f):
    """
    Decorator for flash_attn_varlen_func:
    - On first call, run several warmup passes with different flag combinations:
        * return_softmax_lse in {False, True}
        * global noncausal (window_size=(None,None))
        * causal (window_size=(None,0))
        * local sliding window (window_size=(64,64))
        * optionally pack_gqa=True if qheads > kvheads and allowed
    - No score_mod / softcap (not supported for varlen yet)
    - Executes sequentially to minimize peak GPU mem (one cloned argument set
      alive at a time)
    - Does not modify user tensors (clones)
    - Warmup is best-effort: any failure is logged and the real call proceeds
    - Disabled entirely when SGLANG_DISABLE_FA4_WARMUP is 1/true/yes/on
    """
    disable_warmup = os.getenv("SGLANG_DISABLE_FA4_WARMUP", "").lower() in (
        "1",
        "true",
        "yes",
        "on",
    )
    if disable_warmup:
        return f

    done = False

    def _clone_args(args, kwargs):
        """Clone tensor arguments to avoid sharing storage; deepcopy for others."""

        def maybe_clone(x):
            if isinstance(x, torch.Tensor):
                return x.detach().clone()  # detach to avoid autograd edges
            return copy.deepcopy(x)

        return tuple(maybe_clone(a) for a in args), {
            k: maybe_clone(v) for k, v in kwargs.items()
        }

    def _infer_heads(args, kwargs):
        """Infer q and kv head counts from arguments."""
        # Expect signature: (q, k, v, cu_seqlens_q, cu_seqlens_k, ...)
        q = args[0] if len(args) > 0 else kwargs.get("q")
        k = args[1] if len(args) > 1 else kwargs.get("k")
        try:
            return int(q.shape[-2]), int(k.shape[-2])
        except Exception:
            return None, None

    def _run_warmups(args, kwargs):
        """Run warmup calls sequentially and release memory after each."""
        # Only shapes are read here, so no clone is needed; each combo below
        # clones from the caller's arguments, keeping a single copy alive.
        qh, kvh = _infer_heads(args, kwargs)
        can_pack_gqa = (
            qh is not None and kvh is not None and qh % kvh == 0 and qh // kvh > 1
        )

        # Window presets covering global, causal, and local
        window_presets = [
            (None, None),  # global noncausal
            (None, 0),  # causal
            (64, 64),  # local sliding window
        ]

        combos = [
            dict(window_size=ws, return_softmax_lse=return_lse_flag)
            for ws in window_presets
            for return_lse_flag in (False, True)
        ]
        # Optionally add a pack_gqa=True variant (FA4 may disable it internally for some varlen shapes/SMs)
        if can_pack_gqa:
            combos.extend(
                dict(window_size=ws, return_softmax_lse=False, pack_gqa=True)
                for ws in window_presets
            )
        # A caller-supplied page_table is part of the cloned kwargs, so every
        # combo above already warms the paged path; no extra combo is needed.

        # Run sequentially
        for combo in combos:
            wa, wk = _clone_args(args, kwargs)
            # Keep user-provided softcap/score_mod OUT (varlen+score_mod unsupported)
            wk.pop("score_mod", None)
            if "softcap" in wk and wk["softcap"]:
                wk["softcap"] = 0.0
            # Apply combo
            wk.update(combo)
            with torch.cuda.stream(torch.cuda.current_stream()):
                try:
                    f(*wa, **wk)
                except Exception as e:
                    # Some combos can be invalid for specific head dims / arch. Ignore and continue.
                    logger.debug("Warmup combo skipped: %s", e)
            del wa, wk
            # Collect cycles first so their CUDA blocks return to the caching
            # allocator, then release cached blocks.
            gc.collect()
            torch.cuda.empty_cache()

    @wraps(f)
    def wrapper(*args, **kwargs):
        nonlocal done
        if not done:
            logger.info(
                "Running FA4 warmup (global/causal/local, LSE on/off, optional GQA pack)..."
            )
            try:
                _run_warmups(args, kwargs)
            except Exception:
                # Warmup is an optimization; never let it break the real call.
                logger.warning("FA4 warmup failed; continuing without it", exc_info=True)
            done = True
        return f(*args, **kwargs)

    return wrapper
@warmup_flash_attn
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    return_softmax_lse: Optional[bool] = False,
    score_mod: Optional[Callable] = None,
    aux_tensors: Optional[list] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Variable-length FlashAttention forward pass (public entry point).

    Thin wrapper over ``_flash_attn_fwd``. The first call also triggers the
    ``warmup_flash_attn`` compilation warmups, unless SGLANG_DISABLE_FA4_WARMUP
    is set.

    Args:
        q, k, v: query/key/value tensors (see ``_flash_attn_fwd`` for layouts).
        cu_seqlens_q, cu_seqlens_k: optional int32 cumulative sequence lengths.
        seqused_q, seqused_k: optional int32 per-batch used lengths.
        page_table: optional int32 page table for paged KV.
        softmax_scale: defaults to 1 / sqrt(headdim).
        causal: apply a causal mask.
        window_size: ``(left, right)`` sliding-window extents; None means
            unbounded on that side.
        learnable_sink: optional bf16 per-head sink values.
        softcap: 0.0 disables softcapping.
        num_splits: number of KV splits (< 1 selects heuristically).
        pack_gqa: pack query heads per KV head; None picks a default.
        return_softmax_lse: also return the log-sum-exp tensor.
        score_mod: optional score modification callable.
        aux_tensors: extra tensors threaded through to ``score_mod``.

    Returns:
        ``out`` when ``return_softmax_lse`` is falsy, otherwise ``(out, lse)``.
    """
    out, lse = _flash_attn_fwd(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        page_table=page_table,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        learnable_sink=learnable_sink,
        softcap=softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        return_lse=return_softmax_lse,
        score_mod=score_mod,
        aux_tensors=aux_tensors,
    )
    return (out, lse) if return_softmax_lse else out