A custom implementation of the Gemma 3 architecture (scaled to 164.6M parameters), pre-trained from scratch on the TinyStories dataset.
| Metric | Value |
|---|---|
| Best Val Loss | 1.7845 |
| Perplexity | 5.96 |
| Best Iteration | 13,000 |
| Parameters | 164.6M |
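The perplexity in the table is the exponential of the validation loss. The check below assumes the loss is the mean per-token cross-entropy in nats, which is the usual PyTorch convention but is not stated explicitly in this card.

```python
import math

val_loss = 1.7845  # best validation loss from the table above
perplexity = math.exp(val_loss)  # assumes mean cross-entropy in nats
print(f"{perplexity:.2f}")  # 5.96
```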
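This model implements the Gemma 3 architecture, including interleaved sliding-window and full attention, multi-query attention, GeGLU feed-forward layers, and separate RoPE bases for local and global layers: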
| Component | Specification |
|---|---|
| Layers | 18 (15 sliding + 3 full attention) |
| Embedding Dim | 640 |
| Attention Heads | 4 (Multi-Query, 1 KV group) |
| Head Dimension | 256 |
| FFN Hidden | 2,048 (GeGLU activation) |
| Context Length | 32,768 tokens |
| Vocabulary | 50,257 (GPT-2 BPE) |
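As a rough sanity check, the parameter count can be estimated from the table. This is a back-of-the-envelope sketch and not the author's accounting. It assumes untied input and output embeddings, no bias terms, and it ignores normalization weights. None of these assumptions is stated in the card.

```python
# Rough weight-matrix count from the specification table.
# Assumptions (not stated in the card): untied input/output embeddings,
# no biases, normalization weights ignored.
vocab, d_model, n_layers = 50_257, 640, 18
n_heads, n_kv_heads, head_dim, d_ffn = 4, 1, 256, 2_048

embed = vocab * d_model                        # one embedding matrix
attn = (d_model * n_heads * head_dim           # query projection
        + 2 * d_model * n_kv_heads * head_dim  # key and value projections
        + n_heads * head_dim * d_model)        # output projection
ffn = 3 * d_model * d_ffn                      # gate, up, and down projections (GeGLU)

total = 2 * embed + n_layers * (attn + ffn)    # 2x embed: input table + output head
print(f"{total / 1e6:.1f}M")                   # 164.6M
```

Under these assumptions the estimate lands on the reported 164.6M. Note that the query projection maps 640 dimensions to 4 × 256 = 1,024, so the attention width is larger than the embedding width.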
Layers 1-5: Sliding Attention (local, base=10K)
Layer 6: Full Attention (global, base=1M)
Layers 7-11: Sliding Attention (local, base=10K)
Layer 12: Full Attention (global, base=1M)
Layers 13-17: Sliding Attention (local, base=10K)
Layer 18: Full Attention (global, base=1M)
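The pattern above repeats every six layers: five sliding-attention layers followed by one full-attention layer. A minimal sketch of how such a schedule could be generated follows. The function name `layer_schedule` and its parameters are hypothetical and not taken from the model's code.

```python
def layer_schedule(n_layers=18, period=6):
    """Return (attention_type, rope_base) for layers 1..n_layers.

    Hypothetical helper: every `period`-th layer uses full attention
    with base 1M; all other layers use sliding attention with base 10K.
    """
    schedule = []
    for layer in range(1, n_layers + 1):
        if layer % period == 0:
            schedule.append(("full", 1_000_000))
        else:
            schedule.append(("sliding", 10_000))
    return schedule


schedule = layer_schedule()
for layer, (kind, base) in enumerate(schedule, start=1):
    print(f"Layer {layer}: {kind} attention, base={base}")
```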
```python
import torch
import tiktoken

# Load model. Gemma3Model and config are not defined here: you need the
# model class definition and its configuration from the training code.
model = Gemma3Model(config)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Tokenize with the GPT-2 BPE vocabulary
enc = tiktoken.get_encoding("gpt2")
prompt = "Once upon a time"
input_ids = torch.tensor([enc.encode_ordinary(prompt)])

# Generate without tracking gradients
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=200, temperature=0.7)

print(enc.decode(output[0].tolist()))
```
Prompt: "Once upon a time, there was a little cat named Mittens"
Temperature 0.7: Mittens was very hungry and wanted to eat some food. She went outside to find some grass to eat. Mittens saw a big tree and decided to climb it. She climbed up and up until she reached the top. As she was in the tree, she saw a small bird with a broken wing. Mittens knew just what to do. She took the bird to her mom and asked for help.
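The sample above was generated at temperature 0.7. The card does not show how `generate` applies temperature, so the following is only a sketch of the common approach: divide the logits by the temperature before the softmax. The helper name and the logit values are made up for illustration.

```python
import math


def softmax_with_temperature(logits, temperature=0.7):
    """Softmax over logits / temperature (hypothetical helper)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # made-up scores for three candidate tokens
probs_sharp = softmax_with_temperature(logits, 0.7)
probs_plain = softmax_with_temperature(logits, 1.0)
print(probs_sharp, probs_plain)
```

A temperature below 1 concentrates probability on the highest-scoring token, which tends to give more predictable stories.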
Apache 2.0