modmul / model.py
cire77's picture
tier5 route [2^32,2^64) to base-16 modmul decoder (htop90=5)
9e19660 verified
Raw
History Blame Contribute Delete
27.4 kB
"""Submission entry point: learned modular multiplication.
Compliance contract (see rules/evaluation.md):
- ``preprocess_*`` are per-argument identities (each sees only its own argument).
- Inside ``predict_digits_batch`` we reduce each operand modulo p — ``int(a) % p``
and ``int(b) % p`` — the same two-args-at-a-time normalisation the reference
baselines use. We never form ``a * b`` or ``(a*b) % p`` in Python/tensors; the
modular product is produced by the trained network, whose output (a residue in
``[0, p)``) materially determines the answer.
- We emit the residue as base-10 digits (``output_base = 10``); the harness decodes.
Routing by prime size: tiers 1-2 (p < 512) use the classification head; tiers 3-5
use the interleaved modular-multiply scratchpad decoder (same architecture,
separately trained weights) — tier 3 (512 <= p < 65536) and tier 4
(65536 <= p < 2**32) in numeric base 10, tier 5 (2**32 <= p < 2**64) in numeric
base 16 (shorter Horner chain at large prime sizes). p >= 2**64 (tiers 6+) is out
of regime, so we emit ``[0]`` — an honest fallback, not a guess. The same ``[0]``
fallback is used for tiers 4-5 when the checkpoint lacks that tier's weights;
tier-3 primes instead fall back to the classification head when the checkpoint
has no tier-3 bundle.
The tier 1-2 architecture (encoder + classification/angular head) is selected by the
checkpoint's ``arch`` field (``cls``, ``cls_pp`` or ``angular``), so the same wrapper
serves any of the trained heads.
"""
from __future__ import annotations
import math
from collections import defaultdict
from pathlib import Path
import torch
import torch.nn as nn
from modchallenge.interface.base_model import ModularMultiplicationModel
# ---------------------------------------------------------------------------
# Fixed dimensions (must match the training code that produced the weights)
# ---------------------------------------------------------------------------
VOCAB_SIZE = 10 # decimal digits 0-9; fixed-width inputs, no PAD token
WIDTH = 5 # values < 10**5 = 100000 -> covers tiers 1-3
SEG_X, SEG_Y, SEG_P, SEG_ANS = 0, 1, 2, 3
def digits_fixed(n: int, width: int = WIDTH) -> list[int]:
    """Encode ``n`` as exactly ``width`` decimal digits, most significant first.

    Short values are left-padded with zeros. If ``n`` needs more than ``width``
    digits, only its lowest ``width`` digits are kept. A non-positive ``n``
    encodes as all zeros.
    """
    digits = []
    remaining = n
    for _ in range(width):
        if remaining <= 0:
            digits.append(0)
        else:
            remaining, low = divmod(remaining, 10)
            digits.append(low)
    digits.reverse()  # digits were produced least-significant first
    return digits
def int_to_decimal_digits(n: int) -> list[int]:
    """Return the base-10 digits of a non-negative int, most significant first.

    Zero yields ``[0]``.
    """
    return [0] if n == 0 else list(map(int, str(n)))
# ---------------------------------------------------------------------------
# Architectures (copied verbatim from training/model.py for state_dict match)
# ---------------------------------------------------------------------------
class JointModMulNetCls(nn.Module):
    """Joint-attention encoder scoring each of ``p_max`` residue classes.

    The x, y and p digit runs plus one learned query slot form a
    ``3 * WIDTH + 1`` token sequence. The query slot's final hidden state feeds a
    linear head with ``p_max`` logits. Attribute names (and creation order) must
    stay as-is so the shipped state_dict loads.
    """

    def __init__(self, d_model=256, nhead=8, num_layers=6, dim_ff=1024, p_max=256):
        super().__init__()
        self.p_max = p_max
        seq_len = 3 * WIDTH + 1  # three digit runs + the trailing query slot
        self.tok_emb = nn.Embedding(VOCAB_SIZE, d_model)
        self.cls_query = nn.Parameter(torch.randn(1, d_model) * 0.02)
        self.seg_emb = nn.Embedding(4, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
                dropout=0.0, batch_first=True, activation="gelu",
            ),
            num_layers=num_layers,
        )
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, p_max)
        segments = [SEG_X] * WIDTH + [SEG_Y] * WIDTH + [SEG_P] * WIDTH + [SEG_ANS]
        self.register_buffer("seg_ids", torch.tensor(segments), persistent=False)
        self.register_buffer("pos_ids", torch.arange(seq_len), persistent=False)

    def forward(self, x_digits, y_digits, prime_digits):
        """Return ``(B, p_max)`` logits.

        The three digit tensors are concatenated along dim 1; their combined
        length must equal ``3 * WIDTH`` to line up with the segment/position ids.
        """
        batch = x_digits.shape[0]
        digits = torch.cat([x_digits, y_digits, prime_digits], dim=1)
        query = self.cls_query.expand(batch, -1).unsqueeze(1)  # (B, 1, d_model)
        hidden = torch.cat([self.tok_emb(digits), query], dim=1)
        hidden = hidden + self.seg_emb(self.seg_ids)[None] + self.pos_emb(self.pos_ids)[None]
        hidden = self.ln(self.encoder(hidden))
        return self.head(hidden[:, -1])  # read out the query slot
