---
license: mit
language:
- en
tags:
- 1-bit
- bitnet
- tiny
- language-model
- tinystories
datasets:
- roneneldan/TinyStories
pipeline_tag: text-generation
---
# tiny-tiny-stories
Oh, you think TinyStories is small? Welcome to TinyStories on steroids: a 1M-parameter model at 1.58-bit quantization!
A **1-bit (ternary {-1, 0, +1}) transformer language model** trained on [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories).
## Specs
| | |
|---|---|
| **Parameters** | 998,784 (< 1M) |
| **Weight precision** | 1.58-bit ternary (BitNet b1.58) |
| **Tokenizer** | SentencePiece unigram, 192 vocab |
| **Context length** | 512 tokens |
| **Best val loss** | 1.2087 (perplexity 3.35) |
| **Training** | 100K steps on the ~2.1M-story TinyStories corpus |
| **Checkpoint size** | 3.9 MB (FP32 latent), ~350 KB quantized |
## Architecture
- **d_model**: 128
- **Heads**: 4 (head_dim=32)
- **Layers**: 5
- **FFN**: SwiGLU (d_ff=336)
- **Position encoding**: RoPE (no learned positional embeddings)
- **Normalization**: RMSNorm
- **Embeddings**: Tied input/output, full precision
- **All linear layers**: BitLinear with ternary quantization + straight-through estimator
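For reference, the hyperparameters above can be collected into a small config object. This is a hypothetical sketch; the actual `Config` class lives in `train.py` and its field names may differ:

```python
from dataclasses import dataclass

@dataclass
class Config:
    # Values taken from the spec and architecture sections above;
    # field names are assumptions, not necessarily those in train.py.
    vocab_size: int = 192
    d_model: int = 128
    n_heads: int = 4        # head_dim = d_model // n_heads = 32
    n_layers: int = 5
    d_ff: int = 336         # SwiGLU hidden size
    max_seq_len: int = 512
```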
## How it works
All Q/K/V/O attention projections and SwiGLU FFN matrices use **BitLinear**: weights are quantized to {-1, 0, +1} during the forward pass via `round(W / mean(|W|))`, with gradients flowing through a straight-through estimator to full-precision latent weights during training.
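The quantize-then-STE step can be sketched as follows. This is a minimal illustration of the b1.58 recipe described above, not the repo's actual `BitLinear` code; the function names are hypothetical:

```python
import torch

def ternary_quantize(w: torch.Tensor, eps: float = 1e-5):
    """Quantize weights to {-1, 0, +1}: scale by the mean
    absolute value, round, then clamp to the ternary range."""
    scale = w.abs().mean().clamp(min=eps)
    w_q = (w / scale).round().clamp(-1, 1)
    return w_q, scale

def bitlinear_forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Straight-through estimator: the forward pass sees the
    # quantized (rescaled) weights, while gradients flow through
    # unchanged to the full-precision latent weights.
    w_q, scale = ternary_quantize(w)
    w_ste = w + (w_q * scale - w).detach()
    return x @ w_ste.t()
```

At inference the latent weights can be discarded entirely and only the ternary values plus one scale per matrix kept, which is where the ~350 KB quantized checkpoint size comes from.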
## Usage
```python
import torch
import sentencepiece as spm
# Load tokenizer
sp = spm.SentencePieceProcessor(model_file='tokenizer.model')
# Load model (see train.py for BitLM class definition)
from train import BitLM, Config
cfg = Config()
cfg.vocab_size = 192
model = BitLM(cfg)
ckpt = torch.load('best.pt', map_location='cpu', weights_only=True)
state = ckpt['model']
# Strip the torch.compile prefix if the checkpoint was saved from a compiled model
if any(k.startswith('_orig_mod.') for k in state):
    state = {k.replace('_orig_mod.', ''): v for k, v in state.items()}
model.load_state_dict(state)
model.eval()
# Generate
ids = [sp.bos_id()] + sp.encode("Once upon a time")
idx = torch.tensor([ids])
out = model.generate(idx, max_new=200, temp=0.8, top_k=40, eos_id=sp.eos_id())
print(sp.decode(out[0].tolist()))
```
## Sample output
> Once upon a time, there was a squirrel. He was very curious and loved to play in the park. One day, he noticed a big tree in the sky. He was already laughing, but he was stronger under his houses. The squirrel was glue of all the trees, exploring the walls...
## Training
Trained on 2x RTX 2080 Ti using mixed-precision (FP16) with AdamW optimizer, cosine LR schedule (1.5e-3 peak, 1000 step warmup), and gradient accumulation (effective batch size 384).
```bash
python train.py --exp-dir ./output --device cuda:0 --compile
```
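The LR schedule described above (1.5e-3 peak, 1000-step warmup, cosine decay over 100K steps) can be sketched as a step-indexed function. The decay floor of zero is an illustrative assumption; the actual run may decay to a nonzero minimum:

```python
import math

PEAK_LR = 1.5e-3
WARMUP_STEPS = 1000
TOTAL_STEPS = 100_000
MIN_LR = 0.0  # assumed floor, not confirmed by the training script

def lr_at(step: int) -> float:
    """Linear warmup to PEAK_LR, then cosine decay to MIN_LR."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    return MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1 + math.cos(math.pi * progress))
```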