"""
Vanilla Transformer 的标准 Softmax Attention
用于替换 flash_attn 的实现
"""
import math
import torch
import torch.nn.functional as F
from einops import rearrange
from typing import Optional, Tuple
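# Drop-in usage sketch (hypothetical call site; the flash_attn import in the
# comment is the API being replaced, not something this module provides):
#
#     # from flash_attn import flash_attn_func
#     # out = flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
#     out = vanilla_attention_std(q, k, v, causal=True, window_size=(-1, -1))

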
def vanilla_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    window_size: Optional[Tuple[int, int]] = None,
    sm_scale: Optional[float] = None,
) -> torch.Tensor:
    """
    Standard softmax attention, compatible with the flash_attn_func input format.

    Args:
        q, k, v: tensors of shape [batch, seq_len, num_heads, head_dim]
        causal: whether to apply a causal mask
        window_size: sliding window (left, right); (-1, -1) means unrestricted
        sm_scale: softmax scaling factor; defaults to 1/sqrt(head_dim)

    Returns:
        output: tensor of shape [batch, seq_len, num_heads, head_dim]
    """
    B, T_q, H, D = q.shape
    T_k = k.shape[1]
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Compute in [B, H, T, D] layout
    q = rearrange(q, 'b t h d -> b h t d')
    k = rearrange(k, 'b t h d -> b h t d')
    v = rearrange(v, 'b t h d -> b h t d')

    # Attention scores in float32 for numerical stability
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale

    # Causal mask; the diagonal offset handles the KV-cache case where T_k > T_q
    if causal:
        P_SEQ = T_k - T_q
        causal_mask = torch.triu(
            torch.ones((T_q, T_k), dtype=torch.bool, device=q.device),
            diagonal=P_SEQ + 1,
        )
        scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))
    # Sliding-window mask: query i sits at absolute key position i + (T_k - T_q)
    # and may only attend to keys within [pos - left, pos + right]; -1 leaves
    # that side unbounded, matching flash_attn's window_size semantics.
    if window_size is not None and window_size != (-1, -1):
        left_window, right_window = window_size
        q_pos = torch.arange(T_q, device=q.device)[:, None] + (T_k - T_q)
        k_pos = torch.arange(T_k, device=q.device)[None, :]
        window_mask = torch.zeros((T_q, T_k), dtype=torch.bool, device=q.device)
        if left_window >= 0:
            window_mask |= k_pos < q_pos - left_window
        if right_window >= 0:
            window_mask |= k_pos > q_pos + right_window
        scores = scores.masked_fill(window_mask[None, None, :, :], float('-inf'))
    # Softmax; fully masked rows yield NaN, which is zeroed out below
    attn_weights = F.softmax(scores, dim=-1)
    attn_weights = torch.nan_to_num(attn_weights, 0.0)

    # Apply attention to values
    output = torch.matmul(attn_weights.to(v.dtype), v)

    # Back to [B, T, H, D]
    output = rearrange(output, 'b h t d -> b t h d')
    return output

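# Sanity-check sketch (not part of the original module): with causal masking
# and no window, this path should match PyTorch's built-in
# F.scaled_dot_product_attention, which takes [B, H, T, D] inputs:
#
#     q = torch.randn(2, 16, 4, 64)
#     k = torch.randn(2, 16, 4, 64)
#     v = torch.randn(2, 16, 4, 64)
#     ref = F.scaled_dot_product_attention(
#         q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
#     ).transpose(1, 2)
#     assert torch.allclose(vanilla_attention_std(q, k, v), ref, atol=1e-5)

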
def vanilla_attention_varlen_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    causal: bool = True,
    window_size: Optional[Tuple[int, int]] = None,
    sm_scale: Optional[float] = None,
) -> torch.Tensor:
    """
    Standard softmax attention over variable-length (packed) sequences,
    compatible with flash_attn_varlen_func.

    Args:
        q: [total_q_tokens, num_heads, head_dim]
        k: [total_k_tokens, num_kv_heads, head_dim]
        v: [total_k_tokens, num_kv_heads, head_dim]
        cu_seqlens_q: cumulative query sequence lengths, [batch_size + 1]
        cu_seqlens_k: cumulative key/value sequence lengths, [batch_size + 1]
        max_seqlen_q: maximum query sequence length (kept for API parity)
        max_seqlen_k: maximum key/value sequence length (kept for API parity)
        causal: whether to apply a causal mask
        window_size: sliding window (left, right); (-1, -1) means unrestricted
        sm_scale: softmax scaling factor; defaults to 1/sqrt(head_dim)

    Returns:
        output: [total_q_tokens, num_heads, head_dim]
    """
    batch_size = cu_seqlens_q.shape[0] - 1
    H = q.shape[1]
    D = q.shape[2]
    H_kv = k.shape[1]
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    outputs = []
    # Process one packed sequence at a time
    for b in range(batch_size):
        q_start, q_end = cu_seqlens_q[b].item(), cu_seqlens_q[b + 1].item()
        k_start, k_end = cu_seqlens_k[b].item(), cu_seqlens_k[b + 1].item()
        if q_start == q_end:  # empty sequence contributes no output rows
            continue

        # Slice out this sequence's q, k, v
        q_b = q[q_start:q_end]  # [T_q, H, D]
        k_b = k[k_start:k_end]  # [T_k, H_kv, D]
        v_b = v[k_start:k_end]  # [T_k, H_kv, D]
        if H_kv != H:
            # Grouped-query attention: repeat KV heads to match query heads
            k_b = k_b.repeat_interleave(H // H_kv, dim=1)
            v_b = v_b.repeat_interleave(H // H_kv, dim=1)
        T_q = q_b.shape[0]
        T_k = k_b.shape[0]

        # Compute in [H, T, D] layout
        q_b = rearrange(q_b, 't h d -> h t d')
        k_b = rearrange(k_b, 't h d -> h t d')
        v_b = rearrange(v_b, 't h d -> h t d')
        # Attention scores in float32 for numerical stability
        scores = torch.matmul(q_b.float(), k_b.float().transpose(-2, -1)) * sm_scale

        # Causal mask; the diagonal offset handles T_k > T_q (KV cache)
        if causal:
            P_SEQ = T_k - T_q
            causal_mask = torch.triu(
                torch.ones((T_q, T_k), dtype=torch.bool, device=q.device),
                diagonal=P_SEQ + 1,
            )
            scores = scores.masked_fill(causal_mask[None, :, :], float('-inf'))
        # Sliding-window mask, aligned to absolute key positions as above
        if window_size is not None and window_size != (-1, -1):
            left_window, right_window = window_size
            q_pos = torch.arange(T_q, device=q.device)[:, None] + (T_k - T_q)
            k_pos = torch.arange(T_k, device=q.device)[None, :]
            window_mask = torch.zeros((T_q, T_k), dtype=torch.bool, device=q.device)
            if left_window >= 0:
                window_mask |= k_pos < q_pos - left_window
            if right_window >= 0:
                window_mask |= k_pos > q_pos + right_window
            scores = scores.masked_fill(window_mask[None, :, :], float('-inf'))
        # Softmax; fully masked rows yield NaN, which is zeroed out below
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = torch.nan_to_num(attn_weights, 0.0)

        # Apply attention to values
        output_b = torch.matmul(attn_weights.to(v_b.dtype), v_b)

        # Back to [T, H, D]
        output_b = rearrange(output_b, 'h t d -> t h d')
        outputs.append(output_b)

    # Concatenate per-sequence outputs back into the packed layout
    output = torch.cat(outputs, dim=0)
    return output
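

# Minimal self-test sketch (not part of the original file): packs two sequences,
# runs the varlen path, and checks it against the fixed-length path per sequence.
if __name__ == "__main__":
    from itertools import accumulate

    torch.manual_seed(0)
    H, D = 4, 32
    lens = [5, 9]
    qs = [torch.randn(t, H, D) for t in lens]
    ks = [torch.randn(t, H, D) for t in lens]
    vs = [torch.randn(t, H, D) for t in lens]

    # Cumulative sequence lengths in flash_attn's cu_seqlens format: [0, 5, 14]
    cu = torch.tensor([0] + list(accumulate(lens)), dtype=torch.int32)
    out_varlen = vanilla_attention_varlen_std(
        torch.cat(qs), torch.cat(ks), torch.cat(vs),
        cu, cu, max(lens), max(lens), causal=True,
    )

    # Reference: run each sequence through the batched path as a batch of one
    refs = [
        vanilla_attention_std(qi[None], ki[None], vi[None], causal=True)[0]
        for qi, ki, vi in zip(qs, ks, vs)
    ]
    assert torch.allclose(out_varlen, torch.cat(refs), atol=1e-5)
    print("varlen path matches the fixed-length path")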