"""
Sliding Window / Hard Attention

Based on "Context Limitations Make Neural Language Models More Human-Like"
(Kuribayashi et al., 2022)
"""
|
|
| import math |
| import torch |
| import torch.nn.functional as F |
| from einops import rearrange |
| from typing import Optional |
|
|
|
|
def sliding_window_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    window_size: int = 2,
) -> torch.Tensor:
    """
    Sliding-window (hard-truncation) attention.

    Each query position attends only to the most recent ``window_size``
    key positions, including itself — a hard n-gram-style context cutoff
    (``window_size=1`` sees only the current token, i.e. a 2-gram context;
    ``window_size=2`` adds the previous token, i.e. 3-gram).

    Args:
        q, k, v: Attention inputs, layout ``(B, T, H, D)`` when
            ``head_first`` is False (default), else ``(B, H, T, D)``.
        head_first: Whether the head axis precedes the time axis.
        seq_start: Optional per-batch start offsets, assumed shape ``(B,)``
            (TODO confirm against callers); key positions before
            ``seq_start[b]`` are masked out (left padding).
        sm_scale: Softmax scale; defaults to ``1 / sqrt(D)``.
        window_size: Size of the attention window; ``<= 0`` disables the
            window limit (plain causal attention).

    Returns:
        Attention output with the same layout and dtype as ``v``.
    """
    if not head_first:
        # (B, T, H, D) -> (B, H, T, D)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

    B, H, T_q, D = q.shape
    T_k = k.shape[2]

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Logits in fp32 for numerical stability regardless of input dtype.
    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale

    # Causal + sliding-window mask: query i attends to keys j with
    # i - window_size < j <= i (vectorized; window_size <= 0 => causal only).
    pos_q = torch.arange(T_q, device=q.device)[:, None]
    pos_k = torch.arange(T_k, device=q.device)[None, :]
    mask = pos_k <= pos_q
    if window_size > 0:
        mask &= pos_k > pos_q - window_size
    logits = logits.masked_fill(~mask[None, None], float("-inf"))

    if seq_start is not None:
        # seq_start is (B,): it must broadcast over the batch axis (dim 0).
        # Indexing it as [None, :, None, None] (as before) aligned it with
        # the HEAD axis instead — a shape bug whenever B != H.
        seq_mask = (
            torch.arange(T_k, device=q.device)[None, None, None, :]
            < seq_start[:, None, None, None]
        )
        logits = logits.masked_fill(seq_mask, float("-inf"))

    weights = F.softmax(logits, dim=-1)
    # Rows fully masked by seq_start softmax to NaN; zero them so padded
    # positions yield zeros rather than propagating NaN.
    weights = torch.nan_to_num(weights, nan=0.0)

    # Cast weights back to v's dtype: fp32 @ fp16 would raise a dtype
    # mismatch in torch.matmul for mixed-precision inputs.
    out = torch.matmul(weights.to(v.dtype), v)

    if not head_first:
        out = out.transpose(1, 2)

    return out
|
|
|
|
def create_sliding_window_mask(
    T_q: int,
    T_k: int,
    window_size: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Build a boolean causal sliding-window mask of shape ``(1, 1, T_q, T_k)``.

    ``True`` marks attendable positions: query ``i`` may attend to key ``j``
    iff ``j <= i`` and (when ``window_size > 0``) ``j > i - window_size``,
    i.e. only the last ``window_size`` tokens including the current one.

    window_size=1: attend only to the current token (2-gram context)
    window_size=2: attend to the current and previous token (3-gram)
    window_size<=0: plain causal mask (no window limit)
    """
    rows = torch.arange(T_q, device=device)[:, None]
    cols = torch.arange(T_k, device=device)[None, :]
    # Causal constraint, computed in one vectorized comparison instead of
    # the original per-row Python loop.
    mask = cols <= rows
    if window_size > 0:
        # Hard cutoff: drop keys older than window_size steps. Applied even
        # when window_size >= T_k, fixing rows i >= T_k that the original
        # loop skipped under its `window_size < T_k` guard when T_q > T_k.
        mask &= cols > rows - window_size
    return mask[None, None, :, :]