"""Closed-form next-token predictor built without gradient training.

Pipeline (see ``main``):

1. ``ppmi_embed``  - token embeddings from a positive-PMI co-occurrence
   matrix factorised with an SVD.
2. ``learn_map``   - per layer, a query/key projection ``P`` (SVD) and a
   ridge map ``B`` from hidden state to next-token embedding.
3. ``fit_eval``    - ridge-regression readout on hand-built features,
   scored by perplexity on the validation tokens.
"""
import argparse
import math
import os
import time

import numpy as np
import torch

# Every value accepted by ``--value_mode`` / understood by ``pred_layer``.
VALUE_MODES = (
    "h", "next", "delta", "ridge_next", "ridge_delta",
    "dual_next", "dual_ridge_next", "dual_ridge_delta",
)


def log(*a):
    """Print ``a`` prefixed with the wall-clock time, flushing immediately."""
    print(f"[{time.strftime('%H:%M:%S')}]", *a, flush=True)


def doc_index(x, eos):
    """Return per-position document id and offset inside that document.

    A new document starts at position 0 and right after every ``eos``
    token (the ``eos`` itself still belongs to the document it closes).

    Args:
        x: 1-D tensor of token ids.
        eos: id of the end-of-document token.

    Returns:
        ``(seg, within)``: ``seg[t]`` is the 0-based document number of
        position ``t`` and ``within[t]`` its distance from that document's
        first position. Both are empty for an empty ``x``.
    """
    if x.numel() == 0:
        # Indexing ``starts[0]`` below would fail on an empty tensor.
        empty = torch.zeros(0, dtype=torch.long, device=x.device)
        return empty, empty.clone()
    starts = torch.zeros_like(x, dtype=torch.bool)
    starts[0] = True
    starts[1:] = x[:-1].eq(eos)
    seg = torch.cumsum(starts.long(), 0) - 1
    first = torch.nonzero(starts, as_tuple=False).flatten()
    within = torch.arange(len(x), device=x.device) - first[seg]
    return seg, within


def shift(M, within, s):
    """Shift the rows of ``M`` down by ``s`` without crossing documents.

    Row ``t`` of the result is ``M[t - s]`` when position ``t - s`` lies in
    the same document as ``t`` (``within[t] >= s``) and zero otherwise.
    ``s`` must be a positive integer.
    """
    out = torch.zeros_like(M)
    out[s:] = M[:-s]
    out[within < s] = 0
    return out


def ppmi_embed(tokens, V, d, window, max_tokens, device):
    """Build unit-norm token embeddings from PPMI co-occurrence statistics.

    Counts ordered pairs ``(x[t], x[t + s])`` for ``s = 1..window`` over the
    first ``max_tokens`` tokens, weighting each pair by ``1 / s``, converts
    the counts to positive pointwise mutual information, and keeps the top
    ``d`` left singular directions scaled by the square root of their
    singular values.

    Returns:
        Float tensor of shape ``(V, d)`` with L2-normalised rows.
    """
    x = torch.tensor(tokens[:max_tokens].astype(np.int64), device=device)
    C = torch.zeros((V, V), device=device)
    for s in range(1, window + 1):
        # Encode each (left, right) pair as one flat index into a V*V table.
        idx = x[:-s] * V + x[s:]
        C += torch.bincount(idx, minlength=V * V).float().reshape(V, V) / s
    tot = C.sum()
    row = C.sum(1, keepdim=True) + 1e-6
    col = C.sum(0, keepdim=True) + 1e-6
    # clamp(min=0) keeps only positive PMI values.
    M = torch.clamp(torch.log(C * tot / row / col + 1e-12), min=0)
    U, S, _ = torch.linalg.svd(M)
    E = U[:, :d] * torch.sqrt(S[:d])[None, :]
    E = torch.nn.functional.normalize(E, dim=1)
    return E.float()


