etwk commited on
Commit
ea9333f
·
1 Parent(s): 55de2ed

Docs: correct README to the shipped 5 weight-sets (shared 64-512 cell), fix param count ~21M and tier-5/6 numbers; remove the legacy MLP train.py (superseded by the carry-aware TCN recipe in README/manifest)

Browse files
Files changed (2) hide show
  1. README.md +31 -24
  2. train.py +0 -221
README.md CHANGED
@@ -19,9 +19,10 @@ metrics:
19
  A compliant bit-sequential RNN that **clears every reduction tier, 1 through 10** (primes up to
20
  2^2048) on the public benchmark — tiers 1-5 = 100%, tier 6 = 98%, tiers 7-8 = 100%,
21
  tier 9 = 99%, **tier 10 = 100%** — so `highest_tier_above_90 = 10` (the maximum),
22
- overall_accuracy **0.997**. Every cell is the same **carry-aware TCN** (~30M params total, 0.13 GB),
23
- so its capability comes from *learning one algorithmic step* rather than memorising finite
24
- multiplication tables, and it verifiably generalises to primes never seen in training.
 
25
 
26
  ## The idea
27
 
@@ -47,25 +48,26 @@ The single-step function is **piecewise linear** (`2t + bit*b`, then subtract 0,
47
  `2p`), which is why it generalises across primes where the full bilinear map does not:
48
  held-out-prime validation accuracy tracks training accuracy throughout (no memorisation gap).
49
 
50
- ## Eight cells, routed by prime size
51
 
52
- The recurrence is exact only if the state is wide enough to hold the residue, so the cell is
53
- trained per bit-width. The model ships eight and routes each problem to the narrowest cell
54
- whose state holds its prime:
 
 
55
 
56
- | Cell | Primes | Tiers | Architecture | Params | Public benchmark |
57
  |---|---|---|---|---|---|
58
- | 16-bit | `< 2^16` | 1-3 | carry-aware TCN, 6 blocks, dil 1..8 | ~2.4M | tiers 1-3 = 1.00 |
59
- | 32-bit | `< 2^32` | 4 | carry-aware TCN, 8 blocks, dil 1..16 | ~3.2M | tier 4 = 1.00 |
60
- | 64-bit | `< 2^64` | 5 | carry-aware TCN, 8 blocks, dil 1..32 | ~3.2M | tier 5 = 0.99 |
61
- | 128-bit | `< 2^128` | 6 | carry-aware TCN, 10 blocks, dil 1..64 | ~3.9M | tier 6 = 0.97 |
62
- | 256-bit | `< 2^256` | 7 | carry-aware TCN, 12 blocks, dil 1..128 | ~4.7M | tier 7 = 0.98 |
63
- | 512-bit | `< 2^512` | 8 | carry-aware TCN, 14 blocks, dil 1..256 | ~5.5M | tier 8 = 0.98 |
64
- | 1024-bit | `< 2^1024` | 9 | carry-aware TCN, 12 blocks, dil 1..512 | ~4.7M | tier 9 = 0.99 |
65
- | 2048-bit | `< 2^2048` | 10 | carry-aware TCN, 13 blocks, dil 1..1024 | ~5.1M | tier 10 = 1.00 |
66
-
67
- For `p >= 2^2048` (outside all regimes) the model emits the honest `[0]` fallback without
68
- invoking the network.
69
 
70
  ## The carry-aware TCN (every tier)
71
 
@@ -134,12 +136,17 @@ position. Combined with gradient accumulation (effective batch ~26k) and the wor
134
  loss, this took tier 9 from **0.73 -> 0.99**, even across prime widths (held-out value-uniform
135
  validation 0.99; per-width 1015-1024 all ~0.99).
136
 
 
 
 
137
  ```bash
138
- python horner_rnn/train.py --stage1-minutes 50 # 16-bit cell -> weights16.pt
139
- python exploration/train_horner32.py --minutes 120 # 32-bit cell -> weights32.pt
140
- python exploration/train_horner_tcn.py --bits 64 --blocks 8 --max-dil 32 --lo-bits 62 # tier 5
141
- python exploration/train_horner_tcn.py --bits 256 --blocks 12 --max-dil 128 --lo-bits 251 # tier 7
142
- python exploration/train_horner_tcn.py --bits 512 --blocks 14 --max-dil 256 --accum 2 # tier 8
 
 
143
  ```
144
 
145
  The **1024-bit (tier-9) cell is a multi-stage curriculum**, not a single run — the carry
 
19
  A compliant bit-sequential RNN that **clears every reduction tier, 1 through 10** (primes up to
20
  2^2048) on the public benchmark — tiers 1-5 = 100%, tier 6 = 98%, tiers 7-8 = 100%,
21
  tier 9 = 99%, **tier 10 = 100%** — so `highest_tier_above_90 = 10` (the maximum),
