"""
Sliding Window / Hard Attention
Based on "Context Limitations Make Neural Language Models More Human-Like"
(Kuribayashi et al., 2022)
"""

import math
import torch
import torch.nn.functional as F
from einops import rearrange
from typing import Optional


def sliding_window_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    window_size: int = 2,  # default 2: attend to the current and 1 previous token
) -> torch.Tensor:
    """
    Sliding Window Attention.

    Hard cutoff: each position can only attend to the most recent
    window_size tokens (the current token included).
    """
    
    # Normalize the layout to [batch, heads, seq, dim]
    if not head_first:
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")
    
    B, H, T_q, D = q.shape
    T_k = k.shape[2]
    
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    
    # Attention logits, computed in fp32 for numerical stability
    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
    
    # Create sliding window mask
    mask = create_sliding_window_mask(T_q, T_k, window_size, device=q.device)
    logits = logits.masked_fill(~mask, float('-inf'))
    
    # Mask out positions before each sequence's start offset
    # (assumes seq_start has shape [B]: one start index per batch element)
    if seq_start is not None:
        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[:, None, None, None]
        logits = logits.masked_fill(seq_mask, float('-inf'))
    
    # Standard softmax over the key dimension
    weights = F.softmax(logits, dim=-1)

    # Apply the attention weights to the values; cast the weights back to
    # the value dtype so half-precision inputs do not hit a dtype mismatch
    out = torch.matmul(weights.to(v.dtype), v)

    # Restore the input layout if necessary
    if not head_first:
        out = rearrange(out, "b h t d -> b t h d")
    
    return out
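

# Optional sanity check, not part of the original module: it compares this
# implementation against PyTorch's built-in F.scaled_dot_product_attention
# (PyTorch >= 2.0) under the same boolean mask. The helper name and the
# shapes below are illustrative assumptions.
def _check_against_sdpa() -> None:
    torch.manual_seed(0)
    B, T, H, D = 2, 8, 2, 16
    q = torch.randn(B, T, H, D)
    k = torch.randn(B, T, H, D)
    v = torch.randn(B, T, H, D)

    out = sliding_window_attention_std(q, k, v, window_size=3)

    # Reference: the built-in kernel with the identical sliding-window mask
    mask = create_sliding_window_mask(T, T, window_size=3, device=q.device)
    ref = F.scaled_dot_product_attention(
        rearrange(q, "b t h d -> b h t d"),
        rearrange(k, "b t h d -> b h t d"),
        rearrange(v, "b t h d -> b h t d"),
        attn_mask=mask,
    )
    ref = rearrange(ref, "b h t d -> b t h d")
    assert torch.allclose(out, ref, atol=1e-5)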


def create_sliding_window_mask(
    T_q: int,
    T_k: int,
    window_size: int,
    device: torch.device
) -> torch.Tensor:
    """
    创建sliding window mask
    
    window_size=1: 只看前1个token (2-gram)
    window_size=2: 只看前2个token (3-gram)
    """
    # Index grids: q_idx[i, j] = i (query position), k_idx[i, j] = j (key position)
    q_idx = torch.arange(T_q, device=device)[:, None]
    k_idx = torch.arange(T_k, device=device)[None, :]

    # Causal constraint: position i may attend to positions j <= i
    mask = k_idx <= q_idx

    # Hard window: additionally keep only j in [i - window_size + 1, i]
    if window_size > 0:
        mask = mask & (k_idx >= q_idx - window_size + 1)

    return mask[None, None, :, :]  # [1, 1, T_q, T_k]
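

# Minimal usage sketch (assumptions: random inputs, the default
# [batch, seq, heads, dim] layout; window_size=2 means each token attends
# to itself and one previous token):
if __name__ == "__main__":
    # Visualize the mask for a short sequence
    m = create_sliding_window_mask(5, 5, window_size=2, device=torch.device("cpu"))
    print(m[0, 0].long())
    # tensor([[1, 0, 0, 0, 0],
    #         [1, 1, 0, 0, 0],
    #         [0, 1, 1, 0, 0],
    #         [0, 0, 1, 1, 0],
    #         [0, 0, 0, 1, 1]])

    # Run the attention on random tensors and sanity-check the output shape
    q = torch.randn(1, 5, 2, 8)
    out = sliding_window_attention_std(q, q, q, window_size=2)
    print(out.shape)  # torch.Size([1, 5, 2, 8])

    # Cross-check against PyTorch's built-in kernel (see helper above)
    _check_against_sdpa()
    print("matches F.scaled_dot_product_attention")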