def _next_token_embeddings(E, x):
    """Return ``E[x]`` shifted so row ``t`` holds the embedding of ``x[t+1]``.

    The last row has no successor and repeats its own embedding.
    """
    Y = E[x]
    return torch.cat([Y[1:], Y[-1:]], 0)


def pred_layer(H, x, E, P, B, within, args, layer_idx=0):
    """Apply one windowed mixing layer to the hidden states ``H``.

    Each position forms a weighted average over the previous
    ``args.att_window`` positions of its own document, with weights
    ``exp(clamp(q . k / args.temp, -30, 30))`` where ``q``/``k`` are rows of
    ``H @ P``. What is averaged depends on ``args.value_mode``:

    * ``h``                 - the hidden states themselves;
    * ``next`` / ``delta``  - next-token embeddings (minus ``H`` for delta);
    * ``ridge_next`` / ``ridge_delta`` - ``H @ B`` (minus ``H`` for delta);
    * ``dual_*``            - hidden states for the context term plus a
      second "prediction" stream averaged with the same weights.

    Args:
        H: ``(n, d)`` hidden states.
        x: ``(n,)`` token ids, used to look up ``E``.
        E: token embedding table.
        P: ``(d, r)`` query/key projection.
        B: ``(d, d)`` ridge map; read only by the ``*ridge*`` modes.
        within: per-position offset inside its document (see ``doc_index``).
        args: namespace providing ``value_mode``, ``att_window``, ``temp``,
            ``res_scale`` and, for dual modes, ``orth_delta``, ``pred_norm``,
            ``pred_scale``, ``pred_schedule`` and ``layers``.
        layer_idx: 0-based layer number, used by the prediction schedule.

    Returns:
        ``(H_new, pred)``: row-normalised updated states and the prediction
        stream (``None`` for non-dual modes).

    Raises:
        ValueError: if ``args.value_mode`` is not one of ``VALUE_MODES``.
    """
    mode = args.value_mode
    qk = H @ P
    Vpred = None
    if mode == "h":
        Vv = H
    elif mode == "next":
        Vv = _next_token_embeddings(E, x)
    elif mode == "delta":
        Vv = _next_token_embeddings(E, x) - H
    elif mode == "ridge_next":
        Vv = H @ B
    elif mode == "ridge_delta":
        Vv = H @ B - H
    elif mode == "dual_next":
        Vv = H
        Vpred = _next_token_embeddings(E, x)
    elif mode == "dual_ridge_next":
        Vv = H
        Vpred = H @ B
    elif mode == "dual_ridge_delta":
        Vv = H
        Vpred = H @ B - H
    else:
        raise ValueError(f"unknown value_mode: {mode!r}")

    num = torch.zeros_like(H)
    num_pred = torch.zeros_like(H)
    den = torch.zeros((H.shape[0], 1), device=H.device)
    for s in range(1, args.att_window + 1):
        ks = shift(qk, within, s)
        score = ((qk * ks).sum(1, keepdim=True) / args.temp).clamp(-30, 30)
        # Positions closer than ``s`` to their document start have no
        # key at this lag, so they contribute zero weight.
        w = torch.where(
            (within >= s)[:, None], torch.exp(score), torch.zeros_like(score)
        )
        num += w * shift(Vv, within, s)
        if Vpred is not None:
            num_pred += w * shift(Vpred, within, s)
        den += w
    ctx = num / (den + 1e-6)

    pred_out = None
    if Vpred is not None:
        pred = num_pred / (den + 1e-6)
        if args.orth_delta:
            # Remove the component of the prediction along H.
            pred = pred - H * (pred * H).sum(1, keepdim=True)
        if args.pred_norm:
            pred = pred / (pred.norm(dim=1, keepdim=True) + 1e-6)
        scale = args.pred_scale
        if args.pred_schedule == "linear":
            scale = scale * float(layer_idx + 1) / max(1, args.layers)
        elif args.pred_schedule == "late":
            # Zero for the first third of the layers, then ramps up to 1.
            third = args.layers // 3
            scale = scale * max(
                0.0,
                float(layer_idx + 1 - third) / max(1, args.layers - third),
            )
        pred_out = pred
        H = H + args.res_scale * ctx + scale * pred
    elif "delta" in mode:
        H = H + args.res_scale * ctx
    else:
        H = (1 - args.res_scale) * H + args.res_scale * ctx
    return torch.nn.functional.normalize(H, dim=1), pred_out


