"""
Standard softmax attention for a vanilla Transformer.

Used as a drop-in replacement for the flash_attn implementations.
"""

import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange

def vanilla_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    window_size: Optional[Tuple[int, int]] = None,
    sm_scale: Optional[float] = None,
) -> torch.Tensor:
    """
    Standard softmax attention, compatible with the flash_attn_func input format.

    Args:
        q, k, v: tensors of shape [batch, seq_len, num_heads, head_dim]
        causal: whether to apply a causal mask
        window_size: sliding-window size (left, right); (-1, -1) means unrestricted
        sm_scale: softmax scaling factor

    Returns:
        output: tensor of shape [batch, seq_len, num_heads, head_dim]
    """
    B, T_q, H, D = q.shape
    T_k = k.shape[1]

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Move heads ahead of the sequence dimension: [b, t, h, d] -> [b, h, t, d]
    q = rearrange(q, 'b t h d -> b h t d')
    k = rearrange(k, 'b t h d -> b h t d')
    v = rearrange(v, 'b t h d -> b h t d')

    # Scaled dot-product scores, computed in float32 for numerical stability
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale

    # Causal mask: query i may attend to key j only when j <= i + (T_k - T_q),
    # which matches flash_attn's alignment when the key sequence is longer
    # than the query sequence (e.g. decoding with a KV cache).
    if causal:
        P_SEQ = T_k - T_q
        causal_mask = torch.triu(
            torch.ones((T_q, T_k), dtype=torch.bool, device=q.device),
            diagonal=P_SEQ + 1
        )
        scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))

    # Sliding-window mask: query i keeps only keys within
    # [i + offset - left_window, i + offset + right_window], where the offset
    # T_k - T_q aligns the window with the causal diagonal. A negative value
    # on either side is treated as unbounded on that side.
    if window_size is not None and window_size != (-1, -1):
        left_window, right_window = window_size
        offset = T_k - T_q
        window_mask = torch.ones((T_q, T_k), dtype=torch.bool, device=q.device)
        for i in range(T_q):
            start = 0 if left_window < 0 else max(0, i + offset - left_window)
            end = T_k if right_window < 0 else min(T_k, i + offset + right_window + 1)
            window_mask[i, start:end] = False
        scores = scores.masked_fill(window_mask[None, None, :, :], float('-inf'))

    attn_weights = F.softmax(scores, dim=-1)
    # Fully masked rows produce NaN after softmax; replace them with zeros
    attn_weights = torch.nan_to_num(attn_weights, 0.0)

    output = torch.matmul(attn_weights.to(v.dtype), v)

    output = rearrange(output, 'b h t d -> b t h d')

    return output
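
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sanity check of vanilla_attention_std against PyTorch's
# F.scaled_dot_product_attention. The shapes below and the availability of
# scaled_dot_product_attention (PyTorch >= 2.0) are assumptions.
def _example_vanilla_attention_std():
    B, T, H, D = 2, 16, 4, 32
    q = torch.randn(B, T, H, D)
    k = torch.randn(B, T, H, D)
    v = torch.randn(B, T, H, D)

    out = vanilla_attention_std(q, k, v, causal=True)  # [B, T, H, D]

    # Reference: SDPA expects [B, H, T, D], so transpose in and out.
    # Its default scale of 1/sqrt(D) matches sm_scale=None above.
    ref = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)

    print("max abs diff vs SDPA:", (out - ref).abs().max().item())
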
def vanilla_attention_varlen_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    causal: bool = True,
    window_size: Optional[Tuple[int, int]] = None,
    sm_scale: Optional[float] = None,
) -> torch.Tensor:
    """
    Standard softmax attention over variable-length (packed) sequences,
    compatible with flash_attn_varlen_func.

    Args:
        q: [total_q_tokens, num_heads, head_dim]
        k: [total_k_tokens, num_kv_heads, head_dim]
        v: [total_k_tokens, num_kv_heads, head_dim]
        cu_seqlens_q: cumulative query sequence lengths, shape [batch_size + 1]
        cu_seqlens_k: cumulative key/value sequence lengths, shape [batch_size + 1]
        max_seqlen_q: maximum query sequence length
        max_seqlen_k: maximum key/value sequence length
        causal: whether to apply a causal mask
        window_size: sliding-window size (left, right); (-1, -1) means unrestricted
        sm_scale: softmax scaling factor

    Returns:
        output: [total_q_tokens, num_heads, head_dim]
    """
    batch_size = cu_seqlens_q.shape[0] - 1
    H = q.shape[1]
    D = q.shape[2]

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    outputs = []

    # Attend within each packed sequence independently
    for b in range(batch_size):
        q_start, q_end = cu_seqlens_q[b].item(), cu_seqlens_q[b + 1].item()
        k_start, k_end = cu_seqlens_k[b].item(), cu_seqlens_k[b + 1].item()

        # Skip empty query segments
        if q_start == q_end:
            continue

        q_b = q[q_start:q_end]
        k_b = k[k_start:k_end]
        v_b = v[k_start:k_end]

        T_q = q_b.shape[0]
        T_k = k_b.shape[0]

        # [t, h, d] -> [h, t, d]
        q_b = rearrange(q_b, 't h d -> h t d')
        k_b = rearrange(k_b, 't h d -> h t d')
        v_b = rearrange(v_b, 't h d -> h t d')

        # Scaled dot-product scores in float32 for numerical stability
        scores = torch.matmul(q_b.float(), k_b.float().transpose(-2, -1)) * sm_scale

        # Causal mask, aligned to the end of the key sequence (as above)
        if causal:
            P_SEQ = T_k - T_q
            causal_mask = torch.triu(
                torch.ones((T_q, T_k), dtype=torch.bool, device=q.device),
                diagonal=P_SEQ + 1
            )
            scores = scores.masked_fill(causal_mask[None, :, :], float('-inf'))

        # Sliding-window mask, aligned to the causal diagonal; negative values
        # mean the window is unbounded on that side
        if window_size is not None and window_size != (-1, -1):
            left_window, right_window = window_size
            offset = T_k - T_q
            window_mask = torch.ones((T_q, T_k), dtype=torch.bool, device=q.device)
            for i in range(T_q):
                start = 0 if left_window < 0 else max(0, i + offset - left_window)
                end = T_k if right_window < 0 else min(T_k, i + offset + right_window + 1)
                window_mask[i, start:end] = False
            scores = scores.masked_fill(window_mask[None, :, :], float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = torch.nan_to_num(attn_weights, 0.0)

        output_b = torch.matmul(attn_weights.to(v_b.dtype), v_b)

        output_b = rearrange(output_b, 'h t d -> t h d')
        outputs.append(output_b)

    output = torch.cat(outputs, dim=0)

    return output
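
# --- Usage sketch (illustrative, not part of the original module) ---
# Shows the packed/varlen calling convention: tokens from all sequences are
# concatenated along dim 0 and delimited by cumulative lengths. The sequence
# lengths and sizes below are assumptions chosen for the example.
def _example_vanilla_attention_varlen_std():
    H, D = 4, 32
    seqlens = [5, 9, 3]
    total = sum(seqlens)
    q = torch.randn(total, H, D)
    k = torch.randn(total, H, D)
    v = torch.randn(total, H, D)

    # Cumulative sequence lengths: [0, 5, 14, 17]
    cu = [0]
    for n in seqlens:
        cu.append(cu[-1] + n)
    cu_seqlens = torch.tensor(cu, dtype=torch.int32)

    out = vanilla_attention_varlen_std(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
        causal=True,
    )
    print(out.shape)  # expected: torch.Size([17, 4, 32])

    # Cross-check the first sequence against the fixed-length function
    s0 = seqlens[0]
    ref0 = vanilla_attention_std(q[:s0][None], k[:s0][None], v[:s0][None], causal=True)[0]
    print("max abs diff, seq 0:", (out[:s0] - ref0).abs().max().item())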