""" |
|
|
model.py |
|
|
======== |
|
|
Complete SmolLM2-135M model implementation |
|
|
|
|
|
Architecture: |
|
|
- 30 transformer blocks |
|
|
- 576 hidden dimensions |
|
|
- 9 query heads, 3 KV heads (Grouped Query Attention) |
|
|
- SwiGLU feed-forward network |
|
|
- RoPE position embeddings |
|
|
- RMSNorm layer normalization |
|
|
- Weight tying (embeddings = lm_head) |
|
|
|
|
|
Total parameters: 134,515,008 (~135M) |
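
Parameter breakdown (derived from the architecture values above; head_dim = 576 / 9 = 64,
so the 3 KV heads project to 3 * 64 = 192 dims):
- Embeddings: 49,152 * 576 = 28,311,552 (shared with lm_head via weight tying)
- Per block:  attention (2 * 576*576 + 2 * 576*192) =   884,736
              SwiGLU FFN (2 * 576*1536 + 1536*576)  = 2,654,208
              2x RMSNorm (2 * 576)                  =     1,152
- 30 blocks:  30 * 3,540,096 = 106,202,880
- Final RMSNorm: 576
- Total: 28,311,552 + 106,202,880 + 576 = 134,515,008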
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from components import RMSNorm, TransformerBlock
from transformers import AutoConfig


class SmolLM2Model(nn.Module):
    """
    SmolLM2-135M Language Model

    A decoder-only transformer based on the Llama architecture with:
    - Grouped Query Attention (memory efficient)
    - SwiGLU FFN (improved expressiveness)
    - RoPE position embeddings (length extrapolation)
    - RMSNorm (faster than LayerNorm)

    Model configuration:
    - Layers: 30
    - Hidden size: 576
    - Attention heads: 9 (Q) / 3 (KV)
    - FFN size: 1536
    - Vocab size: 49,152
    - Context length: 2048
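
    With hidden size 576 and 9 query heads, head_dim = 576 / 9 = 64. The 3 KV heads
    therefore share a 3 * 64 = 192-wide K/V projection, so the KV cache is one third
    the size a full 9-head (multi-head) cache would need.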
    """

    def __init__(self, config):
        """
        Initialize the SmolLM2 model.

        Args:
            config: Model configuration object with attributes:
                - vocab_size: Size of vocabulary (49152)
                - hidden_size: Model dimension (576)
                - num_hidden_layers: Number of transformer blocks (30)
                - tie_word_embeddings: Whether to tie input/output embeddings
                - rms_norm_eps: Epsilon for RMSNorm
        """
        super().__init__()
        self.config = config

        # Token embeddings: vocab_size -> hidden_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # Stack of decoder blocks (attention + SwiGLU FFN)
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_hidden_layers)
        ])

        # Final normalization before the output projection
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Output projection to vocabulary logits
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: reuse the embedding matrix as the output projection
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

        print(f"✓ Model initialized with {config.num_hidden_layers} transformer blocks")
        print(f"✓ Weight tying: {config.tie_word_embeddings}")

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """
        Forward pass through the model.

        Args:
            input_ids (torch.Tensor): Input token IDs [batch, seq_len]
            attention_mask (torch.Tensor, optional): Attention mask
            position_ids (torch.Tensor, optional): Position indices

        Returns:
            torch.Tensor: Logits over vocabulary [batch, seq_len, vocab_size]
        """
        batch_size, seq_len = input_ids.shape

        # Default to positions 0..seq_len-1 for RoPE
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=input_ids.device)

        # Token embeddings
        hidden_states = self.embed_tokens(input_ids)

        # Pass through all transformer blocks
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, position_ids)

        # Final norm and projection to vocabulary logits
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        return logits

    def generate(
        self,
        input_ids,
        max_new_tokens=50,
        temperature=1.0,
        top_p=0.9,
        top_k=None,
        do_sample=True
    ):
        """
        Generate text autoregressively.

        Supports multiple sampling strategies:
        - Greedy decoding (temperature=0 or do_sample=False)
        - Temperature sampling
        - Nucleus (top-p) sampling
        - Top-k sampling

        Args:
            input_ids (torch.Tensor): Input token IDs [batch, seq_len]
            max_new_tokens (int): Number of tokens to generate
            temperature (float): Sampling temperature (0 = greedy, >1 = more random)
            top_p (float): Nucleus sampling threshold (0-1)
            top_k (int, optional): Top-k sampling threshold
            do_sample (bool): Whether to sample or use greedy decoding

        Returns:
            torch.Tensor: Generated token IDs [batch, seq_len + max_new_tokens]
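
        Example (hypothetical usage; assumes a HuggingFace tokenizer for
        "HuggingFaceTB/SmolLM2-135M" is available):

            >>> from transformers import AutoTokenizer
            >>> tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
            >>> ids = tok("The capital of France is", return_tensors="pt").input_ids
            >>> out = model.generate(ids, max_new_tokens=20, temperature=0.7, top_p=0.9)
            >>> print(tok.decode(out[0], skip_special_tokens=True))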
        """
        self.eval()

        for _ in range(max_new_tokens):
            with torch.no_grad():
                # Forward pass over the full sequence (no KV cache)
                logits = self(input_ids)

            # Only the last position's logits matter for the next token
            next_token_logits = logits[:, -1, :]

            # Temperature scaling (skipped for temperature=0, i.e. greedy)
            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            if not do_sample or temperature == 0:
                # Greedy decoding: pick the most likely token
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else:
                # Top-k filtering: keep only the k highest-probability tokens
                if top_k is not None:
                    top_k = min(top_k, next_token_logits.size(-1))
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Nucleus (top-p) filtering: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token above the threshold is kept
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                # Sample from the filtered distribution
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            # Append the new token and continue
            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids

    def get_num_params(self, non_embedding=False):
        """
        Count model parameters.

        Args:
            non_embedding (bool): If True, exclude embedding parameters

        Returns:
            int: Number of parameters
        """
        # nn.Module.parameters() deduplicates shared tensors, so tied weights
        # are counted only once
        n_params = sum(p.numel() for p in self.parameters())

        if non_embedding:
            n_params -= self.embed_tokens.weight.numel()
            # With tied weights the lm_head shares the embedding tensor and was
            # therefore not counted separately above
            if not self.config.tie_word_embeddings:
                n_params -= self.lm_head.weight.numel()

        return n_params


def initialize_weights(model, config):
    """
    Initialize model weights using GPT-style initialization.

    Strategy:
    - All weights: Normal(0, 0.02)
    - Residual projections: scaled by 1/sqrt(2 * num_layers)
    - RMSNorm: initialized to 1.0 (PyTorch default)

    The residual scaling prevents variance explosion in deep networks.
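
    For this config (30 layers) the scaling is 1/sqrt(2 * 30) = 1/sqrt(60) ≈ 0.129,
    so the o_proj and down_proj weights start with std ≈ 0.02 * 0.129 ≈ 0.0026.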

    Args:
        model (SmolLM2Model): Model to initialize
        config: Model configuration
    """
    std = 0.02
    num_layers = config.num_hidden_layers

    # Scale residual-path projections down so the variance of the residual
    # stream stays roughly constant with depth
    residual_scaling = 1.0 / math.sqrt(2 * num_layers)

    print(f"Initializing weights with std={std}, residual_scaling={residual_scaling:.6f}")

    # Token embeddings
    nn.init.normal_(model.embed_tokens.weight, mean=0.0, std=std)

    for layer in model.layers:
        # Attention projections
        nn.init.normal_(layer.self_attn.q_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.self_attn.k_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.self_attn.v_proj.weight, mean=0.0, std=std)
        # Output projection feeds the residual stream -> scaled init
        nn.init.normal_(layer.self_attn.o_proj.weight, mean=0.0, std=std * residual_scaling)

        # SwiGLU FFN projections
        nn.init.normal_(layer.mlp.gate_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.mlp.up_proj.weight, mean=0.0, std=std)
        # Down projection feeds the residual stream -> scaled init
        nn.init.normal_(layer.mlp.down_proj.weight, mean=0.0, std=std * residual_scaling)

    print(f"✓ Initialized {sum(1 for _ in model.parameters())} weight tensors")


def load_pretrained_weights(our_model, official_model, device='cuda'):
    """
    Load weights from the official HuggingFace model.

    Maps weight names from the official model to our implementation:
    - model.embed_tokens.weight -> embed_tokens.weight
    - model.layers.{i}.* -> layers[i].*
    - model.norm.weight -> norm.weight
    - lm_head.weight (tied with embeddings)

    Args:
        our_model (SmolLM2Model): Our model to load weights into
        official_model: HuggingFace official model
        device (str): Device to load weights to

    Returns:
        int: Number of weight tensors loaded
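
    Example (hypothetical usage; loads the official checkpoint with the
    transformers library and copies its weights into this implementation):

        >>> from transformers import AutoConfig, AutoModelForCausalLM
        >>> config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
        >>> official = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
        >>> ours = SmolLM2Model(config)
        >>> load_pretrained_weights(ours, official, device="cpu")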
    """
    print("=" * 70)
    print("LOADING PRETRAINED WEIGHTS")
    print("=" * 70)

    official_state = official_model.state_dict()
    loaded_count = 0

    # Token embeddings (also used as lm_head via weight tying)
    our_model.embed_tokens.weight.data = official_state['model.embed_tokens.weight'].clone().to(device)
    loaded_count += 1

    # Transformer blocks
    num_layers = our_model.config.num_hidden_layers
    for layer_idx in range(num_layers):
        prefix = f'model.layers.{layer_idx}'

        # Layer norms
        our_model.layers[layer_idx].input_layernorm.weight.data = \
            official_state[f'{prefix}.input_layernorm.weight'].clone().to(device)
        our_model.layers[layer_idx].post_attention_layernorm.weight.data = \
            official_state[f'{prefix}.post_attention_layernorm.weight'].clone().to(device)

        # Attention projections
        our_model.layers[layer_idx].self_attn.q_proj.weight.data = \
            official_state[f'{prefix}.self_attn.q_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.k_proj.weight.data = \
            official_state[f'{prefix}.self_attn.k_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.v_proj.weight.data = \
            official_state[f'{prefix}.self_attn.v_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.o_proj.weight.data = \
            official_state[f'{prefix}.self_attn.o_proj.weight'].clone().to(device)

        # SwiGLU FFN projections
        our_model.layers[layer_idx].mlp.gate_proj.weight.data = \
            official_state[f'{prefix}.mlp.gate_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].mlp.up_proj.weight.data = \
            official_state[f'{prefix}.mlp.up_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].mlp.down_proj.weight.data = \
            official_state[f'{prefix}.mlp.down_proj.weight'].clone().to(device)

        loaded_count += 9

    # Final norm
    our_model.norm.weight.data = official_state['model.norm.weight'].clone().to(device)
    loaded_count += 1

    print(f"\n✓ Loaded {num_layers} transformer blocks")
    print(f"✓ Total loaded: {loaded_count} weight tensors")
    print("=" * 70)

    return loaded_count
if __name__ == "__main__": |
|
|
"""Test model creation and parameter count""" |
|
|
|
|
|
config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M") |
|
|
|
|
|
|
|
|
model = SmolLM2Model(config) |
|
|
|
|
|
|
|
|
total_params = model.get_num_params() |
|
|
print(f"\nTotal parameters: {total_params:,}") |
|
|
print(f"Expected: 134,515,008") |
|
|
print(f"Match: {total_params == 134_515_008}") |
|
|
|
|
|
|
|
|
test_input = torch.randint(0, config.vocab_size, (1, 10)) |
|
|
output = model(test_input) |
|
|
print(f"\nForward pass test:") |
|
|
print(f" Input shape: {test_input.shape}") |
|
|
print(f" Output shape: {output.shape}") |
|
|
print(f" Expected: torch.Size([1, 10, 49152])") |
|
|
|
|
|
|
|
|
generated = model.generate(test_input, max_new_tokens=5) |
|
|
print(f"\nGeneration test:") |
|
|
print(f" Generated shape: {generated.shape}") |
|
|
print(f" Expected: torch.Size([1, 15])") |