|
|
|
|
|
""" |
|
|
n_flex.py — Flexible Attention Mechanisms
|
|
Constraint: Must support AR (causal), SAT (block), and NAR (bidirectional) |
|
|
|
|
|
Testing: |
|
|
1. Linear Attention - O(n) instead of O(nΒ²) |
|
|
2. Cosine Attention - Different similarity metric |
|
|
3. Differential Attention - Noise cancellation (Microsoft 2024) |
|
|
4. Local + Global - Sparse hybrid |
|
|
5. Multi-Query Attention (MQA) - Inference efficient |
|
|
6. Grouped Query Attention (GQA) - Between MHA and MQA |
|
|
7. Retention - RetNet style (recurrent + parallel) |
|
|
8. Gated Linear Attention - Recent efficient attention |
|
|
9. ReLU Attention - Simpler activation |
|
|
10. Sigmoid Attention - Bounded attention |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
import argparse, math, time |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Optional, Literal |
|
|
|
|
|
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
|
VOCAB = 128256 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_mask(n: int, mode: str = "ar", block_size: int = 2,
             device: Optional[torch.device] = None):
    """Build an additive attention mask for the requested decoding mode.

    AR  (autoregressive):      causal, each position sees only the past.
    SAT (semi-autoregressive): sees its own block plus all past blocks.
    NAR (non-autoregressive):  bidirectional, sees everything (no mask).

    Args:
        n: sequence length.
        mode: "ar", "sat" or "nar".
        block_size: block width for SAT mode.
        device: target device; defaults to the module-level DEV
            (parameter added for testability / backward compatible).

    Returns:
        An (n, n) float tensor with 0.0 at allowed positions and -inf at
        disallowed ones, or None for NAR (nothing to mask).

    Raises:
        ValueError: for an unknown mode string.
    """
    dev = DEV if device is None else device
    if mode == "nar":
        return None
    elif mode == "ar":
        # Strict upper triangle gets -inf: position i attends to j <= i.
        return torch.triu(torch.full((n, n), float("-inf"), device=dev), 1)
    elif mode == "sat":
        idx = torch.arange(n, device=dev)
        block_idx = idx // block_size
        # Entry [i, j] is 0 when key j's block is not in the future of
        # query i's block (i.e. block(j) <= block(i)), else -inf.
        mask = torch.where(
            (block_idx.unsqueeze(0) <= block_idx.unsqueeze(1)),
            torch.tensor(0.0, device=dev),
            torch.tensor(float("-inf"), device=dev)
        )
        return mask
    else:
        raise ValueError(f"Unknown mode: {mode}")
|
|
|
|
|
|
|
|
def alibi_bias(n_heads: int, n_tokens: int,
               device: Optional[torch.device] = None):
    """ALiBi (Attention with Linear Biases, Press et al.) bias tensor.

    Returns a (1, n_heads, n_tokens, n_tokens) additive bias with
    entry [0, h, i, j] = -slope_h * max(i - j, 0): each head linearly
    penalises keys by their distance into the past.

    Bug fixes vs the previous version:
      * The distance was (j - i).clamp_min(0), which penalised only
        *future* positions - exactly the ones a causal mask removes -
        so ALiBi had no effect in AR mode. It now penalises past
        distances, as in the paper.
      * Non-power-of-two head counts used floor(log2), producing fewer
        slopes than heads and crashing the .view(); ceil is used so the
        slope list can be truncated to n_heads.

    Args:
        n_heads: number of attention heads.
        n_tokens: sequence length.
        device: target device; defaults to the module-level DEV.
    """
    dev = DEV if device is None else device

    def slopes(n):
        # Geometric sequence from the ALiBi paper; the ratio equals the
        # starting value, i.e. slope_i = start ** (i + 1).
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        return [start * (start ** i) for i in range(n)]

    if n_heads > 0 and math.log2(n_heads).is_integer():
        s = slopes(n_heads)
    else:
        # Round UP to the next power of two so truncation yields exactly
        # n_heads slopes (floor produced too few and crashed below).
        nearest = 2 ** math.ceil(math.log2(max(1, n_heads)))
        s = slopes(nearest)[:n_heads]
    s = torch.tensor(s, device=dev).view(1, n_heads, 1, 1)
    i = torch.arange(n_tokens, device=dev).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=dev).view(1, 1, 1, n_tokens)
    # (i - j) is the distance into the past; future positions get zero
    # bias (they are handled by the causal mask in AR mode).
    return -s * (i - j).clamp_min(0).float()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StandardAttention(nn.Module):
    """Vanilla multi-head softmax attention with ALiBi bias - O(n^2)."""

    def __init__(self, d: int, h: int):
        super().__init__()
        self.h = h
        self.dk = d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        bsz, n, _ = x.shape
        # One fused projection, then split into per-head q/k/v tensors,
        # each shaped (B, h, N, dk).
        fused = self.qkv(x).reshape(bsz, n, 3, self.h, self.dk)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = q.matmul(k.transpose(-1, -2)).div(math.sqrt(self.dk))
        scores = scores + alibi_bias(self.h, n)
        if mask is not None:
            # Broadcast the (N, N) additive mask over batch and heads.
            scores = scores + mask.unsqueeze(0).unsqueeze(0)

        ctx = scores.softmax(-1).matmul(v)
        out = ctx.transpose(1, 2).reshape(bsz, n, -1)
        return self.proj(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LinearAttention(nn.Module):
    """
    Linear attention: O(n) instead of O(n^2).

    A feature map phi makes phi(q) @ phi(k)^T approximate softmax
    attention; associativity then lets us compute (QK^T)V as Q(K^TV),
    building the (dk x dv) key/value summary first for linear cost.

    mask=None (NAR) uses one global summary; any non-None mask is
    handled with causal prefix sums. NOTE(review): a SAT block mask is
    therefore approximated by a plain causal pattern here - confirm
    that is acceptable for the benchmark.
    """
    def __init__(self, d: int, h: int, feature_map: str = "elu"):
        super().__init__()
        self.h = h
        self.dk = d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.feature_map = feature_map
        self.eps = 1e-6  # floor on the normalizer, avoids divide-by-zero

    def _phi(self, x):
        """Feature map; unknown names fall through to identity."""
        fm = self.feature_map
        if fm == "elu":
            return F.elu(x) + 1
        if fm == "relu":
            return F.relu(x)
        if fm == "softmax":
            return F.softmax(x, dim=-1)
        return x

    def forward(self, x, mask=None):
        bsz, n, _ = x.shape
        fused = self.qkv(x).reshape(bsz, n, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = fused.unbind(0)

        q, k = self._phi(q), self._phi(k)

        if mask is None:
            # Bidirectional: a single global key/value summary serves
            # every query position.
            summary = torch.einsum('bhnd,bhnv->bhdv', k, v)
            out = torch.einsum('bhnd,bhdv->bhnv', q, summary)
            key_total = k.sum(dim=2, keepdim=True)
            denom = torch.einsum('bhnd,bhkd->bhnk', q, key_total).clamp(min=self.eps)
            out = out / denom
        else:
            # Causal: prefix sums give each position its own summary of
            # all earlier keys/values.
            state = torch.cumsum(torch.einsum('bhnd,bhnv->bhndv', k, v), dim=2)
            key_prefix = torch.cumsum(k, dim=2)
            out = torch.einsum('bhnd,bhndv->bhnv', q, state)
            denom = torch.einsum('bhnd,bhnd->bhn', q, key_prefix).unsqueeze(-1).clamp(min=self.eps)
            out = out / denom

        return self.proj(out.transpose(1, 2).reshape(bsz, n, -1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CosineAttention(nn.Module):
    """
    Attention scores from cosine similarity instead of raw dot products.

    Unit-normalizing q and k bounds each similarity to [-1, 1]; a
    learnable temperature rescales the logits before the softmax.
    """
    def __init__(self, d: int, h: int, temp: float = 10.0):
        super().__init__()
        self.h = h
        self.dk = d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.temp = nn.Parameter(torch.tensor(temp))

    def forward(self, x, mask=None):
        bsz, n, _ = x.shape
        fused = self.qkv(x).reshape(bsz, n, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = fused.unbind(0)

        # Normalize along the feature axis: dot product becomes cosine.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        logits = self.temp * (q @ k.transpose(-1, -2))
        if mask is not None:
            logits = logits + mask.unsqueeze(0).unsqueeze(0)

        ctx = logits.softmax(-1) @ v
        return self.proj(ctx.transpose(1, 2).reshape(bsz, n, -1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DifferentialAttention(nn.Module):
    """
    Differential attention (Microsoft, "Differential Transformer", 2024).

    Two independent attention maps are computed and subtracted,
        Attn = softmax(Q1 K1^T) - lambda * softmax(Q2 K2^T),
    so correlated noise in the two maps cancels. The difference is then
    clipped to be non-negative and renormalized row-wise.
    """
    def __init__(self, d: int, h: int):
        super().__init__()
        self.h = h
        self.dk = d // h
        # Separate projections for the two attention maps; one shared V.
        self.q1 = nn.Linear(d, d, bias=False)
        self.k1 = nn.Linear(d, d, bias=False)
        self.q2 = nn.Linear(d, d, bias=False)
        self.k2 = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # Sigmoid of this scalar is the subtraction weight lambda.
        self.lambda_param = nn.Parameter(torch.tensor(0.5))
        self.proj = nn.Linear(d, d, bias=False)

    def _attn_map(self, q, k, scale, mask):
        """One masked softmax attention map from pre-shaped q/k."""
        logits = (q @ k.transpose(-1, -2)) / scale
        if mask is not None:
            logits = logits + mask.unsqueeze(0).unsqueeze(0)
        return logits.softmax(-1)

    def forward(self, x, mask=None):
        bsz, n, _ = x.shape

        def split_heads(t):
            # (B, N, d) -> (B, h, N, dk)
            return t.view(bsz, n, self.h, self.dk).transpose(1, 2)

        scale = math.sqrt(self.dk)
        att1 = self._attn_map(split_heads(self.q1(x)), split_heads(self.k1(x)), scale, mask)
        att2 = self._attn_map(split_heads(self.q2(x)), split_heads(self.k2(x)), scale, mask)
        v = split_heads(self.v(x))

        # Weighted difference, then clip + renormalize so each row is
        # again a non-negative distribution (up to the eps).
        diff = att1 - torch.sigmoid(self.lambda_param) * att2
        diff = F.relu(diff)
        diff = diff / (diff.sum(dim=-1, keepdim=True) + 1e-6)

        ctx = (diff @ v).transpose(1, 2).reshape(bsz, n, -1)
        return self.proj(ctx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiQueryAttention(nn.Module):
    """
    MQA: many query heads share a single key/value head.

    Shrinks the KV cache dramatically at inference time while keeping
    training cost comparable to standard attention.
    """
    def __init__(self, d: int, h: int):
        super().__init__()
        self.h = h
        self.dk = d // h
        # Full-width queries; K and V project down to one head each.
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, self.dk, bias=False)
        self.v = nn.Linear(d, self.dk, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        bsz, n, _ = x.shape
        q = self.q(x).view(bsz, n, self.h, self.dk).transpose(1, 2)
        # Single KV head; broadcast against every query head below.
        k = self.k(x).view(bsz, n, 1, self.dk).transpose(1, 2)
        v = self.v(x).view(bsz, n, 1, self.dk).transpose(1, 2)

        logits = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        logits = logits + alibi_bias(self.h, n)
        if mask is not None:
            logits = logits + mask.unsqueeze(0).unsqueeze(0)

        ctx = (logits.softmax(-1) @ v).transpose(1, 2).reshape(bsz, n, -1)
        return self.proj(ctx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GroupedQueryAttention(nn.Module):
    """
    GQA: groups of query heads share each key/value head.

    Used by Llama 2; a quality/inference-speed middle ground between
    full multi-head attention and MQA.
    """
    def __init__(self, d: int, h: int, num_kv_heads: int = 2):
        super().__init__()
        self.h = h
        self.num_kv_heads = num_kv_heads
        self.dk = d // h
        # How many query heads map onto each KV head.
        self.heads_per_group = h // num_kv_heads

        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.v = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        bsz, n, _ = x.shape
        q = self.q(x).view(bsz, n, self.h, self.dk).transpose(1, 2)
        k = self.k(x).view(bsz, n, self.num_kv_heads, self.dk).transpose(1, 2)
        v = self.v(x).view(bsz, n, self.num_kv_heads, self.dk).transpose(1, 2)

        # Duplicate each KV head across its group of query heads.
        k = k.repeat_interleave(self.heads_per_group, dim=1)
        v = v.repeat_interleave(self.heads_per_group, dim=1)

        logits = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        logits = logits + alibi_bias(self.h, n)
        if mask is not None:
            logits = logits + mask.unsqueeze(0).unsqueeze(0)

        ctx = (logits.softmax(-1) @ v).transpose(1, 2).reshape(bsz, n, -1)
        return self.proj(ctx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RetentionAttention(nn.Module):
    """
    RetNet-style retention, parallel form.

    Instead of a softmax, scores are weighted by a per-head exponential
    decay gamma^(i-j) over the distance to past positions. (RetNet also
    admits an equivalent O(1)-per-step recurrent form at inference; only
    the parallel form is implemented here.)

    Bug fix: the decay exponent was previously (j - i).clamp(min=0),
    which is zero everywhere inside the causal (j <= i) region, so no
    decay was ever applied. It is now (i - j).clamp(min=0) as in RetNet.
    """
    def __init__(self, d: int, h: int, gamma: float = 0.9):
        super().__init__()
        self.h = h
        self.dk = d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        # Learnable per-head decay. NOTE: the effective decay used in
        # forward() is sigmoid(gamma) (~0.71 for the 0.9 default), not
        # gamma itself.
        self.gamma = nn.Parameter(torch.ones(h) * gamma)

    def forward(self, x, mask=None):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Per-head decay matrix: entry [i, j] = gamma^(i-j) for j <= i.
        gamma = torch.sigmoid(self.gamma).view(1, self.h, 1, 1)
        positions = torch.arange(N, device=x.device).float()
        # (i - j): distance into the past; future entries clamp to 0 and
        # are zeroed by the causal mask below anyway.
        dist = (positions.unsqueeze(1) - positions.unsqueeze(0)).clamp(min=0)
        decay = gamma ** dist

        # Retention is inherently causal: zero out future positions.
        causal = torch.tril(torch.ones(N, N, device=x.device))
        decay = decay * causal.unsqueeze(0).unsqueeze(0)

        # An additive (0 / -inf) mask is converted to multiplicative 0/1
        # so it composes with the decay weights.
        if mask is not None:
            mask_binary = (mask == 0).float().unsqueeze(0).unsqueeze(0)
            decay = decay * mask_binary

        att = (q @ k.transpose(-1, -2)) * decay

        # Row-normalize; eps avoids 0/0 for fully masked rows.
        att = att / (att.sum(dim=-1, keepdim=True) + 1e-6)

        z = (att @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GatedLinearAttention(nn.Module):
    """
    Linear attention whose output is modulated by a sigmoid gate.

    From "Gated Linear Attention Transformers" (2024); the gate gives
    the model a data-dependent, per-channel way to suppress the
    attention output.
    """
    def __init__(self, d: int, h: int):
        super().__init__()
        self.h = h
        self.dk = d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.gate = nn.Linear(d, d)
        self.proj = nn.Linear(d, d, bias=False)
        self.eps = 1e-6  # floor for the normalizer

    def forward(self, x, mask=None):
        bsz, n, _ = x.shape
        fused = self.qkv(x).reshape(bsz, n, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = fused.unbind(0)

        # elu+1 keeps features positive, as in the linear-attention paper.
        q = F.elu(q) + 1
        k = F.elu(k) + 1

        if mask is None:
            # Bidirectional: one global key/value summary.
            summary = torch.einsum('bhnd,bhnv->bhdv', k, v)
            out = torch.einsum('bhnd,bhdv->bhnv', q, summary)
            denom = torch.einsum('bhnd,bhd->bhn', q, k.sum(dim=2)).unsqueeze(-1).clamp(min=self.eps)
        else:
            # Any non-None mask is treated as causal via prefix sums.
            state = torch.cumsum(torch.einsum('bhnd,bhnv->bhndv', k, v), dim=2)
            prefix = torch.cumsum(k, dim=2)
            out = torch.einsum('bhnd,bhndv->bhnv', q, state)
            denom = torch.einsum('bhnd,bhnd->bhn', q, prefix).unsqueeze(-1).clamp(min=self.eps)

        out = (out / denom).transpose(1, 2).reshape(bsz, n, -1)

        # Data-dependent output gate in [0, 1].
        out = out * torch.sigmoid(self.gate(x))
        return self.proj(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReLUAttention(nn.Module):
    """
    Softmax replaced by ReLU followed by row normalization.

    Simpler than softmax and sometimes competitive ("ReLU attention"
    line of work).
    """
    def __init__(self, d: int, h: int):
        super().__init__()
        self.h = h
        self.dk = d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        bsz, n, _ = x.shape
        fused = self.qkv(x).reshape(bsz, n, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = fused.unbind(0)

        logits = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        logits = logits + alibi_bias(self.h, n)
        if mask is not None:
            logits = logits + mask.unsqueeze(0).unsqueeze(0)

        # ReLU zeroes negatives (including -inf masked entries); rows
        # are then normalized to approximately sum to one (eps guards
        # all-zero rows).
        weights = F.relu(logits)
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-6)

        ctx = (weights @ v).transpose(1, 2).reshape(bsz, n, -1)
        return self.proj(ctx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SigmoidAttention(nn.Module):
    """
    Per-entry sigmoid attention.

    Weights are not normalized to sum to 1, so each query can allocate
    a variable "total attention" budget. A learnable per-head bias
    shifts the sigmoid operating point.
    """
    def __init__(self, d: int, h: int):
        super().__init__()
        self.h = h
        self.dk = d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.bias = nn.Parameter(torch.zeros(h, 1, 1))

    def forward(self, x, mask=None):
        bsz, n, _ = x.shape
        fused = self.qkv(x).reshape(bsz, n, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = fused.unbind(0)

        logits = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk) + self.bias
        if mask is not None:
            logits = logits + mask.unsqueeze(0).unsqueeze(0)

        weights = torch.sigmoid(logits)

        # sigmoid(-inf) is already 0, but also multiply by the binary
        # mask so disallowed entries are exactly zero.
        if mask is not None:
            weights = weights * (mask == 0).float().unsqueeze(0).unsqueeze(0)

        ctx = (weights @ v).transpose(1, 2).reshape(bsz, n, -1)
        return self.proj(ctx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Registry mapping CLI attention-type names to their implementations.
# Every class here is constructible as cls(d, h) (extra __init__ args
# all have defaults) and exposes forward(x, mask=None).
ATTN_REGISTRY = {
    "standard": StandardAttention,
    "linear": LinearAttention,
    "cosine": CosineAttention,
    "differential": DifferentialAttention,
    "mqa": MultiQueryAttention,
    "gqa": GroupedQueryAttention,
    "retention": RetentionAttention,
    "gated_linear": GatedLinearAttention,
    "relu": ReLUAttention,
    "sigmoid": SigmoidAttention,
}
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: attention then a 4x GELU MLP,
    each wrapped in a residual connection."""

    def __init__(self, d: int, h: int, attn_type: str = "standard"):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        # The attention implementation is looked up by name.
        self.attn = ATTN_REGISTRY[attn_type](d, h)
        self.ff = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ff(self.ln2(x))
        return x
