# File: testoneshot/scripts/torch_predictive_attn.py
# Uploaded by: Asilarknes
# Commit message: upload oneshot glm artifacts
# Commit: a216fa7 (verified)
import argparse, math, os
import numpy as np
import torch
def log(*a):
    """Print the arguments prefixed with an ``[HH:MM:SS]`` wall-clock stamp.

    Output is flushed immediately so progress shows up in buffered logs.
    """
    import time
    stamp = time.strftime('%H:%M:%S')
    print("[" + stamp + "]", *a, flush=True)
def doc_index(x, eos):
    """Split a flat token stream into eos-delimited documents.

    A document starts at position 0 and right after every ``eos`` token, so an
    eos token belongs to the document it terminates.

    Args:
        x: non-empty 1-D integer tensor of token ids.
        eos: id of the end-of-sequence token.

    Returns:
        ``(seg, within)``: for every position, the 0-based document index and
        the offset from that document's first token (both int64, shaped like x).
    """
    is_start = torch.zeros_like(x, dtype=torch.bool)
    is_start[0] = True
    is_start[1:] = x[:-1] == eos
    doc_id = is_start.long().cumsum(0) - 1
    start_pos = is_start.nonzero().reshape(-1)
    offset = torch.arange(x.shape[0], device=x.device) - start_pos[doc_id]
    return doc_id, offset
def shift(M, within, s):
    """Delay the rows of ``M`` by ``s`` positions without crossing documents.

    Row ``i`` of the result holds ``M[i - s]`` when position ``i`` is at least
    ``s`` tokens into its document (``within[i] >= s``), and zeros otherwise.

    Args:
        M: tensor whose first dimension indexes token positions.
        within: 1-D offset-within-document tensor (as produced by doc_index).
        s: non-negative shift; ``s == 0`` returns a copy of ``M``.

    Returns:
        A new tensor with the same shape and dtype as ``M``.

    Raises:
        ValueError: if ``s`` is negative.
    """
    if s < 0:
        raise ValueError(f"shift must be non-negative, got {s}")
    if s == 0:
        # M[:-0] is an empty slice, so the general path cannot express s == 0;
        # no row is masked either, because within is never negative.
        return M.clone()
    out = torch.zeros_like(M)
    out[s:] = M[:-s]
    # Drop rows whose source position lies in an earlier document.
    out[within < s] = 0
    return out
def ppmi_embed(tokens, V, d, window, max_tokens, device):
    """Build unit-norm token embeddings from a windowed PPMI co-occurrence matrix.

    Ordered pairs ``(x[t], x[t + s])`` for ``s = 1..window`` are counted over the
    first ``max_tokens`` tokens, each weighted by ``1 / s``. Pairs are counted
    across eos boundaries as well. The positive-PMI matrix is factorised with
    an SVD, and the embedding is ``U[:, :d] * sqrt(S[:d])`` with rows
    L2-normalised.

    Args:
        tokens: numpy array of token ids. Ids are assumed to be < V; otherwise
            the bincount is longer than V*V and the reshape fails.
        V: vocabulary size.
        d: embedding width (capped at V by the SVD).
        window: maximum pair distance.
        max_tokens: number of leading tokens to use.
        device: torch device for all computation.

    Returns:
        float32 tensor of shape ``(V, min(d, V))``.
    """
    ids = torch.tensor(tokens[:max_tokens].astype(np.int64), device=device)
    cooc = torch.zeros((V, V), device=device)
    for dist in range(1, window + 1):
        pair_codes = ids[:-dist] * V + ids[dist:]
        counts = torch.bincount(pair_codes, minlength=V * V).float().reshape(V, V)
        cooc += counts / dist
    total = cooc.sum()
    row_mass = cooc.sum(1, keepdim=True) + 1e-6
    col_mass = cooc.sum(0, keepdim=True) + 1e-6
    pmi = torch.log(cooc * total / row_mass / col_mass + 1e-12)
    ppmi = torch.clamp(pmi, min=0)
    left, sing, _ = torch.linalg.svd(ppmi)
    emb = left[:, :d] * torch.sqrt(sing[:d])[None, :]
    return torch.nn.functional.normalize(emb, dim=1).float()
