# MiniLM: The 1.58-bit Architecture Deep Dive
MiniLM is not just a quantized model: it is a custom neural network architecture built from the ground up to operate natively in **1.58-bit (Ternary)** precision.
By aggressively compressing the Transformer's weights, we built a 12-layer model that fits entirely into **6.00 MB** of RAM, small enough to target microcontrollers, smartwatches, and embedded IoT devices.
This document explains how MiniLM was engineered.
---
## 1. The Core Innovation: 1.58-bit Ternary Weights
In standard large language models such as Llama 3, the network's learned parameters (its "weights") are typically stored as 16-bit floating-point numbers (`FP16` or `BF16`). In the largest models, a single Transformer layer can exceed a gigabyte of memory.
MiniLM follows the **BitNet b1.58** architecture paradigm. Instead of storing floating-point weights, every weight in MiniLM's Linear layers is constrained to exactly three possible values:
* `-1`
* `0`
* `1`
Because $\log_2(3) \approx 1.58$, we call this a 1.58-bit model.
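The 1.58 figure is the information content of one ternary weight; an actual file still needs an encoding. This document does not specify MiniLM's storage format, so the sketch below assumes one common approach: since $3^5 = 243 \le 256$, five ternary weights fit into a single byte, for 1.6 bits per weight. `pack_trits` and `unpack_trits` are hypothetical helpers, not part of MiniLM.

```python
import math


def pack_trits(weights):
    """Pack ternary weights (-1, 0, 1) five per byte with base-3 encoding.

    Hypothetical helper: this storage format is an assumption, not MiniLM's.
    """
    packed = bytearray()
    for start in range(0, len(weights), 5):
        value = 0
        # Map -1/0/1 to base-3 digits 0/1/2; the first weight is the lowest digit.
        for w in reversed(weights[start:start + 5]):
            value = value * 3 + (w + 1)
        packed.append(value)  # at most 3**5 - 1 = 242, so it fits in one byte
    return bytes(packed)


def unpack_trits(packed, count):
    """Inverse of pack_trits: recover the first `count` ternary weights."""
    weights = []
    for value in packed:
        for _ in range(5):
            weights.append(value % 3 - 1)
            value //= 3
    return weights[:count]


w = [1, 0, -1, -1, 1, 0, 0]
blob = pack_trits(w)
assert unpack_trits(blob, len(w)) == w
print(len(blob))                # 2 bytes hold 7 weights
print(round(math.log2(3), 3))   # 1.585 bits of information per ternary weight
print(8 / 5)                    # 1.6 bits per weight actually used by this encoding
```

The gap between 1.585 and 1.6 bits is the cost of rounding to whole bytes; denser schemes exist, but this one keeps decoding to a few integer divisions.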
### Why is this revolutionary?
Multiplying a number by `-1`, `0`, or `1` does not require a real multiplication, so a matrix product with ternary weights reduces to **addition and subtraction**.
If a weight is `1`, you add the input. If it is `-1`, you subtract the input. If it is `0`, you ignore it.
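This means MiniLM replaces the multiply-accumulate operations inside matrix multiplication, the most computationally expensive operation in a Transformer, with additions and subtractions, which are far cheaper in hardware (and become pure integer additions when the activations are also quantized to integers, as in BitNet b1.58).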
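To make the add/subtract argument concrete, here is a minimal pure-Python sketch of a ternary matrix-vector product, checked against an ordinary multiply-accumulate version. The single per-output `scale` factor and the integer activations are assumptions that mirror the BitNet b1.58 setup (a weight scale plus quantized activations); this is not a description of MiniLM's actual kernels.

```python
def ternary_matvec(weights, x, scale=1.0):
    """Compute y = scale * (W @ x) for a ternary W using only adds and subtracts.

    `weights` is a list of rows with entries -1, 0 or 1. `scale` is an assumed
    per-tensor factor, applied once per output rather than once per weight.
    """
    y = []
    for row in weights:
        acc = 0
        for w, xi in zip(row, x):
            if w == 1:
                acc += xi  # weight +1: add the input
            elif w == -1:
                acc -= xi  # weight -1: subtract the input
            # weight 0: skip the input entirely
        y.append(acc * scale)  # one multiply per output row
    return y


def dense_matvec(weights, x, scale=1.0):
    """Reference version using ordinary multiply-accumulate."""
    return [scale * sum(w * xi for w, xi in zip(row, x)) for row in weights]


W = [[1, 0, -1],
     [-1, 1, 1]]
x = [3, 5, 2]  # integer activations, e.g. after 8-bit quantization
print(ternary_matvec(W, x))  # [1.0, 4.0]
assert ternary_matvec(W, x) == dense_matvec(W, x)
```

A production kernel would vectorize or bit-pack the inner loop; the point here is only that no per-weight multiplication is needed.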
---
## 2. How We Trained It: The Straight-Through Estimator (STE)
You cannot train a ternary neural network directly with standard backpropagation, because the quantization function (rounding a value and clamping it to -1, 0, or 1) has a derivative of zero almost everywhere. The gradient would instantly "die" and the model would never learn.
To solve this, we implemented a custom **Straight-Through Estimator (STE)**:
1. **Forward Pass:** We take the high-precision latent weights, compute a scaling factor (`beta`) from them (in BitNet b1.58 this is the mean of their absolute values), divide the weights by `beta`, and aggressively round and clip them to `[-1, 0, 1]`. The forward calculations are performed using these ternary weights.
2. **Backward Pass:** During backpropagation, we *pretend* the rounding step never happened: the gradient computed for the ternary weights is passed straight through, unchanged, to the high-precision latent weights.
This lets the high-precision weights adjust gradually over time, until their rounded ternary counterparts settle into a good configuration.
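The two passes above can be sketched in plain Python on a single linear layer. Two assumptions: `beta` is computed with the BitNet b1.58 absmean rule, and training is plain SGD on a squared-error toy regression rather than language modeling. The key line is the latent-weight update in `ste_step`: the gradient for the quantized effective weights is applied to the latent weights as if quantization were the identity.

```python
import random


def absmean_quantize(latent, eps=1e-8):
    """Forward-pass quantizer, assuming the BitNet b1.58 absmean rule.

    Returns (beta, ternary) where every entry of `ternary` is -1, 0 or 1.
    """
    flat = [abs(w) for row in latent for w in row]
    beta = sum(flat) / len(flat) + eps
    ternary = [[max(-1, min(1, round(w / beta))) for w in row] for row in latent]
    return beta, ternary


def forward(latent, x):
    """y = (beta * W_ternary) @ x; only the ternary matrix and beta need storing."""
    beta, q = absmean_quantize(latent)
    return [beta * sum(qij * xj for qij, xj in zip(row, x)) for row in q]


def ste_step(latent, x, target, lr):
    """One SGD step on 0.5 * ||y - target||^2 with a straight-through estimator.

    round() has zero gradient almost everywhere, so we treat quantization as the
    identity and apply dLoss/dW_effective = err_i * x_j directly to the latents.
    """
    y = forward(latent, x)
    err = [yi - ti for yi, ti in zip(y, target)]
    for i, row in enumerate(latent):
        for j in range(len(row)):
            row[j] -= lr * err[i] * x[j]  # straight-through gradient update
    return 0.5 * sum(e * e for e in err)


# Toy usage: learn a ternary target matrix from random inputs.
random.seed(0)
true_w = [[1, 0, -1, 1], [0, -1, 1, 1]]
latent = [[random.uniform(-0.1, 0.1) for _ in range(4)] for _ in range(2)]
for _ in range(2000):
    x = [random.uniform(-1, 1) for _ in range(4)]
    target = [sum(w * xj for w, xj in zip(row, x)) for row in true_w]
    loss = ste_step(latent, x, target, lr=0.05)
print(absmean_quantize(latent)[1], round(loss, 6))
```

In an autograd framework the same effect is usually obtained by adding the quantization residual with its gradient detached, so the backward pass sees only the latent weights.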
---
## 3. Breaking the Depth Barrier: Weight Tying
Our initial 4-layer model fit into 3.93 MB and showed promising results, but 4 layers is very shallow for an LLM to produce coherent, long-form text.
To solve this, we implemented **Weight Tying**.
In many LLMs, the `Embedding Layer` (which turns tokens into vectors) and the `Output Head` (which turns vectors back into scores over tokens) are two separate, massive matrices.
Because we used a 32,000 token vocabulary, these two matrices were consuming **over 85%** of our total parameter budget!
By tying the weights together (`model.head.weight = model.embedding.weight`), so that both layers share a single matrix, we freed up 8 million parameters. We re-invested this parameter budget to triple the depth of the neural network from 4 layers to **12 layers**, improving output coherence without increasing the total parameter count.
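Weight tying simply means the two layers reference one shared matrix rather than holding separate copies, the same idea as the `model.head.weight = model.embedding.weight` assignment above. The toy class below is a hypothetical pure-Python illustration, not MiniLM's code, and `d_model = 256` is an assumption chosen because 32,000 x 256 is roughly 8 million; the document does not state MiniLM's hidden size.

```python
import random


class TiedLM:
    """Toy embedding + output head sharing one matrix (hypothetical example)."""

    def __init__(self, vocab_size, d_model, seed=0):
        rng = random.Random(seed)
        self.embedding = [[rng.gauss(0.0, 0.02) for _ in range(d_model)]
                          for _ in range(vocab_size)]
        # Weight tying: the head is the same object, not a copy.
        self.head = self.embedding

    def embed(self, token_id):
        """Token id -> vector: look up one row of the shared matrix."""
        return self.embedding[token_id]

    def logits(self, hidden):
        """Vector -> one score per token: dot product with every row."""
        return [sum(w * h for w, h in zip(row, hidden)) for row in self.head]


def params_saved_by_tying(vocab_size, d_model):
    """Tying removes the separate vocab_size x d_model output matrix."""
    return vocab_size * d_model


model = TiedLM(vocab_size=10, d_model=4)
assert model.head is model.embedding
model.embedding[3][0] = 1.0        # an update through one name...
assert model.head[3][0] == 1.0     # ...is visible through the other
print(params_saved_by_tying(32_000, 256))  # 8192000, roughly 8 million
```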
---
## 4. Knowledge Distillation
Training a 1.58-bit model from absolute scratch using Next-Token Prediction is notoriously difficult and requires massive amounts of data and compute (100k+ steps).
Instead, we used **Knowledge Distillation**.
1. We loaded `HuggingFaceTB/SmolLM-135M-Instruct` as a "Teacher" model.
2. We configured MiniLM to use the exact same tokenizer as SmolLM, so that both models predict distributions over the same vocabulary.
3. For every token position in the training text, the Teacher model produced logits which, after a softmax, give a rich probability distribution over what the next token should be.
4. We used `KLDivLoss` (KL divergence) to train MiniLM to match the Teacher's probability distribution as closely as possible.
By learning from the Teacher's full distribution over the vocabulary rather than only a sparse one-hot next-token target, MiniLM converged in just **3,000 steps** on the TinyStories dataset!
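Step 4's objective can be written for a single token position as $\mathrm{KL}(p_T \,\|\, p_S) = \sum_v p_T(v)\,(\log p_T(v) - \log p_S(v))$, where $p_T$ and $p_S$ are the Teacher's and Student's softmax distributions. The pure-Python sketch below computes it from raw logits; it matches what PyTorch's `nn.KLDivLoss` computes when given student log-probabilities as `input` and teacher probabilities as `target`, summed over the vocabulary. Averaging over positions and the absence of a softmax temperature are assumptions; the document does not say how MiniLM's loss was reduced or tempered.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def kl_teacher_student(teacher_logits, student_logits):
    """KL(teacher || student) for one token position, summed over the vocabulary."""
    t_log = log_softmax(teacher_logits)
    s_log = log_softmax(student_logits)
    return sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t_log, s_log))


def distillation_loss(teacher_seq, student_seq):
    """Mean KL over token positions (this reduction is an assumption)."""
    losses = [kl_teacher_student(t, s) for t, s in zip(teacher_seq, student_seq)]
    return sum(losses) / len(losses)


teacher = [[2.0, 0.5, -1.0], [0.0, 3.0, 1.0]]
print(distillation_loss(teacher, teacher))                 # 0.0: identical distributions
print(distillation_loss(teacher, [[0.0, 0.0, 0.0]] * 2))   # positive
```

Unlike a one-hot cross-entropy target, every vocabulary entry with nonzero Teacher probability contributes to this loss, which is the "rich" signal the paragraph above refers to.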
---
## Conclusion
MiniLM is a testament to the future of Edge AI. By combining Ternary Quantization, Weight Tying, and Knowledge Distillation, we have packed the structural depth of a 12-layer Transformer into a file smaller than a typical MP3 song.