|
|
|
|
|
|
|
|
class FlexModel(nn.Module):
    """Small language model with a pluggable attention mechanism."""

    def __init__(self, d: int, layers: int, h: int, attn_type: str = "standard"):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList(Block(d, h, attn_type) for _ in range(layers))
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, VOCAB, bias=False)
        # Weight tying: the output head shares the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x, mask=None):
        hidden = self.emb(x)
        for blk in self.blocks:
            hidden = blk(hidden, mask)
        return self.head(self.ln(hidden))

    def count_params(self):
        """Total number of parameters in the model."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(attn_type: str, mode: str, d: int, layers: int, h: int,
          batch: int, seq: int, steps: int, block_size: int = 4):
    """Train one FlexModel configuration on random tokens and benchmark it.

    Args:
        attn_type: key into ATTN_REGISTRY selecting the attention class.
        mode: "ar", "sat" or "nar" masking/objective regime.
        d, layers, h: model width, depth, and head count.
        batch, seq, steps: batch size, sequence length, optimizer steps.
        block_size: SAT block width.

    Returns:
        Dict with attn type, mode, average loss and tokens/sec over the
        last <= 20 steps, or None if any step raised (e.g. OOM).
    """
    print(f"\n{'='*60}")
    print(f"ATTENTION: {attn_type.upper()} | MODE: {mode.upper()}")
    print(f"{'='*60}")

    model = FlexModel(d, layers, h, attn_type).to(DEV)
    print(f"Parameters: {model.count_params():,}")

    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

    losses, times = [], []

    for step in range(steps):
        # Synthetic data: uniformly random token ids each step.
        ids = torch.randint(0, VOCAB, (batch, seq), device=DEV)

        if mode == "ar":
            # Next-token prediction with a causal mask.
            target = ids[:, 1:]
            input_ids = ids[:, :-1]
            mask = get_mask(seq - 1, "ar")
        elif mode == "sat":
            # Next-token prediction with a block (semi-autoregressive) mask.
            target = ids[:, 1:]
            input_ids = ids[:, :-1]
            mask = get_mask(seq - 1, "sat", block_size)
        else:
            # NAR denoising objective: corrupt 15% of tokens, predict the
            # clean sequence bidirectionally. NOTE: the loss is computed
            # over ALL positions, not just the corrupted ones.
            target = ids
            noise_mask = torch.rand(batch, seq, device=DEV) < 0.15
            input_ids = ids.clone()
            input_ids[noise_mask] = torch.randint(0, VOCAB, (noise_mask.sum().item(),), device=DEV)
            mask = get_mask(seq, "nar")

        # Timing covers forward + backward + optimizer step only.
        start = time.time()
        opt.zero_grad()

        try:
            logits = model(input_ids, mask)
            loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        except Exception as e:
            # A failed step (shape error, OOM, ...) aborts the whole run.
            print(f"Step {step} failed: {e}")
            return None

        elapsed = time.time() - start
        losses.append(loss.item())
        times.append(elapsed)

        if step % 20 == 0 or step == steps - 1:
            tok_s = batch * seq / elapsed
            print(f"Step {step:3d} | Loss {loss.item():.4f} | {tok_s:.0f} tok/s")

    # Summary statistics over the trailing window of up to 20 steps.
    avg_loss = sum(losses[-20:]) / min(20, len(losses))
    avg_toks = batch * seq / (sum(times[-20:]) / min(20, len(times)))

    return {"attn": attn_type, "mode": mode, "loss": avg_loss, "tok_s": avg_toks}
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: benchmark attention variants across decoding modes.

    Trains every requested (mode, attention type) combination on random
    data, then prints a per-mode summary sorted by loss, with relative
    loss and speed versus the "standard" baseline where available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--d", type=int, default=256)
    parser.add_argument("--layers", type=int, default=4)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--seq", type=int, default=128)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--mode", type=str, default="ar", choices=["ar", "sat", "nar", "all"])
    parser.add_argument("--types", type=str, default="all")
    args = parser.parse_args()

    print(f"Device: {DEV}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")

    # Resolve which attention variants to benchmark ("all" or a
    # comma-separated list of ATTN_REGISTRY keys).
    if args.types == "all":
        types = list(ATTN_REGISTRY.keys())
    else:
        types = [t.strip() for t in args.types.split(",")]

    modes = ["ar", "sat", "nar"] if args.mode == "all" else [args.mode]

    results = []
    for mode in modes:
        for attn_type in types:
            r = train(attn_type, mode, args.d, args.layers, args.heads,
                      args.batch, args.seq, args.steps)
            # train() returns None on failure; keep only completed runs.
            if r:
                results.append(r)
            # Release cached GPU memory between runs (no-op on CPU).
            torch.cuda.empty_cache()

    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")

    for mode in modes:
        print(f"\n--- MODE: {mode.upper()} ---")
        mode_results = [r for r in results if r['mode'] == mode]
        # Baseline for relative comparisons, if its run completed.
        baseline = next((r for r in mode_results if r['attn'] == 'standard'), None)

        for r in sorted(mode_results, key=lambda x: x['loss']):
            rel = ""
            if baseline and r['attn'] != 'standard':
                loss_diff = (baseline['loss'] - r['loss']) / baseline['loss'] * 100
                speed_ratio = r['tok_s'] / baseline['tok_s']
                rel = f" | vs std: {loss_diff:+.1f}%, {speed_ratio:.2f}x"
            print(f"{r['attn']:15s} | Loss {r['loss']:.4f} | {r['tok_s']:6.0f} tok/s{rel}")
|
|
|