# LaughLM
A high-performance decoder-only transformer training system built with JAX + Flax and optimized for TPU training.
LaughLM is designed as a research-friendly yet production-capable framework for experimenting with modern transformer architectures while maintaining high training throughput.
The system emphasizes:
- clean modular architecture
- hardware-efficient training
- reproducible experiments
- flexible configuration
- large-scale dataset streaming
- high MFU optimization on TPUs
## Features
- Decoder-only GPT architecture
- JAX + Flax implementation
- TPU-optimized mixed precision training
- Flexible architecture selection
- Pre-tokenized memory-mapped datasets
- Multiple attention variants
- Multiple FFN architectures
- Weight tying support
- Orbax checkpointing
- Optax optimizers
- Config-driven experiments
Supported architecture features:
- MHA / MQA / GQA attention
- RoPE positional encoding
- SwiGLU / GEGLU / GELU MLP
- RMSNorm / LayerNorm
- configurable residual scaling
- multiple LR schedulers
- masked weight decay
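As a concrete illustration of one of these building blocks, RMSNorm drops LayerNorm's mean-centering and bias in favor of a pure root-mean-square rescale. The Flax sketch below is illustrative only and not necessarily identical to LaughLM/model/layers/normalization.py:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm: no mean subtraction, no bias."""
    dim: int
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        scale = self.param("scale", nn.initializers.ones, (self.dim,))
        rms = jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + self.eps)
        return x * rms * scale
```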
## Project Structure

```
.
├── configs
│   ├── gpu_test.yaml
│   └── test.yaml
├── LaughLM
│   ├── config
│   │   ├── loader.py
│   │   ├── schema.py
│   │   └── validation.py
│   ├── data
│   │   ├── domain_sampler.py
│   │   ├── memmap_loader.py
│   │   ├── shard_writer.py
│   │   ├── tokenizer.py
│   │   └── tokenizer_train.py
│   ├── model
│   │   ├── gpt.py
│   │   ├── layers
│   │   │   ├── attention.py
│   │   │   ├── mlp.py
│   │   │   ├── normalization.py
│   │   │   ├── positional.py
│   │   │   └── residual.py
│   │   ├── parameter_utils.py
│   │   └── transformer_block.py
│   ├── training
│   │   ├── checkpoint.py
│   │   ├── logger.py
│   │   ├── loss.py
│   │   ├── optimizer.py
│   │   ├── scheduler.py
│   │   ├── trainer.py
│   │   ├── train_state.py
│   │   └── train_step.py
│   └── utils
│       └── rng.py
├── LICENSE
├── log.txt
├── pyproject.toml
├── README.md
├── requirements.txt
└── scripts
    ├── build_shard.py
    └── train_gpu_test.py
```
## Installation

Clone the repository:

```bash
git clone https://github.com/your-org/LaughLM.git
cd LaughLM
```

Create and activate a virtual environment:

```bash
python -m venv venv
source venv/bin/activate
```

Install dependencies:

```bash
pip install -r requirements.txt
```

For TPU environments, install the TPU build of JAX:

```bash
pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
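To confirm the accelerator is visible before training, the standard JAX device query is enough:

```python
import jax

print(jax.devices())          # e.g. [TpuDevice(id=0, ...), ...] on a TPU VM
print(jax.default_backend())  # "tpu", "gpu", or "cpu"
```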
## Configuration

Experiments are fully defined via YAML configs; see configs/test.yaml for a complete example.

Configuration sections include:

- model architecture
- optimizer
- scheduler
- runtime parameters
- dataset sources
- tokenizer settings
- hardware configuration

Example snippet:

```yaml
model:
  d_model: 768
  num_layers: 12
  num_heads: 12
  vocab_size: 32000
  max_seq_len: 2048
```
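For reference, loading such a config into a plain dictionary takes only a few lines with PyYAML. This is an illustrative sketch; the project's own loader in LaughLM/config/loader.py (alongside schema.py and validation.py) may expose a different API:

```python
import yaml

# Load the experiment config into a nested dict.
with open("configs/test.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["d_model"])  # 768
```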
## Dataset Pipeline

LaughLM uses a pre-tokenized dataset pipeline for maximum throughput: training datasets are converted into binary token shards ahead of time.

Advantages:

- high throughput
- minimal CPU overhead
- memory-mapped streaming
- scalable to large datasets
### Step 1: Train the Tokenizer

Train a tokenizer using streaming datasets:

```bash
python -m LaughLM.data.tokenizer_train
```

Output: tokenizer.json
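As a rough stand-in, here is what BPE tokenizer training looks like with the Hugging Face tokenizers library; this is an assumption about the underlying tooling, not LaughLM's exact code, and corpus.txt is a hypothetical input path:

```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

# Illustrative BPE training; LaughLM.data.tokenizer_train may differ in detail.
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
trainer = trainers.BpeTrainer(
    vocab_size=32000, special_tokens=["<unk>", "<bos>", "<eos>"]
)

tokenizer.train(files=["corpus.txt"], trainer=trainer)
tokenizer.save("tokenizer.json")
```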
### Step 2: Build Token Shards

Convert raw text into token shards:

```bash
python scripts/build_shard.py
```

Output: dataset_shard.bin
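Conceptually, a shard is the tokenized corpus flattened into one contiguous array. A minimal sketch of the write path follows; the real logic lives in scripts/build_shard.py and LaughLM/data/shard_writer.py, and corpus.txt is again a hypothetical input:

```python
import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")

# Encode each document and append to one flat token stream.
tokens = []
with open("corpus.txt") as f:
    for doc in f:
        tokens.extend(tokenizer.encode(doc).ids)

# vocab_size is 32000, so every token ID fits in uint16 (max 65535).
np.asarray(tokens, dtype=np.uint16).tofile("dataset_shard.bin")
```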
Each shard is a flat stream of uint16 token IDs.
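At training time the shard can be memory-mapped, so batches are sliced out without loading the file into RAM. The sketch below shows the idea behind LaughLM/data/memmap_loader.py, assuming the flat-stream layout above; the actual loader may differ:

```python
import numpy as np

# Map the shard without reading it into memory.
data = np.memmap("dataset_shard.bin", dtype=np.uint16, mode="r")

def get_batch(rng, batch_size=8, seq_len=2048):
    """Sample random windows; inputs are tokens[:-1], targets are tokens[1:]."""
    starts = rng.integers(0, len(data) - seq_len - 1, size=batch_size)
    batch = np.stack([data[s : s + seq_len + 1] for s in starts]).astype(np.int32)
    return batch[:, :-1], batch[:, 1:]

x, y = get_batch(np.random.default_rng(0))
```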
### Step 3: Training

Run training:

```bash
python scripts/train_gpu_test.py
```

Training automatically handles:

- optimizer
- scheduler
- logging
- checkpointing

Example output:

```
STEP  PROGRESS | LOSS  PPL | LR | TOK/S | MFU
```
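For orientation, a jitted train step with Optax typically looks like the sketch below: a warmup + cosine LR schedule and AdamW with weight decay masked off 1-D parameters, mirroring the scheduler and masked-weight-decay features above. This is a generic JAX/Flax pattern, not LaughLM/training/train_step.py verbatim, and the model call signature is assumed:

```python
import jax
import optax

# Warmup + cosine decay schedule (values are illustrative).
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=3e-4,
    warmup_steps=1_000, decay_steps=100_000, end_value=3e-5,
)

# AdamW with weight decay applied only to matrices (ndim > 1),
# skipping biases and norm scales.
tx = optax.adamw(
    schedule, weight_decay=0.1,
    mask=lambda params: jax.tree_util.tree_map(lambda p: p.ndim > 1, params),
)

@jax.jit
def train_step(state, tokens, targets):
    # state: a flax.training.train_state.TrainState built with tx above.
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, tokens)  # assumed signature
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, targets
        ).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss
```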
## Checkpointing

Checkpoints are saved with Orbax to checkpoints/ by default. If checkpoints exist, training resumes from the latest one automatically.
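A sketch of the save/resume pattern with orbax.checkpoint's CheckpointManager follows. It is illustrative; LaughLM/training/checkpoint.py is the authoritative implementation, and the Orbax API has shifted across versions:

```python
import orbax.checkpoint as ocp

def make_manager(path="checkpoints/"):
    # Keep only the three most recent checkpoints.
    return ocp.CheckpointManager(
        path, options=ocp.CheckpointManagerOptions(max_to_keep=3)
    )

def restore_if_available(mngr, state):
    """Return (state, step), resuming from the latest checkpoint if one exists."""
    latest = mngr.latest_step()
    if latest is None:
        return state, 0
    return mngr.restore(latest, args=ocp.args.StandardRestore(state)), latest

def save_checkpoint(mngr, step, state):
    mngr.save(step, args=ocp.args.StandardSave(state))
```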
## Benchmarking Performance

Benchmark raw training throughput:

```bash
python scripts/benchmark_train_step.py
```

This measures:

- compile time
- step time
- tokens/sec
- MFU

Example output:

```
Compile time: 18.2s
Step time:    0.048s
Tokens/sec:   430000
```
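MFU (model FLOPs utilization) compares achieved throughput to the hardware's peak. With the standard ~6N FLOPs-per-token estimate for decoder-only transformers, it reduces to one line; the parameter count and 197 TFLOP/s bf16 peak below are illustrative assumptions, not measured LaughLM numbers:

```python
def mfu(num_params: float, tokens_per_sec: float, peak_flops: float) -> float:
    """Model FLOPs utilization, using the ~6N FLOPs-per-token estimate."""
    return 6.0 * num_params * tokens_per_sec / peak_flops

# Illustrative: a 1B-param model at 20k tok/s against a 197 TFLOP/s bf16 peak.
print(f"{mfu(1e9, 20_000, 197e12):.1%}")  # ~60.9%
```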
## Monitoring

The training logger displays:

- loss
- perplexity
- gradient norm
- tokens/sec
- MFU
- ETA

Example:

```
STEP  PROGRESS | LOSS | LR | TOK/S | MFU | ETA
```
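The perplexity column is derived from the loss: it is the exponential of the mean next-token cross-entropy.

```python
import jax.numpy as jnp

def perplexity(mean_cross_entropy_nats):
    """Perplexity is exp(mean cross-entropy), with the loss measured in nats."""
    return jnp.exp(mean_cross_entropy_nats)
```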
## Optimization Roadmap

LaughLM is designed to progressively reach high TPU utilization, targeting 50-60% MFU on TPU v5e.

Optimization phases:

| Phase | Goal |
| --- | --- |
| Baseline | establish benchmark |
| Data pipeline | remove input bottlenecks |
| Graph optimization | eliminate Python overhead |
| Kernel fusion | maximize MXU utilization |
| Flash attention | reduce memory traffic |
## Development Workflow
Recommended workflow:
- Create branch
- Implement change
- Run benchmark
- Compare tokens/sec
- Merge if improvement
Example:

```bash
git checkout -b optimize_attention
```
## Contributing

Pull requests should include:

- a clear description
- performance impact
- benchmark results
## License
MIT License
## Acknowledgements

LaughLM builds on ideas from GPT, LLaMA, PaLM, DeepSeek, MiniCPM, and the JAX / Flax ecosystem.
## Future Work

Planned improvements:

- Flash Attention
- activation checkpointing
- MoE layers
- PJIT sharding
- distributed training