class JointModMulNetAngular(nn.Module):
    """Joint-attention encoder emitting a 2-vector per query.

    The input layout matches ``JointModMulNetCls``, but the query slot feeds a
    2-unit head. ``EBMModMul`` turns the two outputs into a residue with
    ``_angular_decode``. Attribute names (and creation order) must stay as-is
    so the shipped state_dict loads.
    """

    def __init__(self, d_model=256, nhead=8, num_layers=6, dim_ff=1024):
        super().__init__()
        seq_len = 3 * WIDTH + 1  # three digit runs + the trailing query slot
        self.tok_emb = nn.Embedding(VOCAB_SIZE, d_model)
        self.cls_query = nn.Parameter(torch.randn(1, d_model) * 0.02)
        self.seg_emb = nn.Embedding(4, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
                dropout=0.0, batch_first=True, activation="gelu",
            ),
            num_layers=num_layers,
        )
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, 2)
        segments = [SEG_X] * WIDTH + [SEG_Y] * WIDTH + [SEG_P] * WIDTH + [SEG_ANS]
        self.register_buffer("seg_ids", torch.tensor(segments), persistent=False)
        self.register_buffer("pos_ids", torch.arange(seq_len), persistent=False)

    def forward(self, x_digits, y_digits, prime_digits):
        """Return a ``(B, 2)`` tensor read from the query slot."""
        batch = x_digits.shape[0]
        digits = torch.cat([x_digits, y_digits, prime_digits], dim=1)
        query = self.cls_query.expand(batch, -1).unsqueeze(1)  # (B, 1, d_model)
        hidden = torch.cat([self.tok_emb(digits), query], dim=1)
        hidden = hidden + self.seg_emb(self.seg_ids)[None] + self.pos_emb(self.pos_ids)[None]
        hidden = self.ln(self.encoder(hidden))
        return self.head(hidden[:, -1])
