LLM_D3: A Sparse 350M Architecture Trained on 50B Tokens
This repository contains the implementation of LLM_D3, a decoder-only Large Language Model trained from scratch on 50 billion tokens of the C4 English-only dataset. It features a modern, high-performance architecture optimized for efficiency, combining Mixture of Experts (MoE), Multi-head Latent Attention (MLA), and Rotary Positional Embeddings (RoPE).
To favor generalization over rote memorization, the model was trained in a single epoch (no token is seen twice) and reaches 33% zero-shot accuracy on HellaSwag, where random guessing among the four candidate endings gives 25%. After instruction fine-tuning, it serves as a general-purpose assistant with strong general reasoning and factual recall.
Model Statistics
| Metric | Value |
|---|---|
| Total Parameters | 358.74M |
| Active Parameters | 171.96M |
| Sparsity Ratio | 52.06% |
| Training Data | 50B Tokens (C4 English) |
| Architecture | MLA + Sparse MoE + RoPE |
For the scripts, see:
github: firdavsus/LLM_D3
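The sparsity ratio in the table is the share of parameters left idle for any single token, that is, one minus active over total. A minimal check of that arithmetic using the rounded figures from the table follows; the function name is illustrative and is not taken from check_params.py.

```python
# Parameter counts in millions, copied from the statistics table.
TOTAL_PARAMS_M = 358.74
ACTIVE_PARAMS_M = 171.96


def sparsity_ratio(total: float, active: float) -> float:
    """Fraction of parameters that are not used for a given token."""
    return 1.0 - active / total


ratio = sparsity_ratio(TOTAL_PARAMS_M, ACTIVE_PARAMS_M)
print(f"inactive per token: {ratio:.2%}")
```

With the rounded inputs this comes out at roughly 52.07%, which agrees with the table's 52.06% to within the rounding of the parameter counts.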
Architecture Details
The model utilizes a custom GPT implementation (LLM_2.py) with several key architectural innovations focused on compute efficiency and memory optimization.
Multi-head Latent Attention (MLA)
To reduce the memory bottleneck of the KV cache, LLM_D3 implements Multi-head Latent Attention.
- Latent Compression: Query and KV states are down-projected into a lower-dimensional latent space, then up-projected again for the attention computation.
- Memory Footprint: Because the cache can hold the compressed KV latents rather than full per-head keys and values, the KV-cache footprint during inference shrinks, while quality is intended to stay on par with standard Multi-Head Attention.
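To make the KV-cache saving concrete, here is a back-of-the-envelope sketch comparing a standard per-head cache with a latent cache. Only the layer count (24) comes from this README; the head count, head dimension, latent dimension, sequence length, and 2-byte values are assumptions chosen for illustration, and the sketch ignores any extra per-token state a real MLA implementation may keep.

```python
def kv_cache_bytes_mha(n_layers, n_heads, head_dim, seq_len, bytes_per_value=2):
    """Standard MHA caches one key and one value vector per head, layer, and token."""
    return 2 * n_layers * n_heads * head_dim * seq_len * bytes_per_value


def kv_cache_bytes_latent(n_layers, latent_dim, seq_len, bytes_per_value=2):
    """A latent cache stores one compressed vector per layer and token.
    Keys and values are rebuilt from it by up-projection at attention time."""
    return n_layers * latent_dim * seq_len * bytes_per_value


N_LAYERS = 24      # from the README
N_HEADS = 16       # assumption
HEAD_DIM = 64      # assumption
LATENT_DIM = 256   # assumption
SEQ_LEN = 2048     # assumption

mha = kv_cache_bytes_mha(N_LAYERS, N_HEADS, HEAD_DIM, SEQ_LEN)
mla = kv_cache_bytes_latent(N_LAYERS, LATENT_DIM, SEQ_LEN)
print(f"MHA cache: {mha / 2**20:.1f} MiB")
print(f"latent cache: {mla / 2**20:.1f} MiB")
print(f"reduction: {mha / mla:.1f}x")
```

The reduction factor is simply `2 * n_heads * head_dim / latent_dim`, so it depends entirely on how small the latent dimension is relative to the full key/value width.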
Sparse Mixture of Experts (MoE)
LLM_D3 uses a sparse MoE architecture for 19 out of its 24 layers.
- Expert Configuration: Each MoE layer contains 6 experts, and every token is routed to its Top-2 experts.
- Hybrid Stability Sandwich: For improved training stability, the first 3 layers and the last 2 layers are standard dense MLP blocks rather than MoE layers (3 dense + 19 MoE + 2 dense = 24).
- Routing: Uses a noisy Top-K router with auxiliary load-balancing and router z-loss to prevent expert collapse and ensure balanced utilization across the 19 MoE blocks.
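The layer layout and routing described above can be sketched in plain Python. The layer counts, expert count, and Top-2 setting come from this README; the function names, the Switch-style form of the load-balancing loss, and the squared log-sum-exp form of the z-loss are assumptions about one common way to implement them, not a copy of LLM_2.py or manager.py.

```python
import math
import random

N_LAYERS, N_EXPERTS, TOP_K = 24, 6, 2   # from the README
DENSE_FIRST, DENSE_LAST = 3, 2          # from the README


def layer_kinds(n_layers=N_LAYERS, first=DENSE_FIRST, last=DENSE_LAST):
    """Dense blocks at both ends, MoE blocks in between."""
    return ["dense" if i < first or i >= n_layers - last else "moe"
            for i in range(n_layers)]


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route(logits, k=TOP_K, noise_std=0.0, rng=random):
    """Noisy top-k: perturb the logits, keep the k largest, renormalise their gates."""
    noisy = [l + rng.gauss(0.0, noise_std) for l in logits]
    chosen = sorted(range(len(noisy)), key=lambda e: noisy[e], reverse=True)[:k]
    gates = softmax([noisy[e] for e in chosen])
    return chosen, gates


def aux_losses(batch_logits, k=TOP_K):
    """Load-balancing loss and router z-loss for a batch of per-token router logits."""
    n_tokens, n_experts = len(batch_logits), len(batch_logits[0])
    load = [0.0] * n_experts        # share of routing slots assigned to each expert
    importance = [0.0] * n_experts  # mean router probability for each expert
    z_loss = 0.0
    for logits in batch_logits:
        chosen, _ = route(logits, k)  # noise-free routing for bookkeeping
        for e in chosen:
            load[e] += 1.0 / (n_tokens * k)
        for e, p in enumerate(softmax(logits)):
            importance[e] += p / n_tokens
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
        z_loss += log_sum_exp ** 2 / n_tokens
    # Equals 1.0 when the router probabilities are uniform; grows as routing collapses.
    balance = n_experts * sum(f * p for f, p in zip(load, importance))
    return balance, z_loss


kinds = layer_kinds()
print(kinds.count("moe"), "MoE layers,", kinds.count("dense"), "dense layers")
print(route([0.1, 2.0, -1.0, 1.5, 0.0, 0.3], noise_std=0.1, rng=random.Random(0)))
```

In training, both auxiliary terms would be scaled by small coefficients and added to the language-modeling loss; those coefficients are not stated here, so none are assumed.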
Positional Encoding
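- RoPE: Rotary Positional Embeddings rotate query and key vectors by position-dependent angles, so attention scores depend on the relative offset between tokens. They are used here in place of traditional learned position embeddings for better handling of long-range dependencies.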
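A minimal sketch of the rotation, assuming the interleaved-pair convention and the commonly used base of 10000; the layout and base used in LLM_2.py are not stated in this README. The demo checks the property that matters for attention: the query-key score depends only on the offset between positions.

```python
import math


def rope(vec, position, base=10000.0):
    """Rotate consecutive pairs of `vec` by angles that grow with `position`."""
    dim = len(vec)
    assert dim % 2 == 0, "RoPE needs an even dimension"
    out = []
    for i in range(0, dim, 2):
        theta = position * base ** (-i / dim)  # lower frequency for later pairs
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out += [x * cos_t - y * sin_t, x * sin_t + y * cos_t]
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.0, 0.4, -0.6, 0.9]

# Same offset (3) at different absolute positions gives the same score.
score_a = dot(rope(q, 5), rope(k, 2))
score_b = dot(rope(q, 105), rope(k, 102))
print(abs(score_a - score_b) < 1e-9)
```

Because each pair is only rotated, vector norms are unchanged, which is why RoPE can be applied to queries and keys without rescaling the attention logits.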
Training & Evaluation
Pre-training Setup
- Policy: Single-epoch pass on 50B tokens (no repetition) to prioritize feature extraction and generalization.
- Batch Size: 1M-token effective batch size for stable gradients; 50B tokens at this size works out to 50k optimizer steps.
- Schedule: Warmup-Stable-Decay (WSD) / Stepped Cosine Decay with a 1,000-step warmup.
- Optimizer: AdamW with hardware-optimized settings.
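The setup above implies a concrete step budget and schedule shape. The sketch below derives the step count from the stated token and batch sizes and implements a plain Warmup-Stable-Decay curve. The peak and minimum learning rates, the 10% decay window, and the linear decay shape are assumptions for illustration; train.py may use different values or the stepped cosine variant.

```python
TOTAL_TOKENS = 50_000_000_000   # from the README
TOKENS_PER_STEP = 1_000_000     # from the README (effective batch size)
WARMUP_STEPS = 1_000            # from the README
TOTAL_STEPS = TOTAL_TOKENS // TOKENS_PER_STEP

PEAK_LR = 3e-4         # assumption
MIN_LR = 3e-5          # assumption
DECAY_FRACTION = 0.1   # assumption: the last 10% of steps decay


def wsd_lr(step, total_steps=TOTAL_STEPS, warmup=WARMUP_STEPS,
           peak=PEAK_LR, floor=MIN_LR, decay_fraction=DECAY_FRACTION):
    """Linear warmup, constant plateau, then linear decay down to `floor`."""
    decay_start = total_steps - int(total_steps * decay_fraction)
    if step < warmup:
        return peak * (step + 1) / warmup
    if step < decay_start:
        return peak
    progress = (step - decay_start) / max(1, total_steps - decay_start)
    return peak + (floor - peak) * min(1.0, progress)


print("total steps:", TOTAL_STEPS)
for s in (0, 999, 20_000, 47_500, 50_000):
    print(s, f"{wsd_lr(s):.2e}")
```

A function of this shape can be wrapped in a multiplicative scheduler by dividing its output by the peak rate, which keeps the optimizer's base learning rate as the single source of truth.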
Benchmarks
| Benchmark | Setting | Score |
|---|---|---|
| HellaSwag | Zero-shot | 33% |
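For context on how a zero-shot HellaSwag number is typically produced: each item has four candidate endings, the model scores every ending, and the highest-scoring one is taken as the prediction, so random guessing yields 25%. The sketch below uses mean per-token log-probability as the score, which is one common choice and an assumption here; the exact scoring in eval.py is not described in this README, and the log-probabilities are made-up toy values.

```python
def pick_ending(ending_token_logprobs):
    """Return the index of the ending with the highest mean token log-probability."""
    means = [sum(lps) / len(lps) for lps in ending_token_logprobs]
    return max(range(len(means)), key=means.__getitem__)


def accuracy(examples):
    """`examples` is a list of (per-ending log-prob lists, gold index) pairs."""
    correct = sum(pick_ending(lps) == gold for lps, gold in examples)
    return correct / len(examples)


# Toy, made-up log-probabilities for two items with four endings each.
toy = [
    ([[-2.0, -1.0], [-0.5, -0.7, -0.6], [-3.0], [-1.5, -2.5]], 1),
    ([[-0.9, -1.1], [-1.0, -1.0, -1.6], [-0.4, -0.8], [-2.2]], 0),
]
print(accuracy(toy))
```

Averaging over tokens rather than summing avoids favoring short endings, which matters because HellaSwag endings vary in length.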
Fine-tuning
Fine-tuned on the alpaca-cleaned dataset using an Instruction-Input-Response format.
- Strengths: Strong general reasoning, factual consistency, and instruction adherence.
- Known Limitations: The model currently struggles with complex arithmetic. Additionally, an initialization anomaly in the final 2 layers resulted in a signal spike at the end of the network; while the model remains functional and capable, this is a known area for future refinement.
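As an illustration of the Instruction-Input-Response format mentioned above, the following sketch renders one record into a single training string. The `### ...:` headers follow the widely used Alpaca layout, but the exact strings in fine_tune.py are not given here, so treat them and the function name as assumptions.

```python
def format_example(instruction: str, response: str, input_text: str = "") -> str:
    """Render one record in an Alpaca-style Instruction/Input/Response layout."""
    parts = [f"### Instruction:\n{instruction.strip()}"]
    if input_text.strip():
        # The Input section is skipped for records that carry no extra context.
        parts.append(f"### Input:\n{input_text.strip()}")
    parts.append(f"### Response:\n{response.strip()}")
    return "\n\n".join(parts)


record = {
    "instruction": "Translate the sentence to French.",
    "input": "Good morning.",
    "output": "Bonjour.",
}
print(format_example(record["instruction"], record["output"], record["input"]))
```

At inference time the same layout would be used with the response left empty, so the model continues from the `### Response:` header.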
Visualizations
Pre-Training Curves
50k steps on a 50B token corpus with 1M token effective batch size.
Diagnostics & Utilization
Visualizing weight distribution and expert utilization. Current routing shows healthy balance with utilization under 33%.
Usage
Inference
Interact with the model using the test.py script, which includes Top-K, Top-P, and repetition penalty sampling.
python test.py
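The three sampling controls named above can be combined as in the sketch below. It is a standalone illustration over a plain list of logits, not the code in test.py; the default values, the order of operations, and the divide-or-multiply form of the repetition penalty are assumptions based on common practice.

```python
import math
import random


def sample_next(logits, generated, top_k=50, top_p=0.9,
                repetition_penalty=1.2, temperature=1.0, rng=random):
    """Pick the next token id from raw logits (a list of floats)."""
    logits = list(logits)
    # Repetition penalty: make already-generated tokens less likely.
    for tok in set(generated):
        if logits[tok] > 0:
            logits[tok] /= repetition_penalty
        else:
            logits[tok] *= repetition_penalty
    # Top-K: keep only the k highest-scoring tokens.
    ranked = sorted(range(len(logits)), key=lambda t: logits[t], reverse=True)
    ranked = ranked[:top_k]
    # Softmax over the survivors, with temperature.
    m = logits[ranked[0]]
    weights = [math.exp((logits[t] - m) / temperature) for t in ranked]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Top-P: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for tok, p in zip(ranked, probs):
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    tokens, kept_probs = zip(*kept)
    return rng.choices(tokens, weights=kept_probs, k=1)[0]


toy_logits = [0.1, 3.0, 0.5, 2.5, -1.0]
print(sample_next(toy_logits, generated=[1], rng=random.Random(0)))
```

Setting `top_k=1` reduces this to greedy decoding, which is a convenient way to get deterministic output when debugging.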
Fine-tuning
To replicate the instruction tuning on your own dataset:
- Format your data following the Alpaca template in fine_tune.py.
- Execute:
python fine_tune.py
Repository Structure
- LLM_2.py: Core architecture (MLA, MoE, RoPE).
- train.py: Pre-training logic and WSD scheduler.
- fine_tune.py: Instruction tuning implementation.
- manager.py: MoE auxiliary loss tracking.
- check_params.py: Active vs. total parameter counter.
- eval.py: HellaSwag evaluation suite.
- analysis.py / show.py: Diagnostic and visualization tools.
Note: This model was developed as a research exploration into efficient sparse architectures. Verify all mathematical outputs manually.
