Kaito Model
A GPT-2 style decoder-only transformer with modern architectural enhancements, built from scratch in PyTorch. The model features two variants: a 114M parameter dense model and an optional 510M parameter Mixture of Experts (MoE) variant, both capable of efficient training and generation.
Model Details
- Model Type: Causal Language Model (Decoder-only Transformer)
- Parameters:
  - Dense: 114M parameters
  - MoE (optional): 510M total parameters (per-token FLOPs match the 114M dense model via top-2 routing)
- Architecture Highlights: RoPE, GQA, SwiGLU, RMSNorm, Sliding Window Attention, Parallel Attention+FFN, Mixed Precision, Weight Tying, and Z-Loss.
- Language(s): English (en)
- License: MIT (or see LICENSE file)
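To make the MoE parameter line above concrete: with top-2 routing, only two experts run for each token, so total parameters grow with the expert count while per-token compute does not. The following is a minimal sketch of that idea, not the repository's implementation. The names `Top2MoE` and `SwiGLUExpert`, the layer sizes, and the exact form of the load-balancing loss (a Switch-Transformer-style term) are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUExpert(nn.Module):
    """One expert: down(silu(gate(x)) * up(x)). Hypothetical name."""

    def __init__(self, dim, hidden):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))


class Top2MoE(nn.Module):
    """Routes each token to its top-k experts (assumed: 8 experts, k=2)."""

    def __init__(self, dim, hidden, n_experts=8, top_k=2):
        super().__init__()
        self.n_experts, self.top_k = n_experts, top_k
        self.router = nn.Linear(dim, n_experts, bias=False)
        self.experts = nn.ModuleList(
            SwiGLUExpert(dim, hidden) for _ in range(n_experts)
        )

    def forward(self, x):
        b, t, d = x.shape
        flat = x.reshape(-1, d)                        # (tokens, dim)
        probs = F.softmax(self.router(flat), dim=-1)   # (tokens, n_experts)
        top_p, top_i = probs.topk(self.top_k, dim=-1)  # (tokens, top_k)
        top_p = top_p / top_p.sum(dim=-1, keepdim=True)  # renormalize over chosen experts

        out = torch.zeros_like(flat)
        for e, expert in enumerate(self.experts):
            # Tokens that picked expert e, and in which of their top-k slots.
            token_idx, slot = (top_i == e).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            weight = top_p[token_idx, slot].unsqueeze(-1)
            out[token_idx] += weight * expert(flat[token_idx])

        # Switch-style auxiliary loss: n_experts * sum_e(f_e * P_e), where f_e is
        # the fraction of routing assignments sent to expert e and P_e is the
        # mean router probability for expert e.
        dispatch = F.one_hot(top_i, self.n_experts).float().sum(dim=1)
        f = dispatch.mean(dim=0) / self.top_k
        p = probs.mean(dim=0)
        aux_loss = self.n_experts * (f * p).sum()
        return out.reshape(b, t, d), aux_loss
```

In training, `aux_loss` would be scaled by a small coefficient and added to the language-modeling loss; the coefficient is a tuning choice the source does not specify.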
Architecture Features
The Kaito model incorporates several advancements over the original GPT-2:
- RoPE (Rotary Position Embedding): Encodes relative position by rotating query and key vectors, replacing learned positional embeddings and allowing some length extrapolation.
- GQA (Grouped Query Attention): 12 query heads share 4 KV heads, reducing KV cache size by 3× with negligible quality loss.
- SwiGLU FFN: Replaces GELU with a gated Swish activation, yielding better performance at the same parameter count.
- RMSNorm: Used instead of LayerNorm for faster normalization without mean-centering.
- Parallel Attention + FFN: PaLM-style parallel streams that share one normalized input, speeding up training by ~15%.
- Sliding Window Attention: Configurable banded causal mask limiting attention to N previous tokens.
- Mixture of Experts (MoE) [Optional]: Replaces standard FFN with 8 expert SwiGLU networks using top-2 routing and a load-balancing auxiliary loss.
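Several of the components listed above fit together in a single transformer block. The sketch below shows one way to combine RMSNorm, a SwiGLU FFN, grouped query attention with a sliding-window causal mask, and the PaLM-style parallel residual. It is an illustration under assumptions, not the repository's code: the class names, the dimensions (`dim=768`, `hidden=2048`, `window=128`), and the choice to count the current token as part of the window are all hypothetical, and RoPE is left out to keep the example short. It relies on `torch.nn.functional.scaled_dot_product_attention`, which requires PyTorch 2.0 or later.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Scale by the root-mean-square of the features; no mean subtraction."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


class SwiGLU(nn.Module):
    """Gated FFN: down(silu(gate(x)) * up(x))."""

    def __init__(self, dim, hidden):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))


def sliding_window_mask(seq_len, window):
    """True where query i may attend key j: i - window < j <= i."""
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j > i - window)


class GQAttention(nn.Module):
    """Grouped query attention: n_heads query heads share n_kv_heads KV heads."""

    def __init__(self, dim, n_heads=12, n_kv_heads=4, window=128):
        super().__init__()
        assert dim % n_heads == 0 and n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads, self.window = n_heads, n_kv_heads, window
        self.head_dim = dim // n_heads
        self.q_proj = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.kv_proj = nn.Linear(dim, 2 * n_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        kv = self.kv_proj(x).view(b, t, 2, self.n_kv_heads, self.head_dim)
        k, v = kv.unbind(dim=2)
        k, v = k.transpose(1, 2), v.transpose(1, 2)  # (b, n_kv_heads, t, head_dim)
        # Expand KV heads so each one serves n_heads // n_kv_heads query heads.
        rep = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(rep, dim=1)
        v = v.repeat_interleave(rep, dim=1)
        mask = sliding_window_mask(t, self.window).to(x.device)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return self.out_proj(y.transpose(1, 2).reshape(b, t, -1))


class ParallelBlock(nn.Module):
    """PaLM-style block: x + attn(norm(x)) + ffn(norm(x)), one shared norm."""

    def __init__(self, dim=768, hidden=2048, **attn_kwargs):
        super().__init__()
        self.norm = RMSNorm(dim)
        self.attn = GQAttention(dim, **attn_kwargs)
        self.ffn = SwiGLU(dim, hidden)

    def forward(self, x):
        h = self.norm(x)
        return x + self.attn(h) + self.ffn(h)
```

For example, `ParallelBlock()(torch.randn(1, 16, 768))` returns a tensor of the same shape. In a real KV cache only the 4 unexpanded KV heads would be stored, which is where the 3× saving comes from; the `repeat_interleave` here is just the simplest way to express the head sharing.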
Intended Use
This model is intended for educational purposes, architecture experimentation, and text generation tasks. It demonstrates how modern LLM optimizations can be integrated into a small, GPT-2-scale model.
Training Details
- Optimizer: AdamW with linear warmup and cosine decay.
- Precision: Mixed Precision (bfloat16 autocast on CUDA/TPU).
- Regularization: Weight Decay, Z-Loss (logit magnitude penalty), and Dropout.
- Batching: Supports gradient accumulation for larger effective batch sizes.
- Hardware: Optimized to run on CPU, GPU (CUDA), and TPU environments.
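The training setup listed above can be sketched as a short loop. This is an illustrative outline under assumptions rather than the repository's training script: the function names (`lr_lambda`, `z_loss`, `train`), the hyperparameter values, and the assumption that `model(inputs)` returns logits of shape `(batch, seq, vocab)` are all hypothetical. The Z-Loss coefficient of 1e-4 is the value reported in the PaLM paper. Autocast is enabled only on CUDA here; TPU support is not covered.

```python
import math
import torch
import torch.nn.functional as F


def lr_lambda(step, warmup_steps, total_steps, min_ratio=0.1):
    """Multiplier on the base LR: linear warmup, then cosine decay to min_ratio."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return min_ratio + (1.0 - min_ratio) * cosine


def z_loss(logits, coeff=1e-4):
    """Penalty coeff * log(Z)^2, where Z is the softmax normalizer."""
    log_z = torch.logsumexp(logits.float(), dim=-1)
    return coeff * log_z.pow(2).mean()


def train(model, batches, base_lr=3e-4, accum_steps=4,
          warmup_steps=10, total_steps=100):
    """Run one pass over `batches` of (inputs, targets) already on the model's device."""
    device = next(model.parameters()).device
    opt = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.1)
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt, lambda s: lr_lambda(s, warmup_steps, total_steps))
    opt.zero_grad(set_to_none=True)
    last_loss = float("nan")
    for i, (inputs, targets) in enumerate(batches):
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                            enabled=(device.type == "cuda")):
            logits = model(inputs)  # assumed shape: (batch, seq, vocab)
            loss = F.cross_entropy(logits.flatten(0, 1).float(), targets.flatten())
            loss = loss + z_loss(logits)
        # Divide so that accumulated gradients average over the micro-batches.
        (loss / accum_steps).backward()
        if (i + 1) % accum_steps == 0:
            opt.step()
            sched.step()
            opt.zero_grad(set_to_none=True)
        last_loss = loss.item()
    return last_loss
```

With `accum_steps=4`, the optimizer steps once every four micro-batches, so the effective batch size is four times the micro-batch size.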
How to Get Started with the Model
You can load and use the model with the source code provided in the repository.
from main import kaitomodel
import torch
from text_generation import generate_text
import tiktoken
# Initialize tokenizer and model
tokenizer = tiktoken.get_encoding("gpt2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = kaitomodel().to(device)
# Load weights if available
# model.load_state_dict(torch.load("model.pt", map_location=device))
model.eval()
# Generate text
result = generate_text(
    model=model,
    idx="In the beginning",
    tokenizer=tokenizer,
    new_max_length=100,
    temperature=0.8,
    top_k=50,
    top_p=0.9,
)
print(result)
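The `temperature`, `top_k`, and `top_p` arguments above control how the next token is sampled. The internals of `generate_text` are not shown here, so the following is only a sketch of how these three filters are commonly combined; the function name `sample_next_token` and the order of the filters are assumptions.

```python
import torch


def sample_next_token(logits, temperature=0.8, top_k=50, top_p=0.9):
    """Sample one token id per row from logits of shape (batch, vocab)."""
    # Temperature below 1 sharpens the distribution, above 1 flattens it.
    logits = logits / temperature

    # Top-k: keep only the k largest logits in each row.
    if top_k is not None:
        k = min(top_k, logits.size(-1))
        kth = torch.topk(logits, k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))

    # Top-p (nucleus): keep the smallest set of tokens whose mass reaches top_p.
    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cum = sorted_probs.cumsum(dim=-1)
        # Drop a token if the probability mass before it already reaches top_p.
        remove = (cum - sorted_probs) >= top_p
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(
            -1, sorted_idx, sorted_logits)

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # shape: (batch, 1)
```

A generation loop would call this on the logits of the last position, append the sampled id to the sequence, and repeat until the length limit is reached.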
Citation
If you use this model or its architectural implementations in your work, please cite the original papers for the corresponding techniques:
- RoPE (Su et al., RoFormer, 2021)
- GQA (Ainslie et al., 2023)
- SwiGLU (Shazeer, 2020)
- RMSNorm (Zhang & Sennrich, 2019)
- Parallel Attention+FFN & Z-Loss (Chowdhery et al., PaLM, 2022)
- MoE Techniques (Fedus et al., Switch Transformers, 2021; Jiang et al., Mixtral, 2024)