|
|
""" |
|
|
Sliding Window / Hard Attention |
|
|
Based on "Context Limitations Make Neural Language Models More Human-Like" |
|
|
(Kuribayashi et al., 2022) |
|
|
""" |
|
|
|
|
|
import math |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
def sliding_window_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    window_size: int = 2,
) -> torch.Tensor:
    """Causal attention hard-truncated to the most recent ``window_size`` tokens.

    Reference (non-fused) implementation of sliding-window / "hard" attention,
    based on "Context Limitations Make Neural Language Models More Human-Like"
    (Kuribayashi et al., 2022): each query may attend only to itself and the
    ``window_size - 1`` immediately preceding key positions.

    Args:
        q: Queries, shape ``(B, T_q, H, D)`` — or ``(B, H, T_q, D)`` when
            ``head_first`` is True.
        k: Keys, same layout as ``q``, key length ``T_k``.
        v: Values, same layout as ``k``.
        head_first: If True, tensors are laid out ``(B, H, T, D)`` and no
            rearrange is performed on input/output.
        seq_start: Optional per-batch start offsets, shape ``(B,)``; key
            positions ``< seq_start[b]`` (left padding) are masked out.
            NOTE(review): a query position entirely inside the padded prefix
            has every key masked and produces NaN rows — callers are assumed
            to ignore those positions.
        sm_scale: Softmax scale; defaults to ``1 / sqrt(D)``.
        window_size: Number of most-recent tokens each query may attend to
            (``window_size=1`` means self-only). A non-positive value, or one
            ``>= T_k``, disables the truncation, leaving plain causal
            attention.

    Returns:
        Attention output in the same layout as the inputs, cast to ``v``'s
        dtype.
    """
    if not head_first:
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")

    B, H, T_q, D = q.shape
    T_k = k.shape[2]

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Compute logits in float32 for numerical stability regardless of input dtype.
    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale

    # Vectorized sliding-window mask (replaces an O(T_q) Python loop): causal
    # (key <= query) and, when the window actually constrains anything
    # (0 < window_size < T_k), key within the last `window_size` positions.
    q_idx = torch.arange(T_q, device=q.device).unsqueeze(1)  # (T_q, 1)
    k_idx = torch.arange(T_k, device=q.device).unsqueeze(0)  # (1, T_k)
    mask = k_idx <= q_idx
    if 0 < window_size < T_k:
        # key >= query - window_size + 1  <=>  key > query - window_size
        mask = mask & (k_idx > q_idx - window_size)
    logits = logits.masked_fill(~mask[None, None, :, :], float('-inf'))

    if seq_start is not None:
        # Mask out left-padding: key positions before seq_start[b] are invalid.
        # BUGFIX: seq_start's batch dim must broadcast against logits' dim 0 of
        # (B, H, T_q, T_k) — the previous [None, :, None, None] aligned it with
        # the head dim and broke whenever B != H.
        seq_mask = (
            torch.arange(T_k, device=q.device)[None, None, None, :]
            < seq_start[:, None, None, None]
        )
        logits = logits.masked_fill(seq_mask, float('-inf'))

    weights = F.softmax(logits, dim=-1)

    # BUGFIX: cast back to v's dtype so the matmul is valid for half-precision
    # inputs (weights are float32 from the stability upcast above).
    out = torch.matmul(weights.to(v.dtype), v)

    if not head_first:
        out = rearrange(out, "b h t d -> b t h d")

    return out
|
|
|
|
|
|
|
|
def create_sliding_window_mask(
    T_q: int,
    T_k: int,
    window_size: int,
    device: torch.device
) -> torch.Tensor:
    """Build a boolean sliding-window attention mask of shape ``(1, 1, T_q, T_k)``.

    A ``True`` entry marks a key position a query may attend to: the mask is
    causal (key <= query) and, when ``0 < window_size < T_k``, additionally
    restricted to the most recent ``window_size`` positions.

    window_size=1: attend only to the current token (2-gram behaviour).
    window_size=2: attend to the current and 1 previous token (3-gram).
    A non-positive ``window_size`` or one ``>= T_k`` leaves the mask purely
    causal (matching the original loop's fall-through).

    Args:
        T_q: Number of query positions.
        T_k: Number of key positions.
        window_size: Width of the attention window, in tokens.
        device: Device to allocate the mask on.

    Returns:
        Boolean tensor of shape ``(1, 1, T_q, T_k)``.
    """
    q_idx = torch.arange(T_q, device=device).unsqueeze(1)  # (T_q, 1)
    k_idx = torch.arange(T_k, device=device).unsqueeze(0)  # (1, T_k)

    # Causal constraint: query i may see keys at positions <= i.
    mask = k_idx <= q_idx

    # Hard window: key >= i - window_size + 1, i.e. among the last
    # `window_size` positions. Vectorized form of the original per-row
    # Python loop (O(T_q) slice assignments).
    if 0 < window_size < T_k:
        mask = mask & (k_idx > q_idx - window_size)

    return mask[None, None, :, :]