|
|
--- |
|
|
title: SmolLM2-135M Text Generator |
|
|
emoji: 🐨 |
|
|
colorFrom: yellow |
|
|
colorTo: blue |
|
|
sdk: gradio |
|
|
sdk_version: 6.0.1 |
|
|
app_file: app.py |
|
|
pinned: false |
|
|
short_description: A LLaMA-based SmolLM2-135M transformer (decoder only)
|
|
--- |
|
|
|
|
|
Hugging Face Space with the inference demo: https://huggingface.co/spaces/Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun
|
|
|
|
|
# SmolLM2-135M Implementation |
|
|
|
|
|
A from-scratch PyTorch implementation of the SmolLM2-135M language model, following the LLaMA architecture with modern optimizations. |
|
|
|
|
|
## Overview |
|
|
|
|
|
This repository contains a complete implementation of SmolLM2-135M, a 135-million-parameter decoder-only transformer model. The implementation includes:
|
|
|
|
|
- **Model Architecture** (`model.py`): Complete model definition with KV cache support |
|
|
- **Training Script** (`train.py`): PyTorch Lightning training with WSD scheduler |
|
|
- **Gradio App** (`app.py`): Interactive web interface for text generation |
|
|
|
|
|
## Model Architecture (`model.py`) |
|
|
|
|
|
### Architecture Components |
|
|
|
|
|
The model follows the LLaMA-style decoder-only transformer architecture with the following key components: |
|
|
|
|
|
#### 1. **SmolConfig** (Configuration Class) |
|
|
|
|
|
A dataclass that stores all model hyperparameters: |
|
|
|
|
|
```python
# Shared imports for the snippets below
import math
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class SmolConfig:
    vocab_size: int = 49152              # Vocabulary size
    hidden_size: int = 576               # Hidden dimension
    intermediate_size: int = 1536        # MLP intermediate dimension
    num_hidden_layers: int = 30          # Number of transformer layers
    num_attention_heads: int = 9         # Number of query heads
    num_key_value_heads: int = 3         # Number of key/value heads (GQA)
    max_position_embeddings: int = 8192  # Maximum sequence length
    rope_theta: float = 100000.0         # RoPE base frequency
    rms_norm_eps: float = 1e-5           # RMSNorm epsilon
    attention_bias: bool = False         # Whether to use bias in attention
    mlp_bias: bool = False               # Whether to use bias in MLP
    dtype: torch.dtype = torch.bfloat16
```
|
|
|
|
|
**Key Features:** |
|
|
- `head_dim` property: Automatically computes head dimension (hidden_size // num_attention_heads = 64) |
|
|
- `from_hf()` class method: Loads configuration from HuggingFace model config |
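A minimal sketch of those two members, assuming the HF config exposes the standard `LlamaConfig` field names (the exact body lives in `model.py`):

```python
@property
def head_dim(self) -> int:
    # 576 // 9 = 64 for the default configuration
    return self.hidden_size // self.num_attention_heads

@classmethod
def from_hf(cls, hf_config) -> "SmolConfig":
    return cls(
        vocab_size=hf_config.vocab_size,
        hidden_size=hf_config.hidden_size,
        intermediate_size=hf_config.intermediate_size,
        num_hidden_layers=hf_config.num_hidden_layers,
        num_attention_heads=hf_config.num_attention_heads,
        num_key_value_heads=hf_config.num_key_value_heads,
        max_position_embeddings=hf_config.max_position_embeddings,
        rope_theta=hf_config.rope_theta,
        rms_norm_eps=hf_config.rms_norm_eps,
    )
```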
|
|
|
|
|
#### 2. **RMSNorm** (Root Mean Square Normalization) |
|
|
|
|
|
Replaces LayerNorm with a more efficient normalization: |
|
|
|
|
|
```python
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain
        self.eps = eps

    def forward(self, x):
        norm = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(norm + self.eps)
        return self.weight * x
```
|
|
|
|
|
**Benefits:** |
|
|
- More efficient than LayerNorm (no mean subtraction) |
|
|
- Used throughout the model for pre-norm architecture |
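A quick sanity check of the definition above (with the gain at its initial value of ones, every output vector should have unit RMS):

```python
import torch

norm = RMSNorm(dim=576)
x = torch.randn(2, 8, 576) * 5.0           # arbitrary input scale
rms = norm(x).pow(2).mean(dim=-1).sqrt()   # per-position RMS of the output
print(rms.min().item(), rms.max().item())  # both ~1.0
```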
|
|
|
|
|
#### 3. **RoPE** (Rotary Positional Embeddings) |
|
|
|
|
|
Rotary Position Embeddings applied to query and key tensors: |
|
|
|
|
|
```python
def build_rope_cache(seq_len, head_dim, base, device, dtype):
    # Compute cosine and sine caches for RoPE
    half_dim = head_dim // 2
    freq_seq = torch.arange(half_dim, device=device, dtype=torch.float32)
    inv_freq = 1.0 / (base ** (freq_seq / half_dim))
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)               # (seq_len, half_dim)
    cos = freqs.cos()[None, None, :, :].to(dtype)  # (1, 1, seq_len, half_dim)
    sin = freqs.sin()[None, None, :, :].to(dtype)
    return cos, sin


def apply_rope(x, cos, sin):
    # Rotate the two halves of the head dimension against each other
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    x1_rot = x1 * cos - x2 * sin
    x2_rot = x1 * sin + x2 * cos
    return torch.cat([x1_rot, x2_rot], dim=-1)
```
|
|
|
|
|
**Key Features:** |
|
|
- Relative positional encoding (no absolute position embeddings) |
|
|
- Applied only to Q and K (not V) |
|
|
- Supports efficient caching for inference |
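One property worth verifying under the definitions above: the rotation is norm-preserving, so RoPE changes the direction of Q and K without disturbing the scale of the attention logits. A small check:

```python
import torch

cos, sin = build_rope_cache(seq_len=16, head_dim=64, base=100000.0,
                            device="cpu", dtype=torch.float32)
q = torch.randn(1, 9, 16, 64)  # (batch, heads, seq, head_dim)
q_rot = apply_rope(q, cos, sin)

# cos^2 + sin^2 = 1, so each rotated pair keeps its length
torch.testing.assert_close(q_rot.norm(dim=-1), q.norm(dim=-1))
```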
|
|
|
|
|
#### 4. **MultiHeadSelfAttention** (Grouped Query Attention) |
|
|
|
|
|
Implements GQA (Grouped Query Attention) where: |
|
|
- **Query heads**: 9 (full attention) |
|
|
- **Key/Value heads**: 3 (shared across query heads) |
|
|
|
|
|
```python
class MultiHeadSelfAttention(nn.Module):
    def forward(self, x, cos, sin, past_key_value=None, use_cache=False):
        B, T, _ = x.shape

        # 1. Project to Q, K, V and split into heads: (B, heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # 2. Apply RoPE to Q and K (V carries no positional information)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # 3. KV cache: prepend cached keys/values along the sequence axis
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present_key_value = (k, v) if use_cache else None

        # 4. GQA: repeat each K/V head to serve n_heads // n_kv_heads query heads
        if self.n_kv_heads < self.n_heads:
            repeat_factor = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=1)
            v = v.repeat_interleave(repeat_factor, dim=1)

        # 5. Scaled dot-product attention scores with causal masking
        T_k = k.size(2)  # total key length, including any cache
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        causal_mask = torch.triu(
            torch.full((T, T_k), float("-inf"), device=x.device), diagonal=1 + T_k - T
        )
        scores = scores + causal_mask

        # 6. Softmax, weighted sum, and merge heads back to (B, T, hidden)
        probs = F.softmax(scores, dim=-1)
        out = (probs @ v).transpose(1, 2).reshape(B, T, -1)
        out = self.o_proj(out)

        return out, present_key_value
```
|
|
|
|
|
**Key Features:** |
|
|
- **KV Cache**: Efficient inference by caching past key-value pairs |
|
|
- **GQA**: Reduces memory by sharing K/V heads (3:1 ratio) |
|
|
- **Causal Masking**: Prevents attending to future tokens |
|
|
- **RoPE Integration**: Positional encoding via rotary embeddings |
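The K/V cache savings from GQA are easy to estimate from the numbers above. A back-of-envelope sketch (assuming bfloat16, 2 bytes per element, at the full 8,192-token context):

```python
n_layers, n_kv_heads, head_dim, seq_len, bytes_per_el = 30, 3, 64, 8192, 2

# K and V tensors per layer, summed across all layers
kv_cache = 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_el
print(f"GQA (3 KV heads): {kv_cache / 2**20:.0f} MiB")      # 180 MiB
print(f"MHA (9 KV heads): {kv_cache * 3 / 2**20:.0f} MiB")  # 540 MiB
```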
|
|
|
|
|
#### 5. **SmolMLP** (SwiGLU Activation) |
|
|
|
|
|
Implements the SwiGLU (Swish-Gated Linear Unit) MLP: |
|
|
|
|
|
```python
class SmolMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        # fc1 produces the gate and the up projection in a single matmul
        self.fc1 = nn.Linear(config.hidden_size, 2 * config.intermediate_size,
                             bias=config.mlp_bias)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size,
                             bias=config.mlp_bias)

    def forward(self, x):
        x = self.fc1(x)              # (B, T, 2 * 1536) = (B, T, 3072)
        x1, x2 = x.chunk(2, dim=-1)  # split into gate and value halves
        return self.fc2(F.silu(x1) * x2)  # SwiGLU: SiLU(x1) * x2
```
|
|
|
|
|
**Key Features:** |
|
|
- **SwiGLU**: `SiLU(x1) * x2` activation (better than ReLU/GELU) |
|
|
- **No bias**: Following LLaMA architecture |
|
|
- **Efficient**: Single matrix multiplication with split |
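The fused `fc1` is just the gate and up projections of the usual LLaMA MLP stacked into one weight matrix; a quick equivalence check (the split point and ordering are assumptions about this implementation):

```python
import torch
import torch.nn.functional as F
from torch import nn

fused = nn.Linear(576, 2 * 1536, bias=False)
gate = nn.Linear(576, 1536, bias=False)
up = nn.Linear(576, 1536, bias=False)
gate.weight.data = fused.weight[:1536].clone()  # first half = gate
up.weight.data = fused.weight[1536:].clone()    # second half = up

x = torch.randn(2, 8, 576)
x1, x2 = fused(x).chunk(2, dim=-1)
torch.testing.assert_close(F.silu(x1) * x2, F.silu(gate(x)) * up(x))
```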
|
|
|
|
|
#### 6. **SmolBlock** (Transformer Block) |
|
|
|
|
|
Combines attention and MLP with pre-norm and residual connections: |
|
|
|
|
|
```python
class SmolBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = MultiHeadSelfAttention(config)
        self.mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = SmolMLP(config)

    def forward(self, x, cos, sin, past_key_value=None, use_cache=False):
        # Pre-norm attention with residual
        attn_out, present_kv = self.attn(
            self.attn_norm(x), cos, sin,
            past_key_value=past_key_value, use_cache=use_cache
        )
        x = x + attn_out

        # Pre-norm MLP with residual
        x = x + self.mlp(self.mlp_norm(x))

        return x, present_kv
```
|
|
|
|
|
**Architecture:** |
|
|
- **Pre-norm**: Normalization before attention/MLP (not after) |
|
|
- **Residual connections**: Skip connections for gradient flow |
|
|
- **KV Cache passthrough**: Supports efficient inference |
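To make the pre-norm point concrete, here is a toy contrast with the post-norm arrangement of the original Transformer (illustration only, using a linear layer as a stand-in sublayer):

```python
import torch
from torch import nn

f = nn.Linear(576, 576)     # stand-in for attention or the MLP
norm = nn.LayerNorm(576)
x = torch.randn(2, 8, 576)

pre_norm = x + f(norm(x))   # SmolBlock: the residual stream is never normalized
post_norm = norm(x + f(x))  # original Transformer: norm sits on the residual path
```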
|
|
|
|
|
#### 7. **SmolLM2** (Main Model) |
|
|
|
|
|
Top-level model that combines all components: |
|
|
|
|
|
```python
class SmolLM2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [SmolBlock(config) for _ in range(config.num_hidden_layers)]  # 30 layers
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: share embedding and output weights
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids, past_key_values=None, use_cache=False):
        # 1. Token embeddings
        x = self.embed_tokens(input_ids)

        # 2. Build RoPE cache (offset by the cached length during generation)
        cos, sin = build_rope_cache(...)

        # 3. Pass through transformer layers
        present_key_values = []
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values is not None else None
            x, present_kv = layer(x, cos, sin,
                                  past_key_value=past_kv, use_cache=use_cache)
            if use_cache:
                present_key_values.append(present_kv)

        # 4. Final norm and language modeling head
        x = self.norm(x)
        logits = self.lm_head(x)

        return logits, present_key_values
```
|
|
|
|
|
**Key Features:** |
|
|
- **Weight Tying**: Embeddings and output weights are shared (reduces parameters) |
|
|
- **KV Cache Support**: Full support for efficient autoregressive generation |
|
|
- **30 Layers**: Deep transformer stack for capacity |
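With weight tying, the embedding matrix (49,152 × 576 ≈ 28.3M weights) is counted only once; a quick check against the advertised size (`parameters()` deduplicates shared tensors):

```python
model = SmolLM2(config)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # ~134.5M
```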
|
|
|
|
|
#### 8. **Generate Method** (Text Generation) |
|
|
|
|
|
Autoregressive text generation with KV cache: |
|
|
|
|
|
```python
@torch.no_grad()
def generate(self, input_ids, max_new_tokens=100, temperature=1.0,
             top_k=None, top_p=None, eos_token_id=None):
    generated = input_ids
    past_key_values = None

    for _ in range(max_new_tokens):
        # Forward pass: only the newest token once a KV cache exists
        logits, past_key_values = self.forward(
            generated[:, -1:] if past_key_values is not None else generated,
            past_key_values=past_key_values,
            use_cache=True
        )

        # Sample next token with temperature, top-k, top-p
        next_token_logits = logits[:, -1, :] / temperature
        # Apply top-k / top-p filtering here (see the sketch below)
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        generated = torch.cat([generated, next_token], dim=1)

        if eos_token_id is not None and (next_token == eos_token_id).all():
            break

    return generated
```
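The filtering step elided above can be implemented as a standalone helper. A minimal sketch (the `filter_logits` name and its exact placement in `generate()` are assumptions, not the repository's actual code):

```python
import torch
import torch.nn.functional as F


def filter_logits(logits, top_k=None, top_p=None):
    """Set logits outside the top-k / nucleus set to -inf. logits: (B, vocab)."""
    if top_k is not None:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        probs = F.softmax(sorted_logits, dim=-1)
        # Drop tokens whose *exclusive* cumulative probability exceeds top_p,
        # so the first token crossing the threshold is always kept
        mask = probs.cumsum(dim=-1) - probs > top_p
        sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(
            -1, sorted_idx, sorted_logits
        )
    return logits
```

In `generate()`, this would be applied to `next_token_logits` just before the softmax.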
|
|
|
|
|
**Key Features:** |
|
|
- **KV Cache**: Only processes new tokens (not entire sequence) |
|
|
- **Sampling**: Supports temperature, top-k, and top-p (nucleus) sampling |
|
|
- **Efficient**: After the prompt is processed, each step runs the network on a single token (attention still reads the cached K/V, so per-token cost grows with context length)
|
|
|
|
|
### Model Specifications |
|
|
|
|
|
| Parameter | Value | |
|
|
|-----------|-------| |
|
|
| **Total Parameters** | ~135M | |
|
|
| **Hidden Size** | 576 | |
|
|
| **Layers** | 30 | |
|
|
| **Attention Heads** | 9 (Q), 3 (K/V) | |
|
|
| **Head Dimension** | 64 | |
|
|
| **Intermediate Size** | 1536 | |
|
|
| **Vocabulary Size** | 49,152 | |
|
|
| **Max Sequence Length** | 8,192 | |
|
|
| **RoPE Theta** | 100,000 | |
|
|
| **Activation** | SwiGLU (SiLU-gated) | |
|
|
| **Normalization** | RMSNorm | |
|
|
| **Weight Tying** | Yes (embeddings = output) | |
|
|
|
|
|
### Key Design Choices |
|
|
|
|
|
1. **GQA (Grouped Query Attention)**: 3:1 head ratio cuts the K/V cache to one third (~67% memory savings)
|
|
2. **Pre-norm Architecture**: More stable training than post-norm |
|
|
3. **RMSNorm**: Faster and simpler than LayerNorm |
|
|
4. **RoPE**: Relative positional encoding, no learned embeddings |
|
|
5. **SwiGLU**: Better activation than ReLU/GELU |
|
|
6. **Weight Tying**: Reduces parameters and improves generalization |
|
|
7. **No Biases**: Following LLaMA, reduces parameters slightly |
|
|
|
|
|
### Usage Example |
|
|
|
|
|
```python
import torch
from transformers import AutoConfig, AutoTokenizer

from model import SmolConfig, SmolLM2

# Load config from HuggingFace
hf_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
config = SmolConfig.from_hf(hf_config)

# Create model
model = SmolLM2(config)

# Forward pass (training)
input_ids = torch.randint(0, config.vocab_size, (2, 512))
logits, _ = model(input_ids, use_cache=False)

# Text generation (inference with KV cache)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
prompt_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids
generated = model.generate(
    prompt_ids,
    max_new_tokens=100,
    temperature=0.8,
    top_k=50,
)
print(tokenizer.decode(generated[0]))
```
|
|
|
|
|
## Training |
|
|
|
|
|
See `README_TRAINING.md` for detailed training instructions. |
|
|
|
|
|
## Inference |
|
|
|
|
|
See `app.py` for the Gradio web interface or use the `generate()` method directly. |
|
|
|
|
|
## References |
|
|
|
|
|
- [SmolLM2 Paper](https://arxiv.org/abs/2406.02528) |
|
|
- [LLaMA Architecture](https://arxiv.org/abs/2302.13971) |
|
|
- [RoPE: Rotary Position Embedding](https://arxiv.org/abs/2104.09864) |
|
|
- [SwiGLU Activation](https://arxiv.org/abs/2002.05202) |
|
|
|