import math
from functools import partial
import torch
import torch.nn as nn
from einops import rearrange, repeat
from flash_attn.utils.distributed import all_reduce, get_dim_for_local_rank
try:
from flash_attn import (
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
)
except ImportError:
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
flash_attn_with_kvcache = None
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear, fused_dense_func
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear, fused_dense_func = None, None, None, None
try:
from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
RotaryEmbedding = None
from flash_attn.modules.mha import SelfAttention, FlashSelfAttention, LinearResidual, FlashCrossAttention, CrossAttention
from flash_attn.modules.mha import get_alibi_slopes #, _update_kv_cache
from flash_attn.utils.generation import InferenceParams
# from HybridTensor.modules.references.mha_dejavu import ParallelTracker # use this in the full implementation
# from HybridTensor.modules.references.mha_dejavu import ParallelMHASparseAttMlp
# from HybridTensor.triton.references.attention_proj_sparse import qkv_proj_sparse
# from HybridTensor.triton.select_attn import select_attn
# from HybridTensor.triton.select_attn_64b_kernel import select_attn
from HybridTensor.triton.attn_interface import flash_attn_with_kvcache_triton
from HybridTensor.triton.select_attn_v1 import select_attn
from HybridTensor.utils.utils import arg_parser, generate_BH_index, generate_random_BH_index
from HybridTensor.utils.profiling import cuda_profiler
class MHARouter(torch.nn.Module):
    def __init__(self, embed_dim, low_rank_dim=None, out_dim=None, top_k=0.5, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.model_dim = embed_dim
        self.num_heads = out_dim
        self.topk = top_k
        self.linear1 = torch.nn.Linear(embed_dim, out_dim, bias=True, **factory_kwargs)
def forward(self, x):
out = self.linear1(x)
return out
def _select_heads(self, x, topk = None):
if topk is None:
topk = int(self.topk * self.num_heads)
else:
topk = int(self.num_heads * topk)
head_scores = self.forward(x)
_, selected_heads = torch.topk(head_scores, topk, dim=1)
return selected_heads
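# Illustrative sketch (comments only, not executed): how MHARouter is expected to be used to
# pick attention heads per token. The sizes below are example assumptions, not values taken
# from this repository.
#
#   router = MHARouter(embed_dim=4096, out_dim=32, top_k=0.5, device="cuda", dtype=torch.float16)
#   x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)   # (batch, embed_dim)
#   head_scores = router(x)                                         # (batch, 32) per-head scores
#   selected = router._select_heads(x)                              # (batch, int(0.5 * 32)) head indices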
class ParallelMHARouter(torch.nn.Module):
    def __init__(self, embed_dim, low_rank_dim, out_dim, top_k, process_group, sequence_parallel=False, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
self.model_dim = embed_dim
self.num_heads = out_dim
self.topk = top_k
world_size = torch.distributed.get_world_size(process_group)
self.local_heads = out_dim // world_size
self.linear1 = ColumnParallelLinear(
embed_dim,
out_dim,
process_group,
bias=True,
sequence_parallel=sequence_parallel,
**factory_kwargs,
)
def forward(self, x):
out = self.linear1(x)
return out
def _select_heads(self, x, topk = None):
if topk is None:
topk = int(self.topk * self.local_heads)
else:
topk = int(self.local_heads * topk)
head_scores = self.forward(x)
# head_scores = head_scores.squeeze(1)
# print(f"Head Scores.shape: {head_scores.shape}")
_, selected_heads = torch.topk(head_scores, topk, dim=1)
# print(f"Selected Heads: {selected_heads}")
return selected_heads
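# Note on the tensor-parallel router (an assumption-level sketch, not exercised in this file):
# ColumnParallelLinear shards the out_dim head scores across ranks, so with world size W each
# rank only scores its out_dim // W local heads, and _select_heads returns indices in that
# rank's local head numbering, e.g. out_dim=32, W=4, top_k=0.5 -> int(0.5 * 8) = 4 of 8 local heads.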
def _update_kv_cache(kv, inference_params, layer_idx):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
# Pre-allocate memory for key-values for inference.
num_heads, head_dim = kv.shape[-2:]
if layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.max_batch_size,
inference_params.max_seqlen,
2,
num_heads,
head_dim,
dtype=kv.dtype,
device=kv.device,
)
inference_params.key_value_memory_dict[layer_idx] = kv_cache
else:
kv_cache = inference_params.key_value_memory_dict[layer_idx]
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
sequence_start = inference_params.seqlen_offset
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= kv_cache.shape[0]
assert sequence_end <= kv_cache.shape[1]
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
return kv_cache[batch_start:batch_end, :sequence_end, ...]
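# Illustrative sketch of the kv-cache bookkeeping above (comments only; shapes are example
# assumptions): the first call for a layer allocates the full cache, writes the prefill kv at
# the current offsets, and returns the written slice. The caller advances seqlen_offset itself,
# as done in __main__ below.
#
#   params = InferenceParams(max_seqlen=2048, max_batch_size=8)
#   kv = torch.randn(8, 512, 2, 32, 128, device="cuda", dtype=torch.float16)   # prefill kv
#   cache = _update_kv_cache(kv, params, layer_idx=0)   # allocates (8, 2048, 2, 32, 128),
#                                                       # returns the (8, 512, 2, 32, 128) slice
#   params.seqlen_offset += 512                          # advance before the first decode step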
class SMHA(nn.Module):
"""Multi-head self-attention and cross-attention with Triton decode kernels + Selective Attention"""
def __init__(
self,
embed_dim,
num_heads,
num_heads_kv=None,
cross_attn=False,
qkv_proj_bias=True,
out_proj_bias=True,
dropout=0.0,
softmax_scale=None,
causal=False,
layer_idx=None,
dwconv=False,
rotary_emb_dim=0,
rotary_emb_base=10000.0,
rotary_emb_scale_base=None,
rotary_emb_interleaved=False,
use_alibi=False,
window_size=(-1, -1),
fused_bias_fc=False,
use_flash_attn=False,
return_residual=False,
checkpointing=False,
use_triton=True,
device=None,
dtype=None,
) -> None:
"""
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architectures, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.embed_dim = embed_dim
self.cross_attn = cross_attn
self.causal = causal
self.layer_idx = layer_idx
self.dwconv = dwconv
self.rotary_emb_dim = rotary_emb_dim
self.use_flash_attn = use_flash_attn
self.return_residual = return_residual
self.checkpointing = checkpointing
self.use_triton = use_triton
if use_alibi:
assert use_flash_attn, "ALiBi code path requires flash_attn"
alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
else:
alibi_slopes = None
if window_size != (-1, -1):
assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
self.num_heads = num_heads
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
assert (
self.num_heads % self.num_heads_kv == 0
), "num_heads must be divisible by num_heads_kv"
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
kv_dim = 2 * self.head_dim * self.num_heads_kv
if self.rotary_emb_dim > 0:
assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
assert RotaryEmbedding is not None, "rotary_emb is not installed"
self.rotary_emb = RotaryEmbedding(
self.rotary_emb_dim,
base=rotary_emb_base,
scale_base=rotary_emb_scale_base,
interleaved=rotary_emb_interleaved,
device=device,
)
if fused_bias_fc and FusedDense is None:
raise ImportError("fused_dense is not installed")
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
linear_resid_cls = (
LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
)
wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
inner_attn_cls = (
partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
if use_flash_attn
else SelfAttention
)
inner_cross_attn_cls = (
partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
if use_flash_attn
else CrossAttention
)
if not self.cross_attn:
self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
else:
self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
if self.dwconv:
if self.num_heads_kv == self.num_heads:
self.dwconv_qkv = nn.Conv1d(
qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
)
else:
self.dwconv_q = nn.Conv1d(
embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
)
self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
self.inner_attn = inner_attn_cls(
causal=causal,
softmax_scale=softmax_scale,
attention_dropout=dropout,
)
self.inner_cross_attn = inner_cross_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
dtype = self.out_proj.weight.dtype if dtype is None else dtype
device = self.out_proj.weight.device
return torch.empty(
batch_size,
max_seqlen,
2,
self.num_heads_kv,
self.head_dim,
dtype=dtype,
device=device,
)
def _update_kv_cache(self, kv, inference_params):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
assert not self.dwconv, "Generation does not support dwconv yet"
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
return _update_kv_cache(kv, inference_params, self.layer_idx)
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
"""
        Fast path that combines 3 steps: apply rotary to Q and K, update the kv cache, and apply attention.
q: (batch_size, seqlen_q, nheads, head_dim)
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
"""
assert inference_params is not None and inference_params.seqlen_offset > 0
assert self.use_flash_attn
if self.rotary_emb_dim > 0:
assert self.rotary_emb.scale is None, "This code path does not support xPos"
self.rotary_emb._update_cos_sin_cache(
inference_params.max_seqlen, device=q.device, dtype=q.dtype
)
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
else:
rotary_cos, rotary_sin = None, None
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
context = flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
rotary_cos=rotary_cos,
rotary_sin=rotary_sin,
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
alibi_slopes=alibi_slopes,
)
return context
def _update_kvcache_attention_triton(self, q, kv, inference_params, batch_head_idx=None):
"""
        The rotary embeddings have to be applied before calling this function. The KV cache is updated here.
q: (batch_size, seqlen_q, nheads, head_dim)
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
"""
if (
inference_params.seqlen_offset == 0
or flash_attn_with_kvcache is None
or not self.use_flash_attn
):
# TODO: this only uses seqlen_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
return self.inner_cross_attn(q, kv)
else:
batch = q.shape[0]
kv_cache = self._update_kv_cache(kv, inference_params)
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
context = flash_attn_with_kvcache_triton(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
None, # kv[:, :, 0],
None, #kv[:, :, 1],
rotary_cos=None,
rotary_sin=None,
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
rotary_interleaved= False, #self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
alibi_slopes=alibi_slopes,
batch_head_idx=batch_head_idx,
)
return context
def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention"""
if (
inference_params.seqlen_offset == 0
or flash_attn_with_kvcache is None
or not self.use_flash_attn
):
# TODO: this only uses seqlen_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
return self.inner_cross_attn(q, kv)
else:
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
return flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
alibi_slopes=alibi_slopes,
)
def forward(
self,
x,
x_kv=None,
key_padding_mask=None,
cu_seqlens=None,
max_seqlen=None,
mixer_subset=None,
inference_params=None,
batch_head_idx=None,
use_triton=True,
**kwargs,
):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into x. Only applicable when using
FlashAttention.
max_seqlen: int. Maximum sequence length in the batch.
key_padding_mask: boolean mask, True means to keep, False means to mask out.
(batch, seqlen). Only applicable when not using FlashAttention.
mixer_subset: for cross-attention only. If not None, will take a subset of x
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
inference_params: for generation. Adapted from Megatron-LM (and Apex)
            batch_head_idx: (batch, num_heads). Indices of the heads to be selected. Only
                applicable for Selective Head/Group Attention.
            use_triton: whether to use Triton kernels for attention during decode. If False,
                use the original flash attention implementation.
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
"""
if cu_seqlens is not None:
assert max_seqlen is not None
assert key_padding_mask is None
assert self.use_flash_attn
assert not self.dwconv
assert self.rotary_emb_dim == 0
if key_padding_mask is not None:
assert cu_seqlens is None
assert max_seqlen is None
assert not self.use_flash_attn
if inference_params is not None:
assert key_padding_mask is None
assert cu_seqlens is None and max_seqlen is None
assert not self.dwconv
# use_triton = self.use_triton if use_triton is None else use_triton
kwargs = (
{"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
if self.use_flash_attn
else {"key_padding_mask": key_padding_mask, **kwargs}
)
seqlen_offset = (
0
if inference_params is None
else (
inference_params.lengths_per_sample
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
)
rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
batch, seqlen = x.shape[:2]
if not self.cross_attn and self.num_heads_kv == self.num_heads:
assert x_kv is None and mixer_subset is None
if not self.return_residual:
qkv = self.Wqkv(x)
else:
qkv, x = self.Wqkv(x)
if self.dwconv:
qkv = rearrange(
self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
).contiguous()
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
if (
inference_params is None
or inference_params.seqlen_offset == 0
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
# prefill stage
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(
qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
)
if inference_params is None:
if not self.checkpointing:
context = self.inner_attn(qkv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
else:
if use_triton:
# print("Using the (prefill) triton flash attention implementation")
context = self._update_kvcache_attention_triton(
qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx
)
else:
# print("Using the (prefill) original flash attention implementation")
context = self._update_kvcache_attention(
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else:
# decode stage
# print("Using triton kernels for attention")
if use_triton:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(
qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
)
context = self._update_kvcache_attention_triton(
qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx
)
else:
# print("Using the original flash attention implementation")
context = self._apply_rotary_update_kvcache_attention(
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else: # cross-attention or MQA/GQA
if self.cross_attn:
if not self.return_residual:
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
kv = self.Wkv(x_kv if x_kv is not None else x)
else:
if x_kv is not None:
kv, x_kv = self.Wkv(x_kv)
else:
kv, x = self.Wkv(x)
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
else:
assert self.num_heads_kv != self.num_heads
if not self.return_residual:
qkv = self.Wqkv(x)
else:
qkv, x = self.Wqkv(x)
q = qkv[..., : self.num_heads * self.head_dim]
kv = qkv[..., self.num_heads * self.head_dim :]
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
if self.dwconv:
q = rearrange(
self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
).contiguous()
kv = rearrange(
self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
).contiguous()
if (
inference_params is None
or inference_params.seqlen_offset == 0
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
# prefill
if self.rotary_emb_dim > 0:
q, kv = self.rotary_emb(
q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
)
if inference_params is None:
if not self.checkpointing:
context = self.inner_cross_attn(q, kv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(
self.inner_cross_attn, q, kv, **kwargs
)
else:
if use_triton:
context = self._update_kvcache_attention_triton(
q, kv, inference_params, batch_head_idx
)
else:
context = self._update_kvcache_attention(q, kv, inference_params)
else:
# decode
# print("Using triton kernels for attention")
if use_triton:
if self.rotary_emb_dim > 0:
q, kv = self.rotary_emb(
q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
)
context = self._update_kvcache_attention_triton(
q, kv, inference_params, batch_head_idx
)
else:
# print("Using the original gqa flash attention implementation")
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
# print(f"Context.shape: {context.shape}")
# print(f"Context: {context}")
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
return out if not self.return_residual else (out, x)
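# Minimal SMHA decode sketch (comments only; sizes are example assumptions). A prefill call
# populates the kv cache, then each decode step attends one new token against the cache; with
# use_triton=True the decode step goes through flash_attn_with_kvcache_triton. __main__ below
# instead seeds the cache directly via _update_kv_cache to work around a flash_attn prefill
# issue at larger batch sizes.
#
#   mha = SMHA(embed_dim=4096, num_heads=32, causal=True, layer_idx=0,
#              use_flash_attn=True, device="cuda", dtype=torch.float16)
#   params = InferenceParams(max_seqlen=1024, max_batch_size=4)
#   _ = mha(torch.randn(4, 512, 4096, device="cuda", dtype=torch.float16), inference_params=params)
#   params.seqlen_offset += 512
#   y = mha(torch.randn(4, 1, 4096, device="cuda", dtype=torch.float16), inference_params=params)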
class ParallelSMHA(nn.Module):
"""Multi-head self-attention and cross-attention"""
def __init__(
self,
embed_dim,
num_heads,
process_group,
num_heads_kv=None,
qkv_proj_bias=True,
out_proj_bias=True,
dropout=0.0,
softmax_scale=None,
causal=False,
layer_idx=None,
rotary_emb_dim=0,
rotary_emb_base=10000.0,
rotary_emb_scale_base=None,
rotary_emb_interleaved=False,
use_alibi=False,
window_size=(-1, -1),
use_flash_attn=False,
checkpointing=False,
sequence_parallel=True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.embed_dim = embed_dim
self.causal = causal
self.layer_idx = layer_idx
self.rotary_emb_dim = rotary_emb_dim
self.use_flash_attn = use_flash_attn
self.checkpointing = checkpointing
self.process_group = process_group
self.world_size = process_group.size()
self.local_rank = torch.distributed.get_rank(process_group)
self.num_heads = num_heads
assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
assert (
self.num_heads % self.num_heads_kv == 0
), "num_heads must be divisible by num_heads_kv"
self.num_heads_per_rank = get_dim_for_local_rank(
self.num_heads, self.world_size, self.local_rank
)
self.num_heads_kv_per_rank = get_dim_for_local_rank(
self.num_heads_kv, self.world_size, self.local_rank
)
self.head_dim = self.embed_dim // num_heads
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
if use_alibi:
assert use_flash_attn, "ALiBi code path requires flash_attn"
num_heads_local = math.ceil(self.num_heads / self.world_size)
alibi_slopes = torch.tensor(
get_alibi_slopes(num_heads)[
self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
],
device=device,
)
else:
alibi_slopes = None
if window_size != (-1, -1):
assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
if self.rotary_emb_dim > 0:
assert RotaryEmbedding is not None, "rotary_emb is not installed"
self.rotary_emb = RotaryEmbedding(
self.rotary_emb_dim,
base=rotary_emb_base,
scale_base=rotary_emb_scale_base,
interleaved=rotary_emb_interleaved,
device=device,
)
if ColumnParallelLinear is None or RowParallelLinear is None:
raise ImportError("fused_dense is not installed")
self.Wqkv = ColumnParallelLinear(
embed_dim,
qkv_dim,
process_group,
bias=qkv_proj_bias,
sequence_parallel=sequence_parallel,
multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
**factory_kwargs,
)
inner_attn_cls = (
partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
if use_flash_attn
else SelfAttention
)
inner_cross_attn_cls = (
partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
if use_flash_attn
else CrossAttention
)
self.inner_attn = inner_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
self.inner_cross_attn = inner_cross_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
process_group,
bias=out_proj_bias,
sequence_parallel=sequence_parallel,
multiple_of=self.head_dim,
**factory_kwargs,
)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
dtype = self.out_proj.weight.dtype if dtype is None else dtype
device = self.out_proj.weight.device
return torch.empty(
batch_size,
max_seqlen,
2,
self.num_heads_kv_per_rank,
self.head_dim,
dtype=dtype,
device=device,
)
def _update_kv_cache(self, kv, inference_params):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
return _update_kv_cache(kv, inference_params, self.layer_idx)
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
"""
        Fast path that combines 3 steps: apply rotary to Q and K, update the kv cache, and apply attention.
q: (batch_size, seqlen_q, nheads, head_dim)
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
"""
assert inference_params is not None and inference_params.seqlen_offset > 0
assert self.use_flash_attn
if self.rotary_emb_dim > 0:
assert self.rotary_emb.scale is None, "This code path does not support xPos"
self.rotary_emb._update_cos_sin_cache(
inference_params.max_seqlen, device=q.device, dtype=q.dtype
)
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
else:
rotary_cos, rotary_sin = None, None
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
context = flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
rotary_cos=rotary_cos,
rotary_sin=rotary_sin,
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
alibi_slopes=alibi_slopes,
)
return context
def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention"""
if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
# TODO: this only uses seqlen_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
return self.inner_cross_attn(q, kv)
else:
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
context = flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
alibi_slopes=alibi_slopes,
)
return context
def forward(self, x, seqlen=None, inference_params=None, **kwargs):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
split x during sequence parallel, we split the batch * seqlen dimension
(in case batch is small).
"""
qkv = self.Wqkv(x)
if seqlen is not None:
qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
seqlen_offset = (
0
if inference_params is None
else (
inference_params.lengths_per_sample
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
)
rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
if self.num_heads_kv == self.num_heads:
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
if (
inference_params is None
or inference_params.seqlen_offset == 0
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(
qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
)
if inference_params is None:
if not self.checkpointing:
context = self.inner_attn(qkv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
else:
context = self._update_kvcache_attention(
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else:
context = self._apply_rotary_update_kvcache_attention(
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else: # GQA/MQA
q = rearrange(
qkv[..., : self.num_heads_per_rank * self.head_dim],
"... (h d) -> ... h d",
d=self.head_dim,
)
kv = rearrange(
qkv[..., self.num_heads_per_rank * self.head_dim :],
"... (two hkv d) -> ... two hkv d",
two=2,
d=self.head_dim,
)
if (
inference_params is None
or inference_params.seqlen_offset == 0
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
if self.rotary_emb_dim > 0:
q, kv = self.rotary_emb(
q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
)
if inference_params is None:
if not self.checkpointing:
context = self.inner_cross_attn(q, kv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(
self.inner_cross_attn, q, kv, **kwargs
)
else:
context = self._update_kvcache_attention(q, kv, inference_params)
else:
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
context = rearrange(context, "b s h d -> b s (h d)")
if seqlen is not None:
context = rearrange(context, "b s d -> (b s) d")
out = self.out_proj(context)
return out
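# Note on ParallelSMHA usage (assumption: torch.distributed has already been initialized with a
# NCCL backend). Heads are sharded across ranks, so each rank allocates a kv cache only for its
# num_heads_kv_per_rank heads and out_proj (RowParallelLinear) performs the cross-rank reduction.
#
#   process_group = torch.distributed.group.WORLD
#   mha = ParallelSMHA(embed_dim=4096, num_heads=32, process_group=process_group,
#                      use_flash_attn=True, sequence_parallel=False,
#                      device="cuda", dtype=torch.float16)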
class SelectMHA(nn.Module):
"""Multi-head, Group-query self-attention using select attention"""
def __init__(
self,
embed_dim,
num_heads,
num_heads_kv=None,
cross_attn=False,
qkv_proj_bias=True,
out_proj_bias=True,
dropout=0.0,
softmax_scale=None,
causal=False,
layer_idx=None,
dwconv=False,
rotary_emb_dim=0,
rotary_emb_base=10000.0,
rotary_emb_scale_base=None,
rotary_emb_interleaved=False,
use_alibi=False,
window_size=(-1, -1),
fused_bias_fc=False,
use_flash_attn=True,
return_residual=False,
checkpointing=False,
device=None,
dtype=None,
) -> None:
"""
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architectures, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.embed_dim = embed_dim
self.cross_attn = cross_attn
self.causal = causal
self.layer_idx = layer_idx
self.dwconv = dwconv
self.rotary_emb_dim = rotary_emb_dim
self.use_flash_attn = True # use_flash_attn
self.return_residual = return_residual
self.checkpointing = checkpointing
if use_alibi:
assert use_flash_attn, "ALiBi code path requires flash_attn"
alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
else:
alibi_slopes = None
if window_size != (-1, -1):
assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
self.num_heads = num_heads
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
assert (
self.num_heads % self.num_heads_kv == 0
), "num_heads must be divisible by num_heads_kv"
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
kv_dim = 2 * self.head_dim * self.num_heads_kv
if self.rotary_emb_dim > 0:
assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
assert RotaryEmbedding is not None, "rotary_emb is not installed"
self.rotary_emb = RotaryEmbedding(
self.rotary_emb_dim,
base=rotary_emb_base,
scale_base=rotary_emb_scale_base,
interleaved=rotary_emb_interleaved,
device=device,
)
if fused_bias_fc and FusedDense is None:
raise ImportError("fused_dense is not installed")
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
linear_resid_cls = (
LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
)
wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        # Needed by the _update_kvcache_attention fallback path, which attends with separate
        # q and kv (e.g. when flash_attn_with_kvcache is unavailable or during prefill).
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
self.softmax_scale = softmax_scale
self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
dtype = self.out_proj.weight.dtype if dtype is None else dtype
device = self.out_proj.weight.device
return torch.empty(
batch_size,
max_seqlen,
2,
self.num_heads_kv,
self.head_dim,
dtype=dtype,
device=device,
)
def _update_kv_cache(self, kv, inference_params):
"""Update kv cache in inference_params."""
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
return _update_kv_cache(kv, inference_params, self.layer_idx)
def forward(
self,
x,
x_kv=None,
key_padding_mask=None,
cu_seqlens=None,
max_seqlen=None,
mixer_subset=None,
inference_params=None,
batch_head_idx=None,
**kwargs,
):
"""
Arguments:
x: (batch, seqlen, hidden_dim)
batch_head_idx: Tensor of indices specifying which batch and head indices to select.
Shape: (batch_size, top_k)
inference_params: for generation.
"""
seqlen_offset = (
0
if inference_params is None
else (
inference_params.lengths_per_sample
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
)
rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
batch, seqlen = x.shape[:2]
if not self.cross_attn and self.num_heads_kv == self.num_heads:
# Self-attention, no MQA/GQA
assert x_kv is None and mixer_subset is None
if not self.return_residual:
qkv = self.Wqkv(x)
else:
qkv, x = self.Wqkv(x)
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(
qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
)
if inference_params is None or inference_params.seqlen_offset == 0:
                # Prefill stage, or no inference_params (no kv cache in use)
if inference_params is not None:
# Update kv cache during prefill
kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
context = self.inner_attn(qkv, **kwargs)
else:
# Generation stage
                if batch_head_idx is None:
                    # No head selection provided: standard flash attention with kv cache update
                    context = self._update_kvcache_attention(
                        q=qkv[:, :, 0], kv=qkv[:, :, 1:], inference_params=inference_params
                    )
                else:
                    # Apply select attention with kv cache update
                    context = self._update_kvcache_select_attn(
                        q=qkv[:, :, 0], kv=qkv[:, :, 1:], inference_params=inference_params,
                        batch_head_idx=batch_head_idx,
                    )
else:
raise NotImplementedError("SelectMHA currently supports only self-attention without MQA/GQA.")
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
return out if not self.return_residual else (out, x)
def _update_kvcache_select_attn(self, q, kv, inference_params, batch_head_idx):
"""
Apply select attention during generation stage.
q: (batch_size, seqlen=1, n_heads, head_dim)
kv: (batch_size, seqlen=1, 2, n_heads, head_dim)
batch_head_idx: Tensor of indices specifying which batch and head indices to select.
Shape: (batch_size, top_k)
# currently only supports batches with same seqlen
# different seqlen requires a simple update in the select_attn kernel to load the seqlen, future work
"""
# check batch_head_idx shape
# assert batch_head_idx.shape[0] == 2, "batch_head_idx must have shape (N_selected, 2)"
# check batch_head_idx is not None
assert batch_head_idx is not None, "batch_head_idx must not be None"
# update kv cache
kv_cache = self._update_kv_cache(kv, inference_params)
# inference_params.seqlen_offset += 1 # if seqlen_offset is int
batch = q.shape[0]
# kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
# make sure seqlen_offset accounts for the current token
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset + 1 # +1 for the current token
)
# need to reshape or view keys and value with shape (batch_size, seqlen, 1, n_heads, head_dim)
q = q.unsqueeze(2)
k_cache = kv_cache[:, :, 0].unsqueeze(2)
v_cache = kv_cache[:, :, 1].unsqueeze(2)
# Call select_attn
context = select_attn(
q,
k_cache,
v_cache,
self.softmax_scale,
batch_head_idx,
cache_seqlens)
# context: (batch_size, seqlen_q=1, G=1, H, head_dim)
# context = context.squeeze(2) # Remove G dimension
batch_size = batch_head_idx.shape[0]
context = context.view(batch_size, 1, self.num_heads, self.head_dim)
return context
def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention"""
if (
inference_params.seqlen_offset == 0
or flash_attn_with_kvcache is None
or not self.use_flash_attn
):
# TODO: this only uses seqlen_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
return self.inner_cross_attn(q, kv)
else:
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
# alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
alibi_slopes = None
return flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_attn.softmax_scale,
causal=self.inner_attn.causal,
alibi_slopes=alibi_slopes,
)
# dummy function for testing
def _select_attn(self, q, kv, inference_params, batch_head_idx):
"""
Apply select attention during generation stage.
q: (batch_size, seqlen=1, n_heads, head_dim)
kv: (batch_size, seqlen=1, 2, n_heads, head_dim)
batch_head_idx: Tensor of indices specifying which batch and head indices to select.
Shape: (N_selected, 2)
# currently only supports batches with same seqlen
# different seqlen requires a simple update in the select_attn kernel to load the seqlen, future work
"""
# check batch_head_idx shape
assert batch_head_idx.shape[1] == 2, "batch_head_idx must have shape (N_selected, 2)"
# update kv cache
# kv_cache = self._update_kv_cache(kv, inference_params)
# inference_params.seqlen_offset += 1 # if seqlen_offset is int
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
# make sure seqlen_offset accounts for the current token
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset # +1 for the current token
)
# need to reshape or view keys and value with shape (batch_size, seqlen, 1, n_heads, head_dim)
q = q.unsqueeze(2)
k_cache = kv_cache[:, :, 0].unsqueeze(2)
v_cache = kv_cache[:, :, 1].unsqueeze(2)
# Call select_attn
context = select_attn(
q,
k_cache,
v_cache,
self.softmax_scale,
batch_head_idx,
cache_seqlens)
# context: (batch_size, seqlen_q=1, G=1, H, head_dim)
context = context.squeeze(2) # Remove G dimension
return context
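# Minimal SelectMHA decode sketch (comments only; sizes are example assumptions, mirroring the
# benchmark in __main__ below). batch_head_idx has shape (batch, top_k) and tells select_attn
# which heads to compute per sequence; the kv cache is seeded directly here, as in __main__.
#
#   smha = SelectMHA(embed_dim=4096, num_heads=32, causal=True, layer_idx=0,
#                    softmax_scale=1 / 128 ** 0.5, device="cuda", dtype=torch.float16)
#   params = InferenceParams(max_seqlen=1024, max_batch_size=4)
#   kv = torch.randn(4, 512, 2, 32, 128, device="cuda", dtype=torch.float16)
#   _ = _update_kv_cache(kv, params, 0)
#   params.seqlen_offset += 512
#   heads = generate_random_BH_index(4, 32, 8, device="cuda")   # keep 8 of 32 heads per sequence
#   y = smha(torch.randn(4, 1, 4096, device="cuda", dtype=torch.float16),
#            inference_params=params, batch_head_idx=heads)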
# SelectiveGQA: Future work
class ParallelSelectMHA(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
process_group,
num_heads_kv=None,
qkv_proj_bias=True,
out_proj_bias=True,
dropout=0.0,
softmax_scale=None,
causal=True,
layer_idx=None,
dwconv=False,
rotary_emb_dim=0,
rotary_emb_base=10000.0,
rotary_emb_scale_base=None,
rotary_emb_interleaved=False,
use_alibi=False,
window_size=(-1, -1),
fused_bias_fc=True,
use_flash_attn=True,
return_residual=False,
checkpointing=False,
sequence_parallel=False,
device=None,
dtype=None,
) -> None:
"""
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architectures, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.embed_dim = embed_dim
self.causal = causal
self.layer_idx = layer_idx
self.dwconv = dwconv
self.rotary_emb_dim = rotary_emb_dim
self.use_flash_attn = use_flash_attn
self.return_residual = return_residual
self.checkpointing = checkpointing
self.process_group = process_group
self.world_size = process_group.size()
self.local_rank = torch.distributed.get_rank(process_group)
self.num_heads = num_heads
assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
assert (
self.num_heads % self.num_heads_kv == 0
), "num_heads must be divisible by num_heads_kv"
self.num_heads_per_rank = get_dim_for_local_rank(
self.num_heads, self.world_size, self.local_rank
)
self.num_heads_kv_per_rank = get_dim_for_local_rank(
self.num_heads_kv, self.world_size, self.local_rank
)
self.head_dim = self.embed_dim // num_heads
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
if use_alibi:
assert use_flash_attn, "ALiBi code path requires flash_attn"
num_heads_local = math.ceil(self.num_heads / self.world_size)
alibi_slopes = torch.tensor(
get_alibi_slopes(num_heads)[
self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
],
device=device,
)
else:
alibi_slopes = None
if window_size != (-1, -1):
assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
if self.rotary_emb_dim > 0:
assert RotaryEmbedding is not None, "rotary_emb is not installed"
self.rotary_emb = RotaryEmbedding(
self.rotary_emb_dim,
base=rotary_emb_base,
scale_base=rotary_emb_scale_base,
interleaved=rotary_emb_interleaved,
device=device,
)
if ColumnParallelLinear is None or RowParallelLinear is None:
raise ImportError("fused_dense is not installed")
self.Wqkv = ColumnParallelLinear(
embed_dim,
qkv_dim,
process_group,
bias=qkv_proj_bias,
sequence_parallel=sequence_parallel,
multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
**factory_kwargs,
)
inner_attn_cls = (
partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
if use_flash_attn
else SelfAttention
)
self.inner_attn = inner_attn_cls(
causal=causal,
softmax_scale=softmax_scale,
attention_dropout=dropout,
)
self.softmax_scale = softmax_scale
# replace this with no reduce
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
process_group,
bias=out_proj_bias,
sequence_parallel=sequence_parallel,
multiple_of=self.head_dim,
**factory_kwargs,
)
self.mha_router = None
self.mlp_router = None
# We'll use an extra stream for concurrency
self.current_stream = None
self.sparse_stream = torch.cuda.Stream(device="cuda", priority=0)
self.main_stream = torch.cuda.Stream(device="cuda", priority=-5)
self.mha_router_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.mlp_router_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.main_event = torch.cuda.Event(enable_timing=False, blocking=False)
# self.local_head_idx = generate_random_BH_index(1, self.num_heads_per_rank,self.num_heads_per_rank)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
dtype = self.out_proj.weight.dtype if dtype is None else dtype
device = self.out_proj.weight.device
return torch.empty(
batch_size,
max_seqlen,
2,
self.num_heads_kv_per_rank,
self.head_dim,
dtype=dtype,
device=device,
)
def _update_kv_cache(self, kv, inference_params):
"""Update kv cache in inference_params."""
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
return _update_kv_cache(kv, inference_params, self.layer_idx)
def forward(
self,
x,
seqlen=None,
inference_params=None,
batch_head_idx=None,
**kwargs,
):
"""
Arguments:
x: (batch, seqlen, hidden_dim)
batch_head_idx: Tensor of indices specifying which batch and head indices to select.
Shape: (N_selected,)
inference_params: for generation.
"""
router_inputs = x.squeeze(1)
self.current_stream = torch.cuda.current_stream()
        self.main_stream.wait_stream(self.current_stream)
        self.sparse_stream.wait_stream(self.current_stream)
is_decode = inference_params is not None and inference_params.seqlen_offset > 0
# if self.mha_router and is_decode:
# with torch.cuda.stream(self.sparse_stream):
# batch_head_idx = self.mha_router._select_heads(router_inputs)
# self.sparse_stream.record_event(self.mha_router_event)
qkv = self.Wqkv(x)
if seqlen is not None:
qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
seqlen_offset = (
0
if inference_params is None
else (
inference_params.lengths_per_sample
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
)
rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
batch, seqlen = x.shape[:2]
if self.num_heads_kv == self.num_heads:
# Self-attention, no MQA/GQA
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(
qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
)
if inference_params is None or inference_params.seqlen_offset == 0:
                # Prefill stage, or no inference_params (no kv cache in use)
if inference_params is not None:
# Update kv cache during prefill
kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
context = self.inner_attn(qkv, **kwargs)
            else:
                # Generation (decode) stage. Rotary embeddings were already applied above, so
                # only select attention with kv cache update remains.
                context = self._update_kvcache_select_attn(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx
                )
else: # cross-attention, MQA/GQA
raise NotImplementedError("SelectMHA currently supports only self-attention without MQA/GQA.")
context = rearrange(context, "b s h d -> b s (h d)")
if seqlen is not None:
context = rearrange(context, "b s d -> (b s) d")
# out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
# out = self.out_proj(context)
out = fused_dense_func(context, self.out_proj.weight, self.out_proj.bias)
# if is_decode:
# if self.mlp_router:
# with torch.cuda.stream(self.sparse_stream):
# index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk)
# self.sparse_stream.record_event(self.mlp_router_event)
# with torch.cuda.stream(self.main_stream):
# out = all_reduce(out, self.process_group)
# self.main_stream.record_event(self.main_event)
# self.current_stream.wait_event(self.mlp_router_event)
# self.current_stream.wait_event(self.main_event)
# # index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk)
# # out = all_reduce(out, self.process_group)
# return out, index_vec
# else:
# out = all_reduce(out, self.process_group)
out = all_reduce(out, self.process_group)
return out if not self.return_residual else (out, x)
# return out
def _update_kvcache_select_attn(self, q, kv, inference_params, batch_head_idx = None):
"""
Apply select attention during generation stage.
q: (batch_size, seqlen=1, n_heads, head_dim)
kv: (batch_size, seqlen=1, 2, n_heads, head_dim)
batch_head_idx: Tensor of indices specifying which batch and head indices to select.
Shape: (batch_size, top_k)
"""
# check batch_head_idx shape
# assert batch_head_idx.shape[1] == 2, "batch_head_idx must have shape (N_selected, 2)"
# if batch_head_idx is None:
# batch_head_idx = self.local_head_idx
# print("Using local_head_idx, router not used.")
# batch_head_idx = self.local_head_idx
# print("Using local_head_idx, router not used.")
# update kv cache
kv_cache = self._update_kv_cache(kv, inference_params)
batch = q.shape[0]
# kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset + 1 # +1 for the current token
)
# need to reshape or view keys and value with shape (batch_size, seqlen, 1, n_heads, head_dim)
q = q.unsqueeze(2)
k_cache = kv_cache[:, :, 0].unsqueeze(2)
v_cache = kv_cache[:, :, 1].unsqueeze(2)
self.current_stream.wait_event(self.mha_router_event)
assert batch_head_idx is not None, "batch_head_idx must not be None"
# Call select_attn
context = select_attn(
q,
k_cache,
v_cache,
self.softmax_scale,
batch_head_idx,
cache_seqlens
)
# context: (batch_size, seqlen_q=1, G=1, H, head_dim)
# context = context.squeeze(2) # Remove G dimension
context = context.view(batch, 1, self.num_heads_kv_per_rank, self.head_dim)
return context
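# Note on the CUDA stream setup in ParallelSelectMHA (the router calls in forward() are
# currently commented out, so this describes the intended overlap rather than active behavior):
# sparse_stream is meant to run the MHA/MLP routers concurrently with Wqkv and attention on
# main_stream, and mha_router_event is what _update_kvcache_select_attn waits on before
# launching select_attn with the routed batch_head_idx.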
'''
PYTHONWARNINGS="ignore" python -m HybridTensor.modules.SelectiveMHA --batch_size 8 --in_features 8192 --seq_len 512 --head_density 0.25
'''
if __name__ == "__main__":
args = arg_parser()
max_seqlen = args.seq_len + 128
max_batch_size = args.batch_size
device = torch.device(f"cuda:{args.device}")
# simulates SelectiveMHA inference generation stage
inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size)
nheads = args.in_features // 128
softmax_scale = 1 / (128 ** 0.5)
rotary_emb_dim = 0
# build SelectiveMHA
select_mha = SelectMHA(
embed_dim=args.in_features,
num_heads=nheads,
num_heads_kv=None,
causal=True,
layer_idx=0,
use_flash_attn=True,
softmax_scale=softmax_scale,
return_residual=False,
rotary_emb_dim=rotary_emb_dim,
device=device,
dtype=torch.float16,
)
standard_mha = SMHA(
embed_dim=args.in_features,
num_heads=nheads,
num_heads_kv=None,
causal=True,
layer_idx=0,
use_flash_attn=True,
softmax_scale=softmax_scale,
return_residual=False,
rotary_emb_dim=rotary_emb_dim,
device=device,
dtype=torch.float16,
)
torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
# prefill stage to generate kv cache for all batches
og_x = torch.randn(args.batch_size, args.seq_len, args.in_features, device=device, dtype=torch.float16, requires_grad=False)
# out, time_ms = cuda_profiler(select_mha, og_x, inference_params=inference_params)
# print(f"MHA Prefill time: {time_ms:.3f} ms")
# out = select_mha(og_x, inference_params=inference_params)
# simulate kv cache, bug in flash_attn for larger batches
kv = torch.randn(args.batch_size, args.seq_len, 2, nheads, 128, device=device, dtype=torch.float16, requires_grad=False)
_ = _update_kv_cache(kv, inference_params, 0)
# increment the sequence length to move to the generation stage
inference_params.seqlen_offset += args.seq_len
input_x = torch.randn(args.batch_size, 1, args.in_features, device=device, dtype=torch.float16, requires_grad=False)
selected_heads = math.ceil(nheads * args.head_density)
# generate batch_head_idx for SelectiveMHA
# batch_head_index = generate_BH_index(args.batch_size, nheads, selected_heads, device=device)
batch_head_index = generate_random_BH_index(args.batch_size, nheads, selected_heads, device=device)
        # generation stage: standard MHA
out, standard_time_ms = cuda_profiler(standard_mha, input_x, inference_params=inference_params)
print(f"Standard MHA time: {standard_time_ms:.3f} ms")
        # generation stage: SelectMHA
out, select_time_ms = cuda_profiler(select_mha, input_x, inference_params=inference_params, batch_head_idx=batch_head_index)
print(f"SelectMHA time: {select_time_ms:.3f} ms")
speedup = standard_time_ms / select_time_ms
print(f"Speedup: {speedup:.3f}")