"""
Forgetting Attention - 标准 Softmax 版本
在 forgetting_attention.py 最后添加这个函数
"""

import math
import torch
import torch.nn.functional as F
from einops import rearrange
from typing import Optional


def forgetting_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    log_fgate: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
) -> torch.Tensor:
    """标准 Softmax 版本的 Forgetting Attention"""
    
    if not head_first:
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")
        log_fgate = rearrange(log_fgate, "b t h -> b h t")
    
    B, H, T_q, D = q.shape
    T_k = k.shape[2]
    
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    
    # Scaled QK^T attention scores (computed in float32 for numerical stability)
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
    
    # Zero out log forget gates before each per-batch sequence start so the
    # cumulative decay below ignores tokens preceding the active sequence
    log_fgate_masked = log_fgate.float()
    if seq_start is not None:
        log_fgate_masked = log_fgate_masked.clone()
        mask_idx = torch.arange(T_k, device=q.device)[None, None, :] < seq_start[:, None, None]
        log_fgate_masked[mask_idx] = 0.0
    
    # Cumulative decay bias: bias[i, j] = sum_{l=j+1..i} log_fgate[l] = c[i] - c[j].
    # Queries are aligned to the last T_q key positions, so offset them by T_k - T_q.
    log_lambda = torch.cumsum(log_fgate_masked, dim=-1)
    decay_bias = log_lambda[:, :, T_k - T_q:, None] - log_lambda[:, :, None, :]
    scores = scores + decay_bias
    
    # Causal mask
    P_SEQ = T_k - T_q
    causal_mask = torch.triu(torch.ones((T_q, T_k), dtype=torch.bool, device=q.device), diagonal=P_SEQ + 1)
    scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))
    
    # Mask out keys before each per-batch sequence start
    if seq_start is not None:
        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[:, None, None, None]
        scores = scores.masked_fill(seq_mask, float('-inf'))
    
    # Softmax; rows that are fully masked (all -inf) yield NaN, replace those with 0
    attn = F.softmax(scores, dim=-1)
    attn = torch.nan_to_num(attn, 0.0)
    
    # Weighted sum over values (cast attention weights back to the value dtype)
    out = torch.matmul(attn.to(v.dtype), v)
    
    if not head_first:
        out = rearrange(out, "b h t d -> b t h d")
    
    return out
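

# --- Minimal self-check (a hedged sketch, not part of the original kernel) ---
# Runs forgetting_attention_std on random inputs, checks the output shape, and
# verifies that with zero log forget gates (gates == 1) the decay bias vanishes
# and the result matches plain causal attention. Assumes PyTorch >= 2.0 for
# F.scaled_dot_product_attention; tensor names and sizes here are illustrative.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, H, D = 2, 8, 4, 16
    q = torch.randn(B, T, H, D)
    k = torch.randn(B, T, H, D)
    v = torch.randn(B, T, H, D)

    # Zero-gate case: should reduce to standard causal scaled-dot-product attention
    log_fgate = torch.zeros(B, T, H)
    out = forgetting_attention_std(q, k, v, log_fgate)
    ref = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)
    assert out.shape == (B, T, H, D)
    assert torch.allclose(out, ref, atol=1e-4), "zero-gate case should match causal SDPA"

    # Random (negative) log forget gates: just check the call runs and shapes match
    log_fgate = F.logsigmoid(torch.randn(B, T, H))
    out = forgetting_attention_std(q, k, v, log_fgate)
    assert out.shape == (B, T, H, D)
    print("forgetting_attention_std self-check passed")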