| | """ |
| | Sliding Window / Hard Attention |
| | Based on "Context Limitations Make Neural Language Models More Human-Like" |
| | (Kuribayashi et al., 2022) |
| | """ |
| |
|
| | import math |
| | import torch |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| | from typing import Optional |
| |
|
| |
|
def sliding_window_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    window_size: int = 2,
) -> torch.Tensor:
    """
    Sliding-window ("hard") attention.

    Hard truncation of the causal context: each query may only attend to
    the most recent ``window_size`` keys (itself included).

    Args:
        q, k, v: Attention inputs of shape ``(B, T, H, D)`` when
            ``head_first`` is False (default), or ``(B, H, T, D)`` when True.
        head_first: Whether the head dimension precedes the time dimension.
        seq_start: Optional ``(B,)`` tensor; key positions ``< seq_start[b]``
            are masked out for batch element ``b``.
            NOTE(review): assumed to mark the first valid position of a
            left-padded sequence — confirm against callers.
        sm_scale: Softmax scale; defaults to ``1 / sqrt(D)``.
        window_size: Number of most recent tokens each query may attend to
            (``window_size=1`` means each token attends only to itself).
            ``<= 0`` disables the window (plain causal attention).

    Returns:
        Attention output with the same layout and dtype as the inputs.
    """
    if not head_first:
        # "b t h d -> b h t d" (plain transpose; no einops needed).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

    B, H, T_q, D = q.shape
    T_k = k.shape[2]

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Compute logits in float32 for numerical stability regardless of input dtype.
    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale

    # Causal + sliding-window mask, built vectorized (no Python loop).
    # Key j is visible from query i iff j <= i and, when the window is
    # active, i - window_size < j.
    q_idx = torch.arange(T_q, device=q.device)[:, None]
    k_idx = torch.arange(T_k, device=q.device)[None, :]
    mask = k_idx <= q_idx
    if 0 < window_size < T_k:
        mask = mask & (k_idx > q_idx - window_size)
    logits = logits.masked_fill(~mask[None, None, :, :], float("-inf"))

    if seq_start is not None:
        # BUGFIX: the batch dim of seq_start must broadcast against the batch
        # axis of logits (B, H, T_q, T_k). The original used
        # seq_start[None, :, None, None], which lands on the head axis and
        # raises whenever B != H.
        seq_mask = (
            torch.arange(T_k, device=q.device)[None, None, None, :]
            < seq_start[:, None, None, None]
        )
        logits = logits.masked_fill(seq_mask, float("-inf"))

    weights = F.softmax(logits, dim=-1)
    # Rows where every key is masked are all -inf and softmax yields NaN;
    # emit zeros for those positions instead of propagating NaN.
    weights = torch.nan_to_num(weights)

    # BUGFIX: multiplying float32 weights with half-precision v raises a
    # dtype-mismatch error; compute in float32 and cast back.
    out = torch.matmul(weights, v.float()).to(v.dtype)

    if not head_first:
        # "b h t d -> b t h d"
        out = out.transpose(1, 2)

    return out
| |
|
| |
|
def create_sliding_window_mask(
    T_q: int,
    T_k: int,
    window_size: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Build a boolean causal sliding-window mask of shape ``(1, 1, T_q, T_k)``.

    Query position ``i`` may attend to key ``j`` iff ``j <= i`` and, when
    ``0 < window_size < T_k``, additionally ``j > i - window_size``:

    * ``window_size == 1``: each token attends only to itself.
    * ``window_size == 2``: itself plus one preceding token.
    * ``window_size <= 0`` or ``window_size >= T_k``: plain causal mask.

    Args:
        T_q: Number of query positions.
        T_k: Number of key positions.
        window_size: Window width in tokens, inclusive of the query's own
            position.
        device: Device to allocate the mask on.

    Returns:
        ``torch.bool`` tensor, ``True`` where attention is allowed.
    """
    q_idx = torch.arange(T_q, device=device)[:, None]
    k_idx = torch.arange(T_k, device=device)[None, :]
    mask = k_idx <= q_idx  # causal triangle

    # Vectorized window restriction — replaces the original per-row Python
    # loop (one masked assignment per query position). Guard matches the
    # original: no restriction when window_size >= T_k or <= 0.
    if 0 < window_size < T_k:
        mask = mask & (k_idx > q_idx - window_size)

    return mask[None, None, :, :]