File size: 9,716 Bytes
e6bc942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
#!/usr/bin/env python
"""Token-level (per-residue) evaluation for the bidirectional ProGen2 encoder.

The per-residue counterpart to eval_protein.py. These are the only evals that
exercise the encoder's CONTEXTUAL TOKEN representations (the proposal's
token-level promise) rather than a pooled vector. ProGen2 tokenizes one amino
acid -> one token (verified), so token i aligns to residue i with no
special-token offset.

Tasks (--task):
  ss3      3-state secondary structure  AI4Protein/ssp_q3 (HF)         TAPE
  ptm      phosphosite PTM (binary)     PhosphositePTM.*.csv           ProteinBERT
  disorder intrinsic disorder (binary)  disorder_secondary_structure.* ProteinBERT

PTM and disorder are distributed only as CSVs in the ProteinBERT data repo
(github.com/Brandes-Lab/proteinbert_data_files), staged locally under --data-dir.
Their per-residue label is a contiguous digit string (one char per residue); SS3's
label is a numpy-printed array of space-separated ints.

Protocol: freeze encoder, take last-layer per-residue hidden states, fit a closed-form
one-hot ridge classifier over residues. Report per-residue accuracy; for the imbalanced
binary tasks (ptm, disorder) also report majority-class baseline and ROC-AUC (the honest
metric under class imbalance). Normal equations (A=X'X, B=X'Y) are accumulated
incrementally so residue features never all sit in memory at once.

Run on a GPU node in the pinned transformers-4.44.2 venv.
"""
from __future__ import annotations

import argparse
import os
import re
import sys

import torch

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.bidir_progen import make_bidirectional  # noqa: E402

# Per-task configuration, keyed by the --task CLI choice.
#   source       "hf" (HuggingFace hub dataset) or "csv" (local ProteinBERT CSV)
#   dataset/base HF dataset id, or CSV basename ("<base>.<split>.csv")
#   seq/label    column names holding the sequence and its per-residue labels
#   fmt          label encoding understood by parse_labels ("tokens" | "chars")
#   num_classes  number of per-residue classes
#   binary       whether ROC-AUC / majority baseline are reported
TOKEN_TASKS = {
    # 3-state secondary structure; labels are printed integer arrays.
    "ss3": dict(
        source="hf",
        dataset="AI4Protein/ssp_q3",
        seq="aa_seq",
        label="label",
        fmt="tokens",
        num_classes=3,
        binary=False,
    ),
    # Phosphosite PTM; labels are one digit character per residue.
    "ptm": dict(
        source="csv",
        base="PhosphositePTM",
        seq="seq",
        label="label",
        fmt="chars",
        num_classes=2,
        binary=True,
    ),
    # Intrinsic disorder; labels are one digit character per residue.
    "disorder": dict(
        source="csv",
        base="disorder_secondary_structure",
        seq="seq",
        label="label",
        fmt="chars",
        num_classes=2,
        binary=True,
    ),
}


