# LaughLM
A high-performance **decoder-only transformer training system** built with **JAX + Flax** and optimized for **TPU training**.
LaughLM is designed as a **research-friendly yet production-capable framework** for experimenting with modern transformer architectures while maintaining high training throughput.
The system emphasizes:
- clean modular architecture
- hardware-efficient training
- reproducible experiments
- flexible configuration
- large-scale dataset streaming
- high MFU (Model FLOPs Utilization) on TPUs
---
# Features
- **Decoder-only GPT architecture**
- **JAX + Flax implementation**
- **TPU-optimized mixed precision training**
- **Flexible architecture selection**
- **Pre-tokenized memory-mapped datasets**
- **Multiple attention variants**
- **Multiple FFN architectures**
- **Weight tying support**
- **Orbax checkpointing**
- **Optax optimizers**
- **Config-driven experiments**
Supported architecture features:
- MHA / MQA / GQA attention
- RoPE positional encoding
- SwiGLU / GEGLU / GELU MLP
- RMSNorm / LayerNorm
- configurable residual scaling
- multiple LR schedulers
- masked weight decay
---
# Project Structure
```text
.
β”œβ”€β”€ configs
β”‚   β”œβ”€β”€ gpu_test.yaml
β”‚   └── test.yaml
β”œβ”€β”€ LaughLM
β”‚   β”œβ”€β”€ config
β”‚   β”‚   β”œβ”€β”€ loader.py
β”‚   β”‚   β”œβ”€β”€ schema.py
β”‚   β”‚   └── validation.py
β”‚   β”œβ”€β”€ data
β”‚   β”‚   β”œβ”€β”€ domain_sampler.py
β”‚   β”‚   β”œβ”€β”€ memmap_loader.py
β”‚   β”‚   β”œβ”€β”€ shard_writer.py
β”‚   β”‚   β”œβ”€β”€ tokenizer.py
β”‚   β”‚   └── tokenizer_train.py
β”‚   β”œβ”€β”€ model
β”‚   β”‚   β”œβ”€β”€ gpt.py
β”‚   β”‚   β”œβ”€β”€ layers
β”‚   β”‚   β”‚   β”œβ”€β”€ attention.py
β”‚   β”‚   β”‚   β”œβ”€β”€ mlp.py
β”‚   β”‚   β”‚   β”œβ”€β”€ normalization.py
β”‚   β”‚   β”‚   β”œβ”€β”€ positional.py
β”‚   β”‚   β”‚   └── residual.py
β”‚   β”‚   β”œβ”€β”€ parameter_utils.py
β”‚   β”‚   └── transformer_block.py
β”‚   β”œβ”€β”€ training
β”‚   β”‚   β”œβ”€β”€ checkpoint.py
β”‚   β”‚   β”œβ”€β”€ logger.py
β”‚   β”‚   β”œβ”€β”€ loss.py
β”‚   β”‚   β”œβ”€β”€ optimizer.py
β”‚   β”‚   β”œβ”€β”€ scheduler.py
β”‚   β”‚   β”œβ”€β”€ trainer.py
β”‚   β”‚   β”œβ”€β”€ train_state.py
β”‚   β”‚   └── train_step.py
β”‚   └── utils
β”‚       └── rng.py
β”œβ”€β”€ LICENSE
β”œβ”€β”€ log.txt
β”œβ”€β”€ pyproject.toml
β”œβ”€β”€ README.md
β”œβ”€β”€ requirements.txt
└── scripts
    β”œβ”€β”€ build_shard.py
    └── train_gpu_test.py
```
---
# Installation
Clone the repository:
```bash
git clone https://github.com/your-org/LaughLM.git
cd LaughLM
```
Create and activate a virtual environment:
```bash
python -m venv venv
source venv/bin/activate
```
Install dependencies:
```bash
pip install -r requirements.txt
```
For TPU environments, install the TPU build of JAX:
```bash
pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
---
# Configuration
Experiments are fully defined via YAML configs.
Example: `configs/test.yaml`
Configuration sections include:
- model architecture
- optimizer
- scheduler
- runtime parameters
- dataset sources
- tokenizer settings
- hardware configuration
Example snippet:
```yaml
model:
  d_model: 768
  num_layers: 12
  num_heads: 12
  vocab_size: 32000
  max_seq_len: 2048
```
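The schema and validation for these fields live in `LaughLM/config/schema.py` and `LaughLM/config/validation.py`. As a minimal sketch (not the project's own loader in `LaughLM/config/loader.py`, which may add validation and defaults), a config like the one above can be read with PyYAML:
```python
import yaml

# Minimal illustration: read the experiment config into a plain dict.
with open("configs/test.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["d_model"])  # -> 768
```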
---
# Dataset Pipeline
LaughLM uses a pre-tokenized dataset pipeline for maximum throughput. Training datasets are converted into binary token shards ahead of training.
Advantages:
- high throughput
- minimal CPU overhead
- memory-mapped streaming
- scalable to large datasets
---
# Step 1 β€” Train Tokenizer
Train a tokenizer on the streaming datasets:
```bash
python -m LaughLM.data.tokenizer_train
```
Output: `tokenizer.json`
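For reference, a tokenizer in this format can be trained with the Hugging Face `tokenizers` library. The sketch below is only an illustration of the approach; `LaughLM/data/tokenizer_train.py` may use a different algorithm or settings, and `corpus.txt` is a hypothetical input file:
```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

# Illustrative byte-level BPE tokenizer training; not the project's exact recipe.
tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

trainer = trainers.BpeTrainer(vocab_size=32000, special_tokens=["<|endoftext|>"])
tokenizer.train(files=["corpus.txt"], trainer=trainer)  # hypothetical text corpus

tokenizer.save("tokenizer.json")
```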
---
# Step 2 β€” Build Token Shards
Convert raw text into token shards:
```bash
python scripts/build_shard.py
```
Output: `dataset_shard.bin`
Shards contain a raw `uint16` token stream.
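A shard in this format can be written and streamed back with plain NumPy. This is a minimal sketch of the idea, not the exact layout produced by `LaughLM/data/shard_writer.py`:
```python
import numpy as np

# Append a (hypothetical) batch of token IDs to a uint16 binary shard.
tokens = np.array([17, 4021, 309, 2], dtype=np.uint16)  # example token IDs
with open("dataset_shard.bin", "ab") as f:
    f.write(tokens.tobytes())

# Stream the shard back without loading it into RAM (memory-mapped access).
shard = np.memmap("dataset_shard.bin", dtype=np.uint16, mode="r")
print(len(shard), shard[:4])
```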
---
# Step 3 β€” Training
Run training:
```bash
python scripts/train_gpu_test.py
```
Training automatically handles:
- optimizer
- scheduler
- logging
- checkpointing
Example output:
```text
STEP PROGRESS β”‚ LOSS β”‚ PPL β”‚ LR β”‚ TOK/S β”‚ MFU
```
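Under the hood, the trainer wires together a jitted train step, an Optax optimizer, and a next-token cross-entropy loss. The following is a minimal, self-contained sketch of those mechanics only; it is not the code in `LaughLM/training/train_step.py`, and the tiny model and dummy batch exist purely for illustration:
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

# Toy stand-in for the real GPT model, just to make the sketch runnable.
class TinyLM(nn.Module):
    vocab_size: int = 32000
    d_model: int = 128

    @nn.compact
    def __call__(self, tokens):
        x = nn.Embed(self.vocab_size, self.d_model)(tokens)
        x = nn.Dense(self.d_model)(nn.gelu(x))
        return nn.Dense(self.vocab_size)(x)  # next-token logits

model = TinyLM()
tokens = jnp.zeros((2, 16), dtype=jnp.int32)  # dummy [batch, seq_len] batch
params = model.init(jax.random.PRNGKey(0), tokens)
optimizer = optax.adamw(learning_rate=3e-4, weight_decay=0.1)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, tokens):
    def loss_fn(p):
        logits = model.apply(p, tokens[:, :-1])  # predict token t+1 from token t
        targets = tokens[:, 1:]
        return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss

params, opt_state, loss = train_step(params, opt_state, tokens)
```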
---
# Checkpointing
Checkpoints are saved using Orbax.
Default directory: `checkpoints/`
Training resumes automatically if checkpoints exist.
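For reference, saving and restoring a parameter pytree with Orbax looks roughly like the sketch below. LaughLM's own logic lives in `LaughLM/training/checkpoint.py` and may use a checkpoint manager instead; the pytree and path here are placeholders:
```python
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Illustrative only: persist and restore a parameter pytree with Orbax.
params = {"dense": {"kernel": jnp.ones((4, 4))}}  # stand-in pytree

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save("/tmp/laughlm_ckpt/step_1000", params)  # Orbax expects absolute paths
restored = checkpointer.restore("/tmp/laughlm_ckpt/step_1000")
```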
---
# Benchmarking Performance
Benchmark raw training throughput:
```bash
python scripts/benchmark_train_step.py
```
This measures:
- compile time
- step time
- tokens/sec
- MFU
Example output:
```text
Compile time: 18.2s
Step time: 0.048s
Tokens/sec: 430000
```
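The measurement itself boils down to timing the jitted step, with the first (compiling) call timed separately. A minimal sketch, reusing the hypothetical `train_step`, `params`, `opt_state`, and `tokens` from the training sketch above:
```python
import time
import jax

# First call triggers XLA compilation, so time it on its own.
t0 = time.perf_counter()
params, opt_state, loss = train_step(params, opt_state, tokens)
jax.block_until_ready(loss)
compile_time = time.perf_counter() - t0

# Steady-state step time and throughput over a few repeated steps.
n_steps = 20
t0 = time.perf_counter()
for _ in range(n_steps):
    params, opt_state, loss = train_step(params, opt_state, tokens)
jax.block_until_ready(loss)
step_time = (time.perf_counter() - t0) / n_steps
tokens_per_sec = tokens.size / step_time

print(f"Compile time: {compile_time:.1f}s")
print(f"Step time: {step_time:.3f}s")
print(f"Tokens/sec: {tokens_per_sec:,.0f}")
```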
---
# Monitoring
The training logger displays:
- loss
- perplexity
- gradient norm
- tokens/sec
- MFU
- ETA
Example:
```text
STEP PROGRESS β”‚ LOSS β”‚ LR β”‚ TOK/S β”‚ MFU β”‚ ETA
```
---
# Optimization Roadmap
LaughLM is designed to progressively reach high TPU utilization.
Target MFU: 50–60% on TPU v5e.
Optimization phases:

| Phase | Goal |
| --- | --- |
| Baseline | establish benchmark |
| Data pipeline | remove input bottlenecks |
| Graph optimization | eliminate Python overhead |
| Kernel fusion | maximize MXU utilization |
| Flash attention | reduce memory traffic |
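MFU here is the ratio of achieved training FLOP/s to the hardware's peak FLOP/s. A common back-of-the-envelope estimate counts roughly 6 Γ— parameters FLOPs per trained token; the sketch below shows that arithmetic with purely hypothetical numbers (and assumes the commonly cited ~197 TFLOP/s bf16 peak for a TPU v5e chip):
```python
# MFU = achieved training FLOP/s divided by hardware peak FLOP/s.
# The 6*N*D rule of thumb covers forward + backward passes of a decoder-only
# transformer and ignores attention FLOPs.
def estimate_mfu(n_params: float, tokens_per_sec: float, peak_flops: float) -> float:
    achieved_flops = 6.0 * n_params * tokens_per_sec
    return achieved_flops / peak_flops

# Hypothetical example: a 350M-parameter model at 50k tokens/sec on a chip
# with ~197 TFLOP/s bf16 peak -> roughly 0.53, i.e. ~53% MFU.
print(estimate_mfu(350e6, 50_000, 197e12))
```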
---
# Development Workflow
Recommended workflow:
1. Create branch
2. Implement change
3. Run benchmark
4. Compare tokens/sec
5. Merge if improvement
Example:
```bash
git checkout -b optimize_attention
```
---
# Contributing
Pull requests should include:
- a clear description
- performance impact
- benchmark results
---
# License
MIT License
---
# Acknowledgements
LaughLM builds on ideas from:
- GPT
- LLaMA
- PaLM
- DeepSeek
- MiniCPM
- and the JAX / Flax ecosystem
---
# Future Work
Planned improvements:
- Flash Attention
- activation checkpointing
- MoE layers
- pjit sharding
- distributed training