# Exclusive upper bound of the prime table behind JointModMulNetClsPP's
# per-prime embedding (one learned row per prime below it).
PRIME_ENUM_LIMIT = 2 ** 16
def _sieve_primes(limit: int) -> list[int]:
is_p = bytearray([1]) * limit
is_p[0] = is_p[1] = 0
for i in range(2, int(limit ** 0.5) + 1):
if is_p[i]:
is_p[i * i :: i] = bytearray(len(is_p[i * i :: i]))
return [i for i in range(2, limit) if is_p[i]]
class JointModMulNetClsPP(nn.Module):
    """Joint-attention classifier with a learned per-prime embedding.

    Same token layout as ``JointModMulNetCls``. In addition, the integer value of
    ``prime_digits`` selects a learned embedding (one row per prime below
    ``PRIME_ENUM_LIMIT``) that is added to every position. Non-prime or
    out-of-table values contribute a zero vector. Mirrors training/model.py for
    state_dict compatibility.
    """

    def __init__(self, d_model=256, nhead=8, num_layers=6, dim_ff=1024, p_max=256):
        super().__init__()
        self.p_max = p_max
        self.limit = PRIME_ENUM_LIMIT
        seq_len = 3 * WIDTH + 1  # three digit runs + the trailing query slot
        self.tok_emb = nn.Embedding(VOCAB_SIZE, d_model)
        self.cls_query = nn.Parameter(torch.randn(1, d_model) * 0.02)
        self.seg_emb = nn.Embedding(4, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
                dropout=0.0, batch_first=True, activation="gelu",
            ),
            num_layers=num_layers,
        )
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, p_max)
        primes = _sieve_primes(self.limit)
        self.prime_emb = nn.Embedding(len(primes), d_model)
        # value -> rank among primes (0 for non-primes) and a 0/1 primality mask.
        prime_values = torch.tensor(primes, dtype=torch.long)
        rank_of = torch.zeros(self.limit, dtype=torch.long)
        rank_of[prime_values] = torch.arange(len(primes))
        is_prime = torch.zeros(self.limit, dtype=torch.float)
        is_prime[prime_values] = 1.0
        self.register_buffer("idx_lookup", rank_of, persistent=False)
        self.register_buffer("valid_lookup", is_prime, persistent=False)
        # Decimal place values (10**(WIDTH-1), ..., 1) to rebuild p from its digits.
        self.register_buffer(
            "place_value",
            torch.tensor([10 ** e for e in reversed(range(WIDTH))], dtype=torch.long),
            persistent=False,
        )
        segments = [SEG_X] * WIDTH + [SEG_Y] * WIDTH + [SEG_P] * WIDTH + [SEG_ANS]
        self.register_buffer("seg_ids", torch.tensor(segments), persistent=False)
        self.register_buffer("pos_ids", torch.arange(seq_len), persistent=False)

    def forward(self, x_digits, y_digits, prime_digits):
        """Return ``(B, p_max)`` logits, conditioned on the per-prime embedding."""
        batch = x_digits.shape[0]
        p_value = (prime_digits * self.place_value).sum(dim=1)
        # Clamp into the table; out-of-range values land on a masked-out slot
        # only if that slot is non-prime (true for the last slot 65535).
        slot = p_value.clamp(0, self.limit - 1)
        p_bias = self.prime_emb(self.idx_lookup[slot]) * self.valid_lookup[slot].unsqueeze(-1)
        digits = torch.cat([x_digits, y_digits, prime_digits], dim=1)
        query = self.cls_query.expand(batch, -1).unsqueeze(1)  # (B, 1, d_model)
        hidden = torch.cat([self.tok_emb(digits), query], dim=1)
        hidden = hidden + self.seg_emb(self.seg_ids)[None] + self.pos_emb(self.pos_ids)[None]
        hidden = hidden + p_bias[:, None, :]  # broadcast over every position
        hidden = self.ln(self.encoder(hidden))
        return self.head(hidden[:, -1])
# Checkpoint ``arch`` tag -> tier 1-2 network class (looked up by EBMModMul.load).
_ARCHS = dict(
    cls=JointModMulNetCls,
    cls_pp=JointModMulNetClsPP,
    angular=JointModMulNetAngular,
)
def _angular_decode(pred: torch.Tensor, p_int: torch.Tensor) -> torch.Tensor:
theta = torch.atan2(pred[:, 1], pred[:, 0])
t = torch.round(theta * p_int.float() / (2 * math.pi))
return (t % p_int.float()).long()
# ---------------------------------------------------------------------------
# Tier-3/4 interleaved modular-multiply scratchpad (autoregressive, numeric base 10).
#
# Self-contained copy of the trained training/modmul_probe.py decoder + greedy
# decode. The network emits the schoolbook computation digit by digit:
# BOS x MUL y MOD p EQ d:q1:r1:pp:t:q2:r2 STEP ... EOS
# folding multiply and reduction into one Horner pass so no intermediate exceeds
# ~6 digits. Compliance: the only modular reduction in shipped code is the
# per-operand int(a)%p / int(b)%p done BEFORE the network runs; the product's
# reduction is produced entirely by trained parameters (greedy argmax over digit
# tokens). There is no %, //, Barrett, Montgomery or CRT applied to a*b anywhere.
# ---------------------------------------------------------------------------
# Base-10 scratchpad vocabulary: tokens 0-9 are digits, 10-17 are the specials.
MM_PAD, MM_BOS, MM_MUL, MM_MOD, MM_EQ, MM_COLON, MM_STEP, MM_EOS = range(10, 18)
MM_VOCAB = MM_EOS + 1
MM_SPECIALS = set(range(MM_PAD, MM_VOCAB))
def _digits_msb(n: int) -> list[int]:
if n == 0:
return [0]
s = []
while n > 0:
s.append(n % 10)
n //= 10
return s[::-1]
class AbacusDecoder(nn.Module):
    """Causal (decoder-only) transformer over scratchpad tokens with abacus embeddings.

    Each input position sums three embeddings:
      - a token embedding;
      - an absolute-position embedding;
      - an "abacus" (place-within-number) embedding whose ids the caller supplies.

    Layout is identical to training/modmul_probe.py / tier5_modmul.py for
    state_dict match. ``vocab`` defaults to the 18-token base-10 scratchpad
    vocab used by tiers 3-4. The tier-5 base-16 scratchpad passes vocab=24
    (16 digits + 8 specials).

    Args:
        max_len: number of position embeddings (longest sequence ``forward`` supports).
        abacus_max: number of abacus embeddings; abacus ids must be below it.
    """

    def __init__(self, max_len, abacus_max, d_model=384, nhead=8, num_layers=8,
                 dim_ff=1536, vocab=MM_VOCAB):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.abacus_emb = nn.Embedding(abacus_max, d_model)
        block = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
            dropout=0.0, batch_first=True, activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(block, num_layers=num_layers)
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab, bias=False)
        self.max_len = max_len
        self.register_buffer("pos_ids", torch.arange(max_len), persistent=False)

    def forward(self, toks, abacus):
        """Return per-position next-token logits of shape ``(B, T, vocab)``."""
        seq_len = toks.shape[1]
        hidden = self.tok_emb(toks) + self.pos_emb(self.pos_ids[:seq_len]) + self.abacus_emb(abacus)
        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        causal = torch.full((seq_len, seq_len), float("-inf"), device=toks.device).triu(diagonal=1)
        hidden = self.transformer(hidden, mask=causal, is_causal=True)
        return self.head(self.ln(hidden))
