""" model.py ======== Complete SmolLM2-135M model implementation Architecture: - 30 transformer blocks - 576 hidden dimensions - 9 query heads, 3 KV heads (Grouped Query Attention) - SwiGLU feed-forward network - RoPE position embeddings - RMSNorm layer normalization - Weight tying (embeddings = lm_head) Total parameters: 134,515,008 (~135M) """ import torch import torch.nn as nn import torch.nn.functional as F import math from components import RMSNorm, TransformerBlock from transformers import AutoConfig class SmolLM2Model(nn.Module): """ SmolLM2-135M Language Model A decoder-only transformer based on Llama architecture with: - Grouped Query Attention (memory efficient) - SwiGLU FFN (improved expressiveness) - RoPE position embeddings (length extrapolation) - RMSNorm (faster than LayerNorm) Model configuration: - Layers: 30 - Hidden size: 576 - Attention heads: 9 (Q) / 3 (KV) - FFN size: 1536 - Vocab size: 49,152 - Context length: 2048 """ def __init__(self, config): """ Initialize SmolLM2 model Args: config: Model configuration object with attributes: - vocab_size: Size of vocabulary (49152) - hidden_size: Model dimension (576) - num_hidden_layers: Number of transformer blocks (30) - tie_word_embeddings: Whether to tie input/output embeddings - rms_norm_eps: Epsilon for RMSNorm """ super().__init__() self.config = config # Token embeddings self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) # Transformer blocks (30 layers) self.layers = nn.ModuleList([ TransformerBlock(config) for _ in range(config.num_hidden_layers) ]) # Final layer normalization self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Language modeling head (output projection) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Weight tying: share embeddings with output projection if config.tie_word_embeddings: self.lm_head.weight = self.embed_tokens.weight print(f"✅ Model initialized with {config.num_hidden_layers} transformer blocks") print(f"✅ Weight tying: {config.tie_word_embeddings}") def forward(self, input_ids, attention_mask=None, position_ids=None): """ Forward pass through the model Args: input_ids (torch.Tensor): Input token IDs [batch, seq_len] attention_mask (torch.Tensor, optional): Attention mask position_ids (torch.Tensor, optional): Position indices Returns: torch.Tensor: Logits over vocabulary [batch, seq_len, vocab_size] """ batch_size, seq_len = input_ids.shape # Create position IDs if not provided if position_ids is None: position_ids = torch.arange(seq_len, device=input_ids.device) # Embed tokens hidden_states = self.embed_tokens(input_ids) # Pass through all transformer blocks for layer in self.layers: hidden_states = layer(hidden_states, attention_mask, position_ids) # Final normalization hidden_states = self.norm(hidden_states) # Project to vocabulary logits = self.lm_head(hidden_states) return logits def generate( self, input_ids, max_new_tokens=50, temperature=1.0, top_p=0.9, top_k=None, do_sample=True ): """ Generate text autoregressively Supports multiple sampling strategies: - Greedy decoding (temperature=0) - Temperature sampling - Nucleus (top-p) sampling - Top-k sampling Args: input_ids (torch.Tensor): Input token IDs [batch, seq_len] max_new_tokens (int): Number of tokens to generate temperature (float): Sampling temperature (0 = greedy, >1 = more random) top_p (float): Nucleus sampling threshold (0-1) top_k (int, optional): Top-k sampling threshold do_sample (bool): Whether to sample or use greedy decoding Returns: 
            torch.Tensor: Generated token IDs [batch, seq_len + max_new_tokens]
        """
        self.eval()

        for _ in range(max_new_tokens):
            with torch.no_grad():
                # Forward pass
                logits = self(input_ids)

                # Logits for the next token (last position)
                next_token_logits = logits[:, -1, :]

                # Apply temperature
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature

                if not do_sample or temperature == 0:
                    # Greedy decoding
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                else:
                    # Top-k sampling: drop everything below the k-th largest logit
                    if top_k is not None:
                        top_k = min(top_k, next_token_logits.size(-1))
                        indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                        next_token_logits[indices_to_remove] = float('-inf')

                    # Nucleus (top-p) sampling
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                        # Remove tokens with cumulative probability above the threshold
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift the mask right so the first token that crosses the
                        # threshold is kept; this guarantees at least one candidate
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = False

                        # Scatter the mask back to the original (unsorted) indexing
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        next_token_logits[indices_to_remove] = float('-inf')

                    # Sample from the filtered distribution
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                # Append to the sequence
                input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids

    def get_num_params(self, non_embedding=False):
        """
        Count model parameters.

        Args:
            non_embedding (bool): If True, exclude embedding parameters

        Returns:
            int: Number of parameters
        """
        n_params = sum(p.numel() for p in self.parameters())

        if non_embedding:
            n_params -= self.embed_tokens.weight.numel()
            # With tied weights, parameters() already deduplicates the shared
            # tensor, so lm_head is only subtracted when it is a separate weight
            if not self.config.tie_word_embeddings:
                n_params -= self.lm_head.weight.numel()

        return n_params
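

# Illustration only, not used by the model: a toy simulation of why residual
# projections are scaled by 1/sqrt(2 * num_layers). Each of the 2 * num_layers
# branches is stood in for by a unit-variance random vector; the dimensions
# and trial count are arbitrary.
def _demo_residual_scaling(num_layers=30, dim=576, trials=1000):
    scale = 1.0 / math.sqrt(2 * num_layers)
    unscaled = torch.zeros(trials, dim)
    scaled = torch.zeros(trials, dim)
    for _ in range(2 * num_layers):
        branch = torch.randn(trials, dim)  # stand-in for one branch output
        unscaled += branch
        scaled += scale * branch
    # Unscaled stream variance ≈ 2 * num_layers (= 60); scaled stream ≈ 1
    return unscaled.var().item(), scaled.var().item()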

def initialize_weights(model, config):
    """
    Initialize model weights using GPT-style initialization.

    Strategy:
    - All weights: Normal(0, 0.02)
    - Residual projections: scaled by 1/sqrt(2 * num_layers)
    - RMSNorm: initialized to 1.0 (PyTorch default)

    The residual scaling prevents variance explosion in deep networks: the
    residual stream accumulates the outputs of 2 * num_layers branches (one
    attention and one FFN branch per block), so each is shrunk accordingly.

    Args:
        model (SmolLM2Model): Model to initialize
        config: Model configuration
    """
    std = 0.02
    num_layers = config.num_hidden_layers

    # Residual scaling factor: 1/sqrt(2 * num_layers)
    residual_scaling = 1.0 / math.sqrt(2 * num_layers)

    print(f"Initializing weights with std={std}, residual_scaling={residual_scaling:.6f}")

    # Initialize embeddings
    nn.init.normal_(model.embed_tokens.weight, mean=0.0, std=std)

    # Initialize each transformer block
    for layer in model.layers:
        # Attention projections
        nn.init.normal_(layer.self_attn.q_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.self_attn.k_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.self_attn.v_proj.weight, mean=0.0, std=std)
        # Attention output projection feeds the residual stream: scale it down
        nn.init.normal_(layer.self_attn.o_proj.weight, mean=0.0, std=std * residual_scaling)

        # FFN projections
        nn.init.normal_(layer.mlp.gate_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.mlp.up_proj.weight, mean=0.0, std=std)
        # FFN output projection also feeds the residual stream: scale it down
        nn.init.normal_(layer.mlp.down_proj.weight, mean=0.0, std=std * residual_scaling)

    # RMSNorm weights are initialized to 1.0 by default (PyTorch)

    print(f"✅ Initialized {sum(1 for _ in model.parameters())} weight tensors")
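

# Illustration only, not used above: a sanity check that the name mapping in
# load_pretrained_weights covers every tensor in the official checkpoint. The
# key names assume the HuggingFace Llama layout used above; depending on the
# transformers version, 'lm_head.weight' may remain as the expected (tied)
# leftover.
def _demo_check_key_coverage(our_model, official_model):
    expected = {'model.embed_tokens.weight', 'model.norm.weight'}
    for i in range(our_model.config.num_hidden_layers):
        prefix = f'model.layers.{i}'
        expected.update(f'{prefix}.{name}.weight' for name in (
            'input_layernorm', 'post_attention_layernorm',
            'self_attn.q_proj', 'self_attn.k_proj',
            'self_attn.v_proj', 'self_attn.o_proj',
            'mlp.gate_proj', 'mlp.up_proj', 'mlp.down_proj',
        ))
    # Ideally empty, or just {'lm_head.weight'}
    return set(official_model.state_dict().keys()) - expected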

def load_pretrained_weights(our_model, official_model, device='cuda'):
    """
    Load weights from the official HuggingFace model.

    Maps weight names from the official model to our implementation:
    - model.embed_tokens.weight -> embed_tokens.weight
    - model.layers.{i}.*        -> layers[i].*
    - model.norm.weight         -> norm.weight
    - lm_head.weight is tied with the embeddings, so no separate copy is needed

    Args:
        our_model (SmolLM2Model): Our model to load weights into
        official_model: HuggingFace official model
        device (str): Device to load weights to

    Returns:
        int: Number of weight tensors loaded
    """
    print("=" * 70)
    print("LOADING PRETRAINED WEIGHTS")
    print("=" * 70)

    official_state = official_model.state_dict()
    loaded_count = 0

    # 1. Load token embeddings
    our_model.embed_tokens.weight.data = \
        official_state['model.embed_tokens.weight'].clone().to(device)
    loaded_count += 1

    # 2. Load all transformer blocks
    num_layers = our_model.config.num_hidden_layers
    for layer_idx in range(num_layers):
        prefix = f'model.layers.{layer_idx}'

        # Layer norms
        our_model.layers[layer_idx].input_layernorm.weight.data = \
            official_state[f'{prefix}.input_layernorm.weight'].clone().to(device)
        our_model.layers[layer_idx].post_attention_layernorm.weight.data = \
            official_state[f'{prefix}.post_attention_layernorm.weight'].clone().to(device)

        # Attention projections
        our_model.layers[layer_idx].self_attn.q_proj.weight.data = \
            official_state[f'{prefix}.self_attn.q_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.k_proj.weight.data = \
            official_state[f'{prefix}.self_attn.k_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.v_proj.weight.data = \
            official_state[f'{prefix}.self_attn.v_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.o_proj.weight.data = \
            official_state[f'{prefix}.self_attn.o_proj.weight'].clone().to(device)

        # FFN projections
        our_model.layers[layer_idx].mlp.gate_proj.weight.data = \
            official_state[f'{prefix}.mlp.gate_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].mlp.up_proj.weight.data = \
            official_state[f'{prefix}.mlp.up_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].mlp.down_proj.weight.data = \
            official_state[f'{prefix}.mlp.down_proj.weight'].clone().to(device)

        loaded_count += 9  # 2 norms + 4 attn + 3 ffn

    # 3. Load the final norm
    our_model.norm.weight.data = official_state['model.norm.weight'].clone().to(device)
    loaded_count += 1

    print(f"\n✅ Loaded {num_layers} transformer blocks")
    print(f"✅ Total loaded: {loaded_count} weight tensors")
    print("=" * 70)

    return loaded_count


if __name__ == "__main__":
    # Test model creation and parameter count

    # Load config
    config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")

    # Create model
    model = SmolLM2Model(config)

    # Count parameters
    total_params = model.get_num_params()
    print(f"\nTotal parameters: {total_params:,}")
    print("Expected: 134,515,008")
    print(f"Match: {total_params == 134_515_008}")

    # Test forward pass
    test_input = torch.randint(0, config.vocab_size, (1, 10))
    output = model(test_input)
    print("\nForward pass test:")
    print(f"  Input shape: {test_input.shape}")
    print(f"  Output shape: {output.shape}")
    print("  Expected: torch.Size([1, 10, 49152])")

    # Test generation
    generated = model.generate(test_input, max_new_tokens=5)
    print("\nGeneration test:")
    print(f"  Generated shape: {generated.shape}")
    print("  Expected: torch.Size([1, 15])")
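    # Optional end-to-end check against the official implementation (sketch;
    # downloads the checkpoint from the Hub, so it is left commented out):
    #
    #   from transformers import AutoModelForCausalLM
    #   official = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    #   load_pretrained_weights(model, official, device="cpu")
    #   with torch.no_grad():
    #       max_diff = (model(test_input) - official(test_input).logits).abs().max()
    #   print(f"Max logit difference vs official: {max_diff.item():.2e}")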