grapheneaffiliates's picture
Upload python/baselines.py with huggingface_hub
26849c5 verified
"""
Baseline attention mechanisms for comparison with H4 Polytopic Attention.
Implements standard softmax attention and linear attention (Katharopoulos et al. 2020)
with the SAME model wrapper (embeddings, FFN, LM head) so the only variable is attention.
Usage:
model = BaselineLanguageModel(vocab_size=128, d_model=128, n_heads=8,
n_layers=4, d_value=16, d_ffn=512,
attention_type='softmax') # or 'linear'
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils.phi_positional import PhiPositionalEncoding
from bitlinear import BitLinear
# ---------------------------------------------------------------------------
# Softmax Attention (standard transformer)
# ---------------------------------------------------------------------------
class SoftmaxAttention(nn.Module):
"""Standard multi-head scaled dot-product attention with causal mask."""
def __init__(self, d_model, n_heads, d_value, dropout=0.0, use_bitlinear=False):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.d_value = d_value
self.scale = 1.0 / math.sqrt(self.d_head)
Linear = BitLinear if use_bitlinear else nn.Linear
self.W_q = Linear(d_model, self.d_head * n_heads, bias=False)
self.W_k = Linear(d_model, self.d_head * n_heads, bias=False)
self.W_v = Linear(d_model, d_value * n_heads, bias=False)
self.W_out = Linear(d_value * n_heads, d_model, bias=False)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
def forward(self, x, **kwargs):
B, T, D = x.shape
H = self.n_heads
Q = self.W_q(x).view(B, T, H, self.d_head).transpose(1, 2) # (B, H, T, d_head)
K = self.W_k(x).view(B, T, H, self.d_head).transpose(1, 2)
V = self.W_v(x).view(B, T, H, self.d_value).transpose(1, 2) # (B, H, T, d_value)
# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # (B, H, T, T)
# Causal mask
causal_mask = torch.triu(
torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
)
scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = torch.matmul(attn, V) # (B, H, T, d_value)
out = out.transpose(1, 2).contiguous().view(B, T, H * self.d_value)
return self.W_out(out)
# ---------------------------------------------------------------------------
# Linear Attention (Katharopoulos et al. 2020)
# ---------------------------------------------------------------------------
def elu_feature_map(x):
"""ELU+1 feature map for linear attention: phi(x) = elu(x) + 1."""
return F.elu(x) + 1.0
class LinearAttention(nn.Module):
"""
Linear attention: O(T) causal attention via kernel trick.
Instead of softmax(QK^T)V, computes phi(Q) @ (phi(K)^T @ V)
where phi is the ELU+1 feature map.
For causal attention, uses cumulative sum formulation:
S_t = sum_{i<=t} phi(K_i)^T V_i (running state)
z_t = sum_{i<=t} phi(K_i) (running normalizer)
output_t = (phi(Q_t) @ S_t) / (phi(Q_t) @ z_t)
"""
def __init__(self, d_model, n_heads, d_value, dropout=0.0, use_bitlinear=False):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.d_value = d_value
Linear = BitLinear if use_bitlinear else nn.Linear
self.W_q = Linear(d_model, self.d_head * n_heads, bias=False)
self.W_k = Linear(d_model, self.d_head * n_heads, bias=False)
self.W_v = Linear(d_model, d_value * n_heads, bias=False)
self.W_out = Linear(d_value * n_heads, d_model, bias=False)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
def forward(self, x, **kwargs):
B, T, D = x.shape
H = self.n_heads
Q = self.W_q(x).view(B, T, H, self.d_head) # (B, T, H, d_head)
K = self.W_k(x).view(B, T, H, self.d_head)
V = self.W_v(x).view(B, T, H, self.d_value) # (B, T, H, d_value)
# Apply ELU+1 feature map
Q = elu_feature_map(Q) # (B, T, H, d_head)
K = elu_feature_map(K)
# Causal linear attention via cumulative sum
# S_t = cumsum(phi(K)^T @ V) over time dimension
# For each timestep: KV = outer(K_t, V_t) -> (B, H, d_head, d_value)
# Cumulative: S_t = sum_{i<=t} KV_i
# Reshape for batch computation
Q = Q.permute(0, 2, 1, 3) # (B, H, T, d_head)
K = K.permute(0, 2, 1, 3)
V = V.permute(0, 2, 1, 3) # (B, H, T, d_value)
# KV: outer product at each timestep
KV = torch.einsum('bhti,bhtj->bhtij', K, V) # (B, H, T, d_head, d_value)
S = torch.cumsum(KV, dim=2) # (B, H, T, d_head, d_value)
# Normalizer: cumsum of K
z = torch.cumsum(K, dim=2) # (B, H, T, d_head)
# Output: Q @ S / (Q @ z)
# numerator: (B, H, T, d_head) @ (B, H, T, d_head, d_value) -> (B, H, T, d_value)
num = torch.einsum('bhti,bhtij->bhtj', Q, S)
# denominator: (B, H, T, d_head) . (B, H, T, d_head) -> (B, H, T)
den = torch.einsum('bhti,bhti->bht', Q, z).unsqueeze(-1).clamp(min=1e-6)
out = num / den # (B, H, T, d_value)
out = out.permute(0, 2, 1, 3).contiguous().view(B, T, H * self.d_value)
out = self.dropout(out)
return self.W_out(out)
# ---------------------------------------------------------------------------
# Transformer Block (swappable attention)
# ---------------------------------------------------------------------------
class SoftmaxTransformerBlock(nn.Module):
"""Standard pre-norm transformer block with softmax attention."""
def __init__(self, d_model, n_heads, d_value, d_ffn=None, dropout=0.0,
use_bitlinear=False):
super().__init__()
if d_ffn is None:
d_ffn = d_model * 4
Linear = BitLinear if use_bitlinear else nn.Linear
self.ln1 = nn.LayerNorm(d_model)
self.attn = SoftmaxAttention(d_model, n_heads, d_value, dropout, use_bitlinear)
self.ln2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
Linear(d_model, d_ffn, bias=False),
nn.GELU(),
Linear(d_ffn, d_model, bias=False),
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
)
def forward(self, x, **kwargs):
x = x + self.attn(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class LinearTransformerBlock(nn.Module):
"""Pre-norm transformer block with linear attention (Katharopoulos et al. 2020)."""
def __init__(self, d_model, n_heads, d_value, d_ffn=None, dropout=0.0,
use_bitlinear=False):
super().__init__()
if d_ffn is None:
d_ffn = d_model * 4
Linear = BitLinear if use_bitlinear else nn.Linear
self.ln1 = nn.LayerNorm(d_model)
self.attn = LinearAttention(d_model, n_heads, d_value, dropout, use_bitlinear)
self.ln2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
Linear(d_model, d_ffn, bias=False),
nn.GELU(),
Linear(d_ffn, d_model, bias=False),
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
)
def forward(self, x, **kwargs):
x = x + self.attn(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
# ---------------------------------------------------------------------------
# Baseline Language Model
# ---------------------------------------------------------------------------
class BaselineLanguageModel(nn.Module):
"""
Language model with swappable attention mechanism.
Same architecture as H4LanguageModel (same embeddings, FFN, LM head)
but with standard softmax or linear attention instead of H4 geometric attention.
This ensures the only variable in comparisons is the attention mechanism.
Args:
vocab_size: vocabulary size
d_model: model dimension
n_heads: number of attention heads
n_layers: number of transformer blocks
d_value: value dimension per head
d_ffn: FFN hidden dimension (default: 4 * d_model)
max_seq_len: max sequence length for positional encoding
dropout: dropout rate
attention_type: 'softmax' or 'linear'
use_bitlinear: if True, use ternary weights
"""
def __init__(
self,
vocab_size,
d_model=128,
n_heads=8,
n_layers=4,
d_value=16,
d_ffn=None,
max_seq_len=512,
dropout=0.0,
attention_type='softmax',
use_bitlinear=False,
):
super().__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layers = n_layers
self.attention_type = attention_type
if d_ffn is None:
d_ffn = d_model * 4
# Token embedding (always float)
self.token_emb = nn.Embedding(vocab_size, d_model)
self.emb_scale = math.sqrt(d_model)
# Same golden-angle positional encoding as H4LanguageModel
self.pos_enc = PhiPositionalEncoding(d_model, max_cached=max_seq_len)
self.emb_dropout = nn.Dropout(dropout)
# Transformer blocks with selected attention type
if attention_type == 'softmax':
BlockClass = SoftmaxTransformerBlock
elif attention_type == 'linear':
BlockClass = LinearTransformerBlock
else:
raise ValueError(f"Unknown attention_type: {attention_type}")
self.blocks = nn.ModuleList([
BlockClass(
d_model=d_model,
n_heads=n_heads,
d_value=d_value,
d_ffn=d_ffn,
dropout=dropout,
use_bitlinear=use_bitlinear,
)
for _ in range(n_layers)
])
# Final layer norm
self.ln_f = nn.LayerNorm(d_model)
# LM head (tied with token embedding)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
self.lm_head.weight = self.token_emb.weight
self._init_weights()
def _init_weights(self):
"""Initialize weights following GPT-2 conventions."""
for module in self.modules():
if isinstance(module, BitLinear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, input_ids, **kwargs):
"""
Args:
input_ids: (batch, seq_len) token indices
Returns:
logits: (batch, seq_len, vocab_size)
"""
B, T = input_ids.shape
tok_emb = self.token_emb(input_ids) * self.emb_scale
pos_emb = self.pos_enc(T).unsqueeze(0).to(tok_emb.device)
x = self.emb_dropout(tok_emb + pos_emb)
for block in self.blocks:
x = block(x)
x = self.ln_f(x)
logits = self.lm_head(x)
return logits
def count_params(self):
"""Count trainable and frozen parameters."""
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
buffers = sum(b.numel() for b in self.buffers())
return {
'trainable': trainable,
'frozen': frozen,
'buffers': buffers,
'total': trainable + frozen,
}
@torch.no_grad()
def generate(self, input_ids, max_new_tokens=100, temperature=1.0, top_k_sample=0):
"""Autoregressive generation."""
for _ in range(max_new_tokens):
logits = self.forward(input_ids)
logits = logits[:, -1, :] / temperature
if top_k_sample > 0:
v, _ = torch.topk(logits, min(top_k_sample, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
probs = F.softmax(logits, dim=-1)
next_id = torch.multinomial(probs, num_samples=1)
input_ids = torch.cat([input_ids, next_id], dim=1)
return input_ids