def parse_args():
    """Define this script's command-line options and parse them from sys.argv.

    Returns:
        argparse.Namespace with model, task/data, batching, subsampling and
        probe settings.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Encoder to probe, optionally with a LoRA adapter on top.
    add("--model-name", default="hugohrban/progen2-base")
    add("--adapter", default=None, help="trained LoRA adapter dir (optional)")

    # Which token-level task, and where its local CSVs live.
    add("--task", default="ss3", choices=list(TOKEN_TASKS))
    add("--data-dir", default="./pbert_data",
        help="dir holding the ProteinBERT CSVs (for ptm/disorder)")

    # Tokenization / forward-pass batching.
    add("--max-length", type=int, default=512)
    add("--batch-size", type=int, default=16)

    # Sequence-count caps per split.
    add("--max-train", type=int, default=6000, help="cap train seqs (0 = full)")
    add("--max-test", type=int, default=0, help="cap test seqs (0 = full)")

    # Ridge regularization strength for the closed-form probe.
    add("--ridge-alpha", type=float, default=10.0)

    namespace = parser.parse_args()
    return namespace


def parse_labels(s, fmt):
    """Convert a raw label field into a list of per-residue integer labels.

    Args:
        s: the label field (coerced with str()).
        fmt: "chars" -> every digit character is one label; any other value is
            treated as "tokens" -> every (optionally negative) integer found in
            the text is one label, so brackets/whitespace are ignored.

    Returns:
        list[int] of labels in order of appearance.
    """
    text = str(s)
    if fmt != "chars":
        # "tokens": e.g. a numpy-printed array such as "[0 1 2]".
        return list(map(int, re.findall(r"-?\d+", text)))
    return list(map(int, filter(str.isdigit, text.strip())))


def load_split(cfg, split, cap, data_dir):
    """Load one split of a token-level task.

    Args:
        cfg: a TOKEN_TASKS entry; reads "source", "seq", "label", "fmt" and
            either "dataset" (HF source) or "base" (CSV source).
        split: split name such as "train" / "test". HF: passed to
            load_dataset(split=...). CSV: read from "<data_dir>/<base>.<split>.csv".
        cap: maximum number of sequences; falsy (0/None) keeps all. Note the
            two sources cap differently: HF takes a random subset
            (shuffle seed=0), CSV takes the first `cap` rows in file order.
        data_dir: directory holding the CSVs (unused for HF).

    Returns:
        (seqs, labs): parallel lists of sequence strings and per-residue
        label lists (list[int]), as produced by parse_labels.
    """
    if cfg["source"] == "hf":
        from datasets import load_dataset
        ds = load_dataset(cfg["dataset"], split=split)
        if cap and len(ds) > cap:
            ds = ds.shuffle(seed=0).select(range(cap))
        seqs = list(ds[cfg["seq"]])
        labs = [parse_labels(x, cfg["fmt"]) for x in ds[cfg["label"]]]
        return seqs, labs

    import csv as _csv
    path = os.path.join(data_dir, f"{cfg['base']}.{split}.csv")
    seqs, labs = [], []
    # newline="" is required by the csv module so quoted fields and line
    # endings are handled by the reader, not by text-mode newline translation;
    # the explicit encoding keeps parsing independent of the platform locale.
    with open(path, newline="", encoding="utf-8") as fh:
        for i, row in enumerate(_csv.DictReader(fh)):
            if cap and i >= cap:
                break
            seqs.append(row[cfg["seq"]])
            labs.append(parse_labels(row[cfg["label"]], cfg["fmt"]))
    return seqs, labs


def build_base(model_name, device):
    """Load the pretrained causal LM in bfloat16, make its attention bidirectional,
    and return it on `device`.

    Args:
        model_name: HF hub id or local path passed to from_pretrained.
        device: torch device the returned model is moved to.
    """
    import transformers.modeling_utils as modeling_utils
    from transformers import AutoModelForCausalLM

    # Compat shim: define PreTrainedModel.all_tied_weights_keys when this
    # transformers version lacks it. NOTE(review): presumably needed by the
    # remote-code model class; confirm against the pinned transformers version.
    pretrained_cls = modeling_utils.PreTrainedModel
    if "all_tied_weights_keys" not in vars(pretrained_cls):
        pretrained_cls.all_tied_weights_keys = {}

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    patched = make_bidirectional(model)
    print(f"[build] bidirectional patch applied to {patched} attention modules", flush=True)
    return model.to(device)


@torch.no_grad()
def residue_features(model, tok, seqs, labels, device, max_length, batch_size):
    """Yield last-layer per-residue features and labels, one batch at a time.

    Sequences are tokenized with add_special_tokens=False so token i aligns to
    residue i. Real (non-pad) positions are picked via the attention mask, so
    the alignment holds whether the tokenizer pads on the right or the left.
    Sequences longer than max_length are truncated; only the first
    min(#real tokens, len(labels)) residues of each sequence are kept.

    Args:
        model: called as model(input_ids=, attention_mask=, output_hidden_states=True);
            its `.hidden_states[-1]` is used.
        tok: tokenizer called on a list of strings with return_tensors="pt".
        seqs, labels: parallel lists of sequences and per-residue label lists.
        device: device the encoded batch is moved to.
        max_length, batch_size: tokenization truncation length and batch size.

    Yields:
        (feats, labs): float64 tensor (R, H) and long tensor (R,) for the
        labelled residues of one batch. Batches with no labelled residue are
        skipped.
    """
    for start in range(0, len(seqs), batch_size):
        batch_seqs = seqs[start:start + batch_size]
        batch_labels = labels[start:start + batch_size]
        enc = tok(batch_seqs, padding=True, truncation=True, max_length=max_length,
                  add_special_tokens=False, return_tensors="pt").to(device)
        mask = enc["attention_mask"]
        hidden = model(input_ids=enc["input_ids"], attention_mask=mask,
                       output_hidden_states=True).hidden_states[-1]   # (B, T, H)
        feats, labs = [], []
        for b, lab in enumerate(batch_labels):
            # Select real tokens by mask instead of slicing h[b, :n]: with a
            # left-padding tokenizer the leading positions are padding.
            real = hidden[b][mask[b].bool()]
            n = min(real.size(0), len(lab))
            if n <= 0:
                continue
            feats.append(real[:n].float().cpu())
            labs.append(torch.tensor(lab[:n], dtype=torch.long))
        if feats:
            yield torch.cat(feats, 0).double(), torch.cat(labs, 0)


def roc_auc(scores, labels):
    """Binary ROC-AUC via the rank-sum (Mann-Whitney U) statistic; no sklearn dependency.

    Tied scores get the mean of the ranks they span (midranks), so a tied
    positive/negative pair contributes 1/2. This is the standard AUC tie
    convention; plain argsort ranks would make the result depend on the
    arbitrary order of ties. Elements whose label is neither 0 nor 1 are ignored.

    Args:
        scores: 1-D array-like of real scores (higher = more likely positive).
        labels: 1-D array-like of 0/1 labels, same length as `scores`.

    Returns:
        AUC as a float, or NaN when either class is absent.
    """
    s = torch.as_tensor(scores, dtype=torch.float64).reshape(-1)
    y = torch.as_tensor(labels, dtype=torch.long).reshape(-1)
    keep = (y == 0) | (y == 1)
    s, y = s[keep], y[keep]
    n_pos = int((y == 1).sum())
    n_neg = int((y == 0).sum())
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    # Group equal scores; a group of size c ending at 1-based rank e spans
    # ranks e-c+1..e, whose mean is e - (c-1)/2.
    _, inverse, counts = torch.unique(s, sorted=True, return_inverse=True, return_counts=True)
    counts = counts.to(torch.float64)
    last_rank = torch.cumsum(counts, 0)
    mid_rank = last_rank - (counts - 1) / 2
    ranks = mid_rank[inverse]
    sum_pos = float(ranks[y == 1].sum())
    return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def evaluate(tag, model, tok, cfg, tr, te, device, args):
    """Fit a closed-form one-hot ridge probe on train residues and score it on test residues.

    Args:
        tag: label printed at the start of the result line.
        model, tok: encoder and tokenizer passed through to residue_features.
        cfg: TOKEN_TASKS entry; reads "num_classes" and "binary".
        tr, te: (seqs, labels) pairs as returned by load_split.
        device: device for the forward passes.
        args: namespace; reads max_length, batch_size, ridge_alpha, task.

    Returns:
        (acc, auc): per-residue test accuracy, and ROC-AUC of the
        (class 1 - class 0) logit margin for binary tasks (NaN if the test
        split has no usable residues), else None.

    Raises:
        RuntimeError: if the training split yields no labelled residues.
    """
    model.eval()
    C = cfg["num_classes"]

    def _valid(feats, labs):
        # Drop residues whose label is outside [0, C). parse_labels admits
        # negative integers; unfiltered, a -1 would be one-hot'd into class
        # C-1 when fitting and would crash torch.bincount when scoring.
        keep = (labs >= 0) & (labs < C)
        return feats[keep], labs[keep]

    def _with_bias(feats):
        return torch.cat([feats, torch.ones(feats.size(0), 1, dtype=torch.float64)], 1)

    # Accumulate normal equations A = X'X, B = X'Y batch by batch so the full
    # residue feature matrix never has to be materialized.
    A = B = None
    n_res = 0
    for feats, labs in residue_features(model, tok, tr[0], tr[1], device,
                                        args.max_length, args.batch_size):
        feats, labs = _valid(feats, labs)
        if labs.numel() == 0:
            continue
        X = _with_bias(feats)
        if A is None:
            D = X.size(1)
            A = torch.zeros(D, D, dtype=torch.float64)
            B = torch.zeros(D, C, dtype=torch.float64)
        Y = torch.zeros(X.size(0), C, dtype=torch.float64)
        Y[torch.arange(X.size(0)), labs] = 1.0
        A += X.t() @ X
        B += X.t() @ Y
        n_res += X.size(0)
    if A is None:
        raise RuntimeError(f"[{tag}] no labelled training residues for task {args.task!r}")
    # Note: the identity also covers the bias column, so the bias is penalized too.
    W = torch.linalg.solve(A + args.ridge_alpha * torch.eye(A.size(0), dtype=torch.float64), B)

    correct = total = 0
    cls_count = torch.zeros(C, dtype=torch.long)
    scores, ys = [], []
    for feats, labs in residue_features(model, tok, te[0], te[1], device,
                                        args.max_length, args.batch_size):
        feats, labs = _valid(feats, labs)
        if labs.numel() == 0:
            continue
        logits = _with_bias(feats) @ W
        pred = logits.argmax(1)
        correct += int((pred == labs).sum())
        total += labs.numel()
        cls_count += torch.bincount(labs, minlength=C)
        if cfg["binary"]:
            scores.append(logits[:, 1] - logits[:, 0])
            ys.append(labs)
    acc = correct / max(total, 1)
    auc = None
    if cfg["binary"]:
        auc = roc_auc(torch.cat(scores), torch.cat(ys)) if scores else float("nan")
    msg = f"[{tag}] {args.task} acc={acc:.4f} (train_res={n_res}, eval_res={total})"
    if cfg["binary"]:
        majority = float(cls_count.max()) / max(total, 1)
        pos_frac = float(cls_count[1]) / max(total, 1)
        msg += f" | majority={majority:.4f} pos_frac={pos_frac:.4f} AUC={auc:.4f}"
    print(msg, flush=True)
    return acc, auc


def main():
    """CLI entry point.

    Loads the task splits and the bidirectional base encoder, then evaluates
    the frozen baseline and, if --adapter is given, the adapter-wrapped model.
    Finishes with a one-line summary: AUC for binary tasks, accuracy otherwise.
    """
    args = parse_args()
    from transformers import AutoTokenizer

    cfg = TOKEN_TASKS[args.task]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    if tok.pad_token_id is None:
        # Batched padding needs a pad token: reuse EOS, else the token for id 0.
        tok.pad_token = tok.eos_token or tok.convert_ids_to_tokens(0)

    train = load_split(cfg, "train", args.max_train, args.data_dir)
    test = load_split(cfg, "test", args.max_test, args.data_dir)
    n_train, n_test = len(train[0]), len(test[0])
    print(f"[data] {args.task} train_seqs={n_train} test_seqs={n_test} "
          f"classes={cfg['num_classes']}", flush=True)

    base = build_base(args.model_name, device)
    base_acc, base_auc = evaluate("baseline (frozen, bidir)", base, tok, cfg,
                                  train, test, device, args)

    if args.adapter:
        from peft import PeftModel
        trained = PeftModel.from_pretrained(base, args.adapter).to(device)
        trained_acc, trained_auc = evaluate("trained (adapter)", trained, tok, cfg,
                                            train, test, device, args)
    else:
        trained_acc = trained_auc = None

    print(f"\n=== SUMMARY ({args.task}) ===", flush=True)
    if cfg["binary"]:
        metric, base_score, trained_score = "AUC", base_auc, trained_auc
    else:
        metric, base_score, trained_score = "acc", base_acc, trained_acc
    summary = f"  {args.task:20s} [{metric}] baseline={base_score:.4f}"
    if trained_score is not None:
        summary += f"  trained={trained_score:.4f}  delta={trained_score - base_score:+.4f}"
    print(summary, flush=True)


if __name__ == "__main__":
    main()