File size: 10,004 Bytes
e6bc942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
#!/usr/bin/env python
"""TAPE downstream evaluation for the bidirectional ProGen2 encoder.

Protein analog of eval_sts.py: freeze the encoder, mean-pool a fixed embedding
per sequence, fit a linear (ridge) probe on the train split, and score it on the
test split (Spearman for regression tasks, accuracy for classification tasks) —
the standard frozen-features protocol for judging embedding quality. Reports the
FROZEN baseline (bidirectional ProGen2, no adaptation) vs the TRAINED adapter
(when --adapter is given), so we can see whether MNTP/SimCSE actually improved
the encoder.

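Multi-task: --tasks takes a comma-separated list (default:
stability,fluorescence,homology) so the verdict is checked across more than one
downstream task. By default (--max-train 0 / --max-test 0) the FULL splits are
used; pass a positive value to cap that split to a seeded random subset.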

Two probe kinds (both frozen mean-pooled embedding + closed-form linear probe):
  reg  -> ridge regression, reported as Spearman ρ
  clf  -> one-hot ridge classifier (argmax), reported as top-1 accuracy

Tasks (sequence-level; TAPE = Rao 2019, ProteinBERT = Brandes 2022):
  stability     -> AI4Protein/TAPE_Stability          reg  TAPE
  fluorescence  -> AI4Protein/TAPE_Fluorescence       reg  TAPE
  homology      -> GleghornLab/bom_remote_homology    clf  TAPE (remote homology)
  fold          -> GleghornLab/fold_prediction        clf  TAPE (fold classification)
  signalpeptide -> GrimSqueaker/SignalP_Binary        clf  ProteinBERT
  neuropeptide  -> GrimSqueaker/ProFET_NP_SP_Cleaved  clf  ProteinBERT (cleaved precursor)

(Token-level PTM and disorder live in eval_token.py — they are per-residue and only
distributed as CSVs in the ProteinBERT data repo, not on HF.)

Run on a GPU node in the pinned transformers-4.44.2 venv.
"""
from __future__ import annotations

import argparse
import os
import sys

import torch
import torch.nn.functional as F

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.bidir_progen import make_bidirectional, mean_pool  # noqa: E402


# Task registry. Each entry maps a task name to:
#   "dataset"    - Hugging Face Hub dataset id, passed to load_split.
#   "kind"       - "reg" (ridge regression, scored by Spearman) or "clf" (one-hot
#                  ridge classifier, scored by accuracy); main also uses it to pick
#                  the label cast (float for "reg", int for "clf").
#   "test_split" - optional evaluation split name; main uses "test" when absent.
# Key order is user-visible: it appears in the --tasks help text and in the
# unknown-task error message.
TASKS = {
    "stability": {"dataset": "AI4Protein/TAPE_Stability", "kind": "reg"},
    "fluorescence": {"dataset": "AI4Protein/TAPE_Fluorescence", "kind": "reg"},
    "homology": {"dataset": "GleghornLab/bom_remote_homology", "kind": "clf",
                 "test_split": "test"},
    "fold": {"dataset": "GleghornLab/fold_prediction", "kind": "clf",
             "test_split": "test"},
    # ProteinBERT sequence-level benchmarks (Brandes 2022), via HF mirrors:
    "signalpeptide": {"dataset": "GrimSqueaker/SignalP_Binary", "kind": "clf",
                      "test_split": "test"},
    "neuropeptide": {"dataset": "GrimSqueaker/ProFET_NP_SP_Cleaved", "kind": "clf",
                     "test_split": "test"},
}
# Candidate column names, tried in this order by _pick; the first one present
# in a dataset's columns is used as its sequence / label column.
SEQ_COLS = ("aa_seq", "seqs", "seq", "sequence", "primary")
LABEL_COLS = ("label", "labels", "target", "log_fluorescence", "stability_score")


def parse_args():
    """Define and parse this script's command-line options from ``sys.argv``.

    Returns an ``argparse.Namespace`` holding:
      model_name / adapter     - base checkpoint id and optional LoRA adapter dir;
      tasks                    - raw comma-separated task string (validated in main);
      max_length / batch_size  - tokenizer truncation length and encode batch size;
      max_train / max_test     - per-split sequence caps, 0 meaning no cap;
      ridge_alpha              - L2 penalty for both probe kinds.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add("--model-name", default="hugohrban/progen2-base")
    add("--adapter", default=None, help="trained LoRA adapter dir (optional)")

    task_choices = ",".join(TASKS)
    add("--tasks", default="stability,fluorescence,homology",
        help="comma-separated subset of: " + task_choices)

    for flag, default in (("--max-length", 512), ("--batch-size", 32)):
        add(flag, type=int, default=default)

    add("--max-train", type=int, default=0, help="cap train seqs (0 = full split)")
    add("--max-test", type=int, default=0, help="cap test seqs (0 = full split)")
    add("--ridge-alpha", type=float, default=10.0)

    namespace = parser.parse_args()
    return namespace


def _average_ranks(values):
    """Return ranks of a 1-D tensor, giving tied values the mean of their positions.

    This is the "average" tie convention used by scipy.stats.spearmanr, so the
    torch fallback below agrees with the scipy path when inputs contain ties.
    Ranks are 0-based; the offset is irrelevant once they are mean-centred.
    """
    _, inverse, counts = torch.unique(values, sorted=True, return_inverse=True,
                                      return_counts=True)
    ends = counts.cumsum(0)        # exclusive end position of each tie group
    starts = ends - counts         # first position of each tie group
    group_rank = (starts + ends - 1).double() / 2.0
    return group_rank[inverse]


def _spearman_torch(a, b):
    """Dependency-free Spearman rho: Pearson correlation of tie-averaged ranks.

    Returns nan when either input is constant (correlation undefined), matching
    scipy's result for that case.
    """
    ra = _average_ranks(torch.tensor(a, dtype=torch.float64))
    rb = _average_ranks(torch.tensor(b, dtype=torch.float64))
    ra = ra - ra.mean()
    rb = rb - rb.mean()
    denom = ra.norm() * rb.norm()
    if denom == 0:
        return float("nan")
    return float((ra @ rb) / denom)


def spearman(a, b):
    """Spearman rank correlation between two equal-length sequences of numbers.

    Uses scipy.stats.spearmanr when scipy is installed; otherwise falls back to a
    torch implementation with the same tie handling. Only a missing scipy
    triggers the fallback — genuine errors from scipy propagate instead of being
    silently masked.
    """
    try:
        from scipy.stats import spearmanr
    except ImportError:
        return _spearman_torch(a, b)
    # [0] is the correlation for both the legacy namedtuple result and the newer
    # SignificanceResult, so this works across scipy versions.
    return float(spearmanr(a, b)[0])


@torch.no_grad()
def encode(model, tok, seqs, device, max_length, batch_size):
    """Embed ``seqs`` as unit-length, mean-pooled final-layer hidden states.

    Sequences are processed in consecutive chunks of ``batch_size``. Each chunk is
    tokenised (padded within the chunk, truncated to ``max_length``) and moved to
    ``device``. It then runs through ``model`` with ``output_hidden_states=True``,
    and the last hidden layer is pooled by ``mean_pool`` using the attention mask.
    The pooled vectors are cast to float32, L2-normalised along the last dim,
    moved to CPU, and concatenated along dim 0 in input order. Runs under
    ``torch.no_grad``.
    """
    def _embed_chunk(chunk):
        # One tokenise -> forward -> pool -> normalise pass for a single chunk.
        batch = tok(chunk, padding=True, truncation=True,
                    max_length=max_length, return_tensors="pt").to(device)
        mask = batch["attention_mask"]
        last_hidden = model(input_ids=batch["input_ids"], attention_mask=mask,
                            output_hidden_states=True).hidden_states[-1]
        return F.normalize(mean_pool(last_hidden, mask).float(), dim=-1).cpu()

    chunk_starts = range(0, len(seqs), batch_size)
    return torch.cat([_embed_chunk(seqs[s:s + batch_size]) for s in chunk_starts], 0)


def ridge_probe(xtr, ytr, xte, alpha):
    """Fit closed-form ridge regression on the train set and predict the test set.

    Solves ``w = (X'X + alpha * P)^-1 X'y`` in float64, where ``X`` is the feature
    matrix augmented with a column of ones (the intercept). ``P`` is the identity
    except for a 0 on the intercept entry, so only the feature weights are shrunk.
    This is the standard ridge formulation (e.g. scikit-learn's ``Ridge``).
    Penalising the intercept as well would pull it toward 0 and force the weights
    to absorb the target mean whenever targets are not centred. No sklearn
    dependency is needed.

    Args:
        xtr: train features, presumably a (n_train, d) float tensor (rows = samples).
        ytr: sequence of n_train float targets.
        xte: test features with the same column count as ``xtr``.
        alpha: L2 penalty on the feature weights.

    Returns:
        List of float predictions, one per row of ``xte``.
    """
    Xtr = torch.cat([xtr, torch.ones(xtr.size(0), 1)], 1).double()
    Xte = torch.cat([xte, torch.ones(xte.size(0), 1)], 1).double()
    y = torch.tensor(ytr, dtype=torch.float64).unsqueeze(1)
    penalty = torch.eye(Xtr.size(1), dtype=torch.float64)
    penalty[-1, -1] = 0.0  # leave the intercept (last, all-ones column) unpenalised
    A = Xtr.t() @ Xtr + alpha * penalty
    w = torch.linalg.solve(A, Xtr.t() @ y)
    return (Xte @ w).squeeze(1).tolist()


def clf_probe(xtr, ytr, xte, alpha, num_classes):
    """Fit a closed-form one-hot ridge classifier and predict test classes by argmax.

    Regresses one-hot targets on the features augmented with an all-ones intercept
    column: ``W = (X'X + alpha * P)^-1 X'Y_onehot`` in float64. Each test row is
    assigned the class with the highest score. ``P`` is the identity except for a
    0 on the intercept entry, so only the feature weights are shrunk. The
    per-class intercepts are left free to fit the class priors; penalising them
    would shift every decision boundary. A legitimate, dependency-free linear
    probe.

    Args:
        xtr: train features, presumably a (n_train, d) float tensor (rows = samples).
        ytr: sequence of n_train integer labels; they index columns of the one-hot
            target matrix, so they must lie in [0, num_classes).
        xte: test features with the same column count as ``xtr``.
        alpha: L2 penalty on the feature weights.
        num_classes: number of one-hot target columns.

    Returns:
        List of predicted integer class indices, one per row of ``xte``.
    """
    Xtr = torch.cat([xtr, torch.ones(xtr.size(0), 1)], 1).double()
    Xte = torch.cat([xte, torch.ones(xte.size(0), 1)], 1).double()
    yt = torch.tensor(ytr, dtype=torch.long)
    Y = torch.zeros(Xtr.size(0), num_classes, dtype=torch.float64)
    Y[torch.arange(Xtr.size(0)), yt] = 1.0
    penalty = torch.eye(Xtr.size(1), dtype=torch.float64)
    penalty[-1, -1] = 0.0  # leave the intercept (last, all-ones column) unpenalised
    A = Xtr.t() @ Xtr + alpha * penalty
    W = torch.linalg.solve(A, Xtr.t() @ Y)          # (d + 1, num_classes)
    return (Xte @ W).argmax(1).tolist()


def pairwise_cos(xte, cap=2000):
    """Mean off-diagonal entry of the Gram matrix of the first ``cap`` rows of ``xte``.

    When rows are unit-normalised (as ``encode`` produces), this is the mean
    pairwise cosine similarity, a collapse/anisotropy indicator. Only the first
    ``cap`` rows are used, because the Gram matrix is O(n^2) in memory. Self-pairs
    on the diagonal are excluded. With fewer than two rows the selection is empty
    and torch's mean of an empty tensor yields nan.
    """
    subset = xte[:cap]
    n_rows = subset.size(0)
    gram = subset @ subset.t()
    off_diagonal = ~torch.eye(n_rows, dtype=torch.bool)
    return float(gram[off_diagonal].mean())


def build_base(model_name, device):
    """Load ``model_name`` as a bf16 causal LM, make it bidirectional, move it to ``device``.

    The checkpoint is loaded with ``trust_remote_code=True`` and eager attention,
    then patched by ``make_bidirectional``. The number of patched attention
    modules is logged, and the model is returned on ``device``.
    """
    from transformers import AutoModelForCausalLM
    import transformers.modeling_utils as modeling_utils

    pretrained_cls = modeling_utils.PreTrainedModel
    # Compatibility shim: define ``all_tied_weights_keys`` on PreTrainedModel itself
    # when the class does not already define it. NOTE(review): presumably needed
    # because the remote model code expects this attribute and the pinned
    # transformers version lacks it — confirm.
    if "all_tied_weights_keys" not in vars(pretrained_cls):
        pretrained_cls.all_tied_weights_keys = {}

    # NOTE(review): eager attention is presumably required so make_bidirectional
    # can patch the attention code path — confirm before changing.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    patched_count = make_bidirectional(model)
    print(f"[build] bidirectional patch applied to {patched_count} attention modules", flush=True)
    return model.to(device)


def _pick(cols, candidates, what):
    """Return the first entry of ``candidates`` that is present in ``cols``.

    Candidates are tried in priority order. Raises KeyError naming ``what``, the
    available columns and the tried candidates when none of them is present.
    """
    missing = object()
    match = next((name for name in candidates if name in cols), missing)
    if match is missing:
        raise KeyError(f"no {what} column in {cols} (tried {candidates})")
    return match


def load_split(dataset, split, cap, cast=float):
    """Load one split of a Hugging Face dataset as parallel (sequences, labels) lists.

    If ``split`` cannot be loaded and it is "test" or "valid", the other of the two
    is tried instead, with a printed warning so the swap is visible in the logs.
    For any other split name the original error is re-raised immediately; retrying
    the same split would only fail again. The sequence and label columns are
    picked from SEQ_COLS / LABEL_COLS. A truthy ``cap`` smaller than the split
    size selects a seeded (seed=0) random subset of ``cap`` rows. Labels are
    converted with ``cast``.

    Raises:
        Whatever ``load_dataset`` raises when no usable split loads (chained to the
        first failure when a fallback was attempted), or KeyError when no known
        sequence/label column exists.
    """
    from datasets import load_dataset
    try:
        ds = load_dataset(dataset, split=split)
    except Exception as err:  # exception type for a missing split varies by datasets version
        alt = {"test": "valid", "valid": "test"}.get(split, split)
        if alt == split:
            raise
        print(f"[data] WARNING: could not load split '{split}' of {dataset} ({err!r}); "
              f"falling back to '{alt}'", flush=True)
        try:
            ds = load_dataset(dataset, split=alt)
        except Exception as alt_err:
            raise alt_err from err
    seq_c = _pick(ds.column_names, SEQ_COLS, "sequence")
    lab_c = _pick(ds.column_names, LABEL_COLS, "label")
    if cap and len(ds) > cap:
        ds = ds.shuffle(seed=0).select(range(cap))
    return list(ds[seq_c]), [cast(x) for x in ds[lab_c]]


def main():
    """Run the frozen-probe evaluation for each requested task and print a summary.

    Parses CLI args and validates task names against TASKS (exiting via sys.exit
    on an unknown name). Loads every task's train/test splits once, then builds
    the bidirectional base model and scores it on every task. When --adapter is
    given, the adapter is loaded on top of the same base model and scored too;
    the final summary then shows baseline, trained score, and their difference
    per task.
    """
    args = parse_args()
    from transformers import AutoTokenizer

    tasks = [t.strip() for t in args.tasks.split(",") if t.strip()]
    for t in tasks:
        if t not in TASKS:
            sys.exit(f"unknown task '{t}'; choose from {list(TASKS)}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    # encode tokenises with padding=True, which needs a pad token: reuse EOS, or the
    # id-0 token when the tokenizer has no EOS either.
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token or tok.convert_ids_to_tokens(0)

    # Load every task's splits once, up front: both the baseline and the adapter run
    # reuse the same data, and data problems surface before the model is built.
    data = {}
    for t in tasks:
        meta = TASKS[t]
        d, kind = meta["dataset"], meta["kind"]
        cast = int if kind == "clf" else float
        tr_seq, tr_y = load_split(d, "train", args.max_train, cast)
        te_seq, te_y = load_split(d, meta.get("test_split", "test"), args.max_test, cast)
        # Class count = largest label seen in train+test, plus 1 (assumes 0-based labels).
        nc = (max(tr_y + te_y) + 1) if kind == "clf" else None
        data[t] = (tr_seq, tr_y, te_seq, te_y, kind, nc)
        extra = f" classes={nc}" if nc else ""
        print(f"[data] {t} ({kind}): train={len(tr_seq)} test={len(te_seq)}{extra}", flush=True)

    base = build_base(args.model_name, device)

    def run(model, label):
        """Embed every task with ``model``, fit its probe, and return {task: (metric, score)}."""
        out = {}
        for t in tasks:
            tr_seq, tr_y, te_seq, te_y, kind, nc = data[t]
            # Embeddings depend on the model, so both splits are re-encoded per run.
            xtr = encode(model, tok, tr_seq, device, args.max_length, args.batch_size)
            xte = encode(model, tok, te_seq, device, args.max_length, args.batch_size)
            if kind == "clf":
                pred = clf_probe(xtr, tr_y, xte, args.ridge_alpha, nc)
                # Top-1 accuracy: fraction of test items whose predicted class matches.
                score = float(sum(int(p == y) for p, y in zip(pred, te_y)) / len(te_y))
                metric = "Acc"
            else:
                pred = ridge_probe(xtr, tr_y, xte, args.ridge_alpha)
                score = spearman(pred, te_y)
                metric = "Spearman"
            print(f"[{label}] {t} {metric}={score:.4f} (probe on {len(tr_seq)} train, "
                  f"eval {len(te_seq)} test) | mean_pairwise_cos={pairwise_cos(xte):.4f}",
                  flush=True)
            out[t] = (metric, score)
        return out

    base.eval()
    base_s = run(base, "baseline (frozen, bidir)")

    trained_s = {}
    if args.adapter:
        from peft import PeftModel
        # NOTE(review): PeftModel.from_pretrained wraps `base` and presumably injects
        # the adapter layers into it in place, so the baseline must be scored first
        # (as above) — confirm before reordering these runs.
        trained = PeftModel.from_pretrained(base, args.adapter).to(device)
        trained.eval()
        trained_s = run(trained, "trained (adapter)")

    print("\n=== SUMMARY ===", flush=True)
    for t in tasks:
        metric, b = base_s[t]
        line = f"  {t:14s} [{metric}] baseline={b:.4f}"
        # Trained score and delta are only shown when an adapter run happened.
        if t in trained_s:
            tr = trained_s[t][1]
            line += f"  trained={tr:.4f}  delta={tr - b:+.4f}"
        print(line, flush=True)


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()