""" PHASE B eval: full sweep on N ∈ {25, 60} × all mapping methods × 5 held-out tasks. Reuses adapters trained by phaseB_train.py. """ import os, sys, json, gc, shutil, re, collections from pathlib import Path import numpy as np import torch import torch.nn.functional as F from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed from peft import PeftModel from safetensors.torch import load_file, save_file sys.path.insert(0, "/app") import scaled_pipeline as sp import phaseA_new_methods as pa import phaseB_train as pb set_seed(42) OUT = sp.OUT MODEL_Y = sp.MODEL_Y HELDOUT_NAMES = sp.HELDOUT_NAMES def load_sd(p): return {k: v.float().cpu() for k,v in load_file(str(p/"adapter_model.safetensors")).items()} def cos_sd(s1, s2): keys = sorted(s1.keys()) a = torch.cat([s1[k].float().flatten() for k in keys]) b = torch.cat([s2[k].float().flatten() for k in keys]) return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item() def save_adapter(sd, dirname): d = OUT/"Y_pred_phaseB"/dirname d.mkdir(parents=True, exist_ok=True) src = OUT/"Y"/sp.ANCHOR_NAMES[0] shutil.copy(src/"adapter_config.json", d/"adapter_config.json") save_file({k:v.to(torch.bfloat16) for k,v in sd.items()}, str(d/"adapter_model.safetensors")) return d @torch.no_grad() def eval_adapter(adir, eval_ds, labels, tok, max_n=300): base = AutoModelForCausalLM.from_pretrained(MODEL_Y, torch_dtype=torch.bfloat16, attn_implementation="eager").cuda() m = base if adir is None else PeftModel.from_pretrained(base, str(adir)) acc = sp.eval_classification(m, tok, eval_ds, labels, max_n=max_n) del m, base; gc.collect(); torch.cuda.empty_cache() return acc def main(): # Load adapters new_anchors_avail = json.loads((OUT/"new_anchors.json").read_text()) full_anchors = sp.ANCHOR_NAMES + new_anchors_avail print(f"Total anchors available: {len(full_anchors)}") X_full = [load_sd(OUT/"X"/n) for n in full_anchors] Y_full = [load_sd(OUT/"Y"/n) for n in full_anchors] # for the 25-anchor baseline, just slice splits = {25: (sp.ANCHOR_NAMES, X_full[:25], Y_full[:25]), len(full_anchors): (full_anchors, X_full, Y_full)} tokY = AutoTokenizer.from_pretrained(MODEL_Y) if tokY.pad_token is None: tokY.pad_token = tokY.eos_token tokY.padding_side = "left" results = {"per_task": {}, "config": {"anchor_pool_25": sp.ANCHOR_NAMES, "anchor_pool_full": full_anchors, "heldout": HELDOUT_NAMES}} for t_name in HELDOUT_NAMES: print(f"\n=== {t_name} ===") X_target = load_sd(OUT/"X"/t_name) Y_oracle = load_sd(OUT/"Y"/t_name) _, eval_ds, labels = sp.build_task(t_name, n_train=10, n_eval=sp.EVAL_PER_TASK) task_res = {} # base + oracle once task_res["base_Y"] = eval_adapter(None, eval_ds, labels, tokY) task_res["oracle_Y"] = eval_adapter(OUT/"Y"/t_name, eval_ds, labels, tokY) print(f" base={task_res['base_Y']:.4f} oracle={task_res['oracle_Y']:.4f}") for N, (names, Xa, Ya) in splits.items(): print(f"--- N={N} anchors ---") preds = {} preds[f"N{N}_mean"] = sp.mean_pred(Ya) preds[f"N{N}_global_ridge"], _ = sp.global_ridge_pred(Xa, Ya, X_target, ridge=1e-3) preds[f"N{N}_pertensor_ridge"] = sp.pertensor_ridge_pred(Xa, Ya, X_target, ridge=1e-3) preds[f"N{N}_pertensor_pca"] = sp.pertensor_pca_linear_pred(Xa, Ya, X_target, k_lat=8, ridge=1e-2) preds[f"N{N}_procrustes"] = pa.procrustes_pred(Xa, Ya, X_target) for k in [5, 8, 12]: if k > N: continue p, idx, _ = pa.topk_global_ridge(Xa, Ya, X_target, k=k); preds[f"N{N}_topk{k}_global_ridge"] = p if N==max(splits): print(f" topk{k} selected: {[names[i] for i in idx]}") p, _, _ = pa.topk_pertensor_ridge(Xa, Ya, X_target, k=k); preds[f"N{N}_topk{k}_pertensor_ridge"] = p # MLPs (slower) preds[f"N{N}_pertensor_mlp"], _ = sp.pertensor_pca_mlp_pred(Xa, Ya, X_target, k_lat=8, epochs=300, lr=1e-3) if N >= 12: p, _, _ = pa.topk_pertensor_mlp(Xa, Ya, X_target, k=12, k_lat=6, epochs=300) preds[f"N{N}_topk12_pertensor_mlp"] = p for n, sd in preds.items(): d = save_adapter(sd, f"{t_name}_{n}") acc = eval_adapter(d, eval_ds, labels, tokY) task_res[n] = acc task_res[f"{n}__cos"] = cos_sd(sd, Y_oracle) print(f" {n:38s} cos={task_res[n+'__cos']:.4f} acc={acc:.4f}") results["per_task"][t_name] = task_res # aggregate all_methods = sorted({k for tres in results["per_task"].values() for k in tres.keys() if not k.endswith("__cos")}) avg = {m: float(np.mean([results["per_task"][t][m] for t in HELDOUT_NAMES if m in results["per_task"][t]])) for m in all_methods} results["avg"] = avg print("\n=== AVG ===") for m in sorted(avg, key=lambda x: -avg[x]): print(f" {m:42s} {avg[m]:.4f}") (OUT/"results_phaseB.json").write_text(json.dumps(results, indent=2, default=float)) if __name__ == "__main__": main()