22
+ overall_accuracy **0.997**. Every cell is the same **carry-aware TCN** (~21M params total across
23
+ five weight-sets, 0.13 GB), so its capability comes from *learning one algorithmic step* rather
24
+ than memorising finite multiplication tables, and it verifiably generalises to primes never seen
25
+ in training.
26
 
27
  ## The idea
28
 
 
48
  `2p`), which is why it generalises across primes where the full bilinear map does not:
49
  held-out-prime validation accuracy tracks training accuracy throughout (no memorisation gap).
50
 
51
+ ## Five weight-sets, routed by prime size
52
 
53
+ The recurrence is exact only if the state is wide enough to hold the residue, so each cell is
54
+ trained per bit-width but because the dilated convolution is weight-shared across bit-positions
55
+ and the carry/borrow rule is position-invariant, **one shared weight-set serves the four mid
56
+ widths 64/128/256/512** (run at each prime's native width). The model therefore ships **five
57
+ weight-sets** and routes each problem to the narrowest cell whose state holds its prime:
58
 
59
+ | Weight file | Primes | Tiers | Architecture | Params | Public benchmark |
60
  |---|---|---|---|---|---|
61
+ | `weights16.pt` | `< 2^16` | 1-3 | carry-aware TCN, 6 blocks, dil 1..8 | ~2.4M | tiers 1-3 = 1.00 |
62
+ | `weights32.pt` | `< 2^32` | 4 | carry-aware TCN, 8 blocks, dil 1..16 | ~3.2M | tier 4 = 1.00 |
63
+ | `weights_shared_64_512.pt` | `< 2^512` | 5-8 | carry-aware TCN, 14 blocks, dil 1..256 — **one shared set**, run at native width | ~5.5M | tier 5 = 1.00, tier 6 = 0.98, tier 7 = 1.00, tier 8 = 1.00 |
64
+ | `weights1024.pt` | `< 2^1024` | 9 | carry-aware TCN, 12 blocks, dil 1..512 | ~4.7M | tier 9 = 0.99 |
65
+ | `weights2048.pt` | `< 2^2048` | 10 | carry-aware TCN, 13 blocks, dil 1..1024 | ~5.1M | tier 10 = 1.00 |
66
+
67
+ The four separate mid-width cells it replaced (0.99 / 0.97 / 0.98 / 0.98, ~17M params combined)
68
+ were collapsed into the single shared set, which matches or beats them at ~5.5M total **~21M
69
+ params, 0.13 GB**. For `p >= 2^2048` (outside all regimes) the model emits the honest `[0]`
70
+ fallback without invoking the network.
 
71
 
72
  ## The carry-aware TCN (every tier)
73
 
 
136
  loss, this took tier 9 from **0.73 -> 0.99**, even across prime widths (held-out value-uniform
137
  validation 0.99; per-width 1015-1024 all ~0.99).
138
 
139
+ The training scripts live in the companion research repo (not shipped in this model repo); the
140
+ commands below document *how the weights were obtained* (the provenance the rules ask for):
141
+
142
  ```bash
143
+ # small-prime cells, width-matched (bit-length-uniform over the cell's whole range + value-uniform)
144
+ python exploration/train_horner_tcn.py --bits 16 --lo-bits 2 --bitlen-frac 0.65 --bitlen-lo 2 # -> weights16.pt (tiers 1-3)
145
+ python exploration/train_horner_tcn.py --bits 32 --lo-bits 17 --bitlen-frac 0.6 --bitlen-lo 17 # -> weights32.pt (tier 4)
146
+
147
+ # ONE shared mid-width set for tiers 5-8: warm-start from the dedicated 512-bit cell, then
148
+ # fine-tune on a {64,128,256,512}-bit mix (the carry rule is width-portable) -> weights_shared_64_512.pt
149
+ python exploration/train_unified.py --warm --init-from weights512.pt --widths 64,128,256,512
150
  ```
151
 
152
  The **1024-bit (tier-9) cell is a multi-stage curriculum**, not a single run — the carry
train.py DELETED
@@ -1,221 +0,0 @@
1
- """Train the horner_rnn transition cell (bit-level Horner step) + chain fine-tuning.
2
-
3
- Stage 1: train cell f(t, bit, b, p) = (2t + bit*b) mod p (quotients {0,1,2},
4
- easier than base-4's {0..6}) with grad clipping, EMA, hard-boundary mining.
5
-
6
- Stage 2 (optional, default off): fine-tune end-to-end through the 16-step
7
- chain with a straight-through estimator on the quantized state, loss on every
8
- step's ground-truth intermediate. In practice this was destructive at lr2=5e-5
9
- (chain val collapsed); the shipped weights come from stage 1 alone, which
10
- reaches chain val ~0.998 on held-out primes. Kept for further experimentation
11
- at lower learning rates.
12
- """
13
-
14
- from __future__ import annotations
15
-
16
- import argparse
17
- import time
18
-
19
- import sys
20
- from pathlib import Path
21
-
22
- import torch
23
- import torch.nn as nn
24
-
25
- # Import the shared architecture from the sibling model.py.
26
- HERE = Path(__file__).resolve().parent
27
- sys.path.insert(0, str(HERE))
28
- from model import HornerCell, BITS, _to_bits as to_bits # noqa: E402
29
-
30
-
31
- def sieve_primes(limit: int) -> list[int]:
32
- is_p = bytearray([1]) * limit
33
- is_p[0] = is_p[1] = 0
34
- for i in range(2, int(limit ** 0.5) + 1):
35
- if is_p[i]:
36
- is_p[i * i :: i] = bytearray(len(is_p[i * i :: i]))
37
- return [i for i in range(2, limit) if is_p[i]]
38
-
39
-
40
- def sample_batch(primes_t, n, device, hard_frac=0.5):
41
- p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
42
- b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
43
- bit = torch.randint(0, 2, (n,), device=device)
44
- n_hard = int(n * hard_frac)
45
- t = torch.empty(n, dtype=torch.long, device=device)
46
- t[n_hard:] = (torch.rand(n - n_hard, device=device) * p[n_hard:]).long()
47
- if n_hard:
48
- ph, bh, bith = p[:n_hard], b[:n_hard], bit[:n_hard]
49
- q = torch.randint(0, 3, (n_hard,), device=device)
50
- delta = torch.randint(-2, 3, (n_hard,), device=device)
51
- th = (q * ph + delta - bith * bh) >> 1
52
- t[:n_hard] = th.clamp(min=0) % ph
53
- z = (2 * t + bit * b) % p
54
- return t, bit, b, p, z
55
-
56
-
57
- @torch.no_grad()
58
- def exact_rate(model, primes_t, device, n=200_000, bs=65536) -> float:
59
- ok = 0
60
- for i in range(0, n, bs):
61
- m = min(bs, n - i)
62
- t, bit, b, p, z = sample_batch(primes_t, m, device, hard_frac=0.0)
63
- logits = model(to_bits(t), bit.float().unsqueeze(1), to_bits(b), to_bits(p))
64
- ok += ((logits > 0).long() == to_bits(z).long()).all(dim=1).sum().item()
65
- return ok / n
66
-
67
-
68
- @torch.no_grad()
69
- def chain_exact_rate(model, primes_t, device, n=20_000) -> float:
70
- p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
71
- a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
72
- b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
73
- truth = (a * b) % p
74
- bb, pb = to_bits(b), to_bits(p)
75
- tb = torch.zeros(n, BITS, device=device)
76
- for i in range(BITS - 1, -1, -1):
77
- bit = ((a >> i) & 1).float().unsqueeze(1)
78
- tb = (model(tb, bit, bb, pb) > 0).float()
79
- pred = (tb.long() * (1 << torch.arange(BITS, device=device))).sum(dim=1)
80
- return (pred == truth).float().mean().item()
81
-
82
-
83
- def chain_finetune_batch(model, primes_t, n, device, loss_fn):
84
- """One end-to-end pass: STE state, per-step CE against true intermediates."""
85
- p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
86
- a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
87
- b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
88
- bb, pb = to_bits(b), to_bits(p)
89
- tb = torch.zeros(n, BITS, device=device)
90
- t_true = torch.zeros_like(a)
91
- loss = torch.zeros((), device=device)
92
- for i in range(BITS - 1, -1, -1):
93
- bit_i = (a >> i) & 1
94
- t_true = (2 * t_true + bit_i * b) % p
95
- logits = model(tb, bit_i.float().unsqueeze(1), bb, pb)
96
- loss = loss + loss_fn(logits, to_bits(t_true))
97
- hard = (logits > 0).float()
98
- soft = torch.sigmoid(logits)
99
- tb = hard + (soft - soft.detach()) # straight-through
100
- return loss / BITS
101
-
102
-
103
- def main() -> int:
104
- ap = argparse.ArgumentParser()
105
- ap.add_argument("--stage1-minutes", type=float, default=50.0)
106
- ap.add_argument("--stage2-minutes", type=float, default=0.0)
107
- ap.add_argument("--batch", type=int, default=32768)
108
- ap.add_argument("--chain-batch", type=int, default=4096)
109
- ap.add_argument("--lr", type=float, default=3e-4)
110
- ap.add_argument("--lr2", type=float, default=5e-5)
111
- ap.add_argument("--width", type=int, default=4096)
112
- ap.add_argument("--depth", type=int, default=4)
113
- ap.add_argument("--init", type=str, default="")
114
- ap.add_argument("--out", type=str, default=str(HERE / "weights16.pt"))
115
- args = ap.parse_args()
116
-
117
- device = torch.device("cuda")
118
- torch.manual_seed(0)
119
-
120
- small = sieve_primes(256)
121
- primes = [p for p in sieve_primes(1 << 16) if p >= 256]
122
- g = torch.Generator().manual_seed(1)
123
- perm = torch.randperm(len(primes), generator=g).tolist()
124
- val_primes = torch.tensor([primes[i] for i in perm[: len(primes) // 10]], device=device)
125
- train_primes = torch.tensor(
126
- small + [primes[i] for i in perm[len(primes) // 10 :]], device=device
127
- )
128
- print(f"train primes {len(train_primes)}, val primes {len(val_primes)}")
129
-
130
- model = HornerCell(args.width, args.depth).to(device)
131
- if args.init:
132
- ckpt = torch.load(args.init, map_location=device, weights_only=True)
133
- model.load_state_dict(ckpt["state_dict"])
134
- print(f"initialised from {args.init}")
135
- ema = HornerCell(args.width, args.depth).to(device)
136
- ema.load_state_dict(model.state_dict())
137
- for q in ema.parameters():
138
- q.requires_grad_(False)
139
- print(f"params: {sum(t.numel() for t in model.parameters()):,}")
140
- loss_fn = nn.BCEWithLogitsLoss()
141
- EMA_DECAY = 0.999
142
-
143
- def update_ema():
144
- with torch.no_grad():
145
- for q, w in zip(ema.parameters(), model.parameters()):
146
- q.lerp_(w, 1 - EMA_DECAY)
147
-
148
- best_chain = -1.0
149
-
150
- def save_if_best(tag):
151
- nonlocal best_chain
152
- ch = chain_exact_rate(ema, val_primes, device)
153
- if ch > best_chain:
154
- best_chain = ch
155
- torch.save({"state_dict": ema.state_dict(), "config": ema.config}, args.out)
156
- return ch
157
-
158
- # ----- Stage 1: cell training -----
159
- if args.stage1_minutes > 0:
160
- opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
161
- total_steps = int(args.stage1_minutes * 60 * 16)
162
- sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, eta_min=args.lr * 0.02)
163
- deadline = time.monotonic() + args.stage1_minutes * 60
164
- start = time.monotonic()
165
- step = 0
166
- while time.monotonic() < deadline:
167
- t, bit, b, p, z = sample_batch(train_primes, args.batch, device)
168
- logits = model(to_bits(t), bit.float().unsqueeze(1), to_bits(b), to_bits(p))
169
- loss = loss_fn(logits, to_bits(z))
170
- opt.zero_grad()
171
- loss.backward()
172
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
173
- opt.step()
174
- if step < total_steps:
175
- sched.step()
176
- update_ema()
177
- step += 1
178
- if step % 1000 == 0:
179
- va = exact_rate(ema, val_primes, device, n=100_000)
180
- ch = save_if_best("s1")
181
- print(
182
- f"S1 step {step:6d} | loss {loss.item():.5f} | ema cell val {va:.5f} "
183
- f"| ema CHAIN val {ch:.4f} | {time.monotonic()-start:.0f}s",
184
- flush=True,
185
- )
186
-
187
- # ----- Stage 2: end-to-end chain fine-tuning (STE) -----
188
- if args.stage2_minutes > 0:
189
- opt = torch.optim.AdamW(model.parameters(), lr=args.lr2, weight_decay=1e-5)
190
- total_steps = int(args.stage2_minutes * 60 * 3)
191
- sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, eta_min=args.lr2 * 0.1)
192
- deadline = time.monotonic() + args.stage2_minutes * 60
193
- start = time.monotonic()
194
- step = 0
195
- while time.monotonic() < deadline:
196
- loss = chain_finetune_batch(model, train_primes, args.chain_batch, device, loss_fn)
197
- opt.zero_grad()
198
- loss.backward()
199
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
200
- opt.step()
201
- if step < total_steps:
202
- sched.step()
203
- update_ema()
204
- step += 1
205
- if step % 200 == 0:
206
- va = exact_rate(ema, val_primes, device, n=100_000)
207
- ch = save_if_best("s2")
208
- print(
209
- f"S2 step {step:6d} | loss {loss.item():.5f} | ema cell val {va:.5f} "
210
- f"| ema CHAIN val {ch:.4f} | {time.monotonic()-start:.0f}s",
211
- flush=True,
212
- )
213
-
214
- va = exact_rate(ema, val_primes, device, n=500_000)
215
- ch = chain_exact_rate(ema, val_primes, device, n=50_000)
216
- print(f"FINAL ema cell val {va:.6f} | chain val {ch:.4f} | best chain {best_chain:.4f}")
217
- return 0
218
-
219
-
220
- if __name__ == "__main__":
221
- raise SystemExit(main())