---
license: mit
language:
- en
tags:
- 1-bit
- bitnet
- tiny
- language-model
- tinystories
datasets:
- roneneldan/TinyStories
pipeline_tag: text-generation
---

# tiny-tiny-stories

Oh, you think TinyStories is small? Welcome to TinyStories on steroids: a 1M-parameter model at 1.58-bit quantization!

A **1-bit (ternary {-1, 0, +1}) transformer language model** trained on [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories).

## Specs

| | |
|---|---|
| **Parameters** | 998,784 (< 1M) |
| **Weight precision** | 1.58-bit ternary (BitNet b1.58) |
| **Tokenizer** | SentencePiece unigram, 192 vocab |
| **Context length** | 512 tokens |
| **Best val loss** | 1.2087 (perplexity 3.35) |
| **Training** | 100K steps on 2.1M TinyStories examples |
| **Checkpoint size** | 3.9 MB (FP32 latent), ~350 KB quantized |

## Architecture

- **d_model**: 128
- **Heads**: 4 (head_dim=32)
- **Layers**: 5
- **FFN**: SwiGLU (d_ff=336)
- **Position encoding**: RoPE (no learned positional embeddings)
- **Normalization**: RMSNorm
- **Embeddings**: Tied input/output, full precision
- **All linear layers**: BitLinear with ternary quantization + straight-through estimator
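
One consistent breakdown of the parameter count, assuming tied embeddings, four d×d attention projections, three d×d_ff SwiGLU matrices per layer, and per-tensor RMSNorm gains (two per layer plus a final norm), reproduces the total in the specs table:

```python
d, d_ff, n_layers, vocab = 128, 336, 5, 192

emb = vocab * d                       # tied input/output embedding: 24,576
attn = 4 * d * d                      # Q, K, V, O projections per layer: 65,536
ffn = 3 * d * d_ff                    # SwiGLU gate, up, down per layer: 129,024
norms = n_layers * 2 * d + d          # two RMSNorm gains per layer + final norm: 1,408

total = emb + n_layers * (attn + ffn) + norms
print(total)  # 998784
```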

## How it works

All Q/K/V/O attention projections and SwiGLU FFN matrices use **BitLinear**: weights are quantized to {-1, 0, +1} during the forward pass via `round(W / mean(|W|))` clamped to [-1, 1], with gradients flowing through a straight-through estimator to full-precision latent weights during training.
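
A minimal sketch of this scheme (a hypothetical reimplementation; the model's actual `BitLinear` class lives in `train.py`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Linear):
    """Ternary linear layer with a straight-through estimator (STE)."""
    def forward(self, x):
        w = self.weight
        scale = w.abs().mean().clamp(min=1e-8)   # per-tensor scale: mean(|W|)
        w_q = (w / scale).round().clamp(-1, 1)   # ternary {-1, 0, +1}
        # STE: forward pass uses the quantized weights (rescaled), but the
        # detach() trick routes gradients to the full-precision latent w.
        w_ste = w + (w_q * scale - w).detach()
        return F.linear(x, w_ste, self.bias)
```

At inference time only `w_q` and `scale` need to be stored, which is what shrinks the checkpoint from 3.9 MB to roughly 350 KB.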

## Usage

```python
import torch
import sentencepiece as spm

# Load tokenizer
sp = spm.SentencePieceProcessor(model_file='tokenizer.model')

# Load model (see train.py for BitLM class definition)
from train import BitLM, Config
cfg = Config()
cfg.vocab_size = 192
model = BitLM(cfg)
ckpt = torch.load('best.pt', map_location='cpu', weights_only=True)
state = ckpt['model']
if any(k.startswith('_orig_mod.') for k in state):
    state = {k.replace('_orig_mod.', ''): v for k, v in state.items()}
model.load_state_dict(state)
model.eval()

# Generate
ids = [sp.bos_id()] + sp.encode("Once upon a time")
idx = torch.tensor([ids])
out = model.generate(idx, max_new=200, temp=0.8, top_k=40, eos_id=sp.eos_id())
print(sp.decode(out[0].tolist()))
```
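
The `generate` call above does standard temperature/top-k sampling. One step of that procedure can be sketched as follows (a hypothetical helper, not the actual `train.py` code):

```python
import torch

def sample_next(logits, temp=0.8, top_k=40):
    """Sample one token id from last-position logits with temperature + top-k."""
    logits = logits / temp
    v, _ = torch.topk(logits, top_k)
    logits[logits < v[-1]] = float('-inf')   # mask everything outside the top-k
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()
```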

## Sample output

> Once upon a time, there was a squirrel. He was very curious and loved to play in the park. One day, he noticed a big tree in the sky. He was already laughing, but he was stronger under his houses. The squirrel was glue of all the trees, exploring the walls...

## Training

Trained on 2× RTX 2080 Ti GPUs using mixed precision (FP16) with the AdamW optimizer, a cosine LR schedule (1.5e-3 peak, 1,000-step warmup), and gradient accumulation (effective batch size 384).

```bash
python train.py --exp-dir ./output --device cuda:0 --compile
```
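
The cosine schedule with linear warmup described above can be sketched like this (a hypothetical helper with the stated peak LR and warmup; the actual schedule is in `train.py`, and the minimum LR is an assumption):

```python
import math

def lr_at(step, peak=1.5e-3, warmup=1000, total=100_000, min_lr=0.0):
    """Linear warmup to `peak`, then cosine decay to `min_lr` at `total` steps."""
    if step < warmup:
        return peak * step / warmup
    t = (step - warmup) / (total - warmup)   # progress through the decay phase
    return min_lr + 0.5 * (peak - min_lr) * (1 + math.cos(math.pi * t))
```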