def pred_layer(H, x, E, P, B, within, args, layer_idx=0):
    """Apply one predictive local-attention layer.

    Each position ``i`` uses ``q_i = H_i @ P`` as both query and key. It attends
    to the ``args.att_window`` preceding positions ``i - s`` of the same
    document with weights ``exp(clamp(q_i . q_{i-s} / temp, -30, 30))``. The
    context is the weight-normalised sum of the attended values.

    ``args.value_mode`` selects the attended values:
        h, dual_*    -> H itself
        next         -> embedding of the next token (last row repeats E[x[-1]])
        delta        -> next-token embedding minus H
        ridge_next   -> H @ B
        ridge_delta  -> H @ B - H

    The ``dual_*`` modes also attend, with the same weights, over a second
    "prediction" value (next-token embedding, H @ B, or H @ B - H). That
    read-out may be orthogonalised against H (``args.orth_delta``; an exact
    projection only if H rows are unit-norm) and/or unit-normalised
    (``args.pred_norm``). It is then added with weight ``args.pred_scale``,
    modulated per layer by ``args.pred_schedule``.

    Residual update:
        dual_*               -> H + res_scale * ctx + scale * pred
        delta, ridge_delta   -> H + res_scale * ctx
        other modes          -> (1 - res_scale) * H + res_scale * ctx

    Returns:
        ``(H_new, pred)``: the updated states with L2-normalised rows, and the
        prediction read-out for ``dual_*`` modes (None for every other mode).
    """
    mode = args.value_mode

    def _next_embed():
        # Embedding of the following token; the final row repeats the last token.
        emb = E[x]
        return torch.cat([emb[1:], emb[-1:]], 0)

    # Values read by the main context attention.
    if mode in ("h", "dual_next", "dual_ridge_next", "dual_ridge_delta"):
        values = H
    elif mode == "next":
        values = _next_embed()
    elif mode == "delta":
        values = _next_embed() - H
    elif mode == "ridge_next":
        values = H @ B
    elif mode == "ridge_delta":
        values = H @ B - H

    # Secondary "prediction" values, only for the dual_* modes.
    if mode == "dual_next":
        pred_values = _next_embed()
    elif mode == "dual_ridge_next":
        pred_values = H @ B
    elif mode == "dual_ridge_delta":
        pred_values = H @ B - H
    else:
        pred_values = None

    proj = H @ P
    num = torch.zeros_like(H)
    num_pred = torch.zeros_like(H)
    den = torch.zeros((H.shape[0], 1), device=H.device)
    for s in range(1, args.att_window + 1):
        visible = (within >= s)[:, None]
        logits = ((proj * shift(proj, within, s)).sum(1, keepdim=True) / args.temp).clamp(-30, 30)
        weight = torch.exp(logits)
        # Positions fewer than s tokens into their document have nothing to attend to.
        weight = torch.where(visible, weight, torch.zeros_like(weight))
        num += weight * shift(values, within, s)
        if pred_values is not None:
            num_pred += weight * shift(pred_values, within, s)
        den += weight
    ctx = num / (den + 1e-6)

    if pred_values is None:
        if "delta" in mode:
            H = H + args.res_scale * ctx
        else:
            H = (1 - args.res_scale) * H + args.res_scale * ctx
        return torch.nn.functional.normalize(H, dim=1), None

    pred = num_pred / (den + 1e-6)
    if args.orth_delta:
        pred = pred - H * (pred * H).sum(1, keepdim=True)
    if args.pred_norm:
        pred = pred / (pred.norm(dim=1, keepdim=True) + 1e-6)
    scale = args.pred_scale
    if args.pred_schedule == "linear":
        scale = scale * float(layer_idx + 1) / max(1, args.layers)
    elif args.pred_schedule == "late":
        third = args.layers // 3
        scale = scale * max(0.0, float(layer_idx + 1 - third) / max(1, args.layers - third))
    H = H + args.res_scale * ctx + scale * pred
    return torch.nn.functional.normalize(H, dim=1), pred
