---
license: mit
language:
- en
tags:
- 1-bit
- bitnet
- tiny
- language-model
- tinystories
datasets:
- roneneldan/TinyStories
pipeline_tag: text-generation
---

# tiny-tiny-stories

Oh, you think TinyStories is small? Welcome to TinyStories on steroids: a ~1M-parameter model at 1.58-bit quantization!

A **1-bit (ternary {-1, 0, +1}) transformer language model** trained on [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories).

## Specs

| | |
|---|---|
| **Parameters** | 998,784 (< 1M) |
| **Weight precision** | 1.58-bit ternary (BitNet b1.58) |
| **Tokenizer** | SentencePiece unigram, 192 vocab |
| **Context length** | 512 tokens |
| **Best val loss** | 1.2087 (perplexity 3.35) |
| **Training** | 100K steps on 2.1M TinyStories stories |
| **Checkpoint size** | 3.9 MB (FP32 latent), ~350 KB quantized |

## Architecture

- **d_model**: 128
- **Heads**: 4 (head_dim=32)
- **Layers**: 5
- **FFN**: SwiGLU (d_ff=336)
- **Position encoding**: RoPE (no learned positional embeddings)
- **Normalization**: RMSNorm
- **Embeddings**: tied input/output, full precision
- **All linear layers**: BitLinear with ternary quantization + straight-through estimator

## How it works

All Q/K/V/O attention projections and SwiGLU FFN matrices use **BitLinear**: weights are quantized to {-1, 0, +1} during the forward pass via `round(W / mean(|W|))`, with gradients flowing through a straight-through estimator to full-precision latent weights during training.
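The forward pass above can be sketched as a drop-in `nn.Linear` subclass. This is a minimal illustration, not the exact layer from `train.py`: the clamp, epsilon guard, and rescaling by the per-tensor scale are common BitNet b1.58 conventions assumed here, and the real implementation may also quantize activations.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Linear):
    """Linear layer with ternary {-1, 0, +1} weights (BitNet b1.58 style sketch)."""

    def forward(self, x):
        w = self.weight
        # Per-tensor scale: mean absolute weight (epsilon guards against divide-by-zero).
        scale = w.abs().mean().clamp(min=1e-5)
        # Ternarize: round(W / mean(|W|)), clamped to {-1, 0, +1}.
        w_q = (w / scale).round().clamp(-1, 1)
        # Straight-through estimator: the forward pass sees the quantized
        # (rescaled) weights, while gradients flow to the full-precision
        # latent weights w because the difference is detached.
        w_ste = w + (w_q * scale - w).detach()
        return F.linear(x, w_ste, self.bias)
```

At inference time only `w_q` (2 bits per weight) plus the single `scale` per matrix need to be stored, which is where the ~350 KB quantized checkpoint size comes from.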
## Usage

```python
import torch
import sentencepiece as spm

# Load tokenizer
sp = spm.SentencePieceProcessor(model_file='tokenizer.model')

# Load model (see train.py for the BitLM class definition)
from train import BitLM, Config

cfg = Config()
cfg.vocab_size = 192
model = BitLM(cfg)
ckpt = torch.load('best.pt', map_location='cpu', weights_only=True)
state = ckpt['model']
# Strip the torch.compile prefix if the checkpoint was saved from a compiled model
if any(k.startswith('_orig_mod.') for k in state):
    state = {k.replace('_orig_mod.', ''): v for k, v in state.items()}
model.load_state_dict(state)
model.eval()

# Generate
ids = [sp.bos_id()] + sp.encode("Once upon a time")
idx = torch.tensor([ids])
out = model.generate(idx, max_new=200, temp=0.8, top_k=40, eos_id=sp.eos_id())
print(sp.decode(out[0].tolist()))
```

## Sample output

> Once upon a time, there was a squirrel. He was very curious and loved to play in the park. One day, he noticed a big tree in the sky. He was already laughing, but he was stronger under his houses. The squirrel was glue of all the trees, exploring the walls...

## Training

Trained on 2x RTX 2080 Ti GPUs using mixed-precision (FP16) training with the AdamW optimizer, a cosine LR schedule (1.5e-3 peak, 1000-step warmup), and gradient accumulation (effective batch size 384).

```bash
python train.py --exp-dir ./output --device cuda:0 --compile
```