| """ |
| PHASE B eval: full sweep on N ∈ {25, 60} × all mapping methods × 5 held-out tasks. |
| Reuses adapters trained by phaseB_train.py. |
| """ |
| import os, sys, json, gc, shutil, re, collections |
| from pathlib import Path |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed |
| from peft import PeftModel |
| from safetensors.torch import load_file, save_file |
|
|
| sys.path.insert(0, "/app") |
| import scaled_pipeline as sp |
| import phaseA_new_methods as pa |
| import phaseB_train as pb |
|
|
# Fix the global seed once at import time so repeated sweeps start from the same RNG state.
set_seed(42)

# Local aliases for the settings defined in scaled_pipeline (imported as ``sp``).
OUT = sp.OUT
MODEL_Y = sp.MODEL_Y
HELDOUT_NAMES = sp.HELDOUT_NAMES
|
|
def load_sd(p):
    """Load the adapter weights stored in directory ``p`` as float32 CPU tensors.

    ``p`` must support the ``/`` operator (e.g. a ``pathlib.Path``); the weights are read
    from ``p / "adapter_model.safetensors"``.

    Returns a dict mapping each tensor name in the file to its float32 CPU copy.
    """
    weights_file = p / "adapter_model.safetensors"
    raw = load_file(str(weights_file))
    state = {}
    for name, tensor in raw.items():
        state[name] = tensor.float().cpu()
    return state
|
|
def cos_sd(s1, s2):
    """Return the cosine similarity of two state dicts, each flattened into one vector.

    The tensors are concatenated in sorted order of ``s1``'s keys. ``s2`` is indexed with
    those same keys, so a key missing from ``s2`` raises ``KeyError`` and keys present
    only in ``s2`` are ignored.
    """
    ordered = sorted(s1)

    def as_row(sd):
        # One (1, total_numel) float32 row built from every tensor, in the shared key order.
        pieces = [sd[name].float().flatten() for name in ordered]
        return torch.cat(pieces).unsqueeze(0)

    similarity = F.cosine_similarity(as_row(s1), as_row(s2))
    return similarity.item()
|
|
def save_adapter(sd, dirname):
    """Write a predicted state dict to disk as an adapter directory.

    The directory ``OUT / "Y_pred_phaseB" / dirname`` is created if it does not exist, the
    ``adapter_config.json`` of the first anchor's Y adapter is copied into it, and every
    tensor in ``sd`` is stored as bfloat16 in ``adapter_model.safetensors``.

    Args:
        sd: mapping of tensor name to tensor.
        dirname: name of the sub-directory to create under ``Y_pred_phaseB``.

    Returns:
        The path of the adapter directory.
    """
    d = OUT / "Y_pred_phaseB" / dirname
    d.mkdir(parents=True, exist_ok=True)
    # Every saved prediction reuses the adapter config of the first anchor's Y adapter.
    src = OUT / "Y" / sp.ANCHOR_NAMES[0]
    shutil.copy(src / "adapter_config.json", d / "adapter_config.json")
    # safetensors refuses non-contiguous tensors, and a dtype cast keeps the strides of a
    # transposed/permuted input, so force a contiguous layout before saving. This is a
    # no-op for tensors that are already contiguous.
    tensors = {k: v.to(torch.bfloat16).contiguous() for k, v in sd.items()}
    save_file(tensors, str(d / "adapter_model.safetensors"))
    return d
|
|
@torch.no_grad()
def eval_adapter(adir, eval_ds, labels, tok, max_n=300):
    """Score one adapter (or the bare base model) with ``sp.eval_classification``.

    A fresh bfloat16 copy of ``MODEL_Y`` is loaded and moved to the GPU on every call.
    When ``adir`` is ``None`` the base model is scored as-is; otherwise the PEFT adapter
    stored in ``adir`` is attached first.

    Returns the value produced by ``sp.eval_classification``.
    """
    backbone = AutoModelForCausalLM.from_pretrained(
        MODEL_Y,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    ).cuda()
    if adir is None:
        model = backbone
    else:
        model = PeftModel.from_pretrained(backbone, str(adir))

    score = sp.eval_classification(model, tok, eval_ds, labels, max_n=max_n)

    # Drop both references before collecting so the cached GPU memory can be released.
    del model, backbone
    gc.collect()
    torch.cuda.empty_cache()
    return score