def apply_stack(x, E, Ps, Bs, within, args, expose=True):
    """Run the stacked predictive-attention layers over the token ids ``x``.

    Starts from ``H = E[x]`` and applies ``pred_layer`` once per ``(P, B)`` pair
    (zip of ``Ps`` and ``Bs``).

    When ``expose`` is true, each layer's projection ``q = H @ P`` (taken
    before that layer updates H) contributes these read-out features:
        relu(q), |q|, q*q, and q_j * q_{j+1} for the first min(64, width-1)
        columns.
    If ``args.pred_features`` is also set and the last layer produced a
    prediction read-out, ``pred`` and ``H * pred`` are appended as well.

    Returns:
        ``(H, phis)``: final states and the list of feature tensors (empty when
        ``expose`` is false).
    """
    H = E[x]
    phis = []
    last_pred = None
    for depth, (P, B) in enumerate(zip(Ps, Bs)):
        if expose:
            q = H @ P
            pairs = min(64, q.shape[1] - 1)
            phis.extend([torch.relu(q), torch.abs(q), q * q])
            if pairs > 0:
                phis.append(q[:, :pairs] * q[:, 1:pairs + 1])
        H, last_pred = pred_layer(H, x, E, P, B, within, args, depth)
    # NOTE(review): the source's indentation was lost. The prediction features
    # are appended once, from the final layer, which matches the `last_pred`
    # naming and its pre-loop initialisation. Confirm against the original.
    if expose and args.pred_features and last_pred is not None:
        phis.extend([last_pred, H * last_pred])
    return H, phis
def features(H, within, phis, extra):
    """Assemble the read-out design matrix for one chunk.

    Columns, in order:
      * H, the lag-1 states (zeroed at document starts), and H * lag-1;
      * if ``extra`` is truthy: the lag-2 states, lag-1 * lag-2, and H * lag-2;
      * every tensor in ``phis``;
      * a constant-one bias column.

    ``phis`` itself is not modified.
    """
    lag1 = shift(H, within, 1)
    parts = [H, lag1, H * lag1]
    if extra:
        lag2 = shift(H, within, 2)
        parts.extend([lag2, lag1 * lag2, H * lag2])
    parts.extend(phis)
    parts.append(torch.ones((H.shape[0], 1), device=H.device))
    return torch.cat(parts, 1)
def learn_map(tokens, E, Ps, Bs, eos, args, device):
    """Fit the next layer's attention projection ``P`` and ridge map ``B``.

    The current stack (layers ``Ps``/``Bs``) is run over the first
    ``args.proj_tokens`` tokens, ``args.chunk_docs`` documents at a time. The
    following are accumulated in float64:
      * ``XtY = H^T Y``: cross-covariance between the stack output H and the
        embedding Y of the next token. Only positions whose next token is in
        the same document and whose own token is not eos are used.
      * ``XtX = H^T H`` over the same positions.

    ``P`` is the top ``args.r`` left singular vectors of ``XtY``. ``B`` solves
    ``(XtX + map_lam * tr(XtX)/d * I) B = XtY``, a ridge regression from H to
    the next-token embedding.

    Returns:
        ``(P, B)`` as float32 tensors of shape ``(d, min(r, d))`` and ``(d, d)``,
        where ``d = E.shape[1]``.
    """
    xnp = tokens[:args.proj_tokens]
    starts = np.flatnonzero(np.r_[True, xnp[:-1] == eos])
    d = E.shape[1]
    # A single cross-covariance feeds both the SVD (for P) and the ridge
    # right-hand side (for B). The original kept two identical accumulators.
    XtY = torch.zeros((d, d), device=device, dtype=torch.float64)
    XtX = torch.zeros_like(XtY)
    for i in range(0, len(starts), args.chunk_docs):
        lo = starts[i]
        hi = starts[i + args.chunk_docs] if i + args.chunk_docs < len(starts) else len(xnp)
        x = torch.tensor(xnp[lo:hi].astype(np.int64), device=device)
        seg, within = doc_index(x, eos)
        H, _ = apply_stack(x, E, Ps, Bs, within, args, expose=False)
        emb = E[x]
        Y = torch.cat([emb[1:], emb[-1:]], 0)  # next-token embedding per position
        valid = torch.ones(len(x), device=device, dtype=torch.bool)
        valid[-1] = False                       # last position has no next token
        valid[:-1] &= seg[1:].eq(seg[:-1])      # next token in the same document
        valid &= x.ne(eos)                      # never predict across an eos
        X = H[valid].double()
        T = Y[valid].double()
        XtY += X.T @ T
        XtX += X.T @ X
    U, S, _ = torch.linalg.svd(XtY)
    P = U[:, :args.r].float()
    diag = torch.trace(XtX) / d  # scale the ridge penalty to the feature energy
    eye = torch.eye(d, device=device, dtype=torch.float64)
    B = torch.linalg.solve(XtX + args.map_lam * diag * eye, XtY).float()
    log("sv", S[:4].detach().cpu().numpy().round(1))
    return P, B
