File size: 5,969 Bytes
6b93c3b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 | #!/usr/bin/env python3
"""Train a small transformer over the 512 bits of (x, y) to predict k_state.
This is the model class most likely to catch XOR/bit-level interactions if any exist.
Output: results/bit_xformer.pt, plus appends metrics to results/metrics.json
"""
import os, time, json, sys, math
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, roc_auc_score
# ---- curve params ----
# Field prime and generator point of secp256k1 (y^2 = x^3 + 7 over GF(p_field)).
p_field = 2**256 - 2**32 - 977
Gx = 55066263022277343669578718895168534326250603453777594175500187360389116729240
Gy = 32670510020758816978083085130507043184471273380659243275938904335757337482424


def inv(a):
    """Return the multiplicative inverse of ``a`` modulo ``p_field``.

    Uses Fermat's little theorem (``p_field`` is prime). Negative ``a`` is
    accepted, since Python's three-argument ``pow`` reduces it.

    Raises:
        ZeroDivisionError: if ``a`` is a multiple of ``p_field``. Without
            this guard, ``pow(0, p-2, p)`` silently returns 0, which would
            corrupt the point arithmetic in ``add`` instead of failing.
    """
    if a % p_field == 0:
        raise ZeroDivisionError("0 has no multiplicative inverse modulo p_field")
    return pow(a, p_field - 2, p_field)


def add(P, Q):
    """Add two affine curve points; ``None`` denotes the point at infinity.

    Points are ``(x, y)`` tuples of ints, presumably already reduced modulo
    ``p_field``. The doubling slope ``3*x^2 / (2*y)`` assumes a curve with
    coefficient a = 0. Returns ``None`` when ``Q == -P``.
    """
    if P is None:
        return Q
    if Q is None:
        return P
    x1, y1 = P
    x2, y2 = Q
    # P + (-P) is the identity; this also covers doubling a point with y == 0.
    if x1 == x2 and (y1 + y2) % p_field == 0:
        return None
    if P == Q:
        m = (3 * x1 * x1) * inv(2 * y1) % p_field  # tangent slope (doubling)
    else:
        m = (y2 - y1) * inv(x2 - x1) % p_field  # chord slope
    x3 = (m * m - x1 - x2) % p_field
    return (x3, (m * (x1 - x3) - y1) % p_field)
