| import argparse, math, os |
| import numpy as np |
| import torch |
|
|
|
|
def log(*a):
    """Print ``a`` the way ``print`` would, prefixed by an ``[HH:MM:SS]`` local-time stamp, and flush."""
    import time

    stamp = "[" + time.strftime("%H:%M:%S") + "]"
    print(stamp, *a, flush=True)
|
|
|
|
def doc_index(x, eos):
    """Split a flat 1-D token stream into documents.

    A document starts at position 0 and immediately after every ``eos`` token.

    Returns:
        seg: 0-based document id of every position.
        within: offset of every position from the first position of its document.
    """
    is_start = torch.zeros_like(x, dtype=torch.bool)
    is_start[0] = True
    is_start[1:] = x[:-1] == eos
    seg = is_start.long().cumsum(0) - 1
    start_pos = is_start.nonzero().squeeze(1)
    offsets = torch.arange(x.shape[0], device=x.device)
    return seg, offsets - start_pos[seg]
|
|
|
|
def shift(M, within, s):
    """Delay the rows of ``M`` by ``s`` positions without crossing document boundaries.

    Row ``i`` of the result is row ``i - s`` of ``M`` when ``within[i] >= s``
    (i.e. that earlier row is in the same document), and zero otherwise.

    Args:
        M: tensor whose first dimension indexes positions in the token stream.
        within: per-position offset inside its document (as produced by ``doc_index``).
        s: non-negative lag; ``s == 0`` returns an unmodified copy of ``M``.

    Returns:
        A new tensor shaped like ``M``.

    Raises:
        ValueError: if ``s`` is negative.
    """
    if s < 0:
        raise ValueError(f"shift lag must be non-negative, got {s}")
    out = torch.zeros_like(M)
    n = M.shape[0]
    # Use an explicit stop: M[:-s] is empty for s == 0, and clamping at 0 keeps
    # lags longer than the stream from turning into a wrap-around slice.
    out[s:] = M[:max(n - s, 0)]
    out[within < s] = 0
    return out
|
|
|
|
def ppmi_embed(tokens, V, d, window, max_tokens, device):
    """Build row-normalised token embeddings from a windowed PPMI co-occurrence matrix.

    Ordered (left, right) token pairs up to ``window`` positions apart in the first
    ``max_tokens`` tokens are counted, a pair at distance ``s`` weighted by ``1/s``.
    The positive PMI of those counts is factorised with an SVD; the top ``d`` left
    singular vectors, scaled by the square root of their singular values, are
    L2-normalised per row. Token ids are assumed to be < ``V`` (the pair codes are
    reshaped into a V x V matrix).

    Returns:
        float32 tensor of shape (V, d).
    """
    ids = torch.tensor(tokens[:max_tokens].astype(np.int64), device=device)
    counts = torch.zeros((V, V), device=device)
    for dist in range(1, window + 1):
        pair_code = ids[:-dist] * V + ids[dist:]
        hist = torch.bincount(pair_code, minlength=V * V).float().reshape(V, V)
        counts += hist / dist
    total = counts.sum()
    row_mass = counts.sum(1, keepdim=True) + 1e-6
    col_mass = counts.sum(0, keepdim=True) + 1e-6
    pmi = torch.log(counts * total / row_mass / col_mass + 1e-12)
    ppmi = pmi.clamp(min=0)
    left, sing, _ = torch.linalg.svd(ppmi)
    emb = left[:, :d] * sing[:d].sqrt().unsqueeze(0)
    return torch.nn.functional.normalize(emb, dim=1).float()
