cross-model-lora-prediction / phaseB_eval.py
Samarth0710's picture
Upload phaseB_eval.py with huggingface_hub
97957a4 verified
"""
PHASE B eval: full sweep on N ∈ {25, 60} × all mapping methods × 5 held-out tasks.
Reuses adapters trained by phaseB_train.py.
"""
import os, sys, json, gc, shutil, re, collections
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from peft import PeftModel
from safetensors.torch import load_file, save_file
sys.path.insert(0, "/app")
import scaled_pipeline as sp
import phaseA_new_methods as pa
import phaseB_train as pb
# Seed the RNGs handled by transformers.set_seed (python, numpy, torch) so runs
# are repeatable; presumably this also fixes the MLP mapping initializations.
set_seed(42)
# Shared configuration re-exported from scaled_pipeline.
OUT, MODEL_Y, HELDOUT_NAMES = sp.OUT, sp.MODEL_Y, sp.HELDOUT_NAMES
def load_sd(p):
    """Load a LoRA adapter's tensors from ``<p>/adapter_model.safetensors``.

    Every tensor is cast to float32 and moved to the CPU.

    Args:
        p: adapter directory; must support ``/`` joining (e.g. ``pathlib.Path``).

    Returns:
        dict mapping tensor name -> float32 CPU tensor, in file order.
    """
    weights_file = str(p / "adapter_model.safetensors")
    raw = load_file(weights_file)
    out = {}
    for name, tensor in raw.items():
        out[name] = tensor.float().cpu()
    return out
def cos_sd(s1, s2):
    """Cosine similarity between two state dicts viewed as single flat vectors.

    Tensors are visited in sorted order of ``s1``'s keys; keys present only in
    ``s2`` are ignored, and a key missing from ``s2`` raises ``KeyError``.

    Args:
        s1, s2: mappings of tensor name -> tensor with matching shapes per key.

    Returns:
        The cosine similarity as a Python float.
    """
    order = sorted(s1)

    def _flat(state):
        return torch.cat([state[name].float().flatten() for name in order])

    flat1, flat2 = _flat(s1), _flat(s2)
    return F.cosine_similarity(flat1[None, :], flat2[None, :]).item()
def save_adapter(sd, dirname):
    """Write a predicted LoRA state dict as a loadable PEFT adapter directory.

    The adapter goes to ``OUT/"Y_pred_phaseB"/dirname``. Its
    ``adapter_config.json`` is copied verbatim from the first anchor's model-Y
    adapter (``OUT/"Y"/sp.ANCHOR_NAMES[0]``). The prediction is therefore
    assumed to share that adapter's configuration (rank, target modules, ...).
    Weights are stored in bfloat16.

    Args:
        sd: mapping of tensor name -> tensor for the predicted adapter.
        dirname: name of the sub-directory to create (reused if it exists).

    Returns:
        Path of the adapter directory.
    """
    d = OUT / "Y_pred_phaseB" / dirname
    d.mkdir(parents=True, exist_ok=True)
    template = OUT / "Y" / sp.ANCHOR_NAMES[0]
    shutil.copy(template / "adapter_config.json", d / "adapter_config.json")
    # safetensors refuses to serialize non-contiguous tensors, and a dtype cast
    # may keep the source strides (e.g. for a transposed view). Force a
    # contiguous layout so any mapping method's output can be saved.
    tensors = {k: v.to(torch.bfloat16).contiguous() for k, v in sd.items()}
    save_file(tensors, str(d / "adapter_model.safetensors"))
    return d
@torch.no_grad()
def eval_adapter(adir, eval_ds, labels, tok, max_n=300):
    """Evaluate model Y, optionally with a LoRA adapter, on a classification task.

    Each call loads a fresh bfloat16 copy of ``MODEL_Y`` onto the GPU, so no
    model state carries over between calls.

    Args:
        adir: adapter directory loaded with ``PeftModel.from_pretrained``, or
            ``None`` to evaluate the bare base model.
        eval_ds: evaluation data forwarded to ``sp.eval_classification``.
        labels: label set forwarded to ``sp.eval_classification``.
        tok: tokenizer for ``MODEL_Y``.
        max_n: cap on evaluated examples, forwarded to ``sp.eval_classification``.

    Returns:
        The value returned by ``sp.eval_classification``, which callers treat as
        an accuracy.
    """
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_Y, torch_dtype=torch.bfloat16, attn_implementation="eager"
    ).cuda()
    model = base
    try:
        if adir is not None:
            model = PeftModel.from_pretrained(base, str(adir))
        return sp.eval_classification(model, tok, eval_ds, labels, max_n=max_n)
    finally:
        # Drop this frame's references before collecting, so the GPU memory can
        # be released even when adapter loading or evaluation raised.
        del model, base
        gc.collect()
        torch.cuda.empty_cache()
