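"""Build an HTML page of model predictions for a range of k.

For each k the point kG on secp256k1 is computed, featurized, and fed to the
saved models, which predict the parity of k.

Columns: k | truth | MLP | XGBoost | LightGBM | BitXformer
Each model cell shows the predicted parity letter (O=odd, E=even) followed by
✓ if it matches the truth column or ✗ if it does not, e.g. 'O ✓' or 'E ✗'.
"""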
| import os, sys, time |
| import numpy as np |
| import pandas as pd |
| from sklearn.preprocessing import StandardScaler |
| import xgboost as xgb, lightgbm as lgb |
| import torch, torch.nn as nn |
|
|
class BitTransformer(nn.Module):
    """Transformer encoder that maps a sequence of 0/1 tokens to one scalar per sample.

    The bit tokens are embedded (vocabulary of 2) and a learned positional table is
    added. A learned [CLS] vector is then prepended; it gets no positional term. After
    the encoder, the [CLS] position is projected to a single value. ``main`` applies a
    sigmoid to that value to get a probability.

    Args:
        seq_len: number of bit tokens per sample; inputs must have exactly this length
            because the positional table is added without slicing.
        d: model width.
        nhead: attention heads per layer (``d`` must be divisible by it).
        nlayers: number of stacked encoder layers.
    """

    def __init__(self, seq_len=512, d=128, nhead=4, nlayers=4):
        super().__init__()
        # The attribute names below are the state_dict keys of saved checkpoints.
        # The creation order is kept so seeded initialisation is reproducible.
        self.tok = nn.Embedding(2, d)
        pos_init = torch.randn(1, seq_len, d)
        self.pos = nn.Parameter(pos_init * 0.02)
        cls_init = torch.randn(1, 1, d)
        self.cls = nn.Parameter(cls_init * 0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=d,
            nhead=nhead,
            dim_feedforward=d * 4,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=nlayers)
        self.head = nn.Linear(d, 1)

    def forward(self, x_bits):
        """Return one value per sample for a batch of 0/1 token sequences."""
        tokens = self.tok(x_bits) + self.pos
        prefix = self.cls.expand(tokens.size(0), -1, -1)
        encoded = self.enc(torch.cat((prefix, tokens), dim=1))
        # Read out only the [CLS] slot (position 0).
        return self.head(encoded[:, 0, :]).squeeze(1)
|
|
def bits_of(x, y):
    """Return the 256-bit big-endian binary expansions of x and y, concatenated.

    Element 0 is bit 255 of ``x`` and element 255 is bit 0 of ``x``. Elements
    256..511 are the bits of ``y`` in the same order. Bits above 255 are ignored.

    Returns:
        np.ndarray of shape (512,) and dtype int64 holding 0/1 values.
    """
    shifts = range(255, -1, -1)  # most-significant bit first
    bit_list = [(x >> s) & 1 for s in shifts] + [(y >> s) & 1 for s in shifts]
    return np.array(bit_list, dtype=np.int64)
|
|
# secp256k1 domain parameters: p is the prime modulus of the coordinate field,
# and G = (Gx, Gy) is the standard generator point of the curve y^2 = x^3 + 7.
p = 2**256 - 2**32 - 977
Gx = 55066263022277343669578718895168534326250603453777594175500187360389116729240
Gy = 32670510020758816978083085130507043184471273380659243275938904335757337482424


def inv(a):
    """Return the inverse of ``a`` modulo ``p`` via Fermat's little theorem (a^(p-2)).

    ``p`` is prime, so this is the modular inverse of any ``a`` not divisible by ``p``.
    For ``a ≡ 0 (mod p)`` it returns 0 rather than raising.
    """
    return pow(a, p - 2, p)


def add(P, Q):
    """Return P + Q on a short Weierstrass curve with a = 0 over GF(p).

    Points are ``(x, y)`` tuples of ints, and ``None`` is the point at infinity.
    Coordinates are expected to already be reduced mod p. Two representations of
    the same point with unreduced coordinates are not recognised as equal.
    """
    if P is None:
        return Q
    if Q is None:
        return P
    x1, y1 = P
    x2, y2 = Q
    # Q == -P (this includes doubling a point with y == 0): the sum is infinity.
    if x1 == x2 and (y1 + y2) % p == 0:
        return None
    if P == Q:
        # Tangent slope 3x^2 / 2y (curve coefficient a = 0).
        m = (3 * x1 * x1) * inv(2 * y1) % p
    else:
        m = (y2 - y1) * inv(x2 - x1) % p
    x3 = (m * m - x1 - x2) % p
    return (x3, (m * (x1 - x3) - y1) % p)


def mul(k, P):
    """Return the scalar multiple k*P using right-to-left double-and-add.

    ``k == 0`` (or ``P is None``) gives ``None``, the point at infinity. A negative
    ``k`` is computed as |k| * (-P). Without that step the loop would never end,
    because Python's ``>>`` on a negative int converges to -1, not 0.
    """
    if k < 0:
        # -(x, y) = (x, -y mod p); the point at infinity is its own negative.
        k = -k
        P = None if P is None else (P[0], -P[1] % p)
    R = None
    while k:
        if k & 1:
            R = add(R, P)
        P = add(P, P)
        k >>= 1
    return R
|
|
def num_features(v, prefix):
    """Compute decimal-digit, binary and small-modulus features of an integer.

    Args:
        v: the integer to describe. It must be non-negative: ``int()`` is applied to
            every character of ``str(v)``, so a leading '-' raises ValueError.
        prefix: string prepended, with an underscore, to every feature name.

    Returns:
        dict mapping ``f"{prefix}_<feature>"`` to an int, always in the same key
        order (21 entries).
    """
    text = str(v)
    digits = [int(ch) for ch in text]
    total = sum(digits)
    n_even = sum(1 for dg in digits if dg % 2 == 0)
    n_odd = len(digits) - n_even  # every decimal digit is either even or odd
    named = [
        ("num_digits", len(text)),
        ("first_digit", digits[0]),
        ("last_digit", digits[-1]),
        ("last2", v % 100),
        ("last3", v % 1000),
        ("digit_sum", total),
        ("digit_sum_mod_9", total % 9),
        ("even_digit_count", n_even),
        ("odd_digit_count", n_odd),
        ("zero_count", text.count("0")),
        ("unique_digit_count", len(set(text))),
        ("bit_length", v.bit_length()),
        ("popcount", bin(v).count("1")),
        ("state", v % 2),
    ]
    named += [(f"mod_{m}", v % m) for m in (3, 5, 7, 11, 13, 17, 19)]
    return {f"{prefix}_{name}": value for name, value in named}


def featurize(x, y):
    """Build the feature row for a point (x, y).

    The row holds the ``num_features`` of x (prefix ``x``) and y (prefix ``y``).
    Two cross features follow: ``x_gt_y`` (1 if x > y else 0) and
    ``digit_sum_diff_xy`` (decimal digit sum of x minus that of y).
    """
    row = {**num_features(x, "x"), **num_features(y, "y")}
    row["x_gt_y"] = int(x > y)
    # num_features already summed the decimal digits of each coordinate.
    row["digit_sum_diff_xy"] = row["x_digit_sum"] - row["y_digit_sum"]
    return row
|
|
def _compute_points(N, k_start):
    """Compute k*G for k = k_start .. k_start+N-1 and featurize each point.

    Returns:
        (feats, bits_arr, ks, truths): the ``featurize`` dict for each point, an
        (N, 512) int64 array of ``bits_of`` rows, the k values, and the parity
        labels ``k & 1`` (1 = odd, 0 = even).
    """
    G = (Gx, Gy)
    print(f"computing kG for k = {k_start} .. {k_start+N-1}")
    t0 = time.time()
    # One full scalar multiplication reaches (k_start - 1)*G. Each later point is
    # then a single point addition away from the previous one.
    P = mul(k_start - 1, G)
    feats, ks, truths, bits = [], [], [], []
    for k in range(k_start, k_start + N):
        P = add(P, G)
        feats.append(featurize(*P))
        bits.append(bits_of(*P))
        ks.append(k)
        truths.append(k & 1)
    print(f" {N} points done in {time.time()-t0:.1f}s")
    return feats, np.stack(bits).astype(np.int64), ks, truths


def _predict_probs(feats, bits_arr, batch_size=4096):
    """Load the four saved models and return their predicted probabilities.

    Reads ``features.parquet`` and ``results/{xgb.json,lgbm.txt,mlp.pt,bit_xformer.pt}``.

    Returns:
        (p_mlp, p_xgb, p_lgb, p_bx), one probability array per model, each with one
        entry per row of ``feats`` / ``bits_arr``.
    """
    # The training feature table fixes the column order the tabular models expect.
    df_tr = pd.read_parquet("features.parquet")
    drop = {"k", "k_state", "abs_x_minus_y"}
    feat_cols = [c for c in df_tr.columns if c not in drop]
    # Raises KeyError if the training table has a column featurize() does not produce.
    X = np.array([[r[c] for c in feat_cols] for r in feats], dtype=np.float32)

    bst = xgb.XGBClassifier()
    bst.load_model("results/xgb.json")
    p_xgb = bst.predict_proba(X)[:, 1]  # probability of label 1

    lgbm = lgb.Booster(model_file="results/lgbm.txt")
    # NOTE(review): assumes a binary objective, so predict() returns P(label 1) — confirm.
    p_lgb = lgbm.predict(X)

    # Re-create the MLP's input scaler: fit on the first 70% of the training rows.
    # NOTE(review): this only matches training if training used the same unshuffled split.
    Xtr = df_tr[feat_cols].astype(np.float32).values
    sc = StandardScaler().fit(Xtr[: int(0.7 * len(Xtr))])
    Xs = sc.transform(X)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    D = X.shape[1]
    # Layer order must match the saved state_dict (Sequential indices 0..6).
    mlp = nn.Sequential(
        nn.Linear(D, 512), nn.ReLU(),
        nn.Linear(512, 512), nn.ReLU(),
        nn.Linear(512, 256), nn.ReLU(),
        nn.Linear(256, 1),
    ).to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    mlp.load_state_dict(torch.load("results/mlp.pt", map_location=device))
    mlp.eval()
    with torch.no_grad():
        logits = mlp(torch.tensor(Xs, dtype=torch.float32, device=device)).squeeze(1).cpu().numpy()
    p_mlp = 1 / (1 + np.exp(-logits))

    bx = BitTransformer().to(device)
    bx.load_state_dict(torch.load("results/bit_xformer.pt", map_location=device))
    bx.eval()
    chunks = []
    with torch.no_grad():
        # Batch the transformer so memory stays bounded for large N.
        for i in range(0, len(bits_arr), batch_size):
            chunk = torch.tensor(bits_arr[i:i + batch_size], dtype=torch.long, device=device)
            chunks.append(torch.sigmoid(bx(chunk)).cpu().numpy())
    p_bx = np.concatenate(chunks)
    return p_mlp, p_xgb, p_lgb, p_bx


def _cell(pred, truth):
    """Render one model cell: predicted parity letter plus a ✓/✗ correctness mark."""
    ok = pred == truth
    letter = "O" if pred == 1 else "E"
    mark = "✓" if ok else "✗"
    cls = "ok" if ok else "bad"
    return f'<td class="{cls}">{letter} {mark}</td>'


def _render_html(ks, truths, preds, accs):
    """Return the full HTML report.

    Args:
        ks: the k values, one table row each.
        truths: parity labels (1 = odd, 0 = even), aligned with ``ks``.
        preds: 0/1 prediction arrays in column order MLP, XGBoost, LightGBM, BitXformer.
        accs: accuracies (fractions) in the same model order.
    """
    acc_mlp, acc_xgb, acc_lgb, acc_bx = accs
    N = len(ks)
    rows = []
    for i, (k, t) in enumerate(zip(ks, truths)):
        truth_letter = "O" if t == 1 else "E"
        rows.append(
            "<tr>"
            f"<td class='k'>{k}</td>"
            f"<td class='truth'>{truth_letter}</td>"
            + "".join(_cell(pred[i], t) for pred in preds)
            + "</tr>"
        )
    body_rows = "".join(rows)
    return f"""<!doctype html>
<html><head><meta charset="utf-8"><title>secp256k1 parity predictions</title>
<style>
:root {{ color-scheme: dark; }}
body {{ background:#0e1117; color:#e6edf3; font-family:-apple-system,BlinkMacSystemFont,system-ui,sans-serif; margin:24px;}}
h1 {{ margin:0 0 6px; font-size:22px; }}
.sub {{ color:#8b949e; margin:0 0 18px; font-size:13px; }}
.stats {{ display:flex; gap:18px; margin-bottom:16px; }}
.stats div {{ background:#161b22; border:1px solid #30363d; border-radius:8px; padding:10px 14px; }}
.stats span {{ color:#8b949e; font-size:11px; text-transform:uppercase; display:block; }}
.stats b {{ font-size:18px; color:#f7931a; }}
table {{ border-collapse:collapse; width:100%; font-family:ui-monospace,monospace; font-size:13px; }}
th,td {{ padding:6px 10px; border-bottom:1px solid #21262d; text-align:left; }}
th {{ background:#161b22; color:#8b949e; text-transform:uppercase; font-size:11px; letter-spacing:.05em; position:sticky; top:0; }}
td.k {{ color:#8b949e; }}
td.truth {{ color:#f7931a; font-weight:600; }}
td.ok {{ color:#3fb950; }}
td.bad {{ color:#f85149; }}
tr:hover td {{ background:#161b22; }}
</style></head><body>
<h1>secp256k1 parity predictions — k = {ks[0]} … {ks[-1]}</h1>
<p class="sub">truth column = actual parity of k (O=odd, E=even). Model columns show prediction + ✓ (correct) or ✗ (wrong).</p>
<div class="stats">
<div><span>MLP accuracy</span><b>{acc_mlp*100:.2f}%</b></div>
<div><span>XGBoost accuracy</span><b>{acc_xgb*100:.2f}%</b></div>
<div><span>LightGBM accuracy</span><b>{acc_lgb*100:.2f}%</b></div>
<div><span>Bit-Transformer accuracy</span><b>{acc_bx*100:.2f}%</b></div>
<div><span>rows</span><b>{N}</b></div>
</div>
<table>
<thead><tr><th>k</th><th>truth</th><th>MLP</th><th>XGBoost</th><th>LightGBM</th><th>BitXformer</th></tr></thead>
<tbody>
{body_rows}
</tbody></table>
</body></html>
"""


def main(N=2000, k_start=2_000_000, out="/tmp/predictions.html"):
    """Predict the parity of k from k*G with four saved models and write an HTML report.

    Reads ``features.parquet`` (for column order and the MLP's scaler) and the model
    files under ``results/``. Writes the page to ``out`` and prints each model's
    accuracy.

    Args:
        N: number of consecutive k values; must be >= 1.
        k_start: first k; must be >= 1. For k = 0, k*G is the point at infinity,
            which has no coordinates to featurize.
        out: path of the HTML file to write (UTF-8).

    Raises:
        ValueError: if ``N < 1`` or ``k_start < 1``.
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    if k_start < 1:
        raise ValueError(f"k_start must be >= 1, got {k_start}")

    feats, bits_arr, ks, truths = _compute_points(N, k_start)
    p_mlp, p_xgb, p_lgb, p_bx = _predict_probs(feats, bits_arr)
    y = np.array(truths, dtype=np.int8)

    # Threshold each probability at 0.5: 1 = predicted odd, 0 = predicted even.
    preds = [(probs > 0.5).astype(int) for probs in (p_mlp, p_xgb, p_lgb, p_bx)]
    acc_mlp, acc_xgb, acc_lgb, acc_bx = [(pred == y).mean() for pred in preds]

    html = _render_html(ks, truths, preds, (acc_mlp, acc_xgb, acc_lgb, acc_bx))
    # The page contains non-ASCII marks (✓ ✗ — …) and declares charset utf-8.
    # Write it as UTF-8 regardless of the platform's default encoding.
    with open(out, "w", encoding="utf-8") as f:
        f.write(html)
    print(f"wrote {out} ({os.path.getsize(out)/1024:.0f} KB)")
    print(f"acc: MLP={acc_mlp:.4f} XGB={acc_xgb:.4f} LGB={acc_lgb:.4f} BX={acc_bx:.4f}")
|
|
if __name__ == "__main__":
    # Usage: python <script> [N] [k_start] [out]
    import argparse

    parser = argparse.ArgumentParser(
        description="Write an HTML page of model parity predictions for consecutive k."
    )
    parser.add_argument("N", type=int, nargs="?", default=None,
                        help="number of consecutive k values (default: main()'s default)")
    parser.add_argument("k_start", type=int, nargs="?", default=None,
                        help="first k (default: main()'s default)")
    parser.add_argument("out", nargs="?", default=None,
                        help="output HTML path (default: main()'s default)")
    args = parser.parse_args()
    # Forward only the arguments actually given so main()'s own defaults apply.
    main(**{name: value for name, value in vars(args).items() if value is not None})
|
|