|
|
|
|
def pred_layer(H, x, E, P, B, within, args, layer_idx=0):
    """Apply one causal, document-local attention layer over the last ``args.att_window`` positions.

    Scores are dot products of ``H @ P`` between a position and each earlier
    position of the same document, divided by ``args.temp``, clamped to [-30, 30]
    and exponentiated. The averaged values depend on ``args.value_mode``:
      * "h" and every "dual_*" mode: the hidden states ``H``;
      * "next": the embedding of the token after each attended position;
      * "delta": that next-token embedding minus ``H``;
      * "ridge_next" / "ridge_delta": ``H @ B`` (minus ``H`` for delta).
    The "dual_*" modes also average a second prediction value (next-token
    embedding, ``H @ B`` or ``H @ B - H``) with the same weights.

    Update rule:
      * dual modes:   H + res_scale * ctx + scale * pred, where scale is
        ``args.pred_scale`` optionally ramped by ``args.pred_schedule``;
      * "delta" / "ridge_delta": H + res_scale * ctx;
      * otherwise:    (1 - res_scale) * H + res_scale * ctx.

    Returns:
        (updated H with L2-normalised rows, the prediction term or None).
    """
    mode = args.value_mode

    def following_embedding():
        # Row i holds E[x[i + 1]]; the final row reuses its own token's embedding.
        emb = E[x]
        return torch.cat([emb[1:], emb[-1:]], 0)

    keys = H @ P
    if mode in ("h", "dual_next", "dual_ridge_next", "dual_ridge_delta"):
        values = H
    elif mode == "next":
        values = following_embedding()
    elif mode == "delta":
        values = following_embedding() - H
    elif mode == "ridge_next":
        values = H @ B
    elif mode == "ridge_delta":
        values = H @ B - H

    if mode == "dual_next":
        pred_values = following_embedding()
    elif mode == "dual_ridge_next":
        pred_values = H @ B
    elif mode == "dual_ridge_delta":
        pred_values = H @ B - H
    else:
        pred_values = None

    acc = torch.zeros_like(H)
    acc_pred = torch.zeros_like(H)
    weight_sum = torch.zeros((H.shape[0], 1), device=H.device)
    for lag in range(1, args.att_window + 1):
        past_keys = shift(keys, within, lag)
        past_values = shift(values, within, lag)
        logits = (keys * past_keys).sum(1, keepdim=True) / args.temp
        weight = torch.exp(logits.clamp(-30, 30))
        # Drop lags that would reach back past the start of the current document.
        in_doc = (within >= lag)[:, None]
        weight = torch.where(in_doc, weight, torch.zeros_like(weight))
        acc += weight * past_values
        if pred_values is not None:
            acc_pred += weight * shift(pred_values, within, lag)
        weight_sum += weight
    ctx = acc / (weight_sum + 1e-6)

    if pred_values is None:
        if "delta" in mode:
            H = H + args.res_scale * ctx
        else:
            H = (1 - args.res_scale) * H + args.res_scale * ctx
        return torch.nn.functional.normalize(H, dim=1), None

    pred = acc_pred / (weight_sum + 1e-6)
    if args.orth_delta:
        # Remove the component along H (rows of H are unit-norm on entry from the
        # normalised embeddings / previous layer).
        pred = pred - H * (pred * H).sum(1, keepdim=True)
    if args.pred_norm:
        pred = pred / (pred.norm(dim=1, keepdim=True) + 1e-6)
    n_layers = args.layers
    scale = args.pred_scale
    if args.pred_schedule == "linear":
        scale = scale * float(layer_idx + 1) / max(1, n_layers)
    elif args.pred_schedule == "late":
        skip = n_layers // 3
        scale = scale * max(0.0, float(layer_idx + 1 - skip) / max(1, n_layers - skip))
    H = H + args.res_scale * ctx + scale * pred
    return torch.nn.functional.normalize(H, dim=1), pred
|
|
|
|
def apply_stack(x, E, Ps, Bs, within, args, expose=True):
    """Run tokens ``x`` through every (P, B) layer, optionally collecting probe features.

    The stack starts from the embeddings ``E[x]``. When ``expose`` is true, each
    layer's input is projected with its ``P`` and the projection ``q`` contributes
    relu(q), |q|, q**2 and (when q has >= 2 columns) products of up to 64 adjacent
    column pairs. With ``args.pred_features`` set, the last layer's prediction term
    and its product with the final state are appended as well.

    Returns:
        (final hidden states, list of feature tensors).
    """
    hidden = E[x]
    feats = []
    pred = None
    for layer, (proj, ridge) in enumerate(zip(Ps, Bs)):
        if expose:
            q = hidden @ proj
            feats.extend((torch.relu(q), q.abs(), q * q))
            n_pairs = min(64, q.shape[1] - 1)
            if n_pairs > 0:
                feats.append(q[:, :n_pairs] * q[:, 1:n_pairs + 1])
        hidden, pred = pred_layer(hidden, x, E, proj, ridge, within, args, layer)
    if expose and args.pred_features and pred is not None:
        feats.extend((pred, hidden * pred))
    return hidden, feats
|
|
|
|
def features(H, within, phis, extra):
    """Assemble the per-position design matrix for the next-token read-out.

    Column groups, in order: the current state ``H``; the state one position back
    (zero at a document start, via ``shift``) and its product with ``H``; when
    ``extra`` is truthy, the state two positions back plus its products with the
    previous and current states; every tensor in ``phis``; and a bias column of ones.
    """
    lag1 = shift(H, within, 1)
    parts = [H, lag1, H * lag1]
    if extra:
        lag2 = shift(H, within, 2)
        parts.extend([lag2, lag1 * lag2, H * lag2])
    bias = torch.ones((H.shape[0], 1), device=H.device)
    return torch.cat([*parts, *phis, bias], 1)