|
|
def main():
    """Run the Phase B sweep and write ``results_phaseB.json`` under ``OUT``.

    For every held-out task the base model and the task's own Y adapter ("oracle") are
    scored first. Then, for each anchor pool (the base pool ``sp.ANCHOR_NAMES`` and the
    full pool that also contains the names listed in ``new_anchors.json``), one predicted
    Y adapter is built per mapping method, saved with ``save_adapter``, scored with
    ``eval_adapter`` and compared to the oracle weights with ``cos_sd``.

    Per-method averages over the held-out tasks are printed and stored under ``"avg"``.
    """
    new_anchors_avail = json.loads((OUT / "new_anchors.json").read_text())
    full_anchors = sp.ANCHOR_NAMES + new_anchors_avail
    print(f"Total anchors available: {len(full_anchors)}")

    X_full = [load_sd(OUT / "X" / n) for n in full_anchors]
    Y_full = [load_sd(OUT / "Y" / n) for n in full_anchors]

    # The base pool is the leading slice of the full pool. Its size is taken from
    # sp.ANCHOR_NAMES instead of a literal 25 so the names and the weight lists can never
    # fall out of alignment (identical to the previous behaviour when the pool has 25 names).
    n_base = len(sp.ANCHOR_NAMES)
    splits = {
        n_base: (sp.ANCHOR_NAMES, X_full[:n_base], Y_full[:n_base]),
        len(full_anchors): (full_anchors, X_full, Y_full),
    }
    largest_pool = max(splits)

    tokY = AutoTokenizer.from_pretrained(MODEL_Y)
    if tokY.pad_token is None:
        tokY.pad_token = tokY.eos_token
    tokY.padding_side = "left"

    results = {
        "per_task": {},
        "config": {
            "anchor_pool_25": sp.ANCHOR_NAMES,
            "anchor_pool_full": full_anchors,
            "heldout": HELDOUT_NAMES,
        },
    }
    for t_name in HELDOUT_NAMES:
        print(f"\n=== {t_name} ===")
        X_target = load_sd(OUT / "X" / t_name)
        Y_oracle = load_sd(OUT / "Y" / t_name)
        _, eval_ds, labels = sp.build_task(t_name, n_train=10, n_eval=sp.EVAL_PER_TASK)

        task_res = {}
        # Reference points: no adapter at all, and the adapter stored for this task itself.
        task_res["base_Y"] = eval_adapter(None, eval_ds, labels, tokY)
        task_res["oracle_Y"] = eval_adapter(OUT / "Y" / t_name, eval_ds, labels, tokY)
        print(f" base={task_res['base_Y']:.4f} oracle={task_res['oracle_Y']:.4f}")

        for N, (names, Xa, Ya) in splits.items():
            print(f"--- N={N} anchors ---")
            # Insertion order of ``preds`` is also the evaluation/print order below.
            preds = {}
            preds[f"N{N}_mean"] = sp.mean_pred(Ya)
            preds[f"N{N}_global_ridge"], _ = sp.global_ridge_pred(Xa, Ya, X_target, ridge=1e-3)
            preds[f"N{N}_pertensor_ridge"] = sp.pertensor_ridge_pred(Xa, Ya, X_target, ridge=1e-3)
            preds[f"N{N}_pertensor_pca"] = sp.pertensor_pca_linear_pred(
                Xa, Ya, X_target, k_lat=8, ridge=1e-2
            )
            preds[f"N{N}_procrustes"] = pa.procrustes_pred(Xa, Ya, X_target)
            for k in [5, 8, 12]:
                # Cannot select more anchors than the pool contains.
                if k > N:
                    continue
                p, idx, _ = pa.topk_global_ridge(Xa, Ya, X_target, k=k)
                preds[f"N{N}_topk{k}_global_ridge"] = p
                # Only report the selected anchors for the largest pool.
                if N == largest_pool:
                    print(f" topk{k} selected: {[names[i] for i in idx]}")
                p, _, _ = pa.topk_pertensor_ridge(Xa, Ya, X_target, k=k)
                preds[f"N{N}_topk{k}_pertensor_ridge"] = p

            preds[f"N{N}_pertensor_mlp"], _ = sp.pertensor_pca_mlp_pred(
                Xa, Ya, X_target, k_lat=8, epochs=300, lr=1e-3
            )
            if N >= 12:
                p, _, _ = pa.topk_pertensor_mlp(Xa, Ya, X_target, k=12, k_lat=6, epochs=300)
                preds[f"N{N}_topk12_pertensor_mlp"] = p

            for n, sd in preds.items():
                d = save_adapter(sd, f"{t_name}_{n}")
                acc = eval_adapter(d, eval_ds, labels, tokY)
                task_res[n] = acc
                # Similarity is computed on the in-memory prediction, not the bfloat16 copy on disk.
                task_res[f"{n}__cos"] = cos_sd(sd, Y_oracle)
                print(f" {n:38s} cos={task_res[n+'__cos']:.4f} acc={acc:.4f}")

        results["per_task"][t_name] = task_res

    # Average every score (cosine entries excluded) over the tasks that recorded it.
    all_methods = sorted(
        {k for tres in results["per_task"].values() for k in tres.keys() if not k.endswith("__cos")}
    )
    avg = {
        m: float(
            np.mean(
                [results["per_task"][t][m] for t in HELDOUT_NAMES if m in results["per_task"][t]]
            )
        )
        for m in all_methods
    }
    results["avg"] = avg
    print("\n=== AVG ===")
    for m in sorted(avg, key=lambda x: -avg[x]):
        print(f" {m:42s} {avg[m]:.4f}")
    (OUT / "results_phaseB.json").write_text(json.dumps(results, indent=2, default=float))
|
|
# Run the sweep only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|