def apply_stack(x, E, Ps, Bs, within, args, expose=True):
    """Run every ``(P, B)`` layer over ``E[x]`` and collect side features.

    With ``expose`` set, each layer contributes non-linear functions of its
    query projection ``q = H @ P`` (ReLU, absolute value, square, and
    products of up to 64 adjacent columns), computed *before* the layer is
    applied. If ``args.pred_features`` is set and the layer produced a
    prediction stream, that stream and its product with the new hidden
    state are added too.

    Returns:
        ``(H, phis)``: final hidden states and the list of feature blocks
        (empty when ``expose`` is false).
    """
    H = E[x]
    phis = []
    for li, (P, B) in enumerate(zip(Ps, Bs)):
        if expose:
            q = H @ P
            phis += [torch.relu(q), torch.abs(q), q * q]
            m = min(64, q.shape[1] - 1)
            if m > 0:
                phis.append(q[:, :m] * q[:, 1:m + 1])
        H, last_pred = pred_layer(H, x, E, P, B, within, args, li)
        if expose and args.pred_features and last_pred is not None:
            phis += [last_pred, H * last_pred]
    return H, phis


def features(H, within, phis, extra):
    """Assemble the readout design matrix for one chunk.

    Concatenates ``H``, the previous position's state and their product;
    when ``extra`` is truthy also the state two positions back and its
    products with the other two; then ``phis`` and a constant bias column.
    """
    prev = shift(H, within, 1)
    blocks = [H, prev, H * prev]
    if extra:
        prev2 = shift(H, within, 2)
        blocks += [prev2, prev * prev2, H * prev2]
    blocks += phis + [torch.ones((H.shape[0], 1), device=H.device)]
    return torch.cat(blocks, 1)


def _doc_chunks(tokens, eos, chunk_docs, device):
    """Yield ``(x, within, keep)`` for successive groups of whole documents.

    ``tokens`` is split at document starts into runs of ``chunk_docs``
    documents. ``keep`` marks positions usable as a training/evaluation
    input: the position is not ``eos``, is not the last of the chunk, and
    the following position belongs to the same document.
    """
    if chunk_docs < 1:
        raise ValueError(f"chunk_docs must be >= 1, got {chunk_docs}")
    if len(tokens) == 0:
        return
    starts = np.flatnonzero(np.r_[True, tokens[:-1] == eos])
    for i in range(0, len(starts), chunk_docs):
        lo = starts[i]
        hi = starts[i + chunk_docs] if i + chunk_docs < len(starts) else len(tokens)
        x = torch.tensor(tokens[lo:hi].astype(np.int64), device=device)
        seg, within = doc_index(x, eos)
        keep = torch.ones(len(x), device=device, dtype=torch.bool)
        keep[-1] = False
        keep[:-1] &= seg[1:].eq(seg[:-1])
        keep &= x.ne(eos)
        yield x, within, keep


def _next_tokens(x, eos):
    """Return the target id for every position (``x`` shifted left by one).

    The final slot is filled with ``eos``; callers always mask it out.
    """
    y = torch.empty(len(x), device=x.device, dtype=torch.long)
    y[:-1] = x[1:]
    y[-1] = eos
    return y