|
|
|
|
def learn_map(tokens, E, Ps, Bs, eos, args, device):
    """Fit the next layer's attention projection ``P`` and ridge map ``B``.

    Runs the current stack (``Ps``/``Bs``) over the first ``args.proj_tokens``
    tokens, ``args.chunk_docs`` documents at a time, and accumulates in float64

        XtX = sum H^T H     (Gram matrix of the hidden states)
        XtY = sum H^T Y     (cross-covariance with the next-token embeddings)

    over positions that are not ``eos`` and whose next token is in the same document.

    Returns:
        P: (d, args.r) float32, the top ``args.r`` left singular vectors of XtY.
        B: (d, d) float32, the ridge solution of ``H @ B ~= Y`` with penalty
           ``args.map_lam`` times the mean diagonal entry of XtX.
    """
    xnp = tokens[:args.proj_tokens]
    starts = np.flatnonzero(np.r_[True, xnp[:-1] == eos])
    d = E.shape[1]
    XtX = torch.zeros((d, d), device=device, dtype=torch.float64)
    # XtY feeds both the SVD (for P) and the ridge solve (for B); one accumulator
    # suffices and halves the per-chunk cross-covariance matmuls.
    XtY = torch.zeros_like(XtX)
    for i in range(0, len(starts), args.chunk_docs):
        lo = starts[i]
        hi = starts[i + args.chunk_docs] if i + args.chunk_docs < len(starts) else len(xnp)
        x = torch.tensor(xnp[lo:hi].astype(np.int64), device=device)
        seg, within = doc_index(x, eos)
        H, _ = apply_stack(x, E, Ps, Bs, within, args, expose=False)
        emb = E[x]
        Y = torch.cat([emb[1:], emb[-1:]], 0)  # embedding of the following token
        valid = torch.ones(len(x), device=device, dtype=torch.bool)
        valid[-1] = False
        valid[:-1] &= seg[1:].eq(seg[:-1])  # next token must be in the same document
        valid &= x.ne(eos)
        X = H[valid].double()
        T = Y[valid].double()
        XtY += X.T @ T
        XtX += X.T @ X
    U, S, _ = torch.linalg.svd(XtY)
    P = U[:, :args.r].float()
    diag = torch.trace(XtX) / d
    eye = torch.eye(d, device=device, dtype=torch.float64)
    B = torch.linalg.solve(XtX + args.map_lam * diag * eye, XtY).float()
    log("sv", S[:4].detach().cpu().numpy().round(1))
    return P, B
|
|
|
|
def fit_eval(train, valid, E, Ps, Bs, eos, V, args, device):
    """Fit a ridge read-out from stack features to next tokens and log validation perplexity.

    Features (``features`` of ``apply_stack`` output) for the first
    ``args.fit_tokens`` training tokens give the normal equations
    ``A = Phi^T Phi`` and ``G = Phi^T onehot(y)`` in float64. For each lambda in
    ``args.lams``, ``W`` solves ``(A + lam * mean(diag A) * I) W = G``. On the first
    ``args.eval_tokens`` validation tokens, probabilities are
    ``relu(Phi @ W) + args.floor * unigram`` renormalised per row, where the
    unigram distribution uses add-one counts over all of ``train``. Each lambda's
    perplexity is logged, followed by the best one.

    Raises:
        ValueError: if no training or no validation position survives masking.
    """
    def chunks(tokens, max_tokens):
        # Yield (Phi, y) per group of args.chunk_docs documents, keeping positions
        # that are not eos and whose next token lies in the same document.
        xnp = tokens[:max_tokens]
        starts = np.flatnonzero(np.r_[True, xnp[:-1] == eos])
        for i in range(0, len(starts), args.chunk_docs):
            lo = starts[i]
            hi = starts[i + args.chunk_docs] if i + args.chunk_docs < len(starts) else len(xnp)
            x = torch.tensor(xnp[lo:hi].astype(np.int64), device=device)
            seg, within = doc_index(x, eos)
            H, phis = apply_stack(x, E, Ps, Bs, within, args)
            Phi = features(H, within, phis, args.extra_context)
            y = torch.empty(len(x), device=device, dtype=torch.long)
            y[:-1] = x[1:]
            y[-1] = eos
            keep = torch.ones(len(x), device=device, dtype=torch.bool)
            keep[-1] = False
            keep[:-1] &= seg[1:].eq(seg[:-1])
            keep &= x.ne(eos)
            yield Phi[keep], y[keep]

    # Normal equations of the ridge regression onto one-hot next tokens.
    A = G = None
    for Phi, y in chunks(train, args.fit_tokens):
        if A is None:
            n_feat = Phi.shape[1]
            A = torch.zeros((n_feat, n_feat), device=device, dtype=torch.float64)
            G = torch.zeros((n_feat, V), device=device, dtype=torch.float64)
        Phi64 = Phi.double()
        A += Phi64.T @ Phi64
        G.index_add_(1, y, Phi64.T)  # G[:, y[k]] += Phi[k]
    if A is None:
        raise ValueError("no training tokens to fit (check --fit_tokens and the train data)")
    D = A.shape[0]

    uni = torch.tensor(np.bincount(train.astype(np.int64), minlength=V) + 1, device=device).float()
    uni = uni / uni.sum()

    lams = [float(v) for v in args.lams.split(",")]
    diag = torch.trace(A) / D  # lambda-independent, so computed once
    eye = torch.eye(D, device=device, dtype=torch.float64)
    Ws = [torch.linalg.solve(A + lam * diag * eye, G).float() for lam in lams]

    # Validation features do not depend on lambda: build them once per chunk and
    # score every candidate read-out on them.
    nll = [0.0] * len(lams)
    n = 0
    for Phi, y in chunks(valid, args.eval_tokens):
        rows = torch.arange(len(y), device=device)
        for k, W in enumerate(Ws):
            Pp = torch.relu(Phi @ W) + args.floor * uni[None, :]
            Pp = Pp / Pp.sum(1, keepdim=True)
            nll[k] += float(-torch.log(Pp[rows, y] + 1e-12).sum())
        n += len(y)
    if n == 0:
        raise ValueError("no validation tokens to score (check --eval_tokens and the valid data)")

    best = None
    for lam, total in zip(lams, nll):
        ppl = math.exp(total / n)
        log("lam", lam, "ppl", round(ppl, 2))
        if best is None or ppl < best[1]:
            best = (lam, ppl)
    log(f"TORCH_PRED mode={args.value_mode} ppl={best[1]:.2f} lam={best[0]} D={D}")
