ebm-modmul / model.py
cire77's picture
tier-2 cls_pp checkpoint (htop90=2)
e3c7df7 verified
Raw
History Blame Contribute Delete
11.1 kB
"""Submission entry point: learned modular multiplication.
Compliance contract (see rules/evaluation.md):
- ``preprocess_*`` are per-argument identities (each sees only its own argument).
- Inside ``predict_digits_batch`` we reduce each operand modulo p — ``int(a) % p``
and ``int(b) % p`` — the same two-args-at-a-time normalisation the reference
baselines use. We never form ``a * b`` or ``(a*b) % p`` in Python/tensors; the
modular product is produced by the trained network, whose output (a residue in
``[0, p)``) materially determines the answer.
- We emit the residue as base-10 digits (``output_base = 10``); the harness decodes.
Out of regime (``p >= 10**WIDTH``, i.e. tiers >= 4) the network's fixed-width
residue encoding cannot represent the operands, so we emit ``[0]`` — an honest
fallback, not a guess. This model targets the low tiers (1-3).
The architecture (encoder + classification/angular head) is loaded from the
checkpoint's ``arch`` field, so the same wrapper serves either trained head.
"""
from __future__ import annotations
import math
from pathlib import Path
import torch
import torch.nn as nn
from modchallenge.interface.base_model import ModularMultiplicationModel
# ---------------------------------------------------------------------------
# Fixed dimensions (must match the training code that produced the weights)
# ---------------------------------------------------------------------------
VOCAB_SIZE = 10 # decimal digits 0-9; fixed-width inputs, no PAD token
WIDTH = 5 # values < 10**5 = 100000 -> covers tiers 1-3
SEG_X, SEG_Y, SEG_P, SEG_ANS = 0, 1, 2, 3
def digits_fixed(n: int, width: int = WIDTH) -> list[int]:
"""Non-negative int -> fixed-width zero-padded decimal digits, MSB-first."""
out = [0] * width
i = width - 1
while n > 0 and i >= 0:
out[i] = n % 10
n //= 10
i -= 1
return out
def int_to_decimal_digits(n: int) -> list[int]:
"""Non-negative int -> base-10 digit list, MSB-first ([0] for zero)."""
if n == 0:
return [0]
return [int(c) for c in str(n)]
# ---------------------------------------------------------------------------
# Architectures (copied verbatim from training/model.py for state_dict match)
# ---------------------------------------------------------------------------
class JointModMulNetCls(nn.Module):
def __init__(self, d_model=256, nhead=8, num_layers=6, dim_ff=1024, p_max=256):
super().__init__()
self.p_max = p_max
self.tok_emb = nn.Embedding(VOCAB_SIZE, d_model)
self.cls_query = nn.Parameter(torch.randn(1, d_model) * 0.02)
self.seg_emb = nn.Embedding(4, d_model)
self.pos_emb = nn.Embedding(3 * WIDTH + 1, d_model)
layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
dropout=0.0, batch_first=True, activation="gelu",
)
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
self.ln = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, p_max)
seg = torch.tensor([SEG_X] * WIDTH + [SEG_Y] * WIDTH + [SEG_P] * WIDTH + [SEG_ANS])
self.register_buffer("seg_ids", seg, persistent=False)
self.register_buffer("pos_ids", torch.arange(3 * WIDTH + 1), persistent=False)
def forward(self, x_digits, y_digits, prime_digits):
b = x_digits.shape[0]
inp = torch.cat([x_digits, y_digits, prime_digits], dim=1)
tok = self.tok_emb(inp)
cls = self.cls_query.unsqueeze(0).expand(b, 1, -1)
x = torch.cat([tok, cls], dim=1)
x = x + self.seg_emb(self.seg_ids.unsqueeze(0)) + self.pos_emb(self.pos_ids.unsqueeze(0))
x = self.encoder(x)
x = self.ln(x)
return self.head(x[:, -1, :]) # (B, p_max)
class JointModMulNetAngular(nn.Module):
def __init__(self, d_model=256, nhead=8, num_layers=6, dim_ff=1024):
super().__init__()
self.tok_emb = nn.Embedding(VOCAB_SIZE, d_model)
self.cls_query = nn.Parameter(torch.randn(1, d_model) * 0.02)
self.seg_emb = nn.Embedding(4, d_model)
self.pos_emb = nn.Embedding(3 * WIDTH + 1, d_model)
layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
dropout=0.0, batch_first=True, activation="gelu",
)
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
self.ln = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, 2)
seg = torch.tensor([SEG_X] * WIDTH + [SEG_Y] * WIDTH + [SEG_P] * WIDTH + [SEG_ANS])
self.register_buffer("seg_ids", seg, persistent=False)
self.register_buffer("pos_ids", torch.arange(3 * WIDTH + 1), persistent=False)
def forward(self, x_digits, y_digits, prime_digits):
b = x_digits.shape[0]
inp = torch.cat([x_digits, y_digits, prime_digits], dim=1)
tok = self.tok_emb(inp)
cls = self.cls_query.unsqueeze(0).expand(b, 1, -1)
x = torch.cat([tok, cls], dim=1)
x = x + self.seg_emb(self.seg_ids.unsqueeze(0)) + self.pos_emb(self.pos_ids.unsqueeze(0))
x = self.encoder(x)
x = self.ln(x)
return self.head(x[:, -1, :]) # (B, 2)
PRIME_ENUM_LIMIT = 65536
def _sieve_primes(limit: int) -> list[int]:
is_p = bytearray([1]) * limit
is_p[0] = is_p[1] = 0
for i in range(2, int(limit ** 0.5) + 1):
if is_p[i]:
is_p[i * i :: i] = bytearray(len(is_p[i * i :: i]))
return [i for i in range(2, limit) if is_p[i]]
class JointModMulNetClsPP(nn.Module):
"""Joint-attention classifier with a learned per-prime embedding.
Mirrors training/model.py for state_dict compatibility."""
def __init__(self, d_model=256, nhead=8, num_layers=6, dim_ff=1024, p_max=256):
super().__init__()
self.p_max = p_max
self.limit = PRIME_ENUM_LIMIT
self.tok_emb = nn.Embedding(VOCAB_SIZE, d_model)
self.cls_query = nn.Parameter(torch.randn(1, d_model) * 0.02)
self.seg_emb = nn.Embedding(4, d_model)
self.pos_emb = nn.Embedding(3 * WIDTH + 1, d_model)
layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
dropout=0.0, batch_first=True, activation="gelu",
)
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
self.ln = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, p_max)
primes = _sieve_primes(self.limit)
self.prime_emb = nn.Embedding(len(primes), d_model)
idx = torch.zeros(self.limit, dtype=torch.long)
valid = torch.zeros(self.limit, dtype=torch.float)
for rank, p in enumerate(primes):
idx[p] = rank
valid[p] = 1.0
self.register_buffer("idx_lookup", idx, persistent=False)
self.register_buffer("valid_lookup", valid, persistent=False)
self.register_buffer(
"place_value",
torch.tensor([10 ** (WIDTH - 1 - i) for i in range(WIDTH)], dtype=torch.long),
persistent=False,
)
seg = torch.tensor([SEG_X] * WIDTH + [SEG_Y] * WIDTH + [SEG_P] * WIDTH + [SEG_ANS])
self.register_buffer("seg_ids", seg, persistent=False)
self.register_buffer("pos_ids", torch.arange(3 * WIDTH + 1), persistent=False)
def forward(self, x_digits, y_digits, prime_digits):
b = x_digits.shape[0]
p_int = (prime_digits * self.place_value).sum(dim=1)
safe = p_int.clamp(0, self.limit - 1)
p_emb = self.prime_emb(self.idx_lookup[safe]) * self.valid_lookup[safe].unsqueeze(-1)
inp = torch.cat([x_digits, y_digits, prime_digits], dim=1)
tok = self.tok_emb(inp)
cls = self.cls_query.unsqueeze(0).expand(b, 1, -1)
x = torch.cat([tok, cls], dim=1)
x = x + self.seg_emb(self.seg_ids.unsqueeze(0)) + self.pos_emb(self.pos_ids.unsqueeze(0))
x = x + p_emb.unsqueeze(1)
x = self.encoder(x)
x = self.ln(x)
return self.head(x[:, -1, :])
_ARCHS = {
"cls": JointModMulNetCls,
"cls_pp": JointModMulNetClsPP,
"angular": JointModMulNetAngular,
}
def _angular_decode(pred: torch.Tensor, p_int: torch.Tensor) -> torch.Tensor:
theta = torch.atan2(pred[:, 1], pred[:, 0])
t = torch.round(theta * p_int.float() / (2 * math.pi))
return (t % p_int.float()).long()
# ---------------------------------------------------------------------------
# Submission entry class
# ---------------------------------------------------------------------------
class EBMModMul(ModularMultiplicationModel):
def __init__(self):
self.model = None
self.device = None
self.arch = None
def load(self, model_dir: str) -> None:
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
else:
self.device = torch.device("cpu")
ckpt = torch.load(Path(model_dir) / "weights.pt",
map_location=self.device, weights_only=False)
self.arch = ckpt.get("arch", "cls")
self.model = _ARCHS[self.arch](**ckpt["config"]).to(self.device)
self.model.load_state_dict(ckpt["state_dict"])
self.model.eval()
# Per-argument identity preprocessing (each hook sees only its own argument).
def preprocess_a(self, a): return a
def preprocess_b(self, b): return b
def preprocess_p(self, p): return p
@torch.no_grad()
def predict_digits(self, a_enc, b_enc, p_enc):
return self.predict_digits_batch([(a_enc, b_enc, p_enc)])[0]
@torch.no_grad()
def predict_digits_batch(self, inputs):
out: list[list[int] | None] = [None] * len(inputs)
x_rows, y_rows, p_rows, p_ints, idx = [], [], [], [], []
for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
p = int(p_enc)
# Out of the model's regime (residues don't fit WIDTH digits): honest 0.
if p >= 10 ** WIDTH:
out[i] = [0]
continue
a_red = int(a_enc) % p # per-operand reduction (allowed)
b_red = int(b_enc) % p
x_rows.append(digits_fixed(a_red))
y_rows.append(digits_fixed(b_red))
p_rows.append(digits_fixed(p))
p_ints.append(p)
idx.append(i)
if idx:
t = lambda r: torch.tensor(r, dtype=torch.long, device=self.device)
logits = self.model(t(x_rows), t(y_rows), t(p_rows))
if self.arch == "angular":
residues = _angular_decode(logits, t(p_ints)).tolist()
else:
residues = logits.argmax(dim=-1).tolist()
for j, i in enumerate(idx):
out[i] = int_to_decimal_digits(int(residues[j]))
return [o if o is not None else [0] for o in out]
def max_batch_size(self) -> int:
return 512