def _iter_chunk_features(tokens, max_tokens, E, Ps, Bs, eos, args, device):
    """Yield ``(Phi, y)`` feature/target pairs, ``args.chunk_docs`` documents at a time.

    Only positions whose next token is in the same document and whose own
    token is not eos are kept, so no target crosses a document boundary.
    """
    xnp = tokens[:max_tokens]
    starts = np.flatnonzero(np.r_[True, xnp[:-1] == eos])
    for i in range(0, len(starts), args.chunk_docs):
        lo = starts[i]
        hi = starts[i + args.chunk_docs] if i + args.chunk_docs < len(starts) else len(xnp)
        x = torch.tensor(xnp[lo:hi].astype(np.int64), device=device)
        seg, within = doc_index(x, eos)
        H, phis = apply_stack(x, E, Ps, Bs, within, args)
        Phi = features(H, within, phis, args.extra_context)
        y = torch.empty(len(x), device=device, dtype=torch.long)
        y[:-1] = x[1:]
        y[-1] = eos
        keep = torch.ones(len(x), device=device, dtype=torch.bool)
        keep[-1] = False
        keep[:-1] &= seg[1:].eq(seg[:-1])
        keep &= x.ne(eos)
        yield Phi[keep], y[keep]


def fit_eval(train, valid, E, Ps, Bs, eos, V, args, device):
    """Fit a ridge read-out from stack features to next-token ids and report perplexity.

    Over the first ``args.fit_tokens`` training tokens, accumulates in float64
    ``A = Phi^T Phi`` and ``G = Phi^T onehot(y)``. For each lambda in
    ``args.lams`` it solves ``(A + lam * tr(A)/D * I) W = G``. Scores
    ``relu(Phi W) + floor * unigram`` are renormalised per row into a
    distribution; the unigram uses add-one counts over all of ``train``.
    Perplexity is measured on the first ``args.eval_tokens`` tokens of
    ``valid``. Logs per-lambda perplexity and the best setting; returns None.

    Raises:
        ValueError: if the validation slice has no scorable positions.
    """
    A = G = None
    for Phi, y in _iter_chunk_features(train, args.fit_tokens, E, Ps, Bs, eos, args, device):
        Phi = Phi.double()
        if A is None:
            dim = Phi.shape[1]
            A = torch.zeros((dim, dim), device=device, dtype=torch.float64)
            G = torch.zeros((dim, V), device=device, dtype=torch.float64)
        A += Phi.T @ Phi
        G.index_add_(1, y, Phi.T)  # G[:, y_j] += Phi_j
    D = A.shape[0]

    uni = torch.tensor((np.bincount(train.astype(np.int64), minlength=V) + 1), device=device).float()
    uni = uni / uni.sum()

    lams = [float(v) for v in args.lams.split(",")]
    diag = torch.trace(A) / D
    eye = torch.eye(D, device=device, dtype=torch.float64)
    # Solve every ridge system first, then compute the validation features once
    # and score them against all read-outs. The original re-ran the whole
    # stack over the validation set once per lambda.
    Ws = [torch.linalg.solve(A + lam * diag * eye, G).float() for lam in lams]

    nll = [0.0] * len(lams)
    n = 0
    for Phi, y in _iter_chunk_features(valid, args.eval_tokens, E, Ps, Bs, eos, args, device):
        rows = torch.arange(len(y), device=device)
        for k, W in enumerate(Ws):
            Pp = torch.relu(Phi @ W) + args.floor * uni[None, :]
            Pp = Pp / Pp.sum(1, keepdim=True)
            nll[k] += float(-torch.log(Pp[rows, y] + 1e-12).sum())
        n += len(y)
    if n == 0:
        raise ValueError("no scorable validation tokens; increase --eval_tokens")

    best = None
    for lam, total in zip(lams, nll):
        ppl = math.exp(total / n)
        log("lam", lam, "ppl", round(ppl, 2))
        if best is None or ppl < best[1]:
            best = (lam, ppl)
    log(f"TORCH_PRED mode={args.value_mode} ppl={best[1]:.2f} lam={best[0]} D={D}")
