"""
model.py
========
Complete SmolLM2-135M model implementation
Architecture:
- 30 transformer blocks
- 576 hidden dimensions
- 9 query heads, 3 KV heads (Grouped Query Attention)
- SwiGLU feed-forward network
- RoPE position embeddings
- RMSNorm layer normalization
- Weight tying (embeddings = lm_head)
Total parameters: 134,515,008 (~135M)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from components import RMSNorm, TransformerBlock
from transformers import AutoConfig
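# Parameter-count arithmetic behind the 134,515,008 figure above, derived from
# the config values in the docstring (get_num_params below should agree):
#   embeddings:  49,152 * 576                                      = 28,311,552
#   per layer:   attention (576*576 + 576*192 + 576*192 + 576*576  =    884,736)
#                + FFN (576*1536 + 576*1536 + 1536*576             =  2,654,208)
#                + 2 RMSNorm weights (2 * 576                      =      1,152)
#                                                       layer total =  3,540,096
#   30 layers:   30 * 3,540,096                                    = 106,202,880
#   final norm:  576
#   total:       28,311,552 + 106,202,880 + 576                    = 134,515,008
# (the tied lm_head adds no extra parameters)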
class SmolLM2Model(nn.Module):
"""
SmolLM2-135M Language Model
    A decoder-only transformer based on the Llama architecture with:
- Grouped Query Attention (memory efficient)
- SwiGLU FFN (improved expressiveness)
- RoPE position embeddings (length extrapolation)
- RMSNorm (faster than LayerNorm)
Model configuration:
- Layers: 30
- Hidden size: 576
- Attention heads: 9 (Q) / 3 (KV)
- FFN size: 1536
- Vocab size: 49,152
- Context length: 2048
"""
def __init__(self, config):
"""
Initialize SmolLM2 model
Args:
config: Model configuration object with attributes:
- vocab_size: Size of vocabulary (49152)
- hidden_size: Model dimension (576)
- num_hidden_layers: Number of transformer blocks (30)
- tie_word_embeddings: Whether to tie input/output embeddings
- rms_norm_eps: Epsilon for RMSNorm
"""
super().__init__()
self.config = config
# Token embeddings
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
# Transformer blocks (30 layers)
self.layers = nn.ModuleList([
TransformerBlock(config) for _ in range(config.num_hidden_layers)
])
# Final layer normalization
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Language modeling head (output projection)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Weight tying: share embeddings with output projection
if config.tie_word_embeddings:
self.lm_head.weight = self.embed_tokens.weight
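            # Sharing one (vocab_size x hidden_size) matrix saves
            # 49,152 * 576 = 28,311,552 parameters vs. a separate lm_head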
print(f"βœ… Model initialized with {config.num_hidden_layers} transformer blocks")
print(f"βœ… Weight tying: {config.tie_word_embeddings}")
def forward(self, input_ids, attention_mask=None, position_ids=None):
"""
Forward pass through the model
Args:
input_ids (torch.Tensor): Input token IDs [batch, seq_len]
attention_mask (torch.Tensor, optional): Attention mask
position_ids (torch.Tensor, optional): Position indices
Returns:
torch.Tensor: Logits over vocabulary [batch, seq_len, vocab_size]
"""
batch_size, seq_len = input_ids.shape
# Create position IDs if not provided
if position_ids is None:
position_ids = torch.arange(seq_len, device=input_ids.device)
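            # Shape [seq_len]; assumed to broadcast over the batch inside
            # TransformerBlock (defined in components.py)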
# Embed tokens
hidden_states = self.embed_tokens(input_ids)
# Pass through all transformer blocks
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, position_ids)
# Final normalization
hidden_states = self.norm(hidden_states)
# Project to vocabulary
logits = self.lm_head(hidden_states)
return logits
def generate(
self,
input_ids,
max_new_tokens=50,
temperature=1.0,
top_p=0.9,
top_k=None,
do_sample=True
):
"""
Generate text autoregressively
Supports multiple sampling strategies:
- Greedy decoding (temperature=0)
- Temperature sampling
- Nucleus (top-p) sampling
- Top-k sampling
Args:
input_ids (torch.Tensor): Input token IDs [batch, seq_len]
max_new_tokens (int): Number of tokens to generate
temperature (float): Sampling temperature (0 = greedy, >1 = more random)
top_p (float): Nucleus sampling threshold (0-1)
top_k (int, optional): Top-k sampling threshold
do_sample (bool): Whether to sample or use greedy decoding
Returns:
torch.Tensor: Generated token IDs [batch, seq_len + max_new_tokens]
"""
self.eval()
for _ in range(max_new_tokens):
with torch.no_grad():
# Forward pass
logits = self(input_ids)
# Get next token logits
next_token_logits = logits[:, -1, :]
# Apply temperature
if temperature > 0:
next_token_logits = next_token_logits / temperature
# Greedy decoding
if not do_sample or temperature == 0:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
else:
# Top-k sampling
if top_k is not None:
top_k = min(top_k, next_token_logits.size(-1))
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
next_token_logits[indices_to_remove] = float('-inf')
# Nucleus (top-p) sampling
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above threshold
sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift the mask right so the first token crossing the
                        # threshold is kept, and always keep the top-1 token
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
# Scatter to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')
# Sample from distribution
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
# Append to sequence
input_ids = torch.cat([input_ids, next_token], dim=1)
return input_ids
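    # Example usage (a sketch; `model` and the tokenizer checkpoint are
    # assumptions based on the config used in __main__ below):
    #   from transformers import AutoTokenizer
    #   tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    #   ids = tok("Once upon a time", return_tensors="pt").input_ids
    #   out = model.generate(ids, max_new_tokens=20, temperature=0.7, top_p=0.9)
    #   print(tok.decode(out[0]))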
def get_num_params(self, non_embedding=False):
"""
Count model parameters
Args:
non_embedding (bool): If True, exclude embedding parameters
Returns:
int: Number of parameters
"""
n_params = sum(p.numel() for p in self.parameters())
if non_embedding:
n_params -= self.embed_tokens.weight.numel()
# If weights are tied, don't double-count
if not self.config.tie_word_embeddings:
n_params -= self.lm_head.weight.numel()
return n_params
def initialize_weights(model, config):
"""
Initialize model weights using GPT-style initialization
    Strategy:
    - Linear and embedding weights: Normal(0, 0.02)
    - Residual output projections (o_proj, down_proj): std scaled by 1/sqrt(2 * num_layers)
    - RMSNorm weights: left at 1.0 (PyTorch default)
The residual scaling prevents variance explosion in deep networks.
Args:
model (SmolLM2Model): Model to initialize
config: Model configuration
"""
std = 0.02
num_layers = config.num_hidden_layers
# Residual scaling factor: 1/sqrt(2 * num_layers)
residual_scaling = 1.0 / math.sqrt(2 * num_layers)
print(f"Initializing weights with std={std}, residual_scaling={residual_scaling:.6f}")
# Initialize embeddings
nn.init.normal_(model.embed_tokens.weight, mean=0.0, std=std)
# Initialize each transformer block
for layer in model.layers:
# Attention projections
nn.init.normal_(layer.self_attn.q_proj.weight, mean=0.0, std=std)
nn.init.normal_(layer.self_attn.k_proj.weight, mean=0.0, std=std)
nn.init.normal_(layer.self_attn.v_proj.weight, mean=0.0, std=std)
# Output projection with residual scaling
nn.init.normal_(layer.self_attn.o_proj.weight, mean=0.0, std=std * residual_scaling)
# FFN projections
nn.init.normal_(layer.mlp.gate_proj.weight, mean=0.0, std=std)
nn.init.normal_(layer.mlp.up_proj.weight, mean=0.0, std=std)
# Output projection with residual scaling
nn.init.normal_(layer.mlp.down_proj.weight, mean=0.0, std=std * residual_scaling)
# RMSNorm weights are initialized to 1.0 by default (PyTorch)
print(f"βœ… Initialized {sum(1 for _ in model.parameters())} weight tensors")
def load_pretrained_weights(our_model, official_model, device='cuda'):
"""
Load weights from HuggingFace official model
Maps weight names from official model to our implementation:
- model.embed_tokens.weight -> embed_tokens.weight
- model.layers.{i}.* -> layers[i].*
- model.norm.weight -> norm.weight
    - lm_head.weight: not copied separately (tied to the embeddings)
Args:
our_model (SmolLM2Model): Our model to load weights into
official_model: HuggingFace official model
device (str): Device to load weights to
Returns:
int: Number of weight tensors loaded
"""
print("=" * 70)
print("LOADING PRETRAINED WEIGHTS")
print("=" * 70)
official_state = official_model.state_dict()
loaded_count = 0
# 1. Load token embeddings
our_model.embed_tokens.weight.data = official_state['model.embed_tokens.weight'].clone().to(device)
loaded_count += 1
# 2. Load all transformer blocks
num_layers = our_model.config.num_hidden_layers
for layer_idx in range(num_layers):
prefix = f'model.layers.{layer_idx}'
# Layer norms
our_model.layers[layer_idx].input_layernorm.weight.data = \
official_state[f'{prefix}.input_layernorm.weight'].clone().to(device)
our_model.layers[layer_idx].post_attention_layernorm.weight.data = \
official_state[f'{prefix}.post_attention_layernorm.weight'].clone().to(device)
# Attention projections
our_model.layers[layer_idx].self_attn.q_proj.weight.data = \
official_state[f'{prefix}.self_attn.q_proj.weight'].clone().to(device)
our_model.layers[layer_idx].self_attn.k_proj.weight.data = \
official_state[f'{prefix}.self_attn.k_proj.weight'].clone().to(device)
our_model.layers[layer_idx].self_attn.v_proj.weight.data = \
official_state[f'{prefix}.self_attn.v_proj.weight'].clone().to(device)
our_model.layers[layer_idx].self_attn.o_proj.weight.data = \
official_state[f'{prefix}.self_attn.o_proj.weight'].clone().to(device)
# FFN projections
our_model.layers[layer_idx].mlp.gate_proj.weight.data = \
official_state[f'{prefix}.mlp.gate_proj.weight'].clone().to(device)
our_model.layers[layer_idx].mlp.up_proj.weight.data = \
official_state[f'{prefix}.mlp.up_proj.weight'].clone().to(device)
our_model.layers[layer_idx].mlp.down_proj.weight.data = \
official_state[f'{prefix}.mlp.down_proj.weight'].clone().to(device)
loaded_count += 9 # 2 norms + 4 attn + 3 ffn
# 3. Load final norm
our_model.norm.weight.data = official_state['model.norm.weight'].clone().to(device)
loaded_count += 1
print(f"\nβœ… Loaded {num_layers} transformer blocks")
print(f"βœ… Total loaded: {loaded_count} weight tensors")
print("=" * 70)
return loaded_count
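# A hedged sanity check (not in the original file): after copying weights, the
# two models should produce near-identical logits on the same tokens. The
# `.logits` attribute assumes `official_model` is a HuggingFace causal-LM
# module (e.g. AutoModelForCausalLM).
def verify_against_official(our_model, official_model, vocab_size=49152, seq_len=16, atol=1e-4):
    our_model.eval()
    official_model.eval()
    device = next(our_model.parameters()).device
    test_ids = torch.randint(0, vocab_size, (1, seq_len), device=device)
    with torch.no_grad():
        ours = our_model(test_ids)
        theirs = official_model(test_ids).logits
    max_diff = (ours - theirs).abs().max().item()
    print(f"Max logit difference: {max_diff:.2e}")
    return max_diff < atol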
if __name__ == "__main__":
"""Test model creation and parameter count"""
# Load config
config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
# Create model
model = SmolLM2Model(config)
# Count parameters
total_params = model.get_num_params()
print(f"\nTotal parameters: {total_params:,}")
print(f"Expected: 134,515,008")
print(f"Match: {total_params == 134_515_008}")
# Test forward pass
test_input = torch.randint(0, config.vocab_size, (1, 10))
output = model(test_input)
print(f"\nForward pass test:")
print(f" Input shape: {test_input.shape}")
print(f" Output shape: {output.shape}")
print(f" Expected: torch.Size([1, 10, 49152])")
# Test generation
generated = model.generate(test_input, max_new_tokens=5)
print(f"\nGeneration test:")
print(f" Generated shape: {generated.shape}")
print(f" Expected: torch.Size([1, 15])")