Joblib
ynuozhang commited on
Commit
62e6dc2
·
1 Parent(s): 470021d

add inference

Browse files
Files changed (1) hide show
  1. load.py +891 -0
load.py ADDED
@@ -0,0 +1,891 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # peptiverse_infer.py
2
+ from __future__ import annotations
3
+
4
+ import csv, re, json
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Dict, Optional, Tuple, Any, List
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import joblib
13
+ import xgboost as xgb
14
+
15
+ from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
16
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
17
+
18
+
19
+ # -----------------------------
20
+ # Manifest
21
+ # -----------------------------
22
+ @dataclass(frozen=True)
23
+ class BestRow:
24
+ property_key: str
25
+ best_wt: Optional[str]
26
+ best_smiles: Optional[str]
27
+ task_type: str # "Classifier" or "Regression"
28
+ thr_wt: Optional[float]
29
+ thr_smiles: Optional[float]
30
+
31
+
32
+ def _clean(s: str) -> str:
33
+ return (s or "").strip()
34
+
35
+ def _none_if_dash(s: str) -> Optional[str]:
36
+ s = _clean(s)
37
+ if s in {"", "-", "—", "NA", "N/A"}:
38
+ return None
39
+ return s
40
+
41
+ def _float_or_none(s: str) -> Optional[float]:
42
+ s = _clean(s)
43
+ if s in {"", "-", "—", "NA", "N/A"}:
44
+ return None
45
+ return float(s)
46
+
47
+ def normalize_property_key(name: str) -> str:
48
+ n = name.strip().lower()
49
+ n = re.sub(r"\s*\(.*?\)\s*", "", n)
50
+ n = n.replace("-", "_").replace(" ", "_")
51
+ if "permeability" in n and "pampa" not in n and "caco" not in n:
52
+ return "permeability_penetrance"
53
+ if n == "binding_affinity":
54
+ return "binding_affinity"
55
+ if n == "halflife":
56
+ return "half_life"
57
+ if n == "non_fouling":
58
+ return "nf"
59
+ return n
60
+
61
+ def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
62
+ """
63
+ Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
64
+ Hemolysis, SVM, SGB, Classifier, 0.2801, 0.2223,
65
+ """
66
+ p = Path(path)
67
+ out: Dict[str, BestRow] = {}
68
+
69
+ with p.open("r", newline="") as f:
70
+ reader = csv.reader(f)
71
+ header = None
72
+ for raw in reader:
73
+ if not raw or all(_clean(x) == "" for x in raw):
74
+ continue
75
+ while raw and _clean(raw[-1]) == "":
76
+ raw = raw[:-1]
77
+
78
+ if header is None:
79
+ header = [h.strip() for h in raw]
80
+ continue
81
+
82
+ if len(raw) < len(header):
83
+ raw = raw + [""] * (len(header) - len(raw))
84
+ rec = dict(zip(header, raw))
85
+
86
+ prop_raw = _clean(rec.get("Properties", ""))
87
+ if not prop_raw:
88
+ continue
89
+ prop_key = normalize_property_key(prop_raw)
90
+
91
+ row = BestRow(
92
+ property_key=prop_key,
93
+ best_wt=_none_if_dash(rec.get("Best_Model_WT", "")),
94
+ best_smiles=_none_if_dash(rec.get("Best_Model_SMILES", "")),
95
+ task_type=_clean(rec.get("Type", "Classifier")),
96
+ thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
97
+ thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
98
+ )
99
+ out[prop_key] = row
100
+
101
+ return out
102
+
103
+
104
+ MODEL_ALIAS = {
105
+ "SVM": "svm_gpu",
106
+ "SVR": "svr",
107
+ "ENET": "enet_gpu",
108
+ "CNN": "cnn",
109
+ "MLP": "mlp",
110
+ "TRANSFORMER": "transformer",
111
+ "XGB": "xgb",
112
+ "XGB_REG": "xgb_reg",
113
+ "POOLED": "pooled",
114
+ "UNPOOLED": "unpooled"
115
+ }
116
+ def canon_model(label: Optional[str]) -> Optional[str]:
117
+ if label is None:
118
+ return None
119
+ k = label.strip().upper()
120
+ return MODEL_ALIAS.get(k, label.strip().lower())
121
+
122
+
123
+ # -----------------------------
124
+ # Generic artifact loading
125
+ # -----------------------------
126
+ def find_best_artifact(model_dir: Path) -> Path:
127
+ for pat in ["best_model.json", "best_model.pt", "best_model*.joblib"]:
128
+ hits = sorted(model_dir.glob(pat))
129
+ if hits:
130
+ return hits[0]
131
+ raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
132
+
133
+ def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
134
+ art = find_best_artifact(model_dir)
135
+
136
+ if art.suffix == ".json":
137
+ booster = xgb.Booster()
138
+ print(str(art))
139
+ booster.load_model(str(art))
140
+ return "xgb", booster, art
141
+
142
+ if art.suffix == ".joblib":
143
+ obj = joblib.load(art)
144
+ return "joblib", obj, art
145
+
146
+ if art.suffix == ".pt":
147
+ ckpt = torch.load(art, map_location=device, weights_only=False)
148
+ return "torch_ckpt", ckpt, art
149
+
150
+ raise ValueError(f"Unknown artifact type: {art}")
151
+
152
+
153
+ # -----------------------------
154
+ # NN architectures
155
+ # -----------------------------
156
+ class MaskedMeanPool(nn.Module):
157
+ def forward(self, X, M): # X:(B,L,H), M:(B,L)
158
+ Mf = M.unsqueeze(-1).float()
159
+ denom = Mf.sum(dim=1).clamp(min=1.0)
160
+ return (X * Mf).sum(dim=1) / denom
161
+
162
+ class MLPHead(nn.Module):
163
+ def __init__(self, in_dim, hidden=512, dropout=0.1):
164
+ super().__init__()
165
+ self.pool = MaskedMeanPool()
166
+ self.net = nn.Sequential(
167
+ nn.Linear(in_dim, hidden),
168
+ nn.GELU(),
169
+ nn.Dropout(dropout),
170
+ nn.Linear(hidden, 1),
171
+ )
172
+ def forward(self, X, M):
173
+ z = self.pool(X, M)
174
+ return self.net(z).squeeze(-1)
175
+
176
+ class CNNHead(nn.Module):
177
+ def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
178
+ super().__init__()
179
+ blocks = []
180
+ ch = in_ch
181
+ for _ in range(layers):
182
+ blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
183
+ nn.GELU(),
184
+ nn.Dropout(dropout)]
185
+ ch = c
186
+ self.conv = nn.Sequential(*blocks)
187
+ self.head = nn.Linear(c, 1)
188
+
189
+ def forward(self, X, M):
190
+ Xc = X.transpose(1, 2) # (B,H,L)
191
+ Y = self.conv(Xc).transpose(1, 2) # (B,L,C)
192
+ Mf = M.unsqueeze(-1).float()
193
+ denom = Mf.sum(dim=1).clamp(min=1.0)
194
+ pooled = (Y * Mf).sum(dim=1) / denom
195
+ return self.head(pooled).squeeze(-1)
196
+
197
+ class TransformerHead(nn.Module):
198
+ def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
199
+ super().__init__()
200
+ self.proj = nn.Linear(in_dim, d_model)
201
+ enc_layer = nn.TransformerEncoderLayer(
202
+ d_model=d_model, nhead=nhead, dim_feedforward=ff,
203
+ dropout=dropout, batch_first=True, activation="gelu"
204
+ )
205
+ self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
206
+ self.head = nn.Linear(d_model, 1)
207
+
208
+ def forward(self, X, M):
209
+ pad_mask = ~M
210
+ Z = self.proj(X)
211
+ Z = self.enc(Z, src_key_padding_mask=pad_mask)
212
+ Mf = M.unsqueeze(-1).float()
213
+ denom = Mf.sum(dim=1).clamp(min=1.0)
214
+ pooled = (Z * Mf).sum(dim=1) / denom
215
+ return self.head(pooled).squeeze(-1)
216
+
217
+ def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
218
+ if model_name == "mlp":
219
+ return int(sd["net.0.weight"].shape[1])
220
+ if model_name == "cnn":
221
+ return int(sd["conv.0.weight"].shape[1])
222
+ if model_name == "transformer":
223
+ return int(sd["proj.weight"].shape[1])
224
+ raise ValueError(model_name)
225
+
226
+ def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
227
+ params = ckpt["best_params"]
228
+ sd = ckpt["state_dict"]
229
+ in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name)))
230
+ dropout = float(params.get("dropout", 0.1))
231
+
232
+ if model_name == "mlp":
233
+ model = MLPHead(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
234
+ elif model_name == "cnn":
235
+ model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
236
+ layers=int(params["layers"]), dropout=dropout)
237
+ elif model_name == "transformer":
238
+ model = TransformerHead(in_dim=in_dim, d_model=int(params["d_model"]), nhead=int(params["nhead"]),
239
+ layers=int(params["layers"]), ff=int(params["ff"]), dropout=dropout)
240
+ else:
241
+ raise ValueError(f"Unknown NN model_name={model_name}")
242
+
243
+ model.load_state_dict(sd)
244
+ model.to(device)
245
+ model.eval()
246
+ return model
247
+
248
+
249
+ # -----------------------------
250
+ # Binding affinity models
251
+ # -----------------------------
252
+ def affinity_to_class(y: float) -> int:
253
+ # 0=High(>=9), 1=Moderate(7-9), 2=Low(<7)
254
+ if y >= 9.0: return 0
255
+ if y < 7.0: return 2
256
+ return 1
257
+
258
+ class CrossAttnPooled(nn.Module):
259
+ def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
260
+ super().__init__()
261
+ self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
262
+ self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
263
+
264
+ self.layers = nn.ModuleList([])
265
+ for _ in range(n_layers):
266
+ self.layers.append(nn.ModuleDict({
267
+ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
268
+ "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
269
+ "n1t": nn.LayerNorm(hidden),
270
+ "n2t": nn.LayerNorm(hidden),
271
+ "n1b": nn.LayerNorm(hidden),
272
+ "n2b": nn.LayerNorm(hidden),
273
+ "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
274
+ "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
275
+ }))
276
+
277
+ self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
278
+ self.reg = nn.Linear(hidden, 1)
279
+ self.cls = nn.Linear(hidden, 3)
280
+
281
+ def forward(self, t_vec, b_vec):
282
+ t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
283
+ b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
284
+ for L in self.layers:
285
+ t_attn, _ = L["attn_tb"](t, b, b)
286
+ t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
287
+ t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
288
+
289
+ b_attn, _ = L["attn_bt"](b, t, t)
290
+ b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
291
+ b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
292
+
293
+ z = torch.cat([t[0], b[0]], dim=-1)
294
+ h = self.shared(z)
295
+ return self.reg(h).squeeze(-1), self.cls(h)
296
+
297
+ class CrossAttnUnpooled(nn.Module):
298
+ def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
299
+ super().__init__()
300
+ self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
301
+ self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
302
+
303
+ self.layers = nn.ModuleList([])
304
+ for _ in range(n_layers):
305
+ self.layers.append(nn.ModuleDict({
306
+ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
307
+ "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
308
+ "n1t": nn.LayerNorm(hidden),
309
+ "n2t": nn.LayerNorm(hidden),
310
+ "n1b": nn.LayerNorm(hidden),
311
+ "n2b": nn.LayerNorm(hidden),
312
+ "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
313
+ "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
314
+ }))
315
+
316
+ self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
317
+ self.reg = nn.Linear(hidden, 1)
318
+ self.cls = nn.Linear(hidden, 3)
319
+
320
+ def _masked_mean(self, X, M):
321
+ Mf = M.unsqueeze(-1).float()
322
+ denom = Mf.sum(dim=1).clamp(min=1.0)
323
+ return (X * Mf).sum(dim=1) / denom
324
+
325
+ def forward(self, T, Mt, B, Mb):
326
+ T = self.t_proj(T)
327
+ Bx = self.b_proj(B)
328
+ kp_t = ~Mt
329
+ kp_b = ~Mb
330
+
331
+ for L in self.layers:
332
+ T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
333
+ T = L["n1t"](T + T_attn)
334
+ T = L["n2t"](T + L["fft"](T))
335
+
336
+ B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
337
+ Bx = L["n1b"](Bx + B_attn)
338
+ Bx = L["n2b"](Bx + L["ffb"](Bx))
339
+
340
+ t_pool = self._masked_mean(T, Mt)
341
+ b_pool = self._masked_mean(Bx, Mb)
342
+ z = torch.cat([t_pool, b_pool], dim=-1)
343
+ h = self.shared(z)
344
+ return self.reg(h).squeeze(-1), self.cls(h)
345
+
346
+ def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
347
+ ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
348
+ params = ckpt["best_params"]
349
+ sd = ckpt["state_dict"]
350
+
351
+ # infer Ht/Hb from projection weights
352
+ Ht = int(sd["t_proj.0.weight"].shape[1])
353
+ Hb = int(sd["b_proj.0.weight"].shape[1])
354
+
355
+ common = dict(
356
+ Ht=Ht, Hb=Hb,
357
+ hidden=int(params["hidden_dim"]),
358
+ n_heads=int(params["n_heads"]),
359
+ n_layers=int(params["n_layers"]),
360
+ dropout=float(params["dropout"]),
361
+ )
362
+
363
+ if pooled_or_unpooled == "pooled":
364
+ model = CrossAttnPooled(**common)
365
+ elif pooled_or_unpooled == "unpooled":
366
+ model = CrossAttnUnpooled(**common)
367
+ else:
368
+ raise ValueError(pooled_or_unpooled)
369
+
370
+ model.load_state_dict(sd)
371
+ model.to(device).eval()
372
+ return model
373
+
374
+
375
+ # -----------------------------
376
+ # Embedding generation
377
+ # -----------------------------
378
+ def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
379
+ """
380
+ Pytorch patch
381
+ """
382
+ if hasattr(torch, "isin"):
383
+ return torch.isin(ids, test_ids)
384
+ # Fallback: compare against each special id
385
+ # (B,L,1) == (1,1,K) -> (B,L,K)
386
+ return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)
387
+
388
+ class SMILESEmbedder:
389
+ """
390
+ PeptideCLM RoFormer embeddings for SMILES.
391
+ - pooled(): mean over tokens where attention_mask==1 AND token_id not in SPECIAL_IDS
392
+ - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
393
+ plus a 1-mask of length Li (since already filtered).
394
+ """
395
+ def __init__(
396
+ self,
397
+ device: torch.device,
398
+ vocab_path: str,
399
+ splits_path: str,
400
+ clm_name: str = "aaronfeller/PeptideCLM-23M-all",
401
+ max_len: int = 512,
402
+ use_cache: bool = True,
403
+ ):
404
+ self.device = device
405
+ self.max_len = max_len
406
+ self.use_cache = use_cache
407
+
408
+ self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
409
+ self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()
410
+
411
+ self.special_ids = self._get_special_ids(self.tokenizer)
412
+ self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
413
+ if len(self.special_ids) else None)
414
+
415
+ self._cache_pooled: Dict[str, torch.Tensor] = {}
416
+ self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
417
+
418
+ @staticmethod
419
+ def _get_special_ids(tokenizer) -> List[int]:
420
+ cand = [
421
+ getattr(tokenizer, "pad_token_id", None),
422
+ getattr(tokenizer, "cls_token_id", None),
423
+ getattr(tokenizer, "sep_token_id", None),
424
+ getattr(tokenizer, "bos_token_id", None),
425
+ getattr(tokenizer, "eos_token_id", None),
426
+ getattr(tokenizer, "mask_token_id", None),
427
+ ]
428
+ return sorted({int(x) for x in cand if x is not None})
429
+
430
+ def _tokenize(self, smiles_list: List[str]) -> Dict[str, torch.Tensor]:
431
+ tok = self.tokenizer(
432
+ smiles_list,
433
+ return_tensors="pt",
434
+ padding=True,
435
+ truncation=True,
436
+ max_length=self.max_len,
437
+ )
438
+ for k in tok:
439
+ tok[k] = tok[k].to(self.device)
440
+ if "attention_mask" not in tok:
441
+ tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
442
+ return tok
443
+
444
+ @torch.no_grad()
445
+ def pooled(self, smiles: str) -> torch.Tensor:
446
+ s = smiles.strip()
447
+ if self.use_cache and s in self._cache_pooled:
448
+ return self._cache_pooled[s]
449
+
450
+ tok = self._tokenize([s])
451
+ ids = tok["input_ids"] # (1,L)
452
+ attn = tok["attention_mask"].bool() # (1,L)
453
+
454
+ out = self.model(input_ids=ids, attention_mask=tok["attention_mask"])
455
+ h = out.last_hidden_state # (1,L,H)
456
+
457
+ valid = attn
458
+ if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
459
+ valid = valid & (~_safe_isin(ids, self.special_ids_t))
460
+
461
+ vf = valid.unsqueeze(-1).float()
462
+ summed = (h * vf).sum(dim=1) # (1,H)
463
+ denom = vf.sum(dim=1).clamp(min=1e-9) # (1,1)
464
+ pooled = summed / denom # (1,H)
465
+
466
+ if self.use_cache:
467
+ self._cache_pooled[s] = pooled
468
+ return pooled
469
+
470
+ @torch.no_grad()
471
+ def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
472
+ """
473
+ Returns:
474
+ X: (1, Li, H) float32 on device
475
+ M: (1, Li) bool on device
476
+ where Li excludes padding + special tokens.
477
+ """
478
+ s = smiles.strip()
479
+ if self.use_cache and s in self._cache_unpooled:
480
+ return self._cache_unpooled[s]
481
+
482
+ tok = self._tokenize([s])
483
+ ids = tok["input_ids"] # (1,L)
484
+ attn = tok["attention_mask"].bool() # (1,L)
485
+
486
+ out = self.model(input_ids=ids, attention_mask=tok["attention_mask"])
487
+ h = out.last_hidden_state # (1,L,H)
488
+
489
+ valid = attn
490
+ if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
491
+ valid = valid & (~_safe_isin(ids, self.special_ids_t))
492
+
493
+ # filter valid tokens
494
+ keep = valid[0] # (L,)
495
+ X = h[:, keep, :] # (1,Li,H)
496
+ M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
497
+
498
+ if self.use_cache:
499
+ self._cache_unpooled[s] = (X, M)
500
+ return X, M
501
+
502
+
503
+ class WTEmbedder:
504
+ """
505
+ ESM2 embeddings for AA sequences.
506
+ - pooled(): mean over tokens where attention_mask==1 AND token_id not in {CLS, EOS, PAD,...}
507
+ - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
508
+ plus a 1-mask of length Li (since already filtered).
509
+ """
510
+ def __init__(
511
+ self,
512
+ device: torch.device,
513
+ esm_name: str = "facebook/esm2_t33_650M_UR50D",
514
+ max_len: int = 1022,
515
+ use_cache: bool = True,
516
+ ):
517
+ self.device = device
518
+ self.max_len = max_len
519
+ self.use_cache = use_cache
520
+
521
+ self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
522
+ self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
523
+
524
+ self.special_ids = self._get_special_ids(self.tokenizer)
525
+ self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
526
+ if len(self.special_ids) else None)
527
+
528
+ self._cache_pooled: Dict[str, torch.Tensor] = {}
529
+ self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
530
+
531
+ @staticmethod
532
+ def _get_special_ids(tokenizer) -> List[int]:
533
+ cand = [
534
+ getattr(tokenizer, "pad_token_id", None),
535
+ getattr(tokenizer, "cls_token_id", None),
536
+ getattr(tokenizer, "sep_token_id", None),
537
+ getattr(tokenizer, "bos_token_id", None),
538
+ getattr(tokenizer, "eos_token_id", None),
539
+ getattr(tokenizer, "mask_token_id", None),
540
+ ]
541
+ return sorted({int(x) for x in cand if x is not None})
542
+
543
+ def _tokenize(self, seq_list: List[str]) -> Dict[str, torch.Tensor]:
544
+ tok = self.tokenizer(
545
+ seq_list,
546
+ return_tensors="pt",
547
+ padding=True,
548
+ truncation=True,
549
+ max_length=self.max_len,
550
+ )
551
+ tok = {k: v.to(self.device) for k, v in tok.items()}
552
+ if "attention_mask" not in tok:
553
+ tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
554
+ return tok
555
+
556
+ @torch.no_grad()
557
+ def pooled(self, seq: str) -> torch.Tensor:
558
+ s = seq.strip()
559
+ if self.use_cache and s in self._cache_pooled:
560
+ return self._cache_pooled[s]
561
+
562
+ tok = self._tokenize([s])
563
+ ids = tok["input_ids"] # (1,L)
564
+ attn = tok["attention_mask"].bool() # (1,L)
565
+
566
+ out = self.model(**tok)
567
+ h = out.last_hidden_state # (1,L,H)
568
+
569
+ valid = attn
570
+ if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
571
+ valid = valid & (~_safe_isin(ids, self.special_ids_t))
572
+
573
+ vf = valid.unsqueeze(-1).float()
574
+ summed = (h * vf).sum(dim=1) # (1,H)
575
+ denom = vf.sum(dim=1).clamp(min=1e-9) # (1,1)
576
+ pooled = summed / denom # (1,H)
577
+
578
+ if self.use_cache:
579
+ self._cache_pooled[s] = pooled
580
+ return pooled
581
+
582
+ @torch.no_grad()
583
+ def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
584
+ """
585
+ Returns:
586
+ X: (1, Li, H) float32 on device
587
+ M: (1, Li) bool on device
588
+ where Li excludes padding + special tokens.
589
+ """
590
+ s = seq.strip()
591
+ if self.use_cache and s in self._cache_unpooled:
592
+ return self._cache_unpooled[s]
593
+
594
+ tok = self._tokenize([s])
595
+ ids = tok["input_ids"] # (1,L)
596
+ attn = tok["attention_mask"].bool() # (1,L)
597
+
598
+ out = self.model(**tok)
599
+ h = out.last_hidden_state # (1,L,H)
600
+
601
+ valid = attn
602
+ if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
603
+ valid = valid & (~_safe_isin(ids, self.special_ids_t))
604
+
605
+ keep = valid[0] # (L,)
606
+ X = h[:, keep, :] # (1,Li,H)
607
+ M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
608
+
609
+ if self.use_cache:
610
+ self._cache_unpooled[s] = (X, M)
611
+ return X, M
612
+
613
+
614
+
615
+ # -----------------------------
616
+ # Predictor
617
+ # -----------------------------
618
+ class PeptiVersePredictor:
619
+ """
620
+ - loads best models from training_classifiers/
621
+ - computes embeddings as needed (pooled/unpooled)
622
+ - supports: xgb, joblib(ENET/SVM/SVR), NN(mlp/cnn/transformer), binding pooled/unpooled.
623
+ """
624
+ def __init__(
625
+ self,
626
+ manifest_path: str | Path,
627
+ classifier_weight_root: str | Path,
628
+ esm_name="facebook/esm2_t33_650M_UR50D",
629
+ clm_name="aaronfeller/PeptideCLM-23M-all",
630
+ smiles_vocab="tokenizer/new_vocab.txt",
631
+ smiles_splits="tokenizer/new_splits.txt",
632
+ device: Optional[str] = None,
633
+ ):
634
+ self.root = Path(classifier_weight_root)
635
+ self.training_root = self.root / "training_classifiers"
636
+ self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
637
+
638
+ self.manifest = read_best_manifest_csv(manifest_path)
639
+
640
+ self.wt_embedder = WTEmbedder(self.device)
641
+ self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name,
642
+ vocab_path=str(self.root / smiles_vocab),
643
+ splits_path=str(self.root / smiles_splits))
644
+
645
+ self.models: Dict[Tuple[str, str], Any] = {}
646
+ self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}
647
+
648
+ self._load_all_best_models()
649
+
650
+ def _resolve_dir(self, prop_key: str, model_name: str, mode: str) -> Path:
651
+ """
652
+ Usual layout: training_classifiers/<prop>/<model>_<mode>/
653
+ Fallbacks:
654
+ - training_classifiers/<prop>/<model>/
655
+ - training_classifiers/<prop>/<model>_wt
656
+ """
657
+ base = self.training_root / prop_key
658
+ candidates = [
659
+ base / f"{model_name}_{mode}",
660
+ base / model_name,
661
+ ]
662
+ if mode == "wt":
663
+ candidates += [base / f"{model_name}_wt"]
664
+ if mode == "smiles":
665
+ candidates += [base / f"{model_name}_smiles"]
666
+
667
+ for d in candidates:
668
+ if d.exists():
669
+ return d
670
+ raise FileNotFoundError(f"Cannot find model directory for {prop_key} {model_name} {mode}. Tried: {candidates}")
671
+
672
+ def _load_all_best_models(self):
673
+ for prop_key, row in self.manifest.items():
674
+ for mode, label, thr in [
675
+ ("wt", row.best_wt, row.thr_wt),
676
+ ("smiles", row.best_smiles, row.thr_smiles),
677
+ ]:
678
+ m = canon_model(label)
679
+ if m is None:
680
+ continue
681
+
682
+ # ---- binding affinity special ----
683
+ if prop_key == "binding_affinity":
684
+ # label is pooled/unpooled; mode chooses folder wt_wt_* vs wt_smiles_*
685
+ pooled_or_unpooled = m # "pooled" or "unpooled"
686
+ folder = f"wt_{mode}_{pooled_or_unpooled}" # wt_wt_pooled / wt_smiles_unpooled etc.
687
+ model_dir = self.training_root / "binding_affinity" / folder
688
+ art = find_best_artifact(model_dir)
689
+ if art.suffix != ".pt":
690
+ raise RuntimeError(f"Binding model expected best_model.pt, got {art}")
691
+ model = load_binding_model(art, pooled_or_unpooled=pooled_or_unpooled, device=self.device)
692
+ self.models[(prop_key, mode)] = model
693
+ self.meta[(prop_key, mode)] = {
694
+ "task_type": "Regression",
695
+ "threshold": None,
696
+ "artifact": str(art),
697
+ "model_name": pooled_or_unpooled,
698
+ }
699
+ continue
700
+
701
+ model_dir = self._resolve_dir(prop_key, m, mode)
702
+ kind, obj, art = load_artifact(model_dir, self.device)
703
+
704
+ if kind in {"xgb", "joblib"}:
705
+ self.models[(prop_key, mode)] = obj
706
+ else:
707
+ # rebuild NN architecture
708
+ self.models[(prop_key, mode)] = build_torch_model_from_ckpt(m, obj, self.device)
709
+
710
+ self.meta[(prop_key, mode)] = {
711
+ "task_type": row.task_type,
712
+ "threshold": thr,
713
+ "artifact": str(art),
714
+ "model_name": m,
715
+ "kind": kind,
716
+ }
717
+
718
+ def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
719
+ """
720
+ Returns either:
721
+ - pooled np array shape (1,H) for xgb/joblib
722
+ - unpooled torch tensors (X,M) for NN
723
+ """
724
+ model = self.models[(prop_key, mode)]
725
+ meta = self.meta[(prop_key, mode)]
726
+ kind = meta.get("kind", None)
727
+ model_name = meta.get("model_name", "")
728
+
729
+ if prop_key == "binding_affinity":
730
+ raise RuntimeError("Use predict_binding_affinity().")
731
+
732
+ # If torch NN: needs unpooled
733
+ if kind == "torch_ckpt":
734
+ if mode == "wt":
735
+ X, M = self.wt_embedder.unpooled(input_str)
736
+ else:
737
+ X, M = self.smiles_embedder.unpooled(input_str)
738
+ return X, M
739
+
740
+ # Otherwise pooled vectors for xgb/joblib
741
+ if mode == "wt":
742
+ v = self.wt_embedder.pooled(input_str) # (1,H)
743
+ else:
744
+ v = self.smiles_embedder.pooled(input_str) # (1,H)
745
+ feats = v.detach().cpu().numpy().astype(np.float32)
746
+ feats = np.nan_to_num(feats, nan=0.0)
747
+ feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
748
+ return feats
749
+
750
+ def predict_property(self, prop_key: str, mode: str, input_str: str) -> Dict[str, Any]:
751
+ """
752
+ mode: "wt" for AA sequence input, "smiles" for SMILES input
753
+ Returns dict with score + label if classifier threshold exists.
754
+ """
755
+ if (prop_key, mode) not in self.models:
756
+ raise KeyError(f"No model loaded for ({prop_key}, {mode}). Check manifest and folders.")
757
+
758
+ meta = self.meta[(prop_key, mode)]
759
+ model = self.models[(prop_key, mode)]
760
+ task_type = meta["task_type"].lower()
761
+ thr = meta.get("threshold", None)
762
+ kind = meta.get("kind", None)
763
+
764
+ if prop_key == "binding_affinity":
765
+ raise RuntimeError("Use predict_binding_affinity().")
766
+
767
+ # NN path (logits / regression)
768
+ if kind == "torch_ckpt":
769
+ X, M = self._get_features_for_model(prop_key, mode, input_str)
770
+ with torch.no_grad():
771
+ y = model(X, M).squeeze().float().cpu().item()
772
+ if task_type == "classifier":
773
+ prob = float(1.0 / (1.0 + np.exp(-y))) # sigmoid(logit)
774
+ out = {"property": prop_key, "mode": mode, "score": prob}
775
+ if thr is not None:
776
+ out["label"] = int(prob >= float(thr))
777
+ out["threshold"] = float(thr)
778
+ return out
779
+ else:
780
+ return {"property": prop_key, "mode": mode, "score": float(y)}
781
+
782
+ # xgb path
783
+ if kind == "xgb":
784
+ feats = self._get_features_for_model(prop_key, mode, input_str) # (1,H)
785
+ dmat = xgb.DMatrix(feats)
786
+ pred = float(model.predict(dmat)[0])
787
+ out = {"property": prop_key, "mode": mode, "score": pred}
788
+ if task_type == "classifier" and thr is not None:
789
+ out["label"] = int(pred >= float(thr))
790
+ out["threshold"] = float(thr)
791
+ return out
792
+
793
+ # joblib path (svm/enet/svr)
794
+ if kind == "joblib":
795
+ feats = self._get_features_for_model(prop_key, mode, input_str) # (1,H)
796
+ # classifier vs regressor behavior differs by estimator
797
+ if task_type == "classifier":
798
+ if hasattr(model, "predict_proba"):
799
+ pred = float(model.predict_proba(feats)[:, 1][0])
800
+ else:
801
+ if hasattr(model, "decision_function"):
802
+ logit = float(model.decision_function(feats)[0])
803
+ pred = float(1.0 / (1.0 + np.exp(-logit)))
804
+ else:
805
+ pred = float(model.predict(feats)[0])
806
+ out = {"property": prop_key, "mode": mode, "score": pred}
807
+ if thr is not None:
808
+ out["label"] = int(pred >= float(thr))
809
+ out["threshold"] = float(thr)
810
+ return out
811
+ else:
812
+ pred = float(model.predict(feats)[0])
813
+ return {"property": prop_key, "mode": mode, "score": pred}
814
+
815
+ raise RuntimeError(f"Unknown model kind={kind}")
816
+
817
+ def predict_binding_affinity(self, mode: str, target_seq: str, binder_str: str) -> Dict[str, Any]:
818
+ """
819
+ mode: "wt" (binder is AA sequence) -> wt_wt_(pooled|unpooled)
820
+ "smiles" (binder is SMILES) -> wt_smiles_(pooled|unpooled)
821
+ """
822
+ prop_key = "binding_affinity"
823
+ if (prop_key, mode) not in self.models:
824
+ raise KeyError(f"No binding model loaded for ({prop_key}, {mode}).")
825
+
826
+ model = self.models[(prop_key, mode)]
827
+ pooled_or_unpooled = self.meta[(prop_key, mode)]["model_name"] # pooled/unpooled
828
+
829
+ # target is always WT sequence (ESM)
830
+ if pooled_or_unpooled == "pooled":
831
+ t_vec = self.wt_embedder.pooled(target_seq) # (1,Ht)
832
+ if mode == "wt":
833
+ b_vec = self.wt_embedder.pooled(binder_str) # (1,Hb)
834
+ else:
835
+ b_vec = self.smiles_embedder.pooled(binder_str) # (1,Hb)
836
+ with torch.no_grad():
837
+ reg, logits = model(t_vec, b_vec)
838
+ affinity = float(reg.squeeze().cpu().item())
839
+ cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
840
+ cls_thr = affinity_to_class(affinity)
841
+ else:
842
+ T, Mt = self.wt_embedder.unpooled(target_seq)
843
+ if mode == "wt":
844
+ B, Mb = self.wt_embedder.unpooled(binder_str)
845
+ else:
846
+ B, Mb = self.smiles_embedder.unpooled(binder_str)
847
+ with torch.no_grad():
848
+ reg, logits = model(T, Mt, B, Mb)
849
+ affinity = float(reg.squeeze().cpu().item())
850
+ cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
851
+ cls_thr = affinity_to_class(affinity)
852
+
853
+ names = {0: "High (≥9)", 1: "Moderate (7–9)", 2: "Low (<7)"}
854
+ return {
855
+ "property": "binding_affinity",
856
+ "mode": mode,
857
+ "affinity": affinity,
858
+ "class_by_threshold": names[cls_thr],
859
+ "class_by_logits": names[cls_logit],
860
+ "binding_model": pooled_or_unpooled,
861
+ }
862
+
863
+
864
+ # -----------------------------
865
+ # Minimal usage
866
+ # -----------------------------
867
+ if __name__ == "__main__":
868
+ # Example:
869
+ predictor = PeptiVersePredictor(
870
+ manifest_path="best_models.txt",
871
+ classifier_weight_root="/vast/projects/pranam/lab/yz927/projects/Classifier_Weight"
872
+ )
873
+ print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))
874
+ print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))
875
+
876
+ # Test Embedding #
877
+ """
878
+ device = torch.device("cuda:0")
879
+
880
+ wt = WTEmbedder(device)
881
+ sm = SMILESEmbedder(device,
882
+ vocab_path="/home/enol/PeptideGym/Data_split/tokenizer/new_vocab.txt",
883
+ splits_path="/home/enol/PeptideGym/Data_split/tokenizer/new_splits.txt"
884
+ )
885
+
886
+ p = wt.pooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,1280)
887
+ X, M = wt.unpooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,Li,1280), (1,Li)
888
+
889
+ p2 = sm.pooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,H_smiles)
890
+ X2, M2 = sm.unpooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,Li,H_smiles), (1,Li)
891
+ """