Gemma 3 270M β€” Pre-trained from Scratch on TinyStories

A custom implementation of the Gemma 3 270M architecture, scaled down to 164.6M parameters and pre-trained from scratch on the TinyStories dataset.

πŸ“Š Results

Metric          Value
Best Val Loss   1.7845
Perplexity      5.96
Best Iteration  13,000
Parameters      164.6M

Training Loss Curves

πŸ—οΈ Architecture

This model implements the complete Gemma 3 architecture, including its key modern components:

Component        Specification
Layers           18 (15 sliding + 3 full attention)
Embedding Dim    640
Attention Heads  4 (Multi-Query, 1 KV group)
Head Dimension   256
FFN Hidden       2,048 (GeGLU activation)
Context Length   32,768 tokens
Vocabulary       50,257 (GPT-2 BPE)
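
The table above can be captured in a small config object. This is an illustrative sketch, not the actual class used in this repository; all field names are assumptions.

```python
from dataclasses import dataclass

@dataclass
class Gemma3Config:
    # Values taken directly from the architecture table above.
    n_layers: int = 18            # 15 sliding + 3 full attention
    emb_dim: int = 640
    n_heads: int = 4              # Multi-Query: all share 1 KV group
    n_kv_groups: int = 1
    head_dim: int = 256
    ffn_hidden: int = 2048        # GeGLU expansion (3.2x of emb_dim)
    context_length: int = 32_768
    vocab_size: int = 50_257      # GPT-2 BPE
    sliding_window: int = 512
    rope_base_local: float = 10_000.0
    rope_base_global: float = 1_000_000.0

config = Gemma3Config()
```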

Key Features

  • Sliding Window Attention (w=512): O(nΓ—w) instead of O(nΒ²), 64Γ— cheaper at the full 32K context
  • Multi-Query Attention: All query heads share 1 K,V head β€” 4Γ— less KV cache
  • RoPE with Dual Bases: 10K (local patterns) + 1M (long-range dependencies)
  • QK Normalization: RMSNorm on Q,K vectors before attention
  • Gemma-style RMSNorm: (1 + weight) scaling for stable initialization
  • GeGLU Feed-Forward: Gated GELU activation with 3.2Γ— expansion
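
Two of the features above are easy to show in code. The sketch below illustrates Gemma-style RMSNorm with (1 + weight) scaling (the learnable weight starts at zero, so the layer begins as a plain RMS normalization) and a GeGLU feed-forward block. Module names and details are illustrative, not taken from this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GemmaRMSNorm(nn.Module):
    """RMSNorm with Gemma's (1 + weight) scaling.

    weight is initialized to zero, so at init the effective scale is 1,
    which keeps activations stable early in training."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * (1.0 + self.weight)

class GeGLU(nn.Module):
    """Gated GELU feed-forward: gelu(gate(x)) * up(x), then project down."""
    def __init__(self, dim: int = 640, hidden: int = 2048):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down(F.gelu(self.gate(x), approximate="tanh") * self.up(x))
```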

Layer Type Pattern

Layers 1-5:   Sliding Attention (local, base=10K)
Layer 6:      Full Attention (global, base=1M)
Layers 7-11:  Sliding Attention (local, base=10K)
Layer 12:     Full Attention (global, base=1M)
Layers 13-17: Sliding Attention (local, base=10K)
Layer 18:     Full Attention (global, base=1M)
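
The pattern above places full attention at every 6th layer, with sliding attention in between. A minimal helper that reproduces it (the actual repository may encode this differently):

```python
def layer_types(n_layers: int = 18, full_every: int = 6) -> list[str]:
    """Return 'full' for every full_every-th layer (1-indexed), else 'sliding'."""
    return ["full" if (i + 1) % full_every == 0 else "sliding"
            for i in range(n_layers)]

types = layer_types()
# Yields 15 sliding + 3 full, with full attention at layers 6, 12, and 18.
```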

πŸ“– Training

  • Dataset: TinyStories (2.1M stories, 471M tokens)
  • Tokenizer: GPT-2 BPE via tiktoken (50,257 vocab)
  • Optimizer: AdamW (Ξ²1=0.9, Ξ²2=0.95, Ξ΅=1e-9, weight_decay=0.1)
  • Learning Rate: 1e-4 β†’ 5e-5 (cosine decay with 1K step warmup)
  • Precision: bfloat16 mixed precision
  • Hardware: NVIDIA A100 40GB (Google Colab Pro)
  • Gradient Clipping: max_norm=0.5
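
The schedule above (linear warmup for 1K steps, then cosine decay from 1e-4 to 5e-5) can be sketched as a step-to-LR function. The total step count here is an assumption taken from the best iteration reported above; the actual training run may have used a different horizon.

```python
import math

def get_lr(step: int, max_lr: float = 1e-4, min_lr: float = 5e-5,
           warmup: int = 1000, total: int = 13_000) -> float:
    """Linear warmup to max_lr, then cosine decay down to min_lr."""
    if step < warmup:
        return max_lr * (step + 1) / warmup
    progress = min((step - warmup) / max(total - warmup, 1), 1.0)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
```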

πŸ’» Usage

import torch
import tiktoken

# Load model (you'll need the model class definition)
model = Gemma3Model(config)
state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Tokenize
enc = tiktoken.get_encoding("gpt2")
prompt = "Once upon a time"
input_ids = torch.tensor([enc.encode_ordinary(prompt)])

# Generate
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=200, temperature=0.7)
print(enc.decode(output[0].tolist()))
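
If your copy of the model class does not define generate(), a minimal temperature-sampling loop looks like the sketch below. Here next_logits_fn stands in for the model's forward pass and is assumed to return logits of shape (batch, seq, vocab); this is an illustration, not the repository's actual implementation.

```python
import torch

@torch.no_grad()
def sample(next_logits_fn, input_ids, max_new_tokens: int = 200,
           temperature: float = 0.7):
    """Autoregressively append tokens, sampling from the softmax of the
    last position's logits (or greedily when temperature == 0)."""
    ids = input_ids
    for _ in range(max_new_tokens):
        logits = next_logits_fn(ids)[:, -1, :]          # (batch, vocab)
        if temperature == 0:
            next_id = logits.argmax(-1, keepdim=True)   # greedy decoding
        else:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id], dim=1)
    return ids
```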

πŸ“ Sample Outputs

Prompt: "Once upon a time, there was a little cat named Mittens"

Temperature 0.7: Mittens was very hungry and wanted to eat some food. She went outside to find some grass to eat. Mittens saw a big tree and decided to climb it. She climbed up and up until she reached the top. As she was in the tree, she saw a small bird with a broken wing. Mittens knew just what to do. She took the bird to her mom and asked for help.

πŸ™ Credits

  • Architecture Reference: Vizuara Team - Raj (Tutorial)
  • Dataset: TinyStories by Ronen Eldan & Yuanzhi Li
  • Tokenizer: OpenAI tiktoken (GPT-2 BPE)

πŸ“„ License

Apache 2.0
