"""
Forgetting Attention - standard softmax reference version.

Append this function to the end of forgetting_attention.py.
"""

import math
from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange


def forgetting_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    log_fgate: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
) -> torch.Tensor:
    """Forgetting Attention computed with a plain (materialized) softmax.

    q, k, v: (B, T, H, D), or (B, H, T, D) if ``head_first`` is True.
    log_fgate: per-token log forget gates, (B, T, H) (or (B, H, T) if
        ``head_first`` is True), with T equal to the key length.
    seq_start: optional per-batch start indices; key positions before
        ``seq_start[b]`` are masked out.
    sm_scale: softmax scale; defaults to 1 / sqrt(D).
    Returns the attention output in the same layout as ``q``.
    """

    if not head_first:
        # Work internally in (B, H, T, D) layout.
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")
        log_fgate = rearrange(log_fgate, "b t h -> b h t")

    B, H, T_q, D = q.shape
    T_k = k.shape[2]

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Raw attention logits, computed in float32 for numerical stability.
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale

    # Zero the log forget gates of positions before seq_start so the cumulative
    # sum below is unaffected by left padding. masked_fill broadcasts the
    # (B, 1, T_k) mask over the head dimension.
    log_fgate_masked = log_fgate.float()
    if seq_start is not None:
        mask_idx = torch.arange(T_k, device=q.device)[None, None, :] < seq_start[:, None, None]
        log_fgate_masked = log_fgate_masked.masked_fill(mask_idx, 0.0)

    # Cumulative log forget gates over key positions. The decay bias between
    # query i (absolute key position T_k - T_q + i) and key j is
    # log_lambda[T_k - T_q + i] - log_lambda[j].
    log_lambda = torch.cumsum(log_fgate_masked, dim=-1)
    decay_bias = log_lambda[:, :, T_k - T_q:, None] - log_lambda[:, :, None, :]
    scores = scores + decay_bias

    # Causal mask, offset by the key prefix length so that query i can attend
    # to keys up to absolute position P_SEQ + i.
    P_SEQ = T_k - T_q
    causal_mask = torch.triu(
        torch.ones((T_q, T_k), dtype=torch.bool, device=q.device), diagonal=P_SEQ + 1
    )
    scores = scores.masked_fill(causal_mask[None, None, :, :], float("-inf"))

    if seq_start is not None:
        # Mask out key positions before each sequence's start index.
        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[:, None, None, None]
        scores = scores.masked_fill(seq_mask, float("-inf"))

    attn = F.softmax(scores, dim=-1)
    # Fully masked rows (e.g. queries before seq_start) produce NaNs; zero them.
    attn = torch.nan_to_num(attn, 0.0)

    out = torch.matmul(attn.to(v.dtype), v)

    if not head_first:
        out = rearrange(out, "b h t d -> b t h d")

    return out
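

# Minimal usage sketch (illustration only; the shapes and gate values below are
# arbitrary assumptions, not part of the original module). log_fgate is built
# with logsigmoid so every log forget gate is <= 0.
if __name__ == "__main__":
    B, T, H, D = 2, 16, 4, 32
    q = torch.randn(B, T, H, D)
    k = torch.randn(B, T, H, D)
    v = torch.randn(B, T, H, D)
    log_fgate = F.logsigmoid(torch.randn(B, T, H))
    # Treat the first two key positions of the second sequence as left padding.
    seq_start = torch.tensor([0, 2])
    out = forgetting_attention_std(q, k, v, log_fgate, seq_start=seq_start)
    print(out.shape)  # torch.Size([2, 16, 4, 32])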