# alibi_pile_4layer/ops/sliding_window_attention_std.py
"""
Sliding Window / Hard Attention
Based on "Context Limitations Make Neural Language Models More Human-Like"
(Kuribayashi et al., 2022)
"""
import math
import torch
import torch.nn.functional as F
from einops import rearrange
from typing import Optional
def sliding_window_attention_std(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*,
head_first: bool = False,
seq_start: Optional[torch.Tensor] = None,
sm_scale: Optional[float] = None,
    window_size: int = 2,  # default: 2-gram (the current token plus 1 previous token)
) -> torch.Tensor:
"""
Sliding Window Attention
硬截断:只能attend到最近window_size个token
"""
if not head_first:
q = rearrange(q, "b t h d -> b h t d")
k = rearrange(k, "b t h d -> b h t d")
v = rearrange(v, "b t h d -> b h t d")
B, H, T_q, D = q.shape
T_k = k.shape[2]
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
# Compute logits
logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
# Create sliding window mask
mask = create_sliding_window_mask(T_q, T_k, window_size, device=q.device)
logits = logits.masked_fill(~mask, float('-inf'))
    # Mask out positions before each sequence's start (e.g. left padding).
    if seq_start is not None:
        # seq_start: [B] -> broadcast as [B, 1, 1, 1] against logits [B, H, T_q, T_k]
        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[:, None, None, None]
        logits = logits.masked_fill(seq_mask, float('-inf'))
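    # Caveat: if a query row ends up fully masked (all -inf) -- e.g. a query
    # position earlier than its seq_start -- the softmax below yields NaN there.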
# Standard softmax
weights = F.softmax(logits, dim=-1)
# Apply to values
out = torch.matmul(weights, v)
if not head_first:
out = rearrange(out, "b h t d -> b t h d")
return out
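

# A minimal usage sketch (illustrative only; the shapes below are assumptions,
# not part of the original API): inputs are [batch, seq_len, heads, head_dim]
# when head_first=False. This helper is hypothetical and is not run on import.
def _demo_sliding_window_attention() -> torch.Tensor:
    B, T, H, D = 2, 8, 4, 16
    q = torch.randn(B, T, H, D)
    k = torch.randn(B, T, H, D)
    v = torch.randn(B, T, H, D)
    # window_size=2: each position attends to itself and 1 previous token.
    out = sliding_window_attention_std(q, k, v, window_size=2)
    assert out.shape == (B, T, H, D)
    return out
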
def create_sliding_window_mask(
T_q: int,
T_k: int,
window_size: int,
device: torch.device
) -> torch.Tensor:
"""
创建sliding window mask
window_size=1: 只看前1个token (2-gram)
window_size=2: 只看前2个token (3-gram)
"""
    # Causal constraint: position i may attend to positions j <= i.
    rows = torch.arange(T_q, device=device)[:, None]  # [T_q, 1]
    cols = torch.arange(T_k, device=device)[None, :]  # [1, T_k]
    mask = cols <= rows
    # Window constraint: keep only the range [i - window_size + 1, i].
    if 0 < window_size < T_k:
        mask &= cols > rows - window_size
    return mask[None, None, :, :]  # [1, 1, T_q, T_k]
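

# Illustrative smoke test (an assumed example, not from the original repo).
# With T_q = T_k = 4 and window_size = 2, position i should attend only to
# {max(0, i - 1), i}.
if __name__ == "__main__":
    m = create_sliding_window_mask(4, 4, 2, device=torch.device("cpu"))
    print(m[0, 0].int())
    # Expected:
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 0, 0],
    #         [0, 1, 1, 0],
    #         [0, 0, 1, 1]], dtype=torch.int32)
    out = _demo_sliding_window_attention()
    print("attention output shape:", tuple(out.shape))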