@torch.no_grad()
def _modmul_decode(model, cfg, xyp, device, chunk=128):
    """Greedy-decode (x*y) mod p for each (x, y, p) with x,y already in [0, p).

    Returns a list of residue digit-lists (MSB-first), or [0] if unparseable.
    Decodes in length-grouped chunks to bound memory.

    Args:
        model: base-10 scratchpad decoder, called as ``model(toks, abac)``. The
            argmax of its last-position logits is taken as the next token.
        cfg: decoder config. Only ``max_len`` (cap on prompt + generated length)
            and ``abacus_max`` (size of the abacus embedding) are read.
        xyp: sequence of ``(x, y, p)`` int triples.
        device: device the prompt tensors are created on. Its ``type`` decides
            whether the MPS or CUDA allocator cache is emptied.
        chunk: maximum number of equal-length prompts decoded as one batch.

    Returns:
        One MSB-first decimal digit list per input triple, in input order. Each
        is the digit tokens generated after the last COLON, or ``[0]`` when no
        COLON (or no digit after it) appeared before EOS / ``max_len``.
    """
    max_len, abmax = cfg["max_len"], cfg["abacus_max"]
    specials = torch.tensor(sorted(MM_SPECIALS), device=device)
    out: list[list[int] | None] = [None] * len(xyp)
    # Prompts are bucketed by token length so every batch is rectangular
    # without padding the prompt.
    groups = defaultdict(list)
    prompts = []
    for i, (x, y, p) in enumerate(xyp):
        xd, yd, pd = _digits_msb(x), _digits_msb(y), _digits_msb(p)
        # Prompt layout: BOS x MUL y MOD p EQ (numbers as MSB-first decimal digits).
        toks = [MM_BOS] + xd + [MM_MUL] + yd + [MM_MOD] + pd + [MM_EQ]
        # Abacus ids: each digit's index within its number, counted from the
        # first-written (most significant) digit. Special tokens get 0.
        abac = ([0] + list(range(len(xd))) + [0] + list(range(len(yd)))
                + [0] + list(range(len(pd))) + [0])
        groups[len(toks)].append(i)
        prompts.append((toks, abac))
    for L, idxs in groups.items():
        for s in range(0, len(idxs), chunk):
            sub = idxs[s:s + chunk]
            g = len(sub)
            toks = torch.tensor([prompts[i][0] for i in sub], dtype=torch.long, device=device)
            abac = torch.tensor([prompts[i][1] for i in sub], dtype=torch.long, device=device)
            # seg[j]: digits row j emitted since its last special token, i.e. the
            # abacus id its next digit receives (prompts end on EQ, so it starts at 0).
            seg = torch.zeros(g, dtype=torch.long, device=device)
            done = torch.zeros(g, dtype=torch.bool, device=device)
            gen = [[] for _ in range(g)]
            steps = 0
            # Each step re-runs the model over the whole prefix (no KV cache).
            while toks.shape[1] < max_len and not bool(done.all()):
                nxt = model(toks, abac)[:, -1].argmax(-1)
                # Rows that already emitted EOS are fed PAD to keep the batch aligned.
                nxt = torch.where(done, torch.full_like(nxt, MM_PAD), nxt)
                is_sp = (nxt.unsqueeze(1) == specials).any(1)
                # Specials get abacus id 0 and reset the counter. Digits take the
                # running count, clamped to the last abacus embedding.
                new_abac = torch.where(is_sp, torch.zeros_like(seg),
                                       torch.clamp(seg, max=abmax - 1))
                seg = torch.where(is_sp, torch.zeros_like(seg), seg + 1)
                nc, dc = nxt.tolist(), done.tolist()
                # Record every token except EOS/PAD for still-active rows.
                for j in range(g):
                    if not dc[j] and nc[j] != MM_EOS and nc[j] != MM_PAD:
                        gen[j].append(nc[j])
                toks = torch.cat([toks, nxt.unsqueeze(1)], dim=1)
                abac = torch.cat([abac, new_abac.unsqueeze(1)], dim=1)
                done = done | (nxt == MM_EOS)
                # The sequence grows one token per step, so the caching allocator
                # holds a distinct buffer for every length (~800 on tier-4 chains)
                # and OOMs mid-decode. Periodically release them.
                steps += 1
                if steps % 32 == 0:
                    if device.type == "mps":
                        torch.mps.empty_cache()
                    elif device.type == "cuda":
                        torch.cuda.empty_cache()
            # The answer is the digit run after the last COLON (presumably the
            # scratchpad's final remainder field). Non-digit tokens are dropped.
            for j, i in enumerate(sub):
                gj = gen[j]
                if MM_COLON in gj:
                    k = len(gj) - 1 - gj[::-1].index(MM_COLON)
                    ans = [d for d in gj[k + 1:] if d < 10]
                    out[i] = ans if ans else [0]
                else:
                    out[i] = [0]
            # Release the chunk's activations: the caching allocator otherwise
            # accumulates across length-groups/chunks (MPS in particular never
            # frees mid-run) and OOMs on long tier-4 chains.
            del toks, abac, seg, done, gen
            if device.type == "mps":
                torch.mps.empty_cache()
            elif device.type == "cuda":
                torch.cuda.empty_cache()
    return [o if o is not None else [0] for o in out]
