# LaughLM
A high-performance **decoder-only transformer training system** built with **JAX + Flax** and optimized for **TPU training**.
LaughLM is designed as a **research-friendly yet production-capable framework** for experimenting with modern transformer architectures while maintaining high training throughput.
The system emphasizes:
- clean modular architecture
- hardware-efficient training
- reproducible experiments
- flexible configuration
- large-scale dataset streaming
- high MFU optimization on TPUs
---
# Features
- **Decoder-only GPT architecture**
- **JAX + Flax implementation**
- **TPU-optimized mixed precision training**
- **Flexible architecture selection**
- **Pre-tokenized memory-mapped datasets**
- **Multiple attention variants**
- **Multiple FFN architectures**
- **Weight tying support**
- **Orbax checkpointing**
- **Optax optimizers**
- **Config-driven experiments**
Supported architecture features (see the sketch after this list for one example):
- MHA / MQA / GQA attention
- RoPE positional encoding
- SwiGLU / GEGLU / GELU MLP
- RMSNorm / LayerNorm
- configurable residual scaling
- multiple LR schedulers
- masked weight decay
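As one concrete example of these building blocks, below is a minimal RMSNorm sketch in Flax. It is illustrative only; the project's own implementation lives in `LaughLM/model/layers/normalization.py` and may differ in interface and detail.
```python
# Illustrative sketch of RMSNorm in Flax (linen); not the project's actual module.
import jax.numpy as jnp
import flax.linen as nn


class RMSNorm(nn.Module):
    dim: int
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Scale by the inverse root-mean-square of the features, then apply a learned gain.
        scale = self.param("scale", nn.initializers.ones, (self.dim,))
        rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + self.eps)
        return x / rms * scale
```
Unlike LayerNorm, RMSNorm skips mean-centering and the bias term, which is the convention used in LLaMA-style models.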
---
# Project Structure
```text
.
├── configs
│   ├── gpu_test.yaml
│   └── test.yaml
├── LaughLM
│   ├── config
│   │   ├── loader.py
│   │   ├── schema.py
│   │   └── validation.py
│   ├── data
│   │   ├── domain_sampler.py
│   │   ├── memmap_loader.py
│   │   ├── shard_writer.py
│   │   ├── tokenizer.py
│   │   └── tokenizer_train.py
│   ├── model
│   │   ├── gpt.py
│   │   ├── layers
│   │   │   ├── attention.py
│   │   │   ├── mlp.py
│   │   │   ├── normalization.py
│   │   │   ├── positional.py
│   │   │   └── residual.py
│   │   ├── parameter_utils.py
│   │   └── transformer_block.py
│   ├── training
│   │   ├── checkpoint.py
│   │   ├── logger.py
│   │   ├── loss.py
│   │   ├── optimizer.py
│   │   ├── scheduler.py
│   │   ├── trainer.py
│   │   ├── train_state.py
│   │   └── train_step.py
│   └── utils
│       └── rng.py
├── LICENSE
├── log.txt
├── pyproject.toml
├── README.md
├── requirements.txt
└── scripts
    ├── build_shard.py
    └── train_gpu_test.py
```
---
# Installation
Clone the repository:
```bash
git clone https://github.com/your-org/LaughLM.git
cd LaughLM
```
Create environment:
```bash
python -m venv venv
source venv/bin/activate
```
Install dependencies:
```bash
pip install -r requirements.txt
```
For TPU environments, install the TPU-enabled JAX build:
```bash
pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
---
# Configuration
Experiments are fully defined via YAML configs.
Example: `configs/test.yaml`
Configuration sections include:
- model architecture
- optimizer
- scheduler
- runtime parameters
- dataset sources
- tokenizer settings
- hardware configuration
Example snippet:
```yaml
model:
  d_model: 768
  num_layers: 12
  num_heads: 12
  vocab_size: 32000
  max_seq_len: 2048
```
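As a rough sketch of how such a config can be consumed, the snippet below reads the `model:` section into a typed object with PyYAML. The function and dataclass names are hypothetical; the project's real loading and validation logic lives in `LaughLM/config/loader.py` and `schema.py`.
```python
# Illustrative only: load the `model:` section of a YAML experiment config.
from dataclasses import dataclass

import yaml  # PyYAML


@dataclass
class ModelConfig:
    d_model: int
    num_layers: int
    num_heads: int
    vocab_size: int
    max_seq_len: int


def load_model_config(path: str) -> ModelConfig:
    """Parse the YAML file and build a typed model config."""
    with open(path, "r") as f:
        raw = yaml.safe_load(f)
    return ModelConfig(**raw["model"])


# Example usage:
# cfg = load_model_config("configs/test.yaml")
# print(cfg.d_model, cfg.max_seq_len)
```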
---
# Dataset Pipeline
LaughLM uses a pre-tokenized dataset pipeline for maximum throughput.
Training datasets are converted into binary token shards.
Advantages:
- high throughput
- minimal CPU overhead
- memory-mapped streaming (see the sketch after this list)
- scalable to large datasets
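To make the memory-mapped streaming concrete, here is a minimal NumPy sketch of sampling training batches from a binary `uint16` shard. It is illustrative only; the project's real loader is `LaughLM/data/memmap_loader.py`, and the function name and signature below are assumptions.
```python
# Illustrative only: random batch sampling from a memory-mapped uint16 token shard.
import numpy as np


def sample_batch(shard_path: str, batch_size: int, seq_len: int, rng: np.random.Generator):
    """Draw random contiguous token windows; returns (inputs, next-token targets)."""
    tokens = np.memmap(shard_path, dtype=np.uint16, mode="r")  # no full load into RAM
    starts = rng.integers(0, len(tokens) - seq_len - 1, size=batch_size)
    x = np.stack([tokens[s : s + seq_len] for s in starts]).astype(np.int32)
    y = np.stack([tokens[s + 1 : s + seq_len + 1] for s in starts]).astype(np.int32)
    return x, y


# Example usage:
# x, y = sample_batch("dataset_shard.bin", batch_size=8, seq_len=2048,
#                     rng=np.random.default_rng(0))
```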
---
## Step 1 - Train Tokenizer
Train a tokenizer using streaming datasets.
```bash
python -m LaughLM.data.tokenizer_train
```
Output: `tokenizer.json`
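For orientation, the sketch below shows one way to train a byte-level BPE tokenizer with the Hugging Face `tokenizers` library and save it as `tokenizer.json`. The vocabulary size, special tokens, and function name are illustrative assumptions; the project's actual logic lives in `LaughLM/data/tokenizer_train.py`.
```python
# Illustrative only: train a byte-level BPE tokenizer from a text iterator.
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer


def train_tokenizer(text_iter, vocab_size: int = 32000, out_path: str = "tokenizer.json"):
    tokenizer = Tokenizer(BPE(unk_token="<unk>"))
    tokenizer.pre_tokenizer = ByteLevel()
    trainer = BpeTrainer(vocab_size=vocab_size,
                         special_tokens=["<unk>", "<bos>", "<eos>"])
    # train_from_iterator accepts any iterator of strings, so it works with streaming datasets.
    tokenizer.train_from_iterator(text_iter, trainer=trainer)
    tokenizer.save(out_path)


# Example usage:
# train_tokenizer(["hello world", "another document"])
```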
---
## Step 2 - Build Token Shards
Convert raw text into token shards.
```bash
python scripts/build_shard.py
```
Output: `dataset_shard.bin`
Shards contain a `uint16` token stream.
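The following sketch shows how such a shard can be produced: documents are tokenized and appended, EOS-separated, to one flat `uint16` binary file. It is illustrative only; the function name and EOS id are assumptions, and the real implementation is `LaughLM/data/shard_writer.py` driven by `scripts/build_shard.py`.
```python
# Illustrative only: write tokenized documents into a flat uint16 binary shard.
import numpy as np
from tokenizers import Tokenizer


def write_shard(texts, tokenizer_path: str, out_path: str, eos_id: int) -> None:
    """Tokenize each document and append it (followed by EOS) to the shard file."""
    tok = Tokenizer.from_file(tokenizer_path)
    with open(out_path, "wb") as f:
        for text in texts:
            ids = tok.encode(text).ids + [eos_id]
            # uint16 covers vocabularies up to 65,535 tokens (e.g. vocab_size 32000).
            np.asarray(ids, dtype=np.uint16).tofile(f)


# Example usage:
# write_shard(["hello world", "another document"], "tokenizer.json",
#             "dataset_shard.bin", eos_id=2)
```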
---
## Step 3 - Training
Run training:
```bash
python scripts/train_gpu_test.py
```
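At its core, each iteration is a jitted update like the sketch below: a forward pass, next-token cross-entropy, and an Optax gradient step. This is a simplified stand-in for `LaughLM/training/train_step.py`, not the project's exact code.
```python
# Illustrative only: the shape of a jitted language-model training step.
import jax
import optax
from flax.training import train_state


@jax.jit
def train_step(state: train_state.TrainState, batch):
    """One gradient update on a batch of {'inputs': ..., 'targets': ...} token ids."""

    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["inputs"])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["targets"]
        ).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss
```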
Training automatically handles:
- optimizer (see the Optax sketch after this list)
- scheduler
- logging
- checkpointing
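For the optimizer and scheduler, a typical Optax setup matching the feature list (warmup-cosine LR, masked weight decay) looks like the sketch below. The hyperparameters and function name are illustrative; the project's own versions live in `LaughLM/training/optimizer.py` and `scheduler.py`.
```python
# Illustrative only: AdamW with warmup-cosine decay and weight decay masked off
# 1-D parameters (biases, norm scales).
import jax
import optax


def make_optimizer(peak_lr=3e-4, warmup_steps=1_000, total_steps=100_000, weight_decay=0.1):
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=warmup_steps,
        decay_steps=total_steps,
        end_value=0.1 * peak_lr,
    )
    # Decay only matrices/embeddings; skip biases and normalization scales.
    mask_fn = lambda params: jax.tree_util.tree_map(lambda p: p.ndim > 1, params)
    return optax.adamw(learning_rate=schedule, weight_decay=weight_decay, mask=mask_fn)
```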
Example logger output:
`STEP | PROGRESS | LOSS | PPL | LR | TOK/S | MFU`
---
# Checkpointing
Checkpoints are saved using Orbax.
Default directory: `checkpoints/`
Training resumes automatically if checkpoints exist.
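The save/resume flow with Orbax's CheckpointManager looks roughly like the sketch below. The helper names are hypothetical and the project's actual logic lives in `LaughLM/training/checkpoint.py`; Orbax's API has changed across versions, so treat this as a sketch under recent Orbax releases rather than the framework's exact calls.
```python
# Illustrative only: saving and resuming a training state with Orbax's CheckpointManager.
import os
import orbax.checkpoint as ocp


def make_manager(path: str = "checkpoints") -> ocp.CheckpointManager:
    # Orbax expects an absolute directory; keep a few recent checkpoints.
    return ocp.CheckpointManager(
        os.path.abspath(path),
        options=ocp.CheckpointManagerOptions(max_to_keep=3),
    )


def save(mngr: ocp.CheckpointManager, step: int, state) -> None:
    mngr.save(step, args=ocp.args.StandardSave(state))
    mngr.wait_until_finished()


def maybe_restore(mngr: ocp.CheckpointManager, state):
    """Return the latest checkpointed state if one exists, else the state unchanged."""
    latest = mngr.latest_step()
    return state if latest is None else mngr.restore(latest, args=ocp.args.StandardRestore(state))
```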
---
# Benchmarking Performance
Benchmark raw training throughput:
`python scripts/benchmark_train_step.py`
This measures:
- compile time
- step time
- tokens/sec
- MFU (estimated as in the sketch below)
Example output:
- Compile time: 18.2s
- Step time: 0.048s
- Tokens/sec: 430000
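MFU is commonly estimated with the first-order rule of thumb below: training FLOPs per token of roughly 6 × parameter count, divided by the hardware's peak throughput. The numbers in the usage example are hypothetical, and the project's logger may use a more detailed formula.
```python
# Illustrative only: a first-order MFU (model FLOPs utilization) estimate.
def estimate_mfu(tokens_per_sec: float, num_params: float, peak_flops: float) -> float:
    """Achieved training FLOP/s (about 6 * params * tokens/s) over hardware peak FLOP/s."""
    achieved = 6.0 * num_params * tokens_per_sec  # forward + backward pass
    return achieved / peak_flops


# Hypothetical example: a 124M-parameter model at 100,000 tok/s on one TPU v5e chip
# (peak bf16 throughput of roughly 197 TFLOP/s).
print(estimate_mfu(100_000, 124e6, 197e12))  # ~0.38, i.e. ~38% MFU
```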
---
# Monitoring
The training logger displays:
- loss
- perplexity
- gradient norm
- tokens/sec
- MFU
- ETA
Example:
`STEP | PROGRESS | LOSS | LR | TOK/S | MFU | ETA`
---
# Optimization Roadmap
LaughLM is designed to progressively reach high TPU utilization.
Target MFU: 50-60% on TPU v5e.
Optimization phases:
| Phase | Goal |
| --- | --- |
| Baseline | establish benchmark |
| Data pipeline | remove input bottlenecks |
| Graph optimization | eliminate Python overhead |
| Kernel fusion | maximize MXU utilization |
| Flash attention | reduce memory traffic |
---
# Development Workflow
Recommended workflow:
1. Create a branch
2. Implement the change
3. Run the benchmark
4. Compare tokens/sec
5. Merge if throughput improves
Example:
```bash
git checkout -b optimize_attention
```
---
# Contributing
Pull requests should include:
- clear description
- performance impact
- benchmark results
---
# License
MIT License
---
# Acknowledgements
LaughLM builds on ideas from:
- GPT
- LLaMA
- PaLM
- DeepSeek
- MiniCPM
- the JAX / Flax ecosystem
---
# Future Work
Planned improvements:
- Flash Attention
- Activation checkpointing
- MoE layers
- PJIT sharding
- distributed training