# Minimal SmolLM2-135M style model implemented in PyTorch.
# Architecture: LLaMA-style decoder-only Transformer with:
# - RMSNorm
# - RoPE positional encoding
# - SwiGLU MLP
# - Grouped (GQA/MQA) attention: num_attention_heads != num_key_value_heads
#
# This file depends only on PyTorch (the optional from_hf helper additionally needs
# transformers) and can be used as:
#
# from transformers import AutoConfig
# from model import SmolConfig, SmolLM2
#
# hf_cfg = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
# cfg = SmolConfig.from_hf(hf_cfg)
# model = SmolLM2(cfg)
from dataclasses import dataclass
from typing import Optional, Tuple, List
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# =========================
# 1. Config
# Config obtained from Hugging Face using: transformers.AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
# Config: SmolLM2-135M
# LlamaConfig {
# "architectures": [
# "LlamaForCausalLM"
# ],
# "attention_bias": false,
# "attention_dropout": 0.0,
# "bos_token_id": 0,
# "dtype": "bfloat16",
# "eos_token_id": 0,
# "head_dim": 64,
# "hidden_act": "silu",
# "hidden_size": 576,
# "initializer_range": 0.041666666666666664,
# "intermediate_size": 1536,
# "is_llama_config": true,
# "max_position_embeddings": 8192,
# "mlp_bias": false,
# "model_type": "llama",
# "num_attention_heads": 9,
# "num_hidden_layers": 30,
# "num_key_value_heads": 3,
# "pretraining_tp": 1,
# "rms_norm_eps": 1e-05,
# "rope_interleaved": false,
# "rope_scaling": null,
# "rope_theta": 100000,
# "tie_word_embeddings": true,
# "transformers_version": "4.57.3",
# "use_cache": true,
# "vocab_size": 49152
# }
# =========================
@dataclass
class SmolConfig:
# Core dimensions
vocab_size: int = 49152 # from HF config
hidden_size: int = 576 # "hidden_size"
intermediate_size: int = 1536 # "intermediate_size"
num_hidden_layers: int = 30 # "num_hidden_layers"
num_attention_heads: int = 9 # "num_attention_heads"
num_key_value_heads: int = 3 # "num_key_value_heads"
max_position_embeddings: int = 8192 # "max_position_embeddings"
# Positional / RoPE
rope_theta: float = 100000.0 # "rope_theta"
# Norm / numerical
rms_norm_eps: float = 1e-5 # "rms_norm_eps"
# Biases
attention_bias: bool = False # "attention_bias"
mlp_bias: bool = False # "mlp_bias"
# Misc
dtype: torch.dtype = torch.bfloat16
@property
def head_dim(self) -> int:
# Should be 64 for SmolLM2-135M (576 / 9).
return self.hidden_size // self.num_attention_heads # 576 / 9 = 64
@classmethod
def from_hf(cls, hf_config) -> "SmolConfig":
"""
Build this config from a transformers LlamaConfig (the config class used by the HuggingFace SmolLM2-135M model).
Example:
from transformers import AutoConfig
hf = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
cfg = SmolConfig.from_hf(hf)
The resulting SmolConfig can then be passed to SmolLM2(cfg).
"""
return cls(
vocab_size=hf_config.vocab_size,
hidden_size=hf_config.hidden_size,
intermediate_size=hf_config.intermediate_size,
num_hidden_layers=hf_config.num_hidden_layers,
num_attention_heads=hf_config.num_attention_heads,
num_key_value_heads=getattr(hf_config, "num_key_value_heads",
hf_config.num_attention_heads),
max_position_embeddings=hf_config.max_position_embeddings,
rope_theta=getattr(hf_config, "rope_theta", 10000.0),
rms_norm_eps=hf_config.rms_norm_eps,
attention_bias=getattr(hf_config, "attention_bias", False),
mlp_bias=getattr(hf_config, "mlp_bias", False),
dtype=torch.bfloat16, # SmolLM2 uses bfloat16
)
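# Quick sanity check of the derived head dimension with the defaults above (mirrors the
# "head_dim": 64 entry in the HF config):
#
#   cfg = SmolConfig()
#   assert cfg.head_dim == 64
#   assert cfg.hidden_size == cfg.num_attention_heads * cfg.head_dim  # 576 == 9 * 64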
# =========================
# 2. RMSNorm
# =========================
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization (RMSNorm)
Used in LLaMA / SmolLM2 instead of LayerNorm.
"""
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (..., dim)
# x / rms(x), where rms(x) = sqrt(mean(x^2)); rsqrt is used, with eps added for numerical stability
norm = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(norm + self.eps)
return self.weight * x
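# Small illustrative check of RMSNorm behaviour (any trailing dimension works; 576 is just
# this model's hidden size):
#
#   norm = RMSNorm(576)
#   y = norm(torch.randn(2, 4, 576))
#   print(y.shape)                          # torch.Size([2, 4, 576])
#   print(y.pow(2).mean(-1).sqrt().mean())  # ~1.0, since weight is initialised to ones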
# =========================
# 3. RoPE (Rotary Positional Embeddings)
# =========================
def rope_freqs(head_dim: int, base: float, device, dtype):
"""
Compute inverse frequencies for RoPE.
"""
half_dim = head_dim // 2
# Equivalent to: base^{ -2i / d }
freq_seq = torch.arange(half_dim, device=device, dtype=dtype)
inv_freq = 1.0 / (base ** (freq_seq / half_dim))
return inv_freq # shape: (half_dim,)
def build_rope_cache(
seq_len: int,
head_dim: int,
base: float,
device,
dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Build cosine and sine caches for RoPE.
Returns:
cos: (1, 1, seq_len, head_dim/2)
sin: (1, 1, seq_len, head_dim/2)
"""
inv_freq = rope_freqs(head_dim, base, device, dtype) # (half_dim,)
# Positions
t = torch.arange(seq_len, device=device, dtype=dtype) # (seq_len,)
freqs = torch.outer(t, inv_freq) # (seq_len, half_dim)
cos = freqs.cos()[None, None, :, :] # (1,1,seq_len,half_dim)
sin = freqs.sin()[None, None, :, :] # (1,1,seq_len,half_dim)
return cos, sin
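# Illustrative cache shapes (head_dim=64 and rope_theta=100000.0, matching the config above):
#
#   cos, sin = build_rope_cache(seq_len=16, head_dim=64, base=100000.0,
#                               device="cpu", dtype=torch.float32)
#   # cos.shape == sin.shape == (1, 1, 16, 32): one angle per position and per
#   # frequency pair (head_dim / 2 = 32 pairs)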
def apply_rope(
x: torch.Tensor, # (B, n_head, T, head_dim)
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
"""
Apply RoPE to last dimension of x.
cos, sin are broadcast to match (..., head_dim/2).
"""
b, h, t, d = x.shape
half = d // 2
x1 = x[..., :half] # (B, n_head, T, head_dim/2)
x2 = x[..., half:] # (B, n_head, T, head_dim/2)
# cos/sin: (1,1,T,half) -> broadcast over B,h
cos_t = cos[..., :t, :]
sin_t = sin[..., :t, :]
x1_rot = x1 * cos_t - x2 * sin_t
x2_rot = x1 * sin_t + x2 * cos_t
return torch.cat([x1_rot, x2_rot], dim=-1) # (B, n_head, T, head_dim)
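# Each (x1_i, x2_i) pair is rotated by a position-dependent angle, so RoPE preserves vector
# norms. A quick illustrative check (random tensors, shapes as used in this model):
#
#   q = torch.randn(1, 9, 16, 64)
#   cos, sin = build_rope_cache(16, 64, 100000.0, q.device, q.dtype)
#   q_rot = apply_rope(q, cos, sin)
#   torch.testing.assert_close(q.norm(dim=-1), q_rot.norm(dim=-1))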
# =========================
# 4. Attention
# =========================
class MultiHeadSelfAttention(nn.Module):
"""
LLaMA / SmolLM2-style attention with:
- Q heads = num_attention_heads
- K/V heads = num_key_value_heads (GQA/MQA)
- RoPE on Q and K
- Causal masking
"""
def __init__(self, config: SmolConfig):
super().__init__()
self.config = config
self.n_heads = config.num_attention_heads # 9
self.n_kv_heads = config.num_key_value_heads # 3
self.head_dim = config.head_dim # 64
self.hidden_size = config.hidden_size # 576
assert self.hidden_size == self.n_heads * self.head_dim
# Projections
self.q_proj = nn.Linear(
self.hidden_size,
self.n_heads * self.head_dim,
bias=config.attention_bias,
)
self.k_proj = nn.Linear(
self.hidden_size,
self.n_kv_heads * self.head_dim,
bias=config.attention_bias,
)
self.v_proj = nn.Linear(
self.hidden_size,
self.n_kv_heads * self.head_dim,
bias=config.attention_bias,
)
self.o_proj = nn.Linear(
self.n_heads * self.head_dim,
self.hidden_size,
bias=config.attention_bias,
)
def forward(
self,
x: torch.Tensor, # (B, T, C) or (B, 1, C) for inference
cos: torch.Tensor, # (1,1,T,head_dim/2) or (1,1,1,head_dim/2) for inference
sin: torch.Tensor, # (1,1,T,head_dim/2) or (1,1,1,head_dim/2) for inference
attention_mask: Optional[torch.Tensor] = None, # (B, T) or (B,1,1,T)
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # (k_cache, v_cache)
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
B, T, C = x.shape
# Projections: (B,T,C) -> (B,T,h,d) -> (B,h,T,d)
q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,T,C) -> (B,T,h*d) -> (B,T,h,d) -> (B,h,T,d)
k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,T,C) -> (B,T,k*d) -> (B,T,k,d) -> (B,k,T,d)
v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,T,C) -> (B,T,v*d) -> (B,T,v,d) -> (B,v,T,d)
# Apply RoPE to Q and K
q = apply_rope(q, cos, sin) # (B, h, T, d)
k = apply_rope(k, cos, sin) # (B, n_kv_heads, T, d)
# v doesn't need RoPE
# If using KV cache, concatenate with past keys/values
if past_key_value is not None:
past_k, past_v = past_key_value
# past_k, past_v: (B, n_kv_heads, past_len, head_dim)
k = torch.cat([past_k, k], dim=2) # (B, n_kv_heads, past_len + T, head_dim)
v = torch.cat([past_v, v], dim=2) # (B, n_kv_heads, past_len + T, head_dim)
seq_len = k.shape[2]
else:
seq_len = T
# Store k, v for cache (before GQA expansion)
k_cache = k # (B, n_kv_heads, seq_len, head_dim)
v_cache = v # (B, n_kv_heads, seq_len, head_dim)
# GQA: expand K/V if num_kv_heads < num_heads
if self.n_kv_heads != self.n_heads:
repeat_factor = self.n_heads // self.n_kv_heads
k = k.repeat_interleave(repeat_factor, dim=1) # (B, n_kv_heads, seq_len, d) -> (B, n_heads, seq_len, d)
v = v.repeat_interleave(repeat_factor, dim=1) # (B, n_kv_heads, seq_len, d) -> (B, n_heads, seq_len, d)
# Attention scores: (B,h,T,d) @ (B,h,d,seq_len) -> (B,h,T,seq_len)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# Causal mask: prevent attending to future tokens
# When decoding with a KV cache (T == 1), no extra causal masking is needed (see the else branch below)
if past_key_value is None:
# Full sequence: mask all future positions
causal_mask = torch.full(
(T, T), float("-inf"), device=x.device, dtype=x.dtype
).triu(1) # upper triangle (i < j)
scores = scores + causal_mask.unsqueeze(0).unsqueeze(0) # (B,h,T,T) + (1,1,T,T) -> (B,h,T,T)
else:
# With a KV cache we generate one token at a time (T == 1), and the new token may
# attend to all past positions plus itself, so no causal mask is needed here.
pass
# Optional attention mask (e.g., padding). Should be additive (0 or -inf).
if attention_mask is not None:
# Expect attention_mask as (B, 1, 1, seq_len) or (B, seq_len)
if attention_mask.dim() == 2:
# (B, seq_len) -> (B,1,1,seq_len)
attention_mask = attention_mask[:, None, None, :]
# Adjust mask shape if needed
if attention_mask.shape[-1] != seq_len:
# For inference, we might need to extend the mask
if past_key_value is not None:
# Extend mask to include past positions (all 0s for past, current mask for new token)
past_len = past_k.shape[2]
extended_mask = torch.zeros(B, 1, 1, seq_len, device=attention_mask.device, dtype=attention_mask.dtype)
extended_mask[..., past_len:] = attention_mask[..., -T:]
attention_mask = extended_mask
scores = scores + attention_mask
# Softmax over last dim (seq_len)
probs = F.softmax(scores, dim=-1) # (B,h,T,seq_len) -> (B,h,T,seq_len)
# Weighted sum of values
out = torch.matmul(probs, v) # (B,h,T,seq_len) @ (B,h,seq_len,d) -> (B,h,T,d)
# Reshape back: (B,T,C)
out = out.transpose(1, 2).contiguous().view(B, T, C) # (B,h,T,d) -> (B,T,h,d) -> (B,T,h*d) -> (B,T,C)
out = self.o_proj(out) # (B,T,C) -> (B,T,C)
# Return output and optionally the new KV cache
present_key_value = None
if use_cache:
# Return k_cache, v_cache (before GQA expansion, after RoPE)
present_key_value = (k_cache, v_cache)
return out, present_key_value
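# How the GQA expansion above maps K/V heads onto query heads: with 9 query heads and
# 3 K/V heads, repeat_interleave(3, dim=1) assigns consecutive query heads to the same
# K/V head, i.e. the pattern [kv0, kv0, kv0, kv1, kv1, kv1, kv2, kv2, kv2]. A tiny sketch:
#
#   kv = torch.arange(3).view(1, 3, 1, 1)              # stand-in for (B=1, n_kv_heads=3, T=1, d=1)
#   kv.repeat_interleave(3, dim=1).flatten().tolist()   # [0, 0, 0, 1, 1, 1, 2, 2, 2]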
# =========================
# 5. MLP (SwiGLU)
# =========================
class SmolMLP(nn.Module):
"""
SwiGLU MLP:
z = W1(x) -> split -> (x1, x2)
out = W2( SiLU(x1) * x2 )
"""
def __init__(self, config: SmolConfig):
super().__init__()
self.fc1 = nn.Linear(
config.hidden_size,
2 * config.intermediate_size, # for SwiGLU split (2 x 1536 = 3072)
bias=config.mlp_bias,
)
self.fc2 = nn.Linear(
config.intermediate_size, # 1536
config.hidden_size, # 576
bias=config.mlp_bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)# (B,T,C) -> (B,T,2*intermediate_size) -> (B,T,1536*2) -> (B,T,3072)
x1, x2 = x.chunk(2, dim=-1) # (B,T,2*intermediate_size) = (B,T,3072) -> (B,T,intermediate), (B,T,intermediate) = (B,T,1536), (B,T,1536)
return self.fc2(F.silu(x1) * x2) # (B,T,intermediate) * (B,T,intermediate) -> (B,T,intermediate) -> (B,T,hidden_size) = (B,T,576)
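# Note: fc1 fuses the two SwiGLU input projections and the result is split, whereas HF's
# LlamaMLP keeps separate gate_proj / up_proj / down_proj weights. The two layouts compute
# the same function; a sketch of the correspondence for an instance mlp = SmolMLP(cfg) and
# input x (assuming mlp_bias=False, as in this config):
#
#   gate_w, up_w = mlp.fc1.weight.chunk(2, dim=0)         # each (intermediate_size, hidden_size)
#   out = mlp.fc2(F.silu(x @ gate_w.T) * (x @ up_w.T))    # same result as mlp(x)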
# =========================
# 6. Transformer Block
# =========================
class SmolBlock(nn.Module):
def __init__(self, config: SmolConfig):
super().__init__()
self.attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attn = MultiHeadSelfAttention(config)
self.mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = SmolMLP(config)
def forward(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Pre-norm + residual for attention
attn_out, present_key_value = self.attn(
self.attn_norm(x), cos, sin, attention_mask, past_key_value, use_cache
)
x = x + attn_out
# Pre-norm + residual for MLP
x = x + self.mlp(self.mlp_norm(x))
return x, present_key_value
# =============================================
# 7. Top-level SmolLM2-135M Model Architecture
# SmolLM2 follows the LLaMA-style decoder-only Transformer architecture.
# =============================================
class SmolLM2(nn.Module):
"""
SmolLM2-135M-style LLaMA decoder-only language model.
Usage:
cfg = SmolConfig()
model = SmolLM2(cfg)
input_ids: LongTensor (B, T)
logits, _ = model(input_ids)
"""
def __init__(self, config: SmolConfig):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(
config.vocab_size,
config.hidden_size,
) # (Vocab_Size, Hidden_Size) (49152 x 576)
self.layers = nn.ModuleList(
[SmolBlock(config) for _ in range(config.num_hidden_layers)]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(
config.hidden_size,
config.vocab_size,
bias=False,
) # maps hidden_size -> vocab_size; weight shape is (vocab_size, hidden_size) = (49152, 576)
# Tie the lm_head weight to the token embedding ("tie_word_embeddings": true in the HF config)
self.lm_head.weight = self.embed_tokens.weight
def forward(
self,
input_ids: torch.Tensor, # (B, T)
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
B, T = input_ids.shape
# For inference with KV cache, we might have T=1
if past_key_values is None:
assert T <= self.config.max_position_embeddings, (
f"Sequence length {T} exceeds max_position_embeddings "
f"{self.config.max_position_embeddings}"
)
seq_len = T
else:
# With KV cache, current sequence length is past_len + T
past_len = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
seq_len = past_len + T
assert seq_len <= self.config.max_position_embeddings, (
f"Total sequence length {seq_len} exceeds max_position_embeddings "
f"{self.config.max_position_embeddings}"
)
# Embedding
x = self.embed_tokens(input_ids) # (B,T) -> (B,T,C)
# RoPE cache - build for the full sequence length (past + current)
cos, sin = build_rope_cache(
seq_len=seq_len,
head_dim=self.config.head_dim,
base=self.config.rope_theta,
device=x.device,
dtype=x.dtype,
)
# If using KV cache, we only need cos/sin for current positions
if past_key_values is not None:
past_len = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
# Slice to get only the current positions for RoPE
cos = cos[..., past_len:, :]
sin = sin[..., past_len:, :]
# Layers
present_key_values = [] if use_cache else None
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values is not None else None
x, present_kv = layer(x, cos, sin, attention_mask, past_kv, use_cache)
if use_cache:
present_key_values.append(present_kv)
# Final norm + lm head
x = self.norm(x)
logits = self.lm_head(x) # (B,T,C) -> (B,T,vocab_size)
return logits, present_key_values
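# Sketch of incremental decoding with the KV cache returned by forward(), for an instance
# model = SmolLM2(SmolConfig()) and a random prompt prompt_ids = torch.randint(0, 49152, (1, 8))
# (greedy pick; with the default config each cached k/v is (B, 3, seq_len, 64) and the list
# holds 30 layers):
#
#   logits, cache = model(prompt_ids, use_cache=True)
#   next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # (1, 1)
#   logits, cache = model(next_tok, past_key_values=cache, use_cache=True)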
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 100,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
eos_token_id: Optional[int] = None,
) -> torch.Tensor:
"""
Generate text using KV cache for efficient inference.
Args:
input_ids: (B, T) input token ids
max_new_tokens: maximum number of new tokens to generate
temperature: sampling temperature
top_k: top-k sampling (keep top k tokens)
top_p: nucleus sampling (keep tokens with cumulative probability <= top_p)
eos_token_id: end-of-sequence token id (stop generation when encountered)
Returns:
generated_ids: (B, T + max_new_tokens) generated token ids
"""
self.eval()
device = input_ids.device
B, T = input_ids.shape
# Start with input_ids
generated_ids = input_ids.clone()
past_key_values = None
for step in range(max_new_tokens):
# Forward pass with KV cache
# On first iteration, use full input_ids. On subsequent iterations, use only last token
if past_key_values is None:
# First iteration: process full sequence
current_input = generated_ids
else:
# Subsequent iterations: only process the last generated token
current_input = generated_ids[:, -1:]
logits, past_key_values = self.forward(
input_ids=current_input,
past_key_values=past_key_values,
use_cache=True,
)
# Get logits for the last token (always the last position in logits)
next_token_logits = logits[:, -1, :] / temperature
# Apply top-k filtering
if top_k is not None:
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
next_token_logits[indices_to_remove] = float('-inf')
# Apply top-p (nucleus) filtering
if top_p is not None:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')
# Sample next token
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
# Append to generated sequence
generated_ids = torch.cat([generated_ids, next_token], dim=1)
# Check for EOS token
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return generated_ids
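# Example generate() call (illustrative only: the prompt ids are random, no tokenizer involved):
#
#   model = SmolLM2(SmolConfig())
#   prompt = torch.randint(0, 49152, (1, 8))
#   out = model.generate(prompt, max_new_tokens=20, temperature=0.8, top_k=50)
#   # out.shape == (1, 28); generation stops earlier only if eos_token_id is provided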
# =========================
# 8. Quick self-test
# =========================
if __name__ == "__main__":
# Tiny sanity check: runs a forward pass on random input
cfg = SmolConfig()
model = SmolLM2(cfg)
B, T = 2, 16
x = torch.randint(0, cfg.vocab_size, (B, T))
with torch.no_grad():
logits, _ = model(x)
print("Input shape :", x.shape)
print("Logits shape:", logits.shape) # should be (2, 16, vocab_size)