user-anto's picture
Upload MiniLM base checkpoint
dd3a77b verified
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from typing import Optional
from config import ModelConfig
# RMSNorm
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    The statistics are computed in float32; the normalised tensor is cast back
    to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to 1.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / sqrt(mean(x**2) + eps) * weight`` along the last axis."""
        x32 = x.float()
        mean_square = torch.mean(x32.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = (x32 * inv_rms).to(x.dtype)
        return normed * self.weight
# RoPE
def precompute_rope_freqs(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    """Build the rotary-embedding cosine and sine tables.

    Returns a ``(cos, sin)`` pair, each of shape
    ``(max_seq_len, ceil(head_dim / 2))``, where entry ``[t, i]`` is the
    cos/sin of ``t / theta ** (2 * i / head_dim)``.
    """
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len)
    # Outer product written as an explicit broadcast: (T, 1) * (1, D/2).
    phase = positions[:, None] * inv_freq[None, :]
    return torch.cos(phase), torch.sin(phase)
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent (even, odd) feature pairs of ``x`` by per-position angles.

    ``x`` is indexed as (B, T, n_heads, head_dim); ``cos``/``sin`` are indexed
    as (positions, head_dim / 2) and only their first ``T`` rows are used.
    The rotation is computed in float32 and the result is cast back to
    ``x.dtype``.
    """
    seq_len = x.shape[1]
    x32 = x.float()
    first, second = x32[..., 0::2], x32[..., 1::2]

    # (T, head_dim/2) -> (1, T, 1, head_dim/2) so the tables broadcast over batch and heads.
    c = cos[:seq_len][None, :, None]
    s = sin[:seq_len][None, :, None]

    rot_first = first * c - second * s
    rot_second = first * s + second * c

    # Re-interleave the rotated pairs back into the original feature layout.
    interleaved = torch.stack((rot_first, rot_second), dim=-1).flatten(-2)
    return interleaved.to(x.dtype)
# GQA
class GQAttention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and an optional KV cache.

    ``n_kv_heads`` key/value heads are shared by ``n_heads`` query heads; each
    KV head is repeated ``n_heads // n_kv_heads`` times before attention.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        assert cfg.n_heads % cfg.n_kv_heads == 0
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.n_rep = cfg.n_heads // cfg.n_kv_heads
        self.head_dim = cfg.dim // cfg.n_heads
        self.wq = nn.Linear(cfg.dim, cfg.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.dim, bias=False)
        self.dropout_p = cfg.dropout

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cache_k: Optional[torch.Tensor] = None,
        cache_v: Optional[torch.Tensor] = None,
        return_cache: bool = False,
    ):
        """Attend over ``x`` (plus any cached keys/values).

        Args:
            x: Input of shape (B, T, dim).
            cos, sin: Rotary tables for the positions of the ``T`` new tokens.
            cache_k, cache_v: Previously computed keys/values, laid out as
                (B, T_past, n_kv_heads, head_dim). Must be given together.
            return_cache: When true, also return the concatenated keys/values
                (before KV-head expansion) for reuse on the next call.

        Returns:
            ``(output, new_cache_k, new_cache_v)``; the cache entries are
            ``None`` unless ``return_cache`` is set.

        Raises:
            ValueError: If only one of ``cache_k`` / ``cache_v`` is provided.
        """
        if (cache_k is None) != (cache_v is None):
            raise ValueError("cache_k and cache_v must be provided together")
        has_cache = cache_k is not None

        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        if has_cache:
            # Cached keys were rotated when they were first produced; just append.
            k = torch.cat([cache_k, k], dim=1)
            v = torch.cat([cache_v, v], dim=1)
        new_cache_k, new_cache_v = (k, v) if return_cache else (None, None)

        # Expand KV heads -> Q heads
        k = k.repeat_interleave(self.n_rep, dim=2)
        v = v.repeat_interleave(self.n_rep, dim=2)

        # (B, n_heads, T, head_dim) for SDPA
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        dropout_p = self.dropout_p if self.training else 0.0
        if not has_cache:
            # Queries and keys cover the same positions: plain causal attention.
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)
        elif T == 1:
            # A single new token is the last position and may see every key.
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        else:
            # Several new tokens appended to a cache: query i sits at absolute
            # position (S - T) + i and must not see keys after it, so the
            # causal mask is aligned to the bottom-right corner.
            S = k.shape[2]
            visible = torch.ones(T, S, dtype=torch.bool, device=q.device).tril(diagonal=S - T)
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=visible, dropout_p=dropout_p)

        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out), new_cache_k, new_cache_v