def _predict_all(N, names, Xa, Ya, X_target, *, show_selection):
    """Run every mapping method on one anchor pool; return predictions keyed ``N{N}_<method>``.

    Insertion order is the order in which predictions are later evaluated and
    stored. Each value is whatever the mapping function returns, presumably a
    state dict with the Y adapters' keys (it is consumed by save_adapter and
    cos_sd). ``show_selection`` prints the anchors picked by top-k global ridge.
    """
    preds = {}
    preds[f"N{N}_mean"] = sp.mean_pred(Ya)
    preds[f"N{N}_global_ridge"], _ = sp.global_ridge_pred(Xa, Ya, X_target, ridge=1e-3)
    preds[f"N{N}_pertensor_ridge"] = sp.pertensor_ridge_pred(Xa, Ya, X_target, ridge=1e-3)
    preds[f"N{N}_pertensor_pca"] = sp.pertensor_pca_linear_pred(
        Xa, Ya, X_target, k_lat=8, ridge=1e-2
    )
    preds[f"N{N}_procrustes"] = pa.procrustes_pred(Xa, Ya, X_target)
    for k in (5, 8, 12):
        if k > N:
            continue  # cannot select more anchors than the pool holds
        p, idx, _ = pa.topk_global_ridge(Xa, Ya, X_target, k=k)
        preds[f"N{N}_topk{k}_global_ridge"] = p
        if show_selection:
            print(f" topk{k} selected: {[names[i] for i in idx]}")
        p, _, _ = pa.topk_pertensor_ridge(Xa, Ya, X_target, k=k)
        preds[f"N{N}_topk{k}_pertensor_ridge"] = p
    # MLP-based mappings (slower)
    preds[f"N{N}_pertensor_mlp"], _ = sp.pertensor_pca_mlp_pred(
        Xa, Ya, X_target, k_lat=8, epochs=300, lr=1e-3
    )
    if N >= 12:
        p, _, _ = pa.topk_pertensor_mlp(Xa, Ya, X_target, k=12, k_lat=6, epochs=300)
        preds[f"N{N}_topk12_pertensor_mlp"] = p
    return preds


def _average_over_tasks(per_task, task_names):
    """Mean accuracy per method over the tasks that report it (``__cos`` entries excluded)."""
    methods = sorted(
        {k for tres in per_task.values() for k in tres if not k.endswith("__cos")}
    )
    return {
        m: float(np.mean([per_task[t][m] for t in task_names if m in per_task[t]]))
        for m in methods
    }


def main():
    """Run the Phase B sweep and write ``OUT/"results_phaseB.json"``.

    For each held-out task, first evaluate base model Y and the task's own Y
    adapter (the oracle). Then, for each anchor pool, predict the task's Y
    adapter from its X adapter with every mapping method, save and evaluate
    it, and record its cosine similarity to the oracle. The two pools are the
    ``sp.ANCHOR_NAMES`` pool and that pool extended with
    ``OUT/"new_anchors.json"``. Finally, average each method's accuracy over
    tasks.
    """
    # Anchor pool: the original anchors followed by the extra ones listed in
    # new_anchors.json.
    new_anchors_avail = json.loads((OUT / "new_anchors.json").read_text())
    full_anchors = sp.ANCHOR_NAMES + new_anchors_avail
    print(f"Total anchors available: {len(full_anchors)}")
    X_full = [load_sd(OUT / "X" / n) for n in full_anchors]
    Y_full = [load_sd(OUT / "Y" / n) for n in full_anchors]

    # The base pool is a prefix of the full pool, so slice by its real size
    # instead of a hard-coded 25. A hard-coded size would misalign names and
    # adapters if sp.ANCHOR_NAMES changed length. With no new anchors, both
    # keys coincide and the pool is evaluated once.
    n_base = len(sp.ANCHOR_NAMES)
    splits = {
        n_base: (sp.ANCHOR_NAMES, X_full[:n_base], Y_full[:n_base]),
        len(full_anchors): (full_anchors, X_full, Y_full),
    }
    largest = max(splits)

    tokY = AutoTokenizer.from_pretrained(MODEL_Y)
    if tokY.pad_token is None:
        tokY.pad_token = tokY.eos_token
    # NOTE(review): left padding is presumably what sp.eval_classification's
    # batched decoder-only evaluation expects — confirm.
    tokY.padding_side = "left"

    results = {
        "per_task": {},
        "config": {
            "anchor_pool_25": sp.ANCHOR_NAMES,  # key name kept for existing consumers
            "anchor_pool_full": full_anchors,
            "heldout": HELDOUT_NAMES,
        },
    }
    for t_name in HELDOUT_NAMES:
        print(f"\n=== {t_name} ===")
        X_target = load_sd(OUT / "X" / t_name)
        Y_oracle = load_sd(OUT / "Y" / t_name)
        _, eval_ds, labels = sp.build_task(t_name, n_train=10, n_eval=sp.EVAL_PER_TASK)
        task_res = {}
        # Reference points: base model and oracle adapter, once per task.
        task_res["base_Y"] = eval_adapter(None, eval_ds, labels, tokY)
        task_res["oracle_Y"] = eval_adapter(OUT / "Y" / t_name, eval_ds, labels, tokY)
        print(f" base={task_res['base_Y']:.4f} oracle={task_res['oracle_Y']:.4f}")
        for N, (names, Xa, Ya) in splits.items():
            print(f"--- N={N} anchors ---")
            preds = _predict_all(N, names, Xa, Ya, X_target, show_selection=(N == largest))
            for n, sd in preds.items():
                d = save_adapter(sd, f"{t_name}_{n}")
                acc = eval_adapter(d, eval_ds, labels, tokY)
                task_res[n] = acc
                task_res[f"{n}__cos"] = cos_sd(sd, Y_oracle)
                print(f" {n:38s} cos={task_res[n+'__cos']:.4f} acc={acc:.4f}")
        results["per_task"][t_name] = task_res

    avg = _average_over_tasks(results["per_task"], HELDOUT_NAMES)
    results["avg"] = avg
    print("\n=== AVG ===")
    for m in sorted(avg, key=lambda x: -avg[x]):
        print(f" {m:42s} {avg[m]:.4f}")
    (OUT / "results_phaseB.json").write_text(json.dumps(results, indent=2, default=float))
if __name__ == "__main__":
    # Run the sweep only when executed directly, not when imported.
    main()