"""
Vanilla Transformer 的标准 Softmax Attention
用于替换 flash_attn 的实现
"""
import math
import torch
import torch.nn.functional as F
from einops import rearrange
from typing import Optional, Tuple

def vanilla_attention_std(
    q: torch.Tensor,
    k: torch.Tensor, 
    v: torch.Tensor,
    causal: bool = True,
    window_size: Optional[Tuple[int, int]] = None,
    sm_scale: Optional[float] = None,
) -> torch.Tensor:
    """
    标准 Softmax Attention,兼容 flash_attn_func 的输入格式
    
    Args:
        q, k, v: [batch, seq_len, num_heads, head_dim] 格式
        causal: 是否使用因果mask
        window_size: 滑动窗口大小 (left, right),(-1, -1) 表示无限制
        sm_scale: softmax 缩放因子
    
    Returns:
        output: [batch, seq_len, num_heads, head_dim] 格式
    """
    B, T_q, H, D = q.shape
    T_k = k.shape[1]
    
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    
    # Move to [B, H, T, D] layout for the matmuls
    q = rearrange(q, 'b t h d -> b h t d')
    k = rearrange(k, 'b t h d -> b h t d')
    v = rearrange(v, 'b t h d -> b h t d')
    
    # Attention scores, computed in float32 for numerical stability
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
    
    # Causal mask
    if causal:
        P_SEQ = T_k - T_q  # offset handles the KV-cache case where T_k > T_q
        causal_mask = torch.triu(
            torch.ones((T_q, T_k), dtype=torch.bool, device=q.device),
            diagonal=P_SEQ + 1
        )
        scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))
    
    # Window mask (sliding window attention). Query positions are aligned to
    # the last T_q rows of the key sequence, so the window stays consistent
    # with the causal mask when T_k > T_q (KV cache); -1 on a side means that
    # side is unlimited.
    if window_size is not None and window_size != (-1, -1):
        left_window, right_window = window_size
        q_pos = torch.arange(T_q, device=q.device)[:, None] + (T_k - T_q)
        k_pos = torch.arange(T_k, device=q.device)[None, :]
        window_mask = torch.zeros((T_q, T_k), dtype=torch.bool, device=q.device)
        if left_window >= 0:
            window_mask |= k_pos < q_pos - left_window
        if right_window >= 0:
            window_mask |= k_pos > q_pos + right_window
        scores = scores.masked_fill(window_mask[None, None, :, :], float('-inf'))
    
    # Softmax; fully masked rows produce NaN, which we zero out
    attn_weights = F.softmax(scores, dim=-1)
    attn_weights = torch.nan_to_num(attn_weights, 0.0)
    
    # Apply attention to values
    output = torch.matmul(attn_weights.to(v.dtype), v)
    
    # Back to [B, T, H, D] layout
    output = rearrange(output, 'b h t d -> b t h d')
    
    return output
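

# A minimal self-check sketch (an addition for illustration, not part of the
# original module): for the causal, no-window, equal-heads case this should
# agree with PyTorch's built-in scaled_dot_product_attention up to tolerance.
def _check_against_sdpa() -> None:
    torch.manual_seed(0)
    B, T, H, D = 2, 16, 4, 32
    q = torch.randn(B, T, H, D)
    k = torch.randn(B, T, H, D)
    v = torch.randn(B, T, H, D)
    # F.scaled_dot_product_attention expects [B, H, T, D], so transpose in/out
    ref = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)
    out = vanilla_attention_std(q, k, v, causal=True)
    assert torch.allclose(out, ref, atol=1e-5), "mismatch vs. SDPA reference"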


def vanilla_attention_varlen_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    causal: bool = True,
    window_size: Optional[Tuple[int, int]] = None,
    sm_scale: Optional[float] = None,
) -> torch.Tensor:
    """
    变长序列的标准 Softmax Attention,兼容 flash_attn_varlen_func
    
    Args:
        q: [total_q_tokens, num_heads, head_dim]
        k: [total_k_tokens, num_kv_heads, head_dim]
        v: [total_k_tokens, num_kv_heads, head_dim]
        cu_seqlens_q: 累积序列长度 [batch_size + 1]
        cu_seqlens_k: 累积序列长度 [batch_size + 1]
        max_seqlen_q: 最大查询序列长度
        max_seqlen_k: 最大键值序列长度
    
    Returns:
        output: [total_q_tokens, num_heads, head_dim]
    """
    batch_size = cu_seqlens_q.shape[0] - 1
    H = q.shape[1]
    D = q.shape[2]

    # Expand KV heads when num_kv_heads < num_heads (GQA/MQA), matching the
    # flash_attn convention that num_heads is a multiple of num_kv_heads
    H_kv = k.shape[1]
    if H_kv != H:
        k = k.repeat_interleave(H // H_kv, dim=1)
        v = v.repeat_interleave(H // H_kv, dim=1)

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    
    outputs = []
    
    # Process each packed sequence independently
    for b in range(batch_size):
        q_start, q_end = cu_seqlens_q[b].item(), cu_seqlens_q[b+1].item()
        k_start, k_end = cu_seqlens_k[b].item(), cu_seqlens_k[b+1].item()
        
        if q_start == q_end:  # skip empty query sequences
            continue
            
        # Slice out this sequence's q, k, v
        q_b = q[q_start:q_end]  # [T_q, H, D]
        k_b = k[k_start:k_end]  # [T_k, H, D]
        v_b = v[k_start:k_end]  # [T_k, H, D]
        
        T_q = q_b.shape[0]
        T_k = k_b.shape[0]
        
        # Move to [H, T, D] layout
        q_b = rearrange(q_b, 't h d -> h t d')
        k_b = rearrange(k_b, 't h d -> h t d')
        v_b = rearrange(v_b, 't h d -> h t d')
        
        # Attention scores in float32
        scores = torch.matmul(q_b.float(), k_b.float().transpose(-2, -1)) * sm_scale
        
        # Causal mask
        if causal:
            P_SEQ = T_k - T_q
            causal_mask = torch.triu(
                torch.ones((T_q, T_k), dtype=torch.bool, device=q.device),
                diagonal=P_SEQ + 1
            )
            scores = scores.masked_fill(causal_mask[None, :, :], float('-inf'))
        
        # Window mask, aligned the same way as in vanilla_attention_std
        if window_size is not None and window_size != (-1, -1):
            left_window, right_window = window_size
            q_pos = torch.arange(T_q, device=q.device)[:, None] + (T_k - T_q)
            k_pos = torch.arange(T_k, device=q.device)[None, :]
            window_mask = torch.zeros((T_q, T_k), dtype=torch.bool, device=q.device)
            if left_window >= 0:
                window_mask |= k_pos < q_pos - left_window
            if right_window >= 0:
                window_mask |= k_pos > q_pos + right_window
            scores = scores.masked_fill(window_mask[None, :, :], float('-inf'))
        
        # Softmax; fully masked rows produce NaN, which we zero out
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = torch.nan_to_num(attn_weights, 0.0)
        
        # Apply attention
        output_b = torch.matmul(attn_weights.to(v_b.dtype), v_b)
        
        # Back to [T, H, D] layout
        output_b = rearrange(output_b, 'h t d -> t h d')
        outputs.append(output_b)
    
    # Concatenate the per-sequence outputs
    output = torch.cat(outputs, dim=0)
    
    return output
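

# A hedged usage sketch (also an addition): pack two sequences of lengths 3
# and 5 into flat [total_tokens, H, D] tensors and build the cumulative
# sequence-length index the way the flash_attn varlen API expects.
if __name__ == "__main__":
    _check_against_sdpa()

    torch.manual_seed(0)
    H, D = 4, 32
    seqlens = torch.tensor([3, 5], dtype=torch.int32)
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    total = int(seqlens.sum())
    q = torch.randn(total, H, D)
    k = torch.randn(total, H, D)
    v = torch.randn(total, H, D)
    out = vanilla_attention_varlen_std(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=int(seqlens.max()),
        max_seqlen_k=int(seqlens.max()),
        causal=True,
    )
    assert out.shape == (total, H, D)
    print("varlen sanity check passed, output shape:", tuple(out.shape))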