# SwiGLU FFN
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``dropout(w2(silu(w1(x)) * w3(x)))``.

    The hidden width is ``int(cfg.dim * cfg.ffn_dim_multiplier)`` rounded up
    to the next multiple of 256. All projections are bias-free.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        block = 256
        requested = int(cfg.dim * cfg.ffn_dim_multiplier)
        # Ceil-divide, then scale back up: smallest multiple of `block` >= requested.
        hidden = -(-requested // block) * block
        # Creation order (w1, w2, w3) is kept so seeded initialisation is unchanged.
        self.w1 = nn.Linear(cfg.dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, cfg.dim, bias=False)
        self.w3 = nn.Linear(cfg.dim, hidden, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection and dropout to ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        projected = self.w2(gate * value)
        return self.dropout(projected)
# Transformer Block
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU FFN, each added residually."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.attn = GQAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.ffn = SwiGLU(cfg)

    def _forward(self, x, cos, sin, cache_k, cache_v, return_cache):
        """Run the block once and return ``(x, new_cache_k, new_cache_v)``."""
        attn_out, nck, ncv = self.attn(
            self.attn_norm(x), cos, sin, cache_k, cache_v, return_cache=return_cache
        )
        x = x + attn_out
        x = x + self.ffn(self.ffn_norm(x))
        return x, nck, ncv

    def _forward_no_cache(self, x, cos, sin):
        """Cache-free variant returning only the hidden state (used under checkpointing)."""
        out, _, _ = self._forward(x, cos, sin, None, None, False)
        return out

    def forward(self, x, cos, sin, cache_k=None, cache_v=None, use_grad_ckpt=False, return_cache=False):
        """Apply the block, optionally under gradient checkpointing.

        Checkpointing (recompute activations on backward instead of storing
        them) is used only when the module is in training mode and no KV cache
        is passed in or requested. Previously that branch silently ignored
        ``cache_k``/``cache_v`` and ``return_cache``; such calls now take the
        regular path so the cache is honoured.

        Returns:
            ``(x, new_cache_k, new_cache_v)``; the cache entries are ``None``
            on the checkpointed path and whenever ``return_cache`` is false.
        """
        cache_involved = return_cache or cache_k is not None or cache_v is not None
        if use_grad_ckpt and self.training and not cache_involved:
            out = checkpoint(self._forward_no_cache, x, cos, sin, use_reentrant=False)
            return out, None, None
        return self._forward(x, cos, sin, cache_k, cache_v, return_cache)
# LLM Definition
class LLM(nn.Module):
    """Causal transformer language model with tied input/output embeddings.

    Holds a token embedding, ``cfg.n_layers`` transformer blocks, a final
    RMSNorm and an output projection. Rotary tables are precomputed for
    ``2 * cfg.max_seq_len`` positions and stored as buffers.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.dim)
        self.layers = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
        self.embed.weight = self.lm_head.weight  # weight tying
        head_dim = cfg.dim // cfg.n_heads
        cos, sin = precompute_rope_freqs(head_dim, cfg.max_seq_len * 2, cfg.rope_theta)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Normal init: depth-scaled std for Linear layers, 0.02 for embeddings."""
        # NOTE(review): embed and lm_head share one tensor, so whichever of the
        # two is initialised last determines the tied weight's final scale.
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.cfg.n_layers))
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    @staticmethod
    def _cache_len(cache) -> int:
        """Return the number of cached positions, or 0 when no usable cache is present."""
        if not cache or cache[0][0] is None:
            return 0
        return cache[0][0].shape[1]

    def forward(
        self,
        idx: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        cache: Optional[list] = None,
        use_grad_ckpt: bool = False,
        return_cache: bool = False,
    ):
        """Run the model on token ids.

        Args:
            idx: Token ids of shape (B, T).
            targets: Optional target ids; when given, a cross-entropy loss over
                all positions is returned.
            cache: Optional per-layer list of ``(k, v)`` pairs from a previous
                call; new tokens are positioned after the cached ones.
            use_grad_ckpt: Forwarded to every block.
            return_cache: Request the updated per-layer cache.

        Returns:
            ``(logits, loss, new_cache)``. ``loss`` is ``None`` without
            targets; ``new_cache`` is ``None`` unless a cache was passed in or
            requested.

        Raises:
            ValueError: If the cached length plus ``T`` exceeds the
                precomputed rotary table.
        """
        _, T = idx.shape
        x = self.embed(idx)

        pos_start = self._cache_len(cache)
        pos_end = pos_start + T
        rope_len = self.rope_cos.shape[0]
        if pos_end > rope_len:
            # Slicing past the table would silently hand the layers too few rows.
            raise ValueError(
                f"positions {pos_start}..{pos_end - 1} exceed the precomputed RoPE table "
                f"({rope_len} positions)"
            )
        cos = self.rope_cos[pos_start:pos_end]
        sin = self.rope_sin[pos_start:pos_end]

        need_cache = return_cache or (cache is not None)
        new_cache = [] if need_cache else None
        for i, layer in enumerate(self.layers):
            ck, cv = cache[i] if cache else (None, None)
            x, nck, ncv = layer(
                x,
                cos,
                sin,
                ck,
                cv,
                use_grad_ckpt=use_grad_ckpt,
                return_cache=need_cache,
            )
            if need_cache:
                new_cache.append((nck, ncv))

        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss, new_cache

    def param_count(self) -> int:
        """Return the total number of elements across ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def probe_attention_entropy(self, idx: torch.Tensor, max_probe_len: int = 256) -> float:
        """
        Estimate mean causal attention entropy from layer 0 on a short token window.
        Lower entropy means sharper/more concentrated attention.

        The last ``max_probe_len`` tokens of ``idx`` are used (never more than
        the rotary table holds) and are positioned from 0. Returns NaN for an
        empty sequence.

        Raises:
            ValueError: If ``idx`` is not two-dimensional.
        """
        if idx.ndim != 2:
            raise ValueError(f"idx must be shape (B, T), got {tuple(idx.shape)}")
        if idx.shape[1] == 0:
            return float("nan")
        # Clamp to the RoPE table so the cos/sin rows always match the window length.
        probe_len = min(int(max_probe_len), int(idx.shape[1]), int(self.rope_cos.shape[0]))
        idx = idx[:, -probe_len:]
        B, T = idx.shape
        x = self.embed(idx)
        cos = self.rope_cos[:T]
        sin = self.rope_sin[:T]
        layer0 = self.layers[0]
        attn = layer0.attn
        h = layer0.attn_norm(x)
        q = attn.wq(h).view(B, T, attn.n_heads, attn.head_dim)
        k = attn.wk(h).view(B, T, attn.n_kv_heads, attn.head_dim)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)
        k = k.repeat_interleave(attn.n_rep, dim=2)
        q = q.transpose(1, 2).float()  # (B, H, T, D)
        k = k.transpose(1, 2).float()  # (B, H, T, D)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(attn.head_dim)
        causal_mask = torch.triu(
            torch.ones((T, T), device=scores.device, dtype=torch.bool), diagonal=1
        )
        scores = scores.masked_fill(causal_mask, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        # Clamp before the log so masked (zero-probability) entries contribute 0, not NaN.
        entropy = -(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1)
        return float(entropy.mean().item())

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 0.8,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        no_repeat_ngram_size: int = 3,
    ):
        """Autoregressively append ``max_new_tokens`` tokens to ``idx``.

        Args:
            idx: Prompt token ids of shape (B, T).
            max_new_tokens: Number of tokens to generate.
            temperature: Softmax temperature; exactly ``0.0`` selects greedy decoding.
            top_p: Nucleus threshold applied to the sorted probabilities.
            repetition_penalty: Values above 1 down-weight tokens already in the sequence.
            no_repeat_ngram_size: Values above 1 ban tokens that would repeat an
                n-gram of that size already present in the sequence.

        Returns:
            The prompt with the generated tokens concatenated along dim 1.
        """
        rope_len = self.rope_cos.shape[0]
        cache = None
        for _ in range(max_new_tokens):
            # The rotary table is finite: once the cache fills it, drop the cache
            # and re-encode the most recent max_seq_len tokens from position 0.
            if cache is not None and self._cache_len(cache) + 1 > rope_len:
                cache = None
            idx_cond = idx[:, -self.cfg.max_seq_len:] if cache is None else idx[:, -1:]
            logits, _, cache = self(idx_cond, cache=cache, return_cache=True)
            logits = logits[:, -1, :]

            # Discourage copying previously generated tokens.
            if repetition_penalty > 1.0:
                for b in range(idx.size(0)):
                    used = idx[b].unique()
                    used_logits = logits[b, used]
                    # Dividing positive logits and multiplying negative ones both lower the score.
                    logits[b, used] = torch.where(
                        used_logits > 0, used_logits / repetition_penalty, used_logits * repetition_penalty
                    )

            # Block tokens that would create repeated n-grams.
            if no_repeat_ngram_size and no_repeat_ngram_size > 1 and idx.size(1) >= no_repeat_ngram_size - 1:
                n = int(no_repeat_ngram_size)
                for b in range(idx.size(0)):
                    seq = idx[b].tolist()
                    prefix = tuple(seq[-(n - 1):])
                    banned = set()
                    for i in range(len(seq) - n + 1):
                        if tuple(seq[i : i + n - 1]) == prefix:
                            banned.add(seq[i + n - 1])
                    if banned:
                        logits[b, list(banned)] = float("-inf")

            if temperature == 0.0:
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)
                sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                cumsum = sorted_probs.cumsum(-1)
                # Zero every token whose preceding cumulative mass already exceeds top_p;
                # the most likely token is always kept.
                sorted_probs[cumsum - sorted_probs > top_p] = 0.0
                sorted_probs /= sorted_probs.sum(-1, keepdim=True)
                next_tok = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))
            idx = torch.cat([idx, next_tok], dim=1)
        return idx