secp256k1-parity-prediction / scripts /bit_transformer.py
hackinet's picture
Initial upload: negative-result study on secp256k1 parity prediction.
6b93c3b verified
#!/usr/bin/env python3
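"""Train a small transformer over the 512 bits of (x, y) = k*G to predict k_state,
i.e. the parity of the scalar k (label = k & 1).

This is the model class most likely to catch XOR/bit-level interactions if any exist.
Output: results/bit_xformer.pt (weights from the epoch with the best validation AUC),
plus a "bit_transformer" entry written into results/metrics.json.
"""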
import os, time, json, sys, math
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, roc_auc_score
# ---- secp256k1 curve parameters ----
p_field = (1 << 256) - (1 << 32) - 977  # prime modulus of the base field
Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798  # generator x
Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8  # generator y


def inv(a):
    """Return the multiplicative inverse of ``a`` modulo ``p_field``.

    Uses Fermat's little theorem (a ** (p - 2) mod p). Because it is a plain
    modular power, ``inv(0)`` yields 0 rather than raising.
    """
    exponent = p_field - 2
    return pow(a, exponent, p_field)
def add(P, Q):
    """Add two affine points on the secp256k1 curve over GF(p_field).

    ``None`` represents the point at infinity. When ``P == Q`` the tangent
    (doubling) slope 3x^2 / 2y is used (the curve has a = 0); otherwise the
    chord slope (y2 - y1) / (x2 - x1). Returns the resulting point as a tuple,
    or ``None`` when the two points are negatives of each other.
    """
    if P is None:
        return Q
    if Q is None:
        return P
    (px, py), (qx, qy) = P, Q
    if px == qx and (py + qy) % p_field == 0:
        return None  # P == -Q: the sum is the point at infinity
    if P == Q:
        num, den = 3 * px * px, 2 * py
    else:
        num, den = qy - py, qx - px
    slope = num * inv(den) % p_field
    rx = (slope * slope - px - qx) % p_field
    ry = (slope * (px - rx) - py) % p_field
    return (rx, ry)
def gen_bits(N=1_000_000):
    """Walk k = 1..N incrementally and encode each point k*G as 512 bits.

    Returns:
        bits: uint8 array of shape (N, 512). Columns 0..255 hold the bits of x
            and columns 256..511 the bits of y, each most-significant bit first.
        labels: int8 array of shape (N,) holding k & 1 for the point in that row.
    """
    print(f"generating {N} (x,y) -> 512-bit vectors...", flush=True)
    G = (Gx, Gy)
    P = None
    # 64 big-endian bytes per point (32 for x, 32 for y). They are unpacked to
    # bits once at the end instead of running a 512-step Python loop per point.
    raw = np.empty((N, 64), dtype=np.uint8)
    labels = (np.arange(1, N + 1) & 1).astype(np.int8)  # row i holds k = i + 1
    t0 = time.perf_counter()
    log_every = max(1, N // 20)
    for i in range(N):
        P = add(P, G)  # P == (i + 1) * G
        x, y = P
        raw[i] = np.frombuffer(x.to_bytes(32, "big") + y.to_bytes(32, "big"), dtype=np.uint8)
        if (i + 1) % log_every == 0:
            # Guard against a zero elapsed time on coarse clocks (tiny N).
            r = (i + 1) / max(time.perf_counter() - t0, 1e-9)
            print(f" {i+1}/{N} ({(i+1)/N*100:5.1f}%) rate={r:.0f}/s ETA={(N-i-1)/r:.0f}s", flush=True)
    # Default bitorder="big": the MSB of each byte comes first, so column j is
    # (x >> (255 - j)) & 1 for j < 256 and (y >> (511 - j)) & 1 otherwise.
    bits = np.unpackbits(raw, axis=1)
    print(f" done in {time.perf_counter()-t0:.1f}s", flush=True)
    return bits, labels
class BitTransformer(nn.Module):
    """Encoder-only transformer that maps a bit string to a single logit.

    Each bit is embedded from a two-token vocabulary (0/1), and a learned
    positional embedding is added. A learned [CLS] vector is prepended, the
    sequence runs through a pre-norm TransformerEncoder, and the [CLS] output
    is projected to one logit.
    """

    def __init__(self, seq_len=512, d=128, nhead=4, nlayers=4):
        super().__init__()
        self.seq_len = seq_len  # maximum input length (rows of the positional table)
        self.tok = nn.Embedding(2, d)  # bit 0 or 1
        self.pos = nn.Parameter(torch.randn(1, seq_len, d) * 0.02)
        self.cls = nn.Parameter(torch.randn(1, 1, d) * 0.02)
        enc = nn.TransformerEncoderLayer(
            d_model=d, nhead=nhead, dim_feedforward=4 * d,
            batch_first=True, activation="gelu", norm_first=True,
        )
        self.enc = nn.TransformerEncoder(enc, num_layers=nlayers)
        self.head = nn.Linear(d, 1)

    def forward(self, x_bits):
        """Return one logit per row of ``x_bits``.

        Args:
            x_bits: integer tensor of shape (B, L) with values in {0, 1} and
                L <= seq_len. Positions 0..L-1 of the positional table are used.

        Returns:
            Tensor of shape (B,) holding raw logits.

        Raises:
            ValueError: if L exceeds ``seq_len``.
        """
        length = x_bits.size(1)
        if length > self.seq_len:
            raise ValueError(f"input length {length} exceeds seq_len={self.seq_len}")
        h = self.tok(x_bits) + self.pos[:, :length]  # (B, L, d)
        cls = self.cls.expand(h.size(0), -1, -1)  # (B, 1, d)
        h = torch.cat([cls, h], dim=1)  # (B, L + 1, d); [CLS] at index 0
        h = self.enc(h)
        return self.head(h[:, 0, :]).squeeze(1)  # logits (B,)
def split_seq(N, ftr=0.70, fva=0.15):
    """Partition ``range(N)`` into contiguous train / validation / holdout slices.

    No shuffling: rows ``[0, int(N*ftr))`` are training, rows up to
    ``int(N*(ftr+fva))`` are validation, and the remainder is holdout.
    """
    train_end = int(N * ftr)
    val_end = int(N * (ftr + fva))
    edges = (0, train_end, val_end, N)
    return tuple(slice(lo, hi) for lo, hi in zip(edges, edges[1:]))
def _predict(model, bits_t, rows, device, chunk=4096):
    """Return sigmoid probabilities of ``model`` on ``bits_t[rows]`` as a NumPy array.

    ``rows`` is a slice with explicit start/stop. Rows are scored in chunks of
    ``chunk`` and cast to int64 on the device, because nn.Embedding needs
    integer indices.
    """
    model.eval()
    probs = []
    with torch.no_grad():
        for i in range(rows.start, rows.stop, chunk):
            xb = bits_t[i:min(i + chunk, rows.stop)].to(device).long()
            probs.append(torch.sigmoid(model(xb)).cpu().numpy())
    return np.concatenate(probs) if probs else np.empty(0, dtype=np.float32)


def main(N=1_000_000, epochs=8, batch=512, d=128, nlayers=4):
    """Train BitTransformer on k*G bit encodings and record metrics.

    The N points are split sequentially (70/15/15) into train / validation /
    holdout. After each epoch the model is scored on the validation split. The
    weights with the best validation ROC-AUC are saved to
    ``results/bit_xformer.pt`` and reloaded before the holdout evaluation, so
    the checkpoint, the holdout numbers and the recorded ``val_*`` numbers all
    describe the same model. The results go under the ``"bit_transformer"`` key
    of ``results/metrics.json``; other keys in that file are preserved.

    Args:
        N: number of points k*G (k = 1..N) to generate.
        epochs: passes over the training split.
        batch: training mini-batch size.
        d: transformer width.
        nlayers: number of encoder layers.
    """
    os.makedirs("results", exist_ok=True)  # torch.save / open(..., "w") fail if missing
    bits, labels = gen_bits(N)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tr, va, ho = split_seq(N)
    print(f"split sizes: train={tr.stop-tr.start} val={va.stop-va.start} ho={ho.stop-ho.start}")
    # Keep the bits as uint8 (zero-copy view of the NumPy array). Casting the
    # whole matrix to int64 up front needs 8x the memory (~4 GB at N=1M);
    # batches are cast to long on the device instead.
    bits_t = torch.from_numpy(bits)
    labels_t = torch.tensor(labels, dtype=torch.float32)
    model = BitTransformer(d=d, nlayers=nlayers).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"model: BitTransformer d={d} layers={nlayers} params={n_params/1e6:.2f}M device={device}")
    opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    crit = nn.BCEWithLogitsLoss()
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    # Best-validation bookkeeping; stays NaN/None when epochs == 0.
    best_val_auc = math.nan
    best_val_acc = math.nan
    best_epoch = None
    best_state = None
    for ep in range(epochs):
        model.train()
        idx = tr.start + torch.randperm(tr.stop - tr.start)  # shuffled training rows
        t0 = time.time()
        losses = []
        for i in range(0, len(idx), batch):
            b = idx[i:i + batch]
            xb = bits_t[b].to(device, non_blocking=True).long()
            yb = labels_t[b].to(device, non_blocking=True)
            loss = crit(model(xb), yb)
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        sched.step()

        pv = _predict(model, bits_t, va, device)
        yv = labels[va.start:va.stop]
        acc = accuracy_score(yv, pv > 0.5)
        auc = roc_auc_score(yv, pv)
        print(f"epoch {ep+1}/{epochs} loss={np.mean(losses):.4f} val_acc={acc:.4f} val_auc={auc:.4f} ({time.time()-t0:.0f}s)", flush=True)
        if best_state is None or auc > best_val_auc:
            best_val_auc, best_val_acc, best_epoch = auc, acc, ep + 1
            torch.save(model.state_dict(), "results/bit_xformer.pt")
            # In-memory copy so the holdout is scored on the saved weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    # holdout, scored on the best-validation weights
    if best_state is not None:
        model.load_state_dict(best_state)
    ph = _predict(model, bits_t, ho, device)
    yh = labels[ho.start:ho.stop]
    ho_acc = accuracy_score(yh, ph > 0.5)
    ho_auc = roc_auc_score(yh, ph)
    print(f"\nHOLDOUT acc={ho_acc:.4f} auc={ho_auc:.4f}")

    # merge into metrics.json, keeping entries written by other scripts
    try:
        with open("results/metrics.json", encoding="utf-8") as f:
            m = json.load(f)
    except FileNotFoundError:
        m = {}
    m["bit_transformer"] = {
        "val_acc": float(best_val_acc), "val_auc": float(best_val_auc),
        "ho_acc": float(ho_acc), "ho_auc": float(ho_auc),
        "params": int(n_params), "epochs": epochs, "d": d, "layers": nlayers,
        "best_epoch": best_epoch,
    }
    with open("results/metrics.json", "w", encoding="utf-8") as f:
        json.dump(m, f, indent=2)
    print("metrics saved.")
if __name__ == "__main__":
    # Optional command-line overrides. Options that are not given are omitted
    # (argparse.SUPPRESS), so main() keeps its own defaults and the no-argument
    # invocation behaves exactly as before.
    import argparse

    ap = argparse.ArgumentParser(
        description="Train BitTransformer on the 512 bits of k*G to predict k & 1.",
        argument_default=argparse.SUPPRESS,
    )
    ap.add_argument("--N", type=int, help="number of points k*G, k = 1..N")
    ap.add_argument("--epochs", type=int, help="training epochs")
    ap.add_argument("--batch", type=int, help="training batch size")
    ap.add_argument("--d", type=int, help="transformer width")
    ap.add_argument("--nlayers", type=int, help="number of encoder layers")
    main(**vars(ap.parse_args()))