def main():
    """Command-line entry point.

    Parses options, loads the SentencePiece model (for the eos id and vocab
    size) and the uint16 train/valid token files, and builds PPMI embeddings.
    It then greedily learns ``--layers`` (P, B) maps, each conditioned on the
    layers before it, and finally fits and evaluates the ridge read-out.
    """
    value_modes = ["h", "next", "delta", "ridge_next", "ridge_delta",
                   "dual_next", "dual_ridge_next", "dual_ridge_delta"]
    # (flag, add_argument kwargs) in declaration order, which is also --help order.
    cli_spec = [
        ("--data", {"default": "/workspace/ts_mini"}),
        ("--spm_model", {"default": None}),
        ("--train_bin", {"default": None}),
        ("--valid_bin", {"default": None}),
        ("--vocab", {"type": int, "default": 1024}),
        ("--d", {"type": int, "default": 192}),
        ("--r", {"type": int, "default": 64}),
        ("--layers", {"type": int, "default": 2}),
        ("--att_window", {"type": int, "default": 8}),
        ("--temp", {"type": float, "default": 0.3}),
        ("--window", {"type": int, "default": 8}),
        ("--extra_context", {"type": int, "default": 1}),
        ("--res_scale", {"type": float, "default": 0.12}),
        ("--pred_scale", {"type": float, "default": 0.04}),
        ("--pred_schedule", {"choices": ["flat", "linear", "late"], "default": "flat"}),
        ("--orth_delta", {"type": int, "default": 0}),
        ("--pred_norm", {"type": int, "default": 0}),
        ("--pred_features", {"type": int, "default": 0}),
        ("--map_lam", {"type": float, "default": 0.001}),
        ("--cooc_tokens", {"type": int, "default": 1_000_000}),
        ("--proj_tokens", {"type": int, "default": 500_000}),
        ("--fit_tokens", {"type": int, "default": 800_000}),
        ("--eval_tokens", {"type": int, "default": 100_000}),
        ("--chunk_docs", {"type": int, "default": 40}),
        ("--lams", {"default": "0.003,0.01,0.03,0.1"}),
        ("--floor", {"type": float, "default": 1e-4}),
        ("--value_mode", {"choices": value_modes, "default": "h"}),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in cli_spec:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Imported lazily, after argument parsing, so --help works without it.
    import sentencepiece as spm
    spm_model = args.spm_model or os.path.join(args.data, f"sp{args.vocab}.model")
    train_bin = args.train_bin or os.path.join(args.data, "train.bin")
    valid_bin = args.valid_bin or os.path.join(args.data, "valid.bin")
    sp = spm.SentencePieceProcessor(model_file=spm_model)
    eos, V = sp.eos_id(), sp.get_piece_size()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train = np.fromfile(train_bin, dtype=np.uint16)
    valid = np.fromfile(valid_bin, dtype=np.uint16)
    log("mode", args.value_mode, "device", device, "V", V, "train", len(train), "valid", len(valid))

    E = ppmi_embed(train, V, args.d, args.window, args.cooc_tokens, device)
    Ps, Bs = [], []
    for _ in range(args.layers):
        # Each new layer is fitted on top of the layers learned so far.
        P, B = learn_map(train, E, Ps, Bs, eos, args, device)
        Ps.append(P)
        Bs.append(B)
    fit_eval(train, valid, E, Ps, Bs, eos, V, args, device)
# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()