|
|
|
|
def main():
    """Command-line entry point: parse flags, load tokenizer and data, build the stack, evaluate."""
    parser = argparse.ArgumentParser()
    # Registered in this exact order (it is also the order --help prints).
    flag_specs = [
        ("--data", {"default": "/workspace/ts_mini"}),
        ("--spm_model", {"default": None}),
        ("--train_bin", {"default": None}),
        ("--valid_bin", {"default": None}),
        ("--vocab", {"type": int, "default": 1024}),
        ("--d", {"type": int, "default": 192}),
        ("--r", {"type": int, "default": 64}),
        ("--layers", {"type": int, "default": 2}),
        ("--att_window", {"type": int, "default": 8}),
        ("--temp", {"type": float, "default": 0.3}),
        ("--window", {"type": int, "default": 8}),
        ("--extra_context", {"type": int, "default": 1}),
        ("--res_scale", {"type": float, "default": 0.12}),
        ("--pred_scale", {"type": float, "default": 0.04}),
        ("--pred_schedule", {"choices": ["flat", "linear", "late"], "default": "flat"}),
        ("--orth_delta", {"type": int, "default": 0}),
        ("--pred_norm", {"type": int, "default": 0}),
        ("--pred_features", {"type": int, "default": 0}),
        ("--map_lam", {"type": float, "default": 0.001}),
        ("--cooc_tokens", {"type": int, "default": 1_000_000}),
        ("--proj_tokens", {"type": int, "default": 500_000}),
        ("--fit_tokens", {"type": int, "default": 800_000}),
        ("--eval_tokens", {"type": int, "default": 100_000}),
        ("--chunk_docs", {"type": int, "default": 40}),
        ("--lams", {"default": "0.003,0.01,0.03,0.1"}),
        ("--floor", {"type": float, "default": 1e-4}),
        ("--value_mode", {
            "choices": ["h", "next", "delta", "ridge_next", "ridge_delta",
                        "dual_next", "dual_ridge_next", "dual_ridge_delta"],
            "default": "h",
        }),
    ]
    for flag, spec in flag_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    import sentencepiece as spm

    def resolve(override, filename):
        # An explicit path wins; otherwise look inside --data.
        return override or os.path.join(args.data, filename)

    spm_model = resolve(args.spm_model, f"sp{args.vocab}.model")
    train_bin = resolve(args.train_bin, "train.bin")
    valid_bin = resolve(args.valid_bin, "valid.bin")
    sp = spm.SentencePieceProcessor(model_file=spm_model)
    eos, V = sp.eos_id(), sp.get_piece_size()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train, valid = (np.fromfile(path, dtype=np.uint16) for path in (train_bin, valid_bin))
    log("mode", args.value_mode, "device", device, "V", V, "train", len(train), "valid", len(valid))

    E = ppmi_embed(train, V, args.d, args.window, args.cooc_tokens, device)
    Ps, Bs = [], []
    for _ in range(args.layers):
        # Each new layer is fitted on top of the layers learned so far.
        P, B = learn_map(train, E, Ps, Bs, eos, args, device)
        Ps.append(P)
        Bs.append(B)
    fit_eval(train, valid, E, Ps, Bs, eos, V, args, device)
|
|
|
|
# Run the command-line pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|