
LaughLM

A high-performance decoder-only transformer training system built with JAX + Flax and optimized for TPUs.

LaughLM is designed as a research-friendly yet production-capable framework for experimenting with modern transformer architectures while maintaining high training throughput.

The system emphasizes:

  • clean modular architecture
  • hardware-efficient training
  • reproducible experiments
  • flexible configuration
  • large-scale dataset streaming
  • high MFU on TPUs

Features

  • Decoder-only GPT architecture
  • JAX + Flax implementation
  • TPU-optimized mixed precision training
  • Flexible architecture selection
  • Pre-tokenized memory-mapped datasets
  • Multiple attention variants
  • Multiple FFN architectures
  • Weight tying support
  • Orbax checkpointing
  • Optax optimizers
  • Config-driven experiments

Supported architecture features:

  • MHA / MQA / GQA attention
  • RoPE positional encoding
  • SwiGLU / GEGLU / GELU MLP
  • RMSNorm / LayerNorm
  • configurable residual scaling
  • multiple LR schedulers
  • masked weight decay
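
As a concrete illustration of one of the components listed above, here is a minimal Flax sketch of RMSNorm; the actual implementation in LaughLM/model/layers/normalization.py may differ in details such as dtype handling.

import flax.linen as nn
import jax.numpy as jnp

class RMSNorm(nn.Module):
    dim: int
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        # Learned per-channel scale; no bias and no mean subtraction,
        # unlike LayerNorm.
        scale = self.param("scale", nn.initializers.ones, (self.dim,))
        rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + self.eps)
        return (x / rms) * scale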

Project Structure

.
├── configs
│   ├── gpu_test.yaml
│   └── test.yaml
├── LaughLM
│   ├── config
│   │   ├── loader.py
│   │   ├── schema.py
│   │   └── validation.py
│   ├── data
│   │   ├── domain_sampler.py
│   │   ├── memmap_loader.py
│   │   ├── shard_writer.py
│   │   ├── tokenizer.py
│   │   └── tokenizer_train.py
│   ├── model
│   │   ├── gpt.py
│   │   ├── layers
│   │   │   ├── attention.py
│   │   │   ├── mlp.py
│   │   │   ├── normalization.py
│   │   │   ├── positional.py
│   │   │   └── residual.py
│   │   ├── parameter_utils.py
│   │   └── transformer_block.py
│   ├── training
│   │   ├── checkpoint.py
│   │   ├── logger.py
│   │   ├── loss.py
│   │   ├── optimizer.py
│   │   ├── scheduler.py
│   │   ├── trainer.py
│   │   ├── train_state.py
│   │   └── train_step.py
│   └── utils
│       └── rng.py
├── LICENSE
├── log.txt
├── pyproject.toml
├── README.md
├── requirements.txt
└── scripts
    ├── build_shard.py
    └── train_gpu_test.py

Installation

Clone the repository:

git clone https://github.com/your-org/LaughLM.git
cd LaughLM

Create environment:

python -m venv venv
source venv/bin/activate

Install dependencies:

pip install -r requirements.txt

For TPU environments, install the TPU-enabled JAX build:

pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Configuration

Experiments are fully defined via YAML configs.

Example:

configs/test.yaml

Configuration sections include:

  • model architecture
  • optimizer
  • scheduler
  • runtime parameters
  • dataset sources
  • tokenizer settings
  • hardware configuration

Example snippet:

model:
  d_model: 768
  num_layers: 12
  num_heads: 12
  vocab_size: 32000
  max_seq_len: 2048
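
A minimal sketch of how such a config could be loaded into typed Python objects (assuming PyYAML; the actual loader in LaughLM/config/loader.py and its schema may differ):

from dataclasses import dataclass
import yaml

@dataclass
class ModelConfig:
    d_model: int
    num_layers: int
    num_heads: int
    vocab_size: int
    max_seq_len: int

with open("configs/test.yaml") as f:
    raw = yaml.safe_load(f)

model_cfg = ModelConfig(**raw["model"])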

Dataset Pipeline

LaughLM uses a pre-tokenized dataset pipeline for maximum throughput.

Training datasets are converted into binary token shards.

Advantages:

  • high throughput
  • minimal CPU overhead
  • memory-mapped streaming
  • scalable to large datasets


Step 1 – Train Tokenizer

Train a tokenizer using streaming datasets.

python -m LaughLM.data.tokenizer_train

Output:

tokenizer.json
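
A minimal sketch of what this step does, using the Hugging Face tokenizers library; the corpus path and vocabulary size are placeholders, and LaughLM.data.tokenizer_train may use different settings:

from tokenizers import Tokenizer, models, pre_tokenizers, trainers

tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
trainer = trainers.BpeTrainer(vocab_size=32000, special_tokens=["[UNK]", "[PAD]"])

def text_iterator():
    # Stream raw text line by line; "corpus.txt" is a placeholder path.
    with open("corpus.txt") as f:
        for line in f:
            yield line

tokenizer.train_from_iterator(text_iterator(), trainer=trainer)
tokenizer.save("tokenizer.json")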


Step 2 – Build Token Shards

Convert raw text into token shards.

python scripts/build_shard.py

Output:

dataset_shard.bin

Shards contain a flat uint16 token stream.

Step 3 – Training

Run training:

python scripts/train_gpu_test.py

Training automatically handles:

  • optimizer
  • scheduler
  • logging
  • checkpointing

Example output:

STEP PROGRESS │ LOSS PPL │ LR │ TOK/S │ MFU
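
A minimal sketch of the optimizer and scheduler setup handled above, using Optax with a warmup-cosine schedule and weight decay masked off one-dimensional parameters (norm scales, biases); the hyperparameter values are placeholders and LaughLM.training.optimizer may differ:

import jax
import optax

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-4,
    warmup_steps=2_000,
    decay_steps=100_000,
    end_value=3e-5,
)

def decay_mask(params):
    # Apply weight decay only to matrices (ndim >= 2), not to norms or biases.
    return jax.tree_util.tree_map(lambda p: p.ndim >= 2, params)

optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.1, mask=decay_mask)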


Checkpointing

Checkpoints are saved using Orbax.

Default directory:

checkpoints/

Training resumes automatically if checkpoints exist.
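
A minimal sketch of saving and restoring a state pytree with Orbax (paths must be absolute; the actual wrapper in LaughLM/training/checkpoint.py may differ):

import os
import jax.numpy as jnp
import orbax.checkpoint as ocp

state = {"step": 0, "params": {"w": jnp.zeros((4, 4))}}   # placeholder pytree

checkpointer = ocp.PyTreeCheckpointer()
ckpt_dir = os.path.abspath("checkpoints/step_0")

checkpointer.save(ckpt_dir, state)
restored = checkpointer.restore(ckpt_dir)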


Benchmarking Performance

Benchmark raw training throughput:

python scripts/benchmark_train_step.py

This measures:

  • compile time
  • step time
  • tokens/sec
  • MFU

Example output:

Compile time: 18.2s
Step time:    0.048s
Tokens/sec:   430000
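
A rough sketch of how MFU can be estimated from such numbers, using the common ~6 × N FLOPs-per-token approximation for a dense decoder; the parameter count, throughput, and per-chip peak below are illustrative placeholders:

def estimate_mfu(n_params, tokens_per_sec, peak_flops_per_sec):
    # Model FLOPs per token ≈ 6 * parameter count (forward + backward).
    achieved = tokens_per_sec * 6 * n_params
    return achieved / peak_flops_per_sec

# Example: a ~110M-parameter config at 90k tokens/sec on one TPU v5e chip
# (~197 bf16 TFLOP/s peak) would sit around 30% MFU.
print(estimate_mfu(110e6, 90_000, 197e12))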


Monitoring

The training logger displays:

  • loss
  • perplexity
  • gradient norm
  • tokens/sec
  • MFU
  • ETA

Example:

STEP PROGRESS │ LOSS │ LR │ TOK/S │ MFU │ ETA
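
A small sketch of how the perplexity and gradient-norm values relate to quantities already computed in the train step (illustrative; LaughLM/training/logger.py may compute them differently):

import jax.numpy as jnp
import optax

loss = 2.85                        # mean cross-entropy in nats (placeholder)
perplexity = float(jnp.exp(loss))  # ≈ 17.3

grads = {"w": jnp.ones((4, 4))}              # placeholder gradient pytree
grad_norm = float(optax.global_norm(grads))  # global L2 norm = 4.0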


Optimization Roadmap

LaughLM is designed to progressively reach high TPU utilization.

Target MFU:

50–60% on TPU v5e

Optimization phases:

  Phase                Goal
  Baseline             establish benchmark
  Data pipeline        remove input bottlenecks
  Graph optimization   eliminate Python overhead
  Kernel fusion        maximize MXU utilization
  Flash attention      reduce memory traffic


Development Workflow

Recommended workflow:

  1. Create branch
  2. Implement change
  3. Run benchmark
  4. Compare tokens/sec
  5. Merge if improvement

Example:

git checkout -b optimize_attention

Contributing

Pull requests should include:

  • a clear description
  • performance impact
  • benchmark results


License

MIT License


Acknowledgements

LaughLM builds on ideas from:

  • GPT
  • LLaMA
  • PaLM
  • DeepSeek
  • MiniCPM

and the JAX / Flax ecosystem.


Future Work

Planned improvements:

  • Flash Attention
  • activation checkpointing
  • MoE layers
  • pjit sharding
  • distributed training
