File size: 5,259 Bytes
97957a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
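PHASE B eval: full sweep over anchor-pool sizes N ∈ {25, 60} × all mapping methods × 5 held-out tasks.
The larger pool is every available anchor: the base anchors plus those listed in new_anchors.json.
Reuses adapters trained by phaseB_train.py.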
"""
import os, sys, json, gc, shutil, re, collections
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from peft import PeftModel
from safetensors.torch import load_file, save_file

sys.path.insert(0, "/app")
import scaled_pipeline as sp
import phaseA_new_methods as pa
import phaseB_train as pb

# Seed Python/NumPy/torch RNGs once, before any stochastic fitting or evaluation.
set_seed(42)
# Shared paths/config re-exported from scaled_pipeline under short module-level names.
OUT, MODEL_Y, HELDOUT_NAMES = sp.OUT, sp.MODEL_Y, sp.HELDOUT_NAMES

def load_sd(p):
    """Load ``adapter_model.safetensors`` from adapter directory *p*.

    Args:
        p: adapter directory; must support ``/`` joining with a str (e.g. ``pathlib.Path``).

    Returns:
        dict mapping tensor name -> float32 tensor on the CPU.
    """
    raw = load_file(str(p / "adapter_model.safetensors"))
    converted = {}
    for name, tensor in raw.items():
        converted[name] = tensor.float().cpu()
    return converted

def cos_sd(s1, s2):
    """Cosine similarity of two state dicts viewed as single flattened vectors.

    Tensors are taken in sorted order of *s1*'s keys. *s2* must contain every key of
    *s1* (KeyError otherwise); keys present only in *s2* are ignored.

    Returns:
        Python float in [-1, 1].
    """
    ordered = sorted(s1)

    def _flat(sd):
        # Concatenate all tensors (as float32) into one 1-D vector in a fixed key order.
        return torch.cat([sd[name].float().flatten() for name in ordered])

    v1, v2 = _flat(s1)[None, :], _flat(s2)[None, :]
    return F.cosine_similarity(v1, v2).item()

def save_adapter(sd, dirname):
    """Write predicted adapter weights *sd* as a PEFT adapter directory.

    The ``adapter_config.json`` is copied from the Y adapter of the first base anchor
    (``sp.ANCHOR_NAMES[0]``), so *sd* is presumably expected to match that adapter's
    keys/shapes. This is not checked here; confirm against the prediction functions.

    Args:
        sd: mapping tensor name -> torch tensor (predicted Y-side adapter weights).
        dirname: name of the subdirectory created under ``OUT/"Y_pred_phaseB"``.

    Returns:
        Path of the written adapter directory.
    """
    d = OUT/"Y_pred_phaseB"/dirname
    d.mkdir(parents=True, exist_ok=True)
    src = OUT/"Y"/sp.ANCHOR_NAMES[0]
    shutil.copy(src/"adapter_config.json", d/"adapter_config.json")
    # safetensors' save_file rejects non-contiguous tensors, and `.to(dtype)` keeps the
    # input's strides (preserve_format), so a transposed/permuted prediction would make
    # the save raise. `.contiguous()` is a no-op for tensors that are already contiguous.
    tensors = {k: v.to(torch.bfloat16).contiguous() for k, v in sd.items()}
    save_file(tensors, str(d/"adapter_model.safetensors"))
    return d

@torch.no_grad()
def eval_adapter(adir, eval_ds, labels, tok, max_n=300):
    """Score the Y model (optionally with the PEFT adapter at *adir*) on *eval_ds*.

    A fresh bf16 copy of MODEL_Y is loaded onto CUDA on every call. This is presumably so
    that adapter layers injected into a previous base model by PeftModel cannot leak into
    this evaluation. The models are released and the CUDA cache is emptied before returning.

    Args:
        adir: adapter directory, or None to evaluate the bare base model.
        eval_ds, labels: evaluation data and label set forwarded to sp.eval_classification.
        tok: tokenizer for MODEL_Y.
        max_n: cap on evaluated examples, forwarded to sp.eval_classification.

    Returns:
        Whatever sp.eval_classification returns (used as accuracy by main).
    """
    backbone = AutoModelForCausalLM.from_pretrained(
        MODEL_Y, torch_dtype=torch.bfloat16, attn_implementation="eager"
    ).cuda()
    if adir is None:
        model = backbone
    else:
        model = PeftModel.from_pretrained(backbone, str(adir))
    score = sp.eval_classification(model, tok, eval_ds, labels, max_n=max_n)
    del model, backbone
    gc.collect()
    torch.cuda.empty_cache()
    return score

def _mapping_predictions(N, names, Xa, Ya, X_target, show_topk):
    """Run every anchor-based mapping method for one anchor pool.

    Methods are called in a fixed order so that seeded RNG consumption and log output
    are reproducible across runs.

    Args:
        N: pool size. Used to prefix result keys (``f"N{N}_..."``) and to skip top-k
            variants with k > N.
        names: anchor names index-aligned with Xa/Ya; used only for logging.
        Xa, Ya: X-side and Y-side anchor adapters (from OUT/"X" and OUT/"Y"), index-aligned.
        X_target: X-side adapter of the held-out task.
        show_topk: print which anchors the top-k global ridge selected.

    Returns:
        dict mapping method key -> predicted Y-side state dict.
    """
    preds = {}
    preds[f"N{N}_mean"] = sp.mean_pred(Ya)
    preds[f"N{N}_global_ridge"], _ = sp.global_ridge_pred(Xa, Ya, X_target, ridge=1e-3)
    preds[f"N{N}_pertensor_ridge"] = sp.pertensor_ridge_pred(Xa, Ya, X_target, ridge=1e-3)
    preds[f"N{N}_pertensor_pca"] = sp.pertensor_pca_linear_pred(Xa, Ya, X_target, k_lat=8, ridge=1e-2)
    preds[f"N{N}_procrustes"] = pa.procrustes_pred(Xa, Ya, X_target)
    for k in [5, 8, 12]:
        if k > N:
            continue
        p, idx, _ = pa.topk_global_ridge(Xa, Ya, X_target, k=k)
        preds[f"N{N}_topk{k}_global_ridge"] = p
        if show_topk:
            print(f"  topk{k} selected: {[names[i] for i in idx]}")
        p, _, _ = pa.topk_pertensor_ridge(Xa, Ya, X_target, k=k)
        preds[f"N{N}_topk{k}_pertensor_ridge"] = p
    # MLPs (slower)
    preds[f"N{N}_pertensor_mlp"], _ = sp.pertensor_pca_mlp_pred(Xa, Ya, X_target, k_lat=8, epochs=300, lr=1e-3)
    if N >= 12:
        p, _, _ = pa.topk_pertensor_mlp(Xa, Ya, X_target, k=12, k_lat=6, epochs=300)
        preds[f"N{N}_topk12_pertensor_mlp"] = p
    return preds


def _average_over_tasks(per_task, task_names):
    """Mean accuracy per method across tasks; ``__cos`` entries are excluded.

    Each method is averaged over the tasks (from *task_names*) that recorded it.
    """
    methods = sorted({k for tres in per_task.values() for k in tres if not k.endswith("__cos")})
    return {m: float(np.mean([per_task[t][m] for t in task_names if m in per_task[t]]))
            for m in methods}


def _write_results(results):
    """Serialize *results* to ``OUT/"results_phaseB.json"`` (overwriting it)."""
    (OUT/"results_phaseB.json").write_text(json.dumps(results, indent=2, default=float))


def main():
    """Run the Phase B evaluation sweep and write ``OUT/"results_phaseB.json"``.

    For each held-out task, evaluate three things:
    - the base Y model;
    - the oracle Y adapter;
    - every mapping method for each anchor pool.

    For each method, record the accuracy and the cosine similarity of the predicted
    adapter to the oracle. Results are checkpointed after every task, so a crash late in
    the GPU-heavy sweep keeps the tasks already finished. The final write adds
    cross-task averages under ``"avg"``.
    """
    # Load adapters
    new_anchors_avail = json.loads((OUT/"new_anchors.json").read_text())
    full_anchors = sp.ANCHOR_NAMES + new_anchors_avail
    print(f"Total anchors available: {len(full_anchors)}")

    X_full = [load_sd(OUT/"X"/n) for n in full_anchors]
    Y_full = [load_sd(OUT/"Y"/n) for n in full_anchors]

    # The base pool is the prefix of full_anchors contributed by sp.ANCHOR_NAMES. Its size
    # is derived rather than hard-coded, so pool names, slices and "N{N}" labels always
    # agree. If new_anchors.json is empty, both keys coincide and the pool runs once.
    n_base = len(sp.ANCHOR_NAMES)
    splits = {n_base: (sp.ANCHOR_NAMES, X_full[:n_base], Y_full[:n_base]),
              len(full_anchors): (full_anchors, X_full, Y_full)}
    n_largest = max(splits)

    tokY = AutoTokenizer.from_pretrained(MODEL_Y)
    if tokY.pad_token is None:
        tokY.pad_token = tokY.eos_token
    tokY.padding_side = "left"

    # "anchor_pool_25" is kept as the key name for compatibility with existing readers.
    results = {"per_task": {}, "config": {"anchor_pool_25": sp.ANCHOR_NAMES, "anchor_pool_full": full_anchors,
                                         "heldout": HELDOUT_NAMES}}
    for t_name in HELDOUT_NAMES:
        print(f"\n=== {t_name} ===")
        X_target = load_sd(OUT/"X"/t_name)
        Y_oracle = load_sd(OUT/"Y"/t_name)
        _, eval_ds, labels = sp.build_task(t_name, n_train=10, n_eval=sp.EVAL_PER_TASK)

        task_res = {}
        # base + oracle once
        task_res["base_Y"] = eval_adapter(None, eval_ds, labels, tokY)
        task_res["oracle_Y"] = eval_adapter(OUT/"Y"/t_name, eval_ds, labels, tokY)
        print(f"  base={task_res['base_Y']:.4f}  oracle={task_res['oracle_Y']:.4f}")

        for N, (names, Xa, Ya) in splits.items():
            print(f"--- N={N} anchors ---")
            preds = _mapping_predictions(N, names, Xa, Ya, X_target, show_topk=(N == n_largest))
            for n, sd in preds.items():
                d = save_adapter(sd, f"{t_name}_{n}")
                acc = eval_adapter(d, eval_ds, labels, tokY)
                task_res[n] = acc
                task_res[f"{n}__cos"] = cos_sd(sd, Y_oracle)
                print(f"  {n:38s} cos={task_res[n+'__cos']:.4f} acc={acc:.4f}")

        results["per_task"][t_name] = task_res
        # Checkpoint so completed tasks survive a crash later in the sweep.
        _write_results(results)

    # aggregate
    avg = _average_over_tasks(results["per_task"], HELDOUT_NAMES)
    results["avg"] = avg
    print("\n=== AVG ===")
    for m in sorted(avg, key=lambda x: -avg[x]):
        print(f"  {m:42s} {avg[m]:.4f}")
    _write_results(results)

if __name__ == "__main__":
    main()