def learn_map(tokens, E, Ps, Bs, eos, args, device):
    """Fit the projection ``P`` and ridge map ``B`` for the next layer.

    Runs the existing stack ``(Ps, Bs)`` over the first ``args.proj_tokens``
    tokens and accumulates, in float64, the cross-covariance between hidden
    states and next-token embeddings and the hidden-state Gram matrix.

    Returns:
        ``(P, B)``: ``P`` holds the top ``args.r`` left singular vectors of
        the cross-covariance; ``B`` is the ridge solution mapping hidden
        states to next-token embeddings, regularised by
        ``args.map_lam * trace(Gram) / d``.

    Raises:
        ValueError: if no usable position is found in the token slice.
    """
    d = E.shape[1]
    # One accumulator serves both the SVD and the ridge right-hand side.
    cross = torch.zeros((d, d), device=device, dtype=torch.float64)
    gram = torch.zeros_like(cross)
    count = 0
    chunks = _doc_chunks(tokens[:args.proj_tokens], eos, args.chunk_docs, device)
    for x, within, keep in chunks:
        H, _ = apply_stack(x, E, Ps, Bs, within, args, expose=False)
        X = H[keep].double()
        T = _next_token_embeddings(E, x)[keep].double()
        cross += X.T @ T
        gram += X.T @ X
        count += X.shape[0]
    if count == 0:
        raise ValueError("learn_map: no usable positions in the token slice")
    U, S, _ = torch.linalg.svd(cross)
    P = U[:, :args.r].float()
    diag = torch.trace(gram) / d
    eye = torch.eye(d, device=device, dtype=torch.float64)
    B = torch.linalg.solve(gram + args.map_lam * diag * eye, cross).float()
    log("sv", S[:4].detach().cpu().numpy().round(1))
    return P, B


def _fit_stats(tokens, E, Ps, Bs, eos, V, args, device):
    """Accumulate the readout normal equations over ``tokens``.

    Returns ``(A, G, D)`` where ``A`` is the ``(D, D)`` feature Gram matrix
    and ``G`` the ``(D, V)`` sum of feature vectors grouped by target token,
    both float64.
    """
    A = G = None
    for x, within, keep in _doc_chunks(tokens, eos, args.chunk_docs, device):
        H, phis = apply_stack(x, E, Ps, Bs, within, args)
        Phi = features(H, within, phis, args.extra_context)[keep].double()
        y = _next_tokens(x, eos)[keep]
        if A is None:
            D = Phi.shape[1]
            A = torch.zeros((D, D), device=device, dtype=torch.float64)
            G = torch.zeros((D, V), device=device, dtype=torch.float64)
        A += Phi.T @ Phi
        # Equivalent to Phi.T @ one_hot(y) without materialising the one-hot.
        G.index_add_(1, y, Phi.T)
    if A is None:
        raise ValueError("fit_eval: no training tokens to fit on")
    return A, G, A.shape[0]


def fit_eval(train, valid, E, Ps, Bs, eos, V, args, device):
    """Fit the ridge readout and report validation perplexity per lambda.

    The readout is solved in closed form for every value in ``args.lams``
    (comma-separated). Scores are turned into probabilities by a ReLU plus
    ``args.floor`` times the add-one-smoothed training unigram distribution,
    then normalised. All lambdas are scored in a single pass over the
    validation chunks, since the features do not depend on lambda; the
    per-lambda ``lam``/``ppl`` lines and the final ``TORCH_PRED`` summary
    are logged afterwards.

    Raises:
        ValueError: if there is nothing to fit or evaluate on, or if
            ``train`` contains token ids outside ``[0, V)``.
    """
    A, G, D = _fit_stats(train[:args.fit_tokens], E, Ps, Bs, eos, V, args, device)

    counts = np.bincount(train.astype(np.int64), minlength=V) + 1
    if counts.shape[0] != V:
        raise ValueError(
            f"train contains token id {counts.shape[0] - 1} but V={V}"
        )
    uni = torch.tensor(counts, device=device).float()
    uni = uni / uni.sum()
    floor = args.floor * uni[None, :]

    lams = [float(v) for v in args.lams.split(",")]
    diag = torch.trace(A) / D
    eye = torch.eye(D, device=device, dtype=torch.float64)
    weights = [torch.linalg.solve(A + lam * diag * eye, G).float() for lam in lams]

    nll = [0.0] * len(lams)
    n = 0
    chunks = _doc_chunks(valid[:args.eval_tokens], eos, args.chunk_docs, device)
    for x, within, keep in chunks:
        H, phis = apply_stack(x, E, Ps, Bs, within, args)
        Phi = features(H, within, phis, args.extra_context)[keep]
        y = _next_tokens(x, eos)[keep]
        rows = torch.arange(len(y), device=device)
        for k, W in enumerate(weights):
            Pp = torch.relu(Phi @ W) + floor
            Pp = Pp / Pp.sum(1, keepdim=True)
            nll[k] += float(-torch.log(Pp[rows, y] + 1e-12).sum())
        n += len(y)
    if n == 0:
        raise ValueError("fit_eval: no usable validation positions")

    best = None
    for lam, total in zip(lams, nll):
        ppl = math.exp(total / n)
        log("lam", lam, "ppl", round(ppl, 2))
        if best is None or ppl < best[1]:
            best = (lam, ppl)
    log(f"TORCH_PRED mode={args.value_mode} ppl={best[1]:.2f} lam={best[0]} D={D}")