# ---------------------------------------------------------------------------
# Tier-5 base-16 modular-multiply scratchpad (autoregressive).
#
# Same AbacusDecoder architecture and 9-field Horner scratchpad as tiers 3-4, but
# trained in numeric BASE 16 (so the per-step partial products / quotient digits
# stay easy while the chain length is bounded). tier-5 primes are 33-64 bit, so the
# chain is ~16 base-16 Horner blocks (~1853 tokens). Vocab: digits 0..15 then
# PAD,BOS,MUL,MOD,EQ,COLON,STEP,EOS = base..base+7 (see tier5_modmul.make_vocab).
# The decoded answer is a BASE-16 residue; we convert it to base-10 digits with
# multiply-add only (no %, //, Barrett/Montgomery/CRT on the product) so it matches
# the global output_base=10. Compliance is unchanged: the only modular reduction in
# shipped code is the per-operand int(a)%p / int(b)%p done before the network runs.
# ---------------------------------------------------------------------------
def _make_vocab_base(base: int) -> dict:
"""Base-B scratchpad vocab, matching training/tier5_modmul.make_vocab."""
PAD, BOS, MUL, MOD, EQ, COLON, STEP, EOS = (
base, base + 1, base + 2, base + 3, base + 4, base + 5, base + 6, base + 7)
return dict(PAD=PAD, BOS=BOS, MUL=MUL, MOD=MOD, EQ=EQ, COLON=COLON, STEP=STEP,
EOS=EOS, VOCAB=base + 8,
SPECIALS={PAD, BOS, MUL, MOD, EQ, COLON, STEP, EOS})
def _digits_base_msb(n: int, base: int) -> list[int]:
if n == 0:
return [0]
s = []
while n > 0:
s.append(n % base)
n //= base
return s[::-1]
def _base_to_int(ds: list[int], base: int) -> int:
v = 0
for d in ds:
v = v * base + d # multiply-add only; no %/// on the product
return v
@torch.no_grad()
def _modmul_decode_base(model, cfg, xyp, device, base, chunk=64):
    """Greedy-decode (x*y) mod p in numeric base ``base`` for each (x, y, p) with
    x, y already in [0, p).

    Returns base-10 digit-lists (MSB-first), or [0] if unparseable. Mirrors
    _modmul_decode but is base-parametrized. The final base-``base`` residue is
    re-expressed in base 10 via multiply-add (compliant).

    Args:
        model: scratchpad decoder over the ``_make_vocab_base(base)`` vocab,
            called as ``model(toks, abac)``.
        cfg: decoder config. Only ``max_len`` and ``abacus_max`` are read.
        xyp: sequence of ``(x, y, p)`` int triples.
        device: device the prompt tensors are created on. Only MPS gets its
            allocator cache emptied here.
        base: numeric base of the scratchpad digits.
        chunk: maximum number of equal-length prompts decoded as one batch.

    Returns:
        One MSB-first decimal digit list per input triple, in input order. Each
        is the value of the base-``base`` digits generated after the last COLON,
        or ``[0]`` when none were generated.
    """
    V = _make_vocab_base(base)
    PAD, EOS, COLON = V["PAD"], V["EOS"], V["COLON"]
    max_len, abmax = cfg["max_len"], cfg["abacus_max"]
    specials = torch.tensor(sorted(V["SPECIALS"]), device=device)
    out: list[list[int] | None] = [None] * len(xyp)
    # Bucket prompts by token length so every batch is rectangular.
    groups = defaultdict(list)
    prompts = []
    for i, (x, y, p) in enumerate(xyp):
        xd, yd, pd = (_digits_base_msb(x, base), _digits_base_msb(y, base),
                      _digits_base_msb(p, base))
        # Prompt layout: BOS x MUL y MOD p EQ (numbers as MSB-first base-``base`` digits).
        toks = [V["BOS"]] + xd + [V["MUL"]] + yd + [V["MOD"]] + pd + [V["EQ"]]
        # Abacus ids: each digit's index within its number (from the first-written
        # digit). Special tokens get 0.
        abac = ([0] + list(range(len(xd))) + [0] + list(range(len(yd)))
                + [0] + list(range(len(pd))) + [0])
        groups[len(toks)].append(i)
        prompts.append((toks, abac))
    for L, idxs in groups.items():
        for s in range(0, len(idxs), chunk):
            sub = idxs[s:s + chunk]
            g = len(sub)
            toks = torch.tensor([prompts[i][0] for i in sub], dtype=torch.long, device=device)
            abac = torch.tensor([prompts[i][1] for i in sub], dtype=torch.long, device=device)
            # seg[j]: abacus id row j's next digit receives (resets on specials).
            seg = torch.zeros(g, dtype=torch.long, device=device)
            done = torch.zeros(g, dtype=torch.bool, device=device)
            gen = [[] for _ in range(g)]
            # Each step re-runs the model over the whole prefix (no KV cache).
            while toks.shape[1] < max_len and not bool(done.all()):
                nxt = model(toks, abac)[:, -1].argmax(-1)
                # Finished rows are fed PAD to keep the batch aligned.
                nxt = torch.where(done, torch.full_like(nxt, PAD), nxt)
                is_sp = (nxt.unsqueeze(1) == specials).any(1)
                # Specials: abacus 0 and counter reset. Digits: running count,
                # clamped to the last abacus embedding.
                new_abac = torch.where(is_sp, torch.zeros_like(seg),
                                       torch.clamp(seg, max=abmax - 1))
                seg = torch.where(is_sp, torch.zeros_like(seg), seg + 1)
                nc, dc = nxt.tolist(), done.tolist()
                # Record every token except EOS/PAD for still-active rows.
                for j in range(g):
                    if not dc[j] and nc[j] != EOS and nc[j] != PAD:
                        gen[j].append(nc[j])
                toks = torch.cat([toks, nxt.unsqueeze(1)], dim=1)
                abac = torch.cat([abac, new_abac.unsqueeze(1)], dim=1)
                done = done | (nxt == EOS)
            # Answer: base-``base`` digits after the last COLON (presumably the
            # final remainder field), converted to decimal digits.
            for j, i in enumerate(sub):
                gj = gen[j]
                if COLON in gj:
                    k = len(gj) - 1 - gj[::-1].index(COLON)
                    ans = [d for d in gj[k + 1:] if d < base]
                    out[i] = int_to_decimal_digits(_base_to_int(ans, base)) if ans else [0]
                else:
                    out[i] = [0]
            del toks, abac, seg, done, gen
            # Only MPS needs the cache drop (it never frees mid-run and OOMs on
            # long chains). On CUDA, empty_cache() synchronizes + forces the
            # allocator to re-acquire buffers every chunk -- a ~5x slowdown over
            # the ~1853-step base-16 decode -- and tier-5 peak is <3 GB anyway.
            if device.type == "mps":
                torch.mps.empty_cache()
    return [o if o is not None else [0] for o in out]