def gen_bits(N=1_000_000):
    """Walk k=1..N incrementally, return (bits[N,512] as uint8, labels[N] as int8).

    Row ``i`` encodes the point ``k*G`` with ``k = i + 1``. It holds the 256
    bits of x followed by the 256 bits of y, each most-significant bit first,
    with one bit (0/1) per uint8 element. ``labels[i]`` is the parity ``k & 1``.
    Progress is printed about every 5% of N.
    """
    print(f"generating {N} (x,y) -> 512-bit vectors...", flush=True)
    G = (Gx, Gy)
    P = None
    # Collect the 64 big-endian bytes of (x, y) per point and expand them to
    # bits with a single vectorized np.unpackbits call at the end. Assigning
    # 512 numpy scalars per point in a Python loop dominated the runtime.
    raw = bytearray(64 * N)
    t0 = time.perf_counter()
    LOG = max(1, N // 20)
    for i in range(N):
        P = add(P, G)
        x, y = P
        off = 64 * i
        # add() reduces coordinates mod p_field < 2**256, so 32 bytes suffice.
        raw[off:off + 32] = x.to_bytes(32, "big")
        raw[off + 32:off + 64] = y.to_bytes(32, "big")
        if (i + 1) % LOG == 0:
            # perf_counter is high-resolution. max() guards against a zero
            # elapsed time (ZeroDivisionError) on coarse clocks at small N.
            r = (i + 1) / max(time.perf_counter() - t0, 1e-9)
            print(f" {i+1}/{N} ({(i+1)/N*100:5.1f}%) rate={r:.0f}/s ETA={(N-i-1)/r:.0f}s", flush=True)
    # unpackbits' default bitorder="big" emits each byte's MSB first. So for
    # j < 256, column j equals (x >> (255 - j)) & 1, and the same holds for y.
    bits = np.unpackbits(np.frombuffer(raw, dtype=np.uint8).reshape(N, 64), axis=1)
    # The label is the parity of the scalar k = i + 1.
    labels = (np.arange(1, N + 1, dtype=np.int64) & 1).astype(np.int8)
    print(f" done in {time.perf_counter()-t0:.1f}s")
    return bits, labels
class BitTransformer(nn.Module):
    """Encoder-only transformer mapping a fixed-length bit string to one logit.

    Every position holds a 0/1 token embedding plus a learned positional
    vector. A learned [CLS] vector is prepended, and its final hidden state is
    projected to a single logit.

    Args:
        seq_len: bit positions per input row. Inputs must have exactly this
            length, because the positional table is added by broadcasting.
        d: model width.
        nhead: attention heads per encoder layer.
        nlayers: number of pre-norm encoder layers.
    """

    def __init__(self, seq_len=512, d=128, nhead=4, nlayers=4):
        super().__init__()
        # Attribute names and creation order are deliberate. They fix the
        # state_dict keys of saved checkpoints and the RNG draw order at init.
        self.tok = nn.Embedding(2, d)
        self.pos = nn.Parameter(0.02 * torch.randn(1, seq_len, d))
        self.cls = nn.Parameter(0.02 * torch.randn(1, 1, d))
        layer = nn.TransformerEncoderLayer(
            d_model=d,
            nhead=nhead,
            dim_feedforward=d * 4,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=nlayers)
        self.head = nn.Linear(d, 1)

    def forward(self, x_bits):
        """Map integer bits of shape (B, seq_len) to logits of shape (B,)."""
        batch_size = x_bits.shape[0]
        tokens = self.tok(x_bits) + self.pos  # (B, seq_len, d)
        cls_tok = self.cls.expand(batch_size, -1, -1)  # (B, 1, d)
        hidden = self.enc(torch.cat((cls_tok, tokens), dim=1))  # (B, seq_len + 1, d)
        return self.head(hidden[:, 0]).squeeze(1)
def split_seq(N, ftr=0.70, fva=0.15):
    """Partition row indices 0..N-1 into contiguous train/val/holdout slices.

    The split is sequential, with no shuffling. Rows ``[0, int(N*ftr))`` train,
    rows up to ``int(N*(ftr+fva))`` validate, and the remainder is holdout.

    Returns:
        A ``(train, val, holdout)`` tuple of ``slice`` objects.
    """
    train_end = int(N * ftr)
    val_end = int(N * (ftr + fva))
    bounds = (0, train_end, val_end, N)
    return tuple(slice(lo, hi) for lo, hi in zip(bounds, bounds[1:]))
def _predict_proba(model, bits_t, sl, device, chunk=4096):
    """Score rows ``bits_t[sl]`` in eval mode and return sigmoid probabilities.

    Rows are fed in chunks of ``chunk`` to bound device memory. Batches are
    widened to int64 on the device because ``nn.Embedding`` needs integer
    indices. Returns a 1-D numpy array, which is empty if the slice is empty.
    """
    model.eval()
    parts = []
    with torch.no_grad():
        for start in range(sl.start, sl.stop, chunk):
            xb = bits_t[start:min(start + chunk, sl.stop)].to(device).long()
            parts.append(torch.sigmoid(model(xb)).cpu().numpy())
    return np.concatenate(parts) if parts else np.empty(0, dtype=np.float32)


def _update_metrics_file(path, key, entry):
    """Store ``entry`` under ``key`` in the JSON object at ``path``.

    A missing file starts a fresh object. Other keys already present in the
    file are preserved.
    """
    try:
        with open(path) as f:
            m = json.load(f)
    except FileNotFoundError:
        m = {}
    m[key] = entry
    with open(path, "w") as f:
        json.dump(m, f, indent=2)


def main(N=1_000_000, epochs=8, batch=512, d=128, nlayers=4):
    """Generate (x, y) bit data, train a BitTransformer on k parity, and report.

    Args:
        N: number of consecutive points k*G (k = 1..N) to generate.
        epochs: training epochs (cosine-annealed learning rate).
        batch: training mini-batch size.
        d: transformer model width.
        nlayers: number of encoder layers.

    Side effects:
        Creates ``results/`` if needed. Writes the weights with the best
        validation AUC to ``results/bit_xformer.pt``. Records the val/holdout
        metrics of that same checkpoint under ``"bit_transformer"`` in
        ``results/metrics.json``.
    """
    ckpt_path = "results/bit_xformer.pt"
    metrics_path = "results/metrics.json"
    # Create the output dir up front. Otherwise torch.save fails only after
    # the (slow) data generation and the first training epoch.
    os.makedirs("results", exist_ok=True)

    bits, labels = gen_bits(N)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tr, va, ho = split_seq(N)
    print(f"split sizes: train={tr.stop-tr.start} val={va.stop-va.start} ho={ho.stop-ho.start}")
    # Keep the 0/1 matrix as a zero-copy uint8 tensor. An int64 copy of an
    # N x 512 matrix would cost ~4 GB at N=1M. Batches are widened with
    # .long() after the device transfer.
    bits_t = torch.from_numpy(bits)
    labels_t = torch.tensor(labels, dtype=torch.float32)

    model = BitTransformer(d=d, nlayers=nlayers).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"model: BitTransformer d={d} layers={nlayers} params={n_params/1e6:.2f}M device={device}")
    opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    crit = nn.BCEWithLogitsLoss()
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, epochs))

    best_val_auc = -math.inf
    best_epoch = 0
    best_state = None
    for ep in range(epochs):
        model.train()
        # Offset by tr.start so the permutation always indexes training rows.
        idx = tr.start + torch.randperm(tr.stop - tr.start)
        t0 = time.time()
        losses = []
        for i in range(0, len(idx), batch):
            b = idx[i:i + batch]
            xb = bits_t[b].to(device, non_blocking=True).long()
            yb = labels_t[b].to(device, non_blocking=True)
            loss = crit(model(xb), yb)
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        sched.step()

        pv = _predict_proba(model, bits_t, va, device)
        yv = labels[va.start:va.stop]
        acc = accuracy_score(yv, pv > 0.5)
        auc = roc_auc_score(yv, pv)
        mean_loss = float(np.mean(losses)) if losses else float("nan")
        print(f"epoch {ep+1}/{epochs} loss={mean_loss:.4f} val_acc={acc:.4f} val_auc={auc:.4f} ({time.time()-t0:.0f}s)", flush=True)
        if auc > best_val_auc:
            best_val_auc = auc
            best_epoch = ep + 1
            # Keep an in-memory copy so the holdout uses exactly these weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save(model.state_dict(), ckpt_path)

    # Evaluate the checkpoint that was saved (best val AUC), not whatever the
    # last epoch produced. Then val and holdout metrics describe the same model.
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"restored best checkpoint from epoch {best_epoch} (val_auc={best_val_auc:.4f})")
    else:
        # No epoch ran or none produced a comparable AUC; still write the weights.
        torch.save(model.state_dict(), ckpt_path)

    yv = labels[va.start:va.stop]
    pv = _predict_proba(model, bits_t, va, device)
    acc = accuracy_score(yv, pv > 0.5)
    auc = roc_auc_score(yv, pv)

    yh = labels[ho.start:ho.stop]
    ph = _predict_proba(model, bits_t, ho, device)
    ho_acc = accuracy_score(yh, ph > 0.5)
    ho_auc = roc_auc_score(yh, ph)
    print(f"\nHOLDOUT acc={ho_acc:.4f} auc={ho_auc:.4f}")

    _update_metrics_file(metrics_path, "bit_transformer", {
        "val_acc": float(acc), "val_auc": float(auc),
        "ho_acc": float(ho_acc), "ho_auc": float(ho_auc),
        "params": int(n_params), "epochs": epochs, "d": d, "layers": nlayers,
        "best_epoch": best_epoch,
    })
    print("metrics saved.")
if __name__ == "__main__":
    # Script entry point: run generation, training and evaluation with the
    # default arguments of main().
    main()
|