def main():
    """Parse CLI options, build embeddings and layers, then fit and score."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", default="/workspace/ts_mini")
    ap.add_argument("--spm_model", default=None)
    ap.add_argument("--train_bin", default=None)
    ap.add_argument("--valid_bin", default=None)
    ap.add_argument("--vocab", type=int, default=1024)
    ap.add_argument("--d", type=int, default=192)
    ap.add_argument("--r", type=int, default=64)
    ap.add_argument("--layers", type=int, default=2)
    ap.add_argument("--att_window", type=int, default=8)
    ap.add_argument("--temp", type=float, default=0.3)
    ap.add_argument("--window", type=int, default=8)
    ap.add_argument("--extra_context", type=int, default=1)
    ap.add_argument("--res_scale", type=float, default=0.12)
    ap.add_argument("--pred_scale", type=float, default=0.04)
    ap.add_argument("--pred_schedule", choices=["flat", "linear", "late"], default="flat")
    ap.add_argument("--orth_delta", type=int, default=0)
    ap.add_argument("--pred_norm", type=int, default=0)
    ap.add_argument("--pred_features", type=int, default=0)
    ap.add_argument("--map_lam", type=float, default=0.001)
    ap.add_argument("--cooc_tokens", type=int, default=1_000_000)
    ap.add_argument("--proj_tokens", type=int, default=500_000)
    ap.add_argument("--fit_tokens", type=int, default=800_000)
    ap.add_argument("--eval_tokens", type=int, default=100_000)
    ap.add_argument("--chunk_docs", type=int, default=40)
    ap.add_argument("--lams", default="0.003,0.01,0.03,0.1")
    ap.add_argument("--floor", type=float, default=1e-4)
    ap.add_argument("--value_mode", choices=list(VALUE_MODES), default="h")
    args = ap.parse_args()

    # Imported here so the numeric helpers above stay usable without it.
    import sentencepiece as spm

    spm_model = args.spm_model or os.path.join(args.data, f"sp{args.vocab}.model")
    train_bin = args.train_bin or os.path.join(args.data, "train.bin")
    valid_bin = args.valid_bin or os.path.join(args.data, "valid.bin")
    sp = spm.SentencePieceProcessor(model_file=spm_model)
    eos = sp.eos_id()
    V = sp.get_piece_size()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train = np.fromfile(train_bin, dtype=np.uint16)
    valid = np.fromfile(valid_bin, dtype=np.uint16)
    log("mode", args.value_mode, "device", device, "V", V,
        "train", len(train), "valid", len(valid))

    E = ppmi_embed(train, V, args.d, args.window, args.cooc_tokens, device)
    Ps, Bs = [], []
    # Each layer is fit on the output of the layers learned so far.
    for _ in range(args.layers):
        P, B = learn_map(train, E, Ps, Bs, eos, args, device)
        Ps.append(P)
        Bs.append(B)
    fit_eval(train, valid, E, Ps, Bs, eos, V, args, device)


if __name__ == "__main__":
    main()