# Gemma 3 270M: Pre-trained from Scratch on TinyStories
A custom implementation of the Gemma 3 architecture (scaled to 164.6M parameters), pre-trained from scratch on the TinyStories dataset.
## Results
| Metric | Value |
|---|---|
| Best Val Loss | 1.7845 |
| Perplexity | 5.96 |
| Best Iteration | 13,000 |
| Parameters | 164.6M |
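The reported perplexity is just the exponential of the cross-entropy validation loss (in nats), which you can verify directly:

```python
import math

# Perplexity = exp(cross-entropy loss in nats)
val_loss = 1.7845
perplexity = math.exp(val_loss)
print(f"{perplexity:.2f}")  # 5.96
```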
## Architecture
This model implements the complete Gemma 3 architecture with all modern innovations:
| Component | Specification |
|---|---|
| Layers | 18 (15 sliding + 3 full attention) |
| Embedding Dim | 640 |
| Attention Heads | 4 (Multi-Query, 1 KV group) |
| Head Dimension | 256 |
| FFN Hidden | 2,048 (GeGLU activation) |
| Context Length | 32,768 tokens |
| Vocabulary | 50,257 (GPT-2 BPE) |
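The table above can be collected into a single config object. This is an illustrative sketch only; the field names are hypothetical, not the actual class used for training:

```python
from dataclasses import dataclass

@dataclass
class GemmaConfig:
    """Hypothetical config mirroring the architecture table above."""
    n_layers: int = 18
    emb_dim: int = 640
    n_heads: int = 4
    n_kv_groups: int = 1          # Multi-Query Attention: one shared K,V head
    head_dim: int = 256
    ffn_hidden: int = 2048        # GeGLU feed-forward width
    context_length: int = 32_768
    vocab_size: int = 50_257      # GPT-2 BPE
    sliding_window: int = 512
    rope_base_local: float = 10_000.0      # sliding-attention layers
    rope_base_global: float = 1_000_000.0  # full-attention layers
```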
### Key Features
- Sliding Window Attention (w=512): O(n×w) instead of O(n²), 64× cheaper at full context
- Multi-Query Attention: All query heads share 1 K,V head → 4× less KV cache
- RoPE with Dual Bases: 10K (local patterns) + 1M (long-range dependencies)
- QK Normalization: RMSNorm on Q,K vectors before attention
- Gemma-style RMSNorm: (1 + weight) scaling for stable initialization
- GeGLU Feed-Forward: Gated GELU activation with 3.2× expansion
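The (1 + weight) scaling mentioned above lets the norm's weight be zero-initialized while still acting as an identity scale at the start of training. A minimal sketch (the actual module in this repo may differ in detail):

```python
import torch
import torch.nn as nn

class GemmaRMSNorm(nn.Module):
    """Sketch of Gemma-style RMSNorm: output is scaled by (1 + weight),
    so a zero-initialized weight behaves like a plain RMSNorm at init."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # zero-init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()  # normalize in float32 for stability
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * rms * (1.0 + self.weight.float())).to(dtype)
```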
### Layer Type Pattern
```
Layers  1-5:  Sliding Attention (local,  base=10K)
Layer   6:    Full Attention    (global, base=1M)
Layers  7-11: Sliding Attention (local,  base=10K)
Layer  12:    Full Attention    (global, base=1M)
Layers 13-17: Sliding Attention (local,  base=10K)
Layer  18:    Full Attention    (global, base=1M)
```
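The pattern above is simply "full attention every 6th layer, sliding-window otherwise", which can be generated programmatically (a sketch; names are illustrative):

```python
# Every 6th layer uses full attention; the rest use sliding-window attention.
SLIDING, FULL = "sliding", "full"

def layer_types(n_layers: int = 18, full_every: int = 6) -> list[str]:
    """Return the attention type for each of the n_layers layers."""
    return [FULL if (i + 1) % full_every == 0 else SLIDING
            for i in range(n_layers)]

pattern = layer_types()
# -> 15 sliding + 3 full, with full attention at layers 6, 12, and 18
```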
## Training
- Dataset: TinyStories (2.1M stories, 471M tokens)
- Tokenizer: GPT-2 BPE via tiktoken (50,257 vocab)
- Optimizer: AdamW (β₁=0.9, β₂=0.95, ε=1e-9, weight_decay=0.1)
- Learning Rate: 1e-4 → 5e-5 (cosine decay with 1K step warmup)
- Precision: bfloat16 mixed precision
- Hardware: NVIDIA A100 40GB (Google Colab Pro)
- Gradient Clipping: max_norm=0.5
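The optimizer and schedule above can be sketched as follows. This is illustrative, not the actual training loop; in particular, `max_steps` is an assumed value and the model is a placeholder:

```python
import math
import torch

max_lr, min_lr = 1e-4, 5e-5
warmup_steps, max_steps = 1_000, 15_000  # max_steps is an assumption

def lr_at(step: int) -> float:
    """Linear warmup followed by cosine decay from max_lr to min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    t = min((step - warmup_steps) / (max_steps - warmup_steps), 1.0)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * t))

model = torch.nn.Linear(4, 4)  # placeholder for the actual Gemma3Model
optimizer = torch.optim.AdamW(
    model.parameters(), lr=max_lr,
    betas=(0.9, 0.95), eps=1e-9, weight_decay=0.1,
)
# Per optimization step (sketch):
#   for g in optimizer.param_groups: g["lr"] = lr_at(step)
#   loss.backward()
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
#   optimizer.step(); optimizer.zero_grad()
```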
## Usage
```python
import torch
import tiktoken

# Load model (the Gemma3Model class definition and config are required;
# they are not bundled with the weights)
model = Gemma3Model(config)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Tokenize
enc = tiktoken.get_encoding("gpt2")
prompt = "Once upon a time"
input_ids = torch.tensor([enc.encode_ordinary(prompt)])

# Generate
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=200, temperature=0.7)
print(enc.decode(output[0].tolist()))
```
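Since the model class isn't bundled here, the `generate` call above is assumed. A minimal temperature-sampling loop of that shape could look like this (a sketch, assuming the model returns logits of shape `(batch, seq, vocab)`; the actual method may differ):

```python
import torch

@torch.no_grad()
def generate(model, input_ids, max_new_tokens=200, temperature=0.7,
             context_length=32_768):
    """Illustrative autoregressive sampling loop with temperature scaling."""
    for _ in range(max_new_tokens):
        logits = model(input_ids[:, -context_length:])   # (B, T, vocab)
        logits = logits[:, -1, :] / max(temperature, 1e-8)
        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_id], dim=1)
    return input_ids
```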
## Sample Outputs
**Prompt:** "Once upon a time, there was a little cat named Mittens"

**Temperature 0.7:**
> Mittens was very hungry and wanted to eat some food. She went outside to find some grass to eat. Mittens saw a big tree and decided to climb it. She climbed up and up until she reached the top. As she was in the tree, she saw a small bird with a broken wing. Mittens knew just what to do. She took the bird to her mom and asked for help.
## Credits
- Architecture Reference: Vizuara Team - Raj (Tutorial)
- Dataset: TinyStories by Ronen Eldan & Yuanzhi Li
- Tokenizer: OpenAI tiktoken (GPT-2 BPE)
## License
Apache 2.0
