---
license: mit
language:
- en
tags:
- 1-bit
- bitnet
- tiny
- language-model
- tinystories
datasets:
- roneneldan/TinyStories
pipeline_tag: text-generation
---

# tiny-tiny-stories

Oh, you think TinyStories is small? Welcome to TinyStories on steroids: a 1M-parameter model at 1.58-bit quantization!

A **1-bit (ternary {-1, 0, +1}) transformer language model** trained on [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories).

## Specs

| | |
|---|---|
| **Parameters** | 998,784 (< 1M) |
| **Weight precision** | 1.58-bit ternary (BitNet b1.58) |
| **Tokenizer** | SentencePiece unigram, 192 vocab |
| **Context length** | 512 tokens |
| **Best val loss** | 1.2087 (perplexity 3.35) |
| **Training** | 100K steps on 2.1M TinyStories |
| **Checkpoint size** | 3.9 MB (FP32 latent), ~350 KB quantized |
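The two checkpoint sizes in the table can be sanity-checked with quick arithmetic (the packing scheme and the per-component breakdown are assumptions; the real file layout may differ and includes norms plus optimizer metadata):

```python
params = 998_784                 # total parameters
emb = 192 * 128                  # tied embedding table, kept in full precision

fp32_mb = params * 4 / 1e6       # 4 bytes per FP32 latent weight
ternary_kb = (params - emb) * 1.58 / 8 / 1e3   # ~1.58 bits per ternary weight
emb_kb = emb * 4 / 1e3

print(f"FP32 latent: {fp32_mb:.1f} MB")             # ~4.0 MB, matching the 3.9 MB checkpoint
print(f"Quantized:   ~{ternary_kb + emb_kb:.0f} KB")  # ~291 KB, in the ~350 KB ballpark
```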

## Architecture

- **d_model**: 128
- **Heads**: 4 (head_dim=32)
- **Layers**: 5
- **FFN**: SwiGLU (d_ff=336)
- **Position encoding**: RoPE (no learned positional embeddings)
- **Normalization**: RMSNorm
- **Embeddings**: tied input/output, full precision
- **All linear layers**: BitLinear with ternary quantization + straight-through estimator

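For reference, these hyperparameters correspond to a config along the following lines (the field names here are illustrative; see `train.py` for the actual `Config` class):

```python
from dataclasses import dataclass

@dataclass
class Config:
    vocab_size: int = 192   # SentencePiece unigram vocab
    d_model: int = 128
    n_heads: int = 4        # head_dim = d_model // n_heads = 32
    n_layers: int = 5
    d_ff: int = 336         # SwiGLU hidden size
    max_seq_len: int = 512  # context length (RoPE, no learned positions)
```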
## How it works

All Q/K/V/O attention projections and SwiGLU FFN matrices use **BitLinear**: weights are quantized to {-1, 0, +1} during the forward pass via `round(W / mean(|W|))`, with gradients flowing through a straight-through estimator to full-precision latent weights during training.
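
In PyTorch, that forward pass can be sketched roughly as follows (a minimal illustration, not the actual `train.py` code; the clamp to [-1, 1] and the re-applied scale follow the BitNet b1.58 recipe and are assumptions here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Linear):
    """Linear layer with on-the-fly ternary {-1, 0, +1} weight quantization."""
    def forward(self, x):
        w = self.weight                          # full-precision latent weights
        scale = w.abs().mean().clamp(min=1e-8)   # per-tensor scale, mean(|W|)
        w_q = (w / scale).round().clamp(-1, 1) * scale  # ternarize, then rescale
        # Straight-through estimator: forward uses w_q, gradients flow to w.
        w_ste = w + (w_q - w).detach()
        return F.linear(x, w_ste, self.bias)
```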

## Usage

```python
import torch
import sentencepiece as spm

# Load tokenizer
sp = spm.SentencePieceProcessor(model_file='tokenizer.model')

# Build model (see train.py for the BitLM class definition)
from train import BitLM, Config
cfg = Config()
cfg.vocab_size = 192
model = BitLM(cfg)

# Load checkpoint, stripping the torch.compile prefix if present
ckpt = torch.load('best.pt', map_location='cpu', weights_only=True)
state = ckpt['model']
if any(k.startswith('_orig_mod.') for k in state):
    state = {k.replace('_orig_mod.', ''): v for k, v in state.items()}
model.load_state_dict(state)
model.eval()

# Generate
ids = [sp.bos_id()] + sp.encode("Once upon a time")
idx = torch.tensor([ids])
out = model.generate(idx, max_new=200, temp=0.8, top_k=40, eos_id=sp.eos_id())
print(sp.decode(out[0].tolist()))
```

## Sample output

> Once upon a time, there was a squirrel. He was very curious and loved to play in the park. One day, he noticed a big tree in the sky. He was already laughing, but he was stronger under his houses. The squirrel was glue of all the trees, exploring the walls...

## Training

Trained on 2x RTX 2080 Ti GPUs with mixed precision (FP16), the AdamW optimizer, a cosine LR schedule (1.5e-3 peak, 1,000-step warmup), and gradient accumulation (effective batch size 384).

```bash
python train.py --exp-dir ./output --device cuda:0 --compile
```