# ---------------------------------------------------------------------------
# Submission entry class
# ---------------------------------------------------------------------------
class EBMModMul(ModularMultiplicationModel):
    """Submission model: routes each (a, b, p) query to a trained network by prime size.

    ``load`` builds the tier 1-2 head (``self.model``) plus whichever tier-3/4/5
    scratchpad decoders the checkpoint bundles (``self.mm``, ``self.mm4``,
    ``self.mm5``). Each decoder stays ``None`` when absent.
    ``predict_digits_batch`` reduces each operand modulo p, and the selected
    network produces the product's residue as MSB-first base-10 digits.
    """

    # Prime routing boundaries used by predict_digits_batch:
    #   p < TIER3_LO              tiers 1-2: classification/angular head
    #   TIER3_LO <= p < TIER3_HI  tier 3: base-10 scratchpad (self.mm)
    #   TIER3_HI <= p < TIER4_HI  tier 4: base-10 scratchpad (self.mm4)
    #   TIER4_HI <= p < TIER5_HI  tier 5: base-16 scratchpad (self.mm5)
    #   p >= TIER5_HI             tiers 6+: out of regime, emit [0]
    TIER3_LO = 512  # 2**9, the tier-3 floor
    TIER3_HI = 65536  # 2**16, the tier-3/4 boundary
    TIER4_HI = 2 ** 32
    TIER5_HI = 2 ** 64

    def __init__(self):
        self.model = None  # tier 1-2 classification/angular head
        self.device = None
        self.arch = None  # checkpoint "arch" tag selecting the tier 1-2 class
        self.mm = None  # tier-3 modmul scratchpad
        self.mm_cfg = None
        self.mm4 = None  # tier-4 modmul scratchpad
        self.mm4_cfg = None
        self.mm5 = None  # tier-5 base-16 modmul scratchpad
        self.mm5_cfg = None

    def load(self, model_dir: str) -> None:
        """Load ``<model_dir>/weights.pt`` onto CUDA if available, else MPS, else CPU.

        The checkpoint supplies ``config``/``state_dict`` (plus an optional
        ``arch`` tag, default ``"cls"``) for the tier 1-2 head. Optional
        ``tier3``/``tier4``/``tier5`` entries each bundle a scratchpad decoder.
        """
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        ckpt = torch.load(Path(model_dir) / "weights.pt",
                          map_location=self.device, weights_only=False)
        # Tiers 1-2: the classification/angular head (banked).
        self.arch = ckpt.get("arch", "cls")
        self.model = _ARCHS[self.arch](**ckpt["config"]).to(self.device)
        self.model.load_state_dict(ckpt["state_dict"])
        self.model.eval()
        # Tiers 3-4 use the default base-10 scratchpad vocab. Tier 5 is trained
        # in numeric base cfg["base"], with a vocab of base + 8 specials.
        self.mm, self.mm_cfg = self._load_scratchpad(ckpt, "tier3")
        self.mm4, self.mm4_cfg = self._load_scratchpad(ckpt, "tier4")
        self.mm5, self.mm5_cfg = self._load_scratchpad(ckpt, "tier5", numeric_base=True)

    def _load_scratchpad(self, ckpt, key, numeric_base=False):
        """Build, load and eval the AbacusDecoder bundled under ``ckpt[key]``.

        Returns ``(decoder, config)``, or ``(None, None)`` when ``key`` is absent.
        With ``numeric_base`` the vocab size is ``config["base"] + 8``. Otherwise
        the decoder keeps its default base-10 vocab.
        """
        if key not in ckpt:
            return None, None
        cfg = ckpt[key]["config"]
        vocab_kw = {"vocab": cfg["base"] + 8} if numeric_base else {}
        decoder = AbacusDecoder(
            max_len=cfg["max_len"], abacus_max=cfg["abacus_max"], d_model=cfg["d_model"],
            nhead=cfg["nhead"], num_layers=cfg["layers"], dim_ff=cfg["dim_ff"],
            **vocab_kw,
        ).to(self.device)
        decoder.load_state_dict(ckpt[key]["state_dict"])
        decoder.eval()
        return decoder, cfg

    # Per-argument identity preprocessing (each hook sees only its own argument).
    def preprocess_a(self, a):
        return a

    def preprocess_b(self, b):
        return b

    def preprocess_p(self, p):
        return p

    @torch.no_grad()
    def predict_digits(self, a_enc, b_enc, p_enc):
        """Single-query wrapper around predict_digits_batch."""
        return self.predict_digits_batch([(a_enc, b_enc, p_enc)])[0]

    @torch.no_grad()
    def predict_digits_batch(self, inputs):
        """Predict (a * b) mod p as MSB-first base-10 digits for each (a, b, p).

        Each operand is reduced modulo p on its own. The product's residue comes
        from the network chosen by p's tier (see the TIER* constants). The
        following all get ``[0]``:
          - queries with no usable modulus (p < 2);
          - p >= TIER5_HI;
          - queries in tier 4 or 5 when that tier's weights are missing.
        Tier-3 primes fall back to the tier 1-2 head when no tier-3 bundle was
        loaded.
        """
        out: list[list[int] | None] = [None] * len(inputs)
        x_rows, y_rows, p_rows, p_ints, idx = [], [], [], [], []  # tiers 1-2
        mm_items, mm_idx = [], []  # tier 3
        mm4_items, mm4_idx = [], []  # tier 4
        mm5_items, mm5_idx = [], []  # tier 5
        for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
            p = int(p_enc)
            # Out of regime. p < 2 is not a usable modulus: p == 0 would make the
            # per-operand reduction divide by zero and abort the whole batch.
            # p >= TIER5_HI exceeds every trained range. Emit an honest 0, not a guess.
            if p < 2 or p >= self.TIER5_HI:
                out[i] = [0]
                continue
            a_red = int(a_enc) % p  # per-operand reduction (allowed)
            b_red = int(b_enc) % p
            if p >= self.TIER4_HI:
                if self.mm5 is not None:
                    mm5_items.append((a_red, b_red, p))
                    mm5_idx.append(i)
                else:
                    out[i] = [0]
            elif p >= self.TIER3_HI:
                if self.mm4 is not None:
                    mm4_items.append((a_red, b_red, p))
                    mm4_idx.append(i)
                else:
                    out[i] = [0]
            elif p >= self.TIER3_LO and self.mm is not None:
                mm_items.append((a_red, b_red, p))
                mm_idx.append(i)
            else:
                # Tiers 1-2, plus tier-3 primes when no tier-3 bundle was loaded.
                x_rows.append(digits_fixed(a_red))
                y_rows.append(digits_fixed(b_red))
                p_rows.append(digits_fixed(p))
                p_ints.append(p)
                idx.append(i)
        if idx:
            def as_tensor(rows):
                return torch.tensor(rows, dtype=torch.long, device=self.device)

            logits = self.model(as_tensor(x_rows), as_tensor(y_rows), as_tensor(p_rows))
            if self.arch == "angular":
                residues = _angular_decode(logits, as_tensor(p_ints)).tolist()
            else:
                residues = logits.argmax(dim=-1).tolist()
            for i, residue in zip(idx, residues):
                out[i] = int_to_decimal_digits(int(residue))
        if mm_items:
            res = _modmul_decode(self.mm, self.mm_cfg, mm_items, self.device)
            for i, digits in zip(mm_idx, res):
                out[i] = digits
        if mm4_items:
            # Tier-4 chains are ~800 tokens. Without a KV-cache the per-step
            # forward is O(L^2), so decode in small sub-batches to bound peak
            # memory (a single batch of 100 OOMs on a 20 GB device).
            res = _modmul_decode(self.mm4, self.mm4_cfg, mm4_items, self.device, chunk=16)
            for i, digits in zip(mm4_idx, res):
                out[i] = digits
        if mm5_items:
            # Tier-5 base-16 chains are ~1853 tokens. The model was trained under
            # bf16 autocast. Decoding in bf16 (not fp32) keeps the long attention
            # both fast (~125s/100 vs ~470-660s in fp32) and within memory
            # (<3 GB vs a fp32 OOM at 1853-length attention). Match training precision.
            base = self.mm5_cfg["base"]
            if self.device.type == "cuda":
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    res = _modmul_decode_base(self.mm5, self.mm5_cfg, mm5_items,
                                              self.device, base=base, chunk=64)
            else:
                res = _modmul_decode_base(self.mm5, self.mm5_cfg, mm5_items,
                                          self.device, base=base, chunk=64)
            for i, digits in zip(mm5_idx, res):
                out[i] = digits
        return [o if o is not None else [0] for o in out]

    def max_batch_size(self) -> int:
        """Batch-size hint for the harness (presumably the most queries per call)."""
        return 512