---
license: mit
library_name: pytorch
tags:
- modular-arithmetic
- scratchpad
- grokking
- transformer
- sair-challenge
---
# modmul-tier3: learned modular multiplication
A submission to the **SAIR Modular Arithmetic Challenge**: compute `(a · b) mod p`
with the arithmetic produced entirely by **trained parameters**, with no `%`, `//`,
Barrett, Montgomery, or CRT applied to the product.
## What it does
Two trained networks behind one interface, routed by the size of `p`:
| `p` range | Model |
|---|---|
| `p < 512` (tiers 1–2) | joint-attention Transformer with a classification head over the residue |
| `512 ≤ p < 65536` (tier 3) | **autoregressive "abacus" decoder** that emits an interleaved modular-multiply scratchpad |
| `p ≥ 65536` | out of trained range → returns `[0]` |
Operands are reduced per-argument (`a mod p`, `b mod p`) before the network runs;
the **product's** reduction is learned.
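
The routing rule in the table can be sketched as below. This is a minimal illustration, not the code in `model.py`: the names `Route`, `select_route`, and `reduce_operands` are hypothetical, and the sketch assumes `p ≥ 2`. Only the thresholds (512 and 65536) and the per-argument reduction come from the description above.

```python
from enum import Enum


class Route(Enum):
    """Hypothetical labels for the three cases in the routing table."""
    CLASSIFIER = "classifier"      # tiers 1-2: Transformer with a classification head
    ABACUS = "abacus"              # tier 3: autoregressive scratchpad decoder
    OUT_OF_RANGE = "out_of_range"  # p >= 65536: the card says the result is [0]


def select_route(p: int) -> Route:
    """Pick the model for modulus p (assumes p >= 2; thresholds from the table)."""
    if p < 512:
        return Route.CLASSIFIER
    if p < 65536:
        return Route.ABACUS
    return Route.OUT_OF_RANGE


def reduce_operands(a: int, b: int, p: int) -> tuple[int, int]:
    """Per-argument reduction, applied before the network runs.

    Only the operands are reduced here; reducing the product is left
    to the trained model.
    """
    return a % p, b % p
```

For example, `select_route(512)` falls in the tier-3 branch, and `reduce_operands(1000, 7, 97)` gives `(30, 7)`.
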
## The tier-3 idea
A one-shot 5-digit × 5-digit multiply is a hard learnability wall. Instead the model
runs **Horner's method**, reducing mod `p` at every step so that no intermediate ever
exceeds ~6 digits, the regime the decoder handles reliably. Every intermediate is
written to the scratchpad and supervised:
```
BOS x MUL y MOD p EQ d:q1:r1:pp:t:q2:r2 STEP … EOS
```
Per `y`-digit `d`, the scratchpad records a shift-reduce (`q1,r1`), a single-digit partial product (`pp`),
the explicit add `t = r1 + pp`, then an add-reduce (`q2,r2`). Emitting `t` explicitly
was the decisive fix: leaving that addition implicit capped end-to-end accuracy.
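
To make the scratchpad fields concrete, here is a reference sketch in ordinary Python of the trace they describe. It is an illustration under stated assumptions, not the submission's code: `horner_trace` and `format_scratchpad` are hypothetical names, `pp` is assumed to be the unreduced product `x * d`, `y`'s digits are assumed to be processed most significant first, and numbers are written as whole decimal strings (the real tokenization is not specified here). It uses `divmod` freely because it only shows what a correct trace contains; the model itself must produce these values from trained parameters.

```python
def horner_trace(x: int, y: int, p: int) -> list[tuple[int, ...]]:
    """Return one (d, q1, r1, pp, t, q2, r2) tuple per decimal digit of y.

    Assumes 0 <= x, y < p (operands already reduced) and p >= 2.
    """
    steps = []
    acc = 0  # running residue of the Horner accumulation
    for ch in str(y):  # assumption: most significant digit first
        d = int(ch)
        q1, r1 = divmod(acc * 10, p)  # shift-reduce: (10 * acc) = q1 * p + r1
        pp = x * d                    # single-digit partial product (assumed unreduced)
        t = r1 + pp                   # the explicit add
        q2, r2 = divmod(t, p)         # add-reduce: t = q2 * p + r2
        steps.append((d, q1, r1, pp, t, q2, r2))
        acc = r2
    return steps


def format_scratchpad(x: int, y: int, p: int) -> str:
    """Render the trace in the BOS ... EOS layout shown above."""
    steps = horner_trace(x, y, p)
    body = " STEP ".join(":".join(str(v) for v in step) for step in steps)
    return f"BOS {x} MUL {y} MOD {p} EQ {body} EOS"


if __name__ == "__main__":
    print(format_scratchpad(123, 45, 1009))
```

With `x = 123`, `y = 45`, `p = 1009`, the last step's `r2` is 490, which equals `(123 · 45) mod 1009`. Under these assumptions every recorded value stays below 10^6 for any `p < 65536`, consistent with the ~6-digit bound described above.
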
## Results
Verified through the official `modchallenge evaluate-hf` pipeline:
| Tier | Accuracy |
|---|---|
| 1 | 1.00 |
| 2 | 0.99 |
| **3** | **1.00** |
**`highest_tier_above_90 = 3`**, deterministic, static-compliance check passed.
A 30-seed local sweep gave `highest_tier_above_90 = 3` on all 30 seeds (tier-3 accuracy: min 0.95, mean 0.986).
## Files
- `model.py`: the `EBMModMul` entry class (loader + abacus decoder)
- `manifest.json`: entry class, `output_base=10`, descriptions
- `weights.pt`: trained weights for both routed models