File size: 5,969 Bytes
6b93c3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/env python3
"""Train a small transformer over the 512 bits of (x, y) to predict k_state.
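The label (k_state) is the parity bit k & 1 of the scalar k behind each point (x, y) = k*G.
This is the model class most likely to catch XOR/bit-level interactions, if any exist.
Output: results/bit_xformer.pt; metrics are also stored under the 'bit_transformer' key of results/metrics.json.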
"""
import os, time, json, sys, math
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, roc_auc_score

# ---- curve params ----
# Prime modulus of the base field; inv() and add() below reduce everything mod p_field.
# (This modulus and the generator coordinates match the secp256k1 parameters.)
p_field = 2**256 - 2**32 - 977
# Affine coordinates of the base point G that gen_bits() adds repeatedly.
Gx = 55066263022277343669578718895168534326250603453777594175500187360389116729240
Gy = 32670510020758816978083085130507043184471273380659243275938904335757337482424

def inv(a):
    """Return the multiplicative inverse of ``a`` modulo ``p_field``.

    Computed as ``a ** (p_field - 2) mod p_field`` (Fermat's little theorem).
    An input congruent to 0 yields 0 instead of raising.
    """
    fermat_exponent = p_field - 2
    return pow(a, fermat_exponent, p_field)
def add(P, Q):
    """Add two affine curve points, with ``None`` standing for the point at infinity.

    Args:
        P, Q: ``(x, y)`` tuples of integers, or ``None``.

    Returns:
        The sum as an ``(x, y)`` tuple reduced mod ``p_field``, or ``None``
        when the two points cancel each other out.
    """
    # The point at infinity is the identity element.
    if P is None:
        return Q
    if Q is None:
        return P

    (px, py), (qx, qy) = P, Q

    # Same x and y-coordinates summing to zero mod p: the points are inverses.
    if px == qx and (py + qy) % p_field == 0:
        return None

    if P == Q:
        # Tangent slope for doubling; the formula carries no linear curve
        # coefficient, i.e. it is the y^2 = x^3 + b form.
        slope = 3 * px * px * inv(2 * py) % p_field
    else:
        # Chord slope through two distinct points.
        slope = (qy - py) * inv(qx - px) % p_field

    rx = (slope * slope - px - qx) % p_field
    ry = (slope * (px - rx) - py) % p_field
    return (rx, ry)

def gen_bits(N=1_000_000):
    """Generate bit-vector features and parity labels for the points k*G, k = 1..N.

    Starting from the point at infinity, G is added once per step, so row i
    describes the point (i + 1) * G. Each coordinate is written as 256 bits,
    most-significant bit first, x before y.

    Args:
        N: number of consecutive multiples of G to generate.

    Returns:
        ``(bits, labels)`` where ``bits`` is a uint8 array of shape (N, 512)
        (columns 0..255 are the bits of x, columns 256..511 the bits of y) and
        ``labels`` is an int8 array of shape (N,) with ``labels[i] = (i + 1) & 1``.
    """
    print(f"generating {N} (x,y) -> 512-bit vectors...", flush=True)
    G = (Gx, Gy)
    P = None
    bits = np.empty((N, 512), dtype=np.uint8)
    # The label only depends on k = i + 1, so it needs no per-point work.
    labels = (np.arange(1, N + 1, dtype=np.int64) & 1).astype(np.int8)
    t0 = time.time()
    LOG = max(1, N // 20)
    for i in range(N):
        P = add(P, G)
        x, y = P
        # Big-endian bytes followed by unpackbits (MSB first within each byte)
        # produce exactly the layout bit j = (value >> (255 - j)) & 1, without
        # 512 Python-level scalar stores per point.
        raw = x.to_bytes(32, "big") + y.to_bytes(32, "big")
        bits[i] = np.unpackbits(np.frombuffer(raw, dtype=np.uint8))
        if (i+1) % LOG == 0:
            # Clamp the elapsed time: a coarse clock can report 0.0 for small N,
            # which would otherwise raise ZeroDivisionError.
            elapsed = max(time.time() - t0, 1e-9)
            r = (i+1)/elapsed
            print(f"  {i+1}/{N} ({(i+1)/N*100:5.1f}%)  rate={r:.0f}/s  ETA={(N-i-1)/r:.0f}s", flush=True)
    print(f"  done in {time.time()-t0:.1f}s")
    return bits, labels

class BitTransformer(nn.Module):
    """Transformer encoder that maps a fixed-length bit sequence to one logit.

    Every bit is embedded, a learned positional offset is added, a learned
    classification vector is prepended, and the encoder output at that first
    position is projected to a single logit.

    Args:
        seq_len: number of input bits per example (length of the positional table).
        d: model width.
        nhead: attention heads per encoder layer.
        nlayers: number of encoder layers.
    """

    def __init__(self, seq_len=512, d=128, nhead=4, nlayers=4):
        super().__init__()
        # One learned vector for bit value 0 and one for bit value 1.
        self.tok = nn.Embedding(2, d)
        # Learned positional table and classification vector, small random init.
        self.pos = nn.Parameter(torch.randn(1, seq_len, d) * 0.02)
        self.cls = nn.Parameter(torch.randn(1, 1, d) * 0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=d,
            nhead=nhead,
            dim_feedforward=4 * d,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=nlayers)
        self.head = nn.Linear(d, 1)

    def forward(self, x_bits):
        """Return logits of shape (B,) for integer bit inputs of shape (B, seq_len)."""
        tokens = self.tok(x_bits) + self.pos                  # (B, seq_len, d)
        cls_tokens = self.cls.expand(tokens.size(0), -1, -1)  # (B, 1, d)
        encoded = self.enc(torch.cat((cls_tokens, tokens), dim=1))
        # Only the output at the prepended position feeds the classifier.
        first_position = encoded[:, 0, :]
        return self.head(first_position).squeeze(1)

def split_seq(N, ftr=0.70, fva=0.15):
    """Split ``range(N)`` into contiguous train / validation / holdout slices.

    Args:
        N: total number of examples.
        ftr: fraction assigned to the training slice.
        fva: fraction assigned to the validation slice.

    Returns:
        A 3-tuple of ``slice`` objects ``(train, val, holdout)`` that cover
        ``0..N`` in order; boundaries are truncated to integers.
    """
    train_end = int(N * ftr)
    val_end = int(N * (ftr + fva))
    edges = (0, train_end, val_end, N)
    return tuple(slice(lo, hi) for lo, hi in zip(edges, edges[1:]))

def _predict_proba(model, bits_t, span, device, chunk=4096):
    """Return sigmoid probabilities (numpy array) for rows span.start..span.stop of bits_t."""
    model.eval()
    parts = []
    with torch.no_grad():
        for lo in range(span.start, span.stop, chunk):
            hi = min(lo + chunk, span.stop)
            # Bits are stored as uint8; the embedding lookup needs int64 indices.
            xb = bits_t[lo:hi].to(device).long()
            parts.append(torch.sigmoid(model(xb)).cpu().numpy())
    return np.concatenate(parts)


def main(N=1_000_000, epochs=8, batch=512, d=128, nlayers=4):
    """Generate data, train a BitTransformer, and record validation/holdout metrics.

    The data is split sequentially into train / validation / holdout. After
    each epoch the model is scored on the validation slice and the weights
    with the highest validation AUC are saved to results/bit_xformer.pt.
    The holdout slice is scored with those saved weights, and the metrics are
    stored under the "bit_transformer" key of results/metrics.json.

    Args:
        N: number of points to generate.
        epochs: number of training epochs.
        batch: training mini-batch size.
        d: transformer model width.
        nlayers: number of transformer encoder layers.
    """
    # Create the output directory up front so a missing folder fails fast
    # instead of after the first full epoch.
    os.makedirs("results", exist_ok=True)

    bits, labels = gen_bits(N)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tr, va, ho = split_seq(N)
    print(f"split sizes: train={tr.stop-tr.start}  val={va.stop-va.start}  ho={ho.stop-ho.start}")

    # Keep the bits as uint8 on the host (1 byte per bit). Converting the whole
    # array to int64 would need 8x the memory; batches are cast when used.
    bits_t   = torch.from_numpy(bits)
    labels_t = torch.tensor(labels, dtype=torch.float32)

    model = BitTransformer(d=d, nlayers=nlayers).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"model: BitTransformer d={d} layers={nlayers}  params={n_params/1e6:.2f}M  device={device}")
    opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    crit = nn.BCEWithLogitsLoss()
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    n_train = tr.stop - tr.start
    best_val_auc = 0.0
    saved_best = False
    # Validation metrics of the weights that are scored on the holdout set.
    # They stay None only when no epoch runs at all.
    val_acc = val_auc = None
    for ep in range(epochs):
        model.train()
        # Shuffle indices inside the training slice (offset by its start).
        idx = torch.randperm(n_train) + tr.start
        t0 = time.time()
        losses = []
        for i in range(0, n_train, batch):
            b = idx[i:i+batch]
            xb = bits_t[b].to(device, non_blocking=True).long()
            yb = labels_t[b].to(device, non_blocking=True)
            loss = crit(model(xb), yb)
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        sched.step()

        # val
        pv = _predict_proba(model, bits_t, va, device)
        yv = labels[va]
        acc = accuracy_score(yv, pv>0.5); auc = roc_auc_score(yv, pv)
        print(f"epoch {ep+1}/{epochs}  loss={np.mean(losses):.4f}  val_acc={acc:.4f}  val_auc={auc:.4f}  ({time.time()-t0:.0f}s)", flush=True)
        if auc > best_val_auc:
            best_val_auc = auc
            torch.save(model.state_dict(), "results/bit_xformer.pt")
            saved_best = True
            val_acc, val_auc = float(acc), float(auc)
        elif not saved_best:
            # No checkpoint yet: the in-memory weights are what gets evaluated.
            val_acc, val_auc = float(acc), float(auc)

    # holdout: score the checkpoint that was actually saved, so the reported
    # numbers describe results/bit_xformer.pt rather than the last epoch.
    if saved_best:
        model.load_state_dict(torch.load("results/bit_xformer.pt", map_location=device))
    ph = _predict_proba(model, bits_t, ho, device)
    yh = labels[ho]
    ho_acc = accuracy_score(yh, ph>0.5); ho_auc = roc_auc_score(yh, ph)
    print(f"\nHOLDOUT  acc={ho_acc:.4f}  auc={ho_auc:.4f}")

    # Merge into metrics.json, keeping entries written by other scripts.
    try:
        with open("results/metrics.json") as f:
            m = json.load(f)
    except FileNotFoundError:
        m = {}
    m["bit_transformer"] = {
        "val_acc": val_acc, "val_auc": val_auc,
        "ho_acc": float(ho_acc), "ho_auc": float(ho_auc),
        "params": int(n_params), "epochs": epochs, "d": d, "layers": nlayers,
    }
    with open("results/metrics.json", "w") as f:
        json.dump(m, f, indent=2)
    print("metrics saved.")

# Run the full pipeline with main()'s default arguments when executed as a script.
if __name__ == "__main__":
    main()