yinuozhang committed
Commit 8950131 · 1 Parent(s): c0891bf

model update

Browse files
Files changed (4)
  1. .gitattributes +0 -35
  2. app.py +0 -0
  3. best_models.txt +10 -0
  4. inference.py +907 -0
.gitattributes DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
The diff for this file is too large to render. See raw diff
 
best_models.txt ADDED
@@ -0,0 +1,10 @@
+ Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
+ Hemolysis, XGB, Transformer, Classifier, 0.2801, 0.4343,
+ Non-Fouling, MLP, XGB, Classifier, 0.57, 0.3982,
+ Solubility, CNN, -, Classifier, 0.377, -,
+ Permeability (Penetrance), XGB, -, Classifier, 0.4301, -,
+ Toxicity, -, Transformer, Classifier, -, 0.3401,
+ Binding_affinity, unpooled, unpooled, Regression, -, -,
+ Permeability_PAMPA, -, Transformer, Regression, -, -,
+ Permeability_CACO2, -, Transformer, Regression, -, -,
+ Halflife, xgb_wt_log, xgb_smiles, Regression, -, -,
inference.py ADDED
@@ -0,0 +1,907 @@
+ # peptiverse_infer.py
+ from __future__ import annotations
+ 
+ import csv, re
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, Optional, Tuple, Any, List
+ 
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import joblib
+ import xgboost as xgb
+ 
+ from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+ 
+ 
+ # -----------------------------
+ # Manifest
+ # -----------------------------
+ @dataclass(frozen=True)
+ class BestRow:
+     property_key: str
+     best_wt: Optional[str]
+     best_smiles: Optional[str]
+     task_type: str  # "Classifier" or "Regression"
+     thr_wt: Optional[float]
+     thr_smiles: Optional[float]
+ 
+ 
+ def _clean(s: str) -> str:
+     return (s or "").strip()
+ 
+ def _none_if_dash(s: str) -> Optional[str]:
+     s = _clean(s)
+     if s in {"", "-", "—", "NA", "N/A"}:
+         return None
+     return s
+ 
+ def _float_or_none(s: str) -> Optional[float]:
+     s = _clean(s)
+     if s in {"", "-", "—", "NA", "N/A"}:
+         return None
+     return float(s)
+ 
+ def normalize_property_key(name: str) -> str:
+     n = name.strip().lower()
+     n = re.sub(r"\s*\(.*?\)\s*", "", n)
+     n = n.replace("-", "_").replace(" ", "_")
+ 
+     if "permeability" in n and "pampa" not in n and "caco" not in n:
+         return "permeability_penetrance"
+     if n == "binding_affinity":
+         return "binding_affinity"
+     if n in {"halflife", "half_life"}:
+         return "halflife"
+     if n == "non_fouling":
+         return "nf"
+     return n
+
62
+
63
+ def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
64
+ """
65
+ Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
66
+ Hemolysis, SVM, SGB, Classifier, 0.2801, 0.2223,
67
+ """
68
+ p = Path(path)
69
+ out: Dict[str, BestRow] = {}
70
+
71
+ with p.open("r", newline="") as f:
72
+ reader = csv.reader(f)
73
+ header = None
74
+ for raw in reader:
75
+ if not raw or all(_clean(x) == "" for x in raw):
76
+ continue
77
+ while raw and _clean(raw[-1]) == "":
78
+ raw = raw[:-1]
79
+
80
+ if header is None:
81
+ header = [h.strip() for h in raw]
82
+ continue
83
+
84
+ if len(raw) < len(header):
85
+ raw = raw + [""] * (len(header) - len(raw))
86
+ rec = dict(zip(header, raw))
87
+
88
+ prop_raw = _clean(rec.get("Properties", ""))
89
+ if not prop_raw:
90
+ continue
91
+ prop_key = normalize_property_key(prop_raw)
92
+
93
+ row = BestRow(
94
+ property_key=prop_key,
95
+ best_wt=_none_if_dash(rec.get("Best_Model_WT", "")),
96
+ best_smiles=_none_if_dash(rec.get("Best_Model_SMILES", "")),
97
+ task_type=_clean(rec.get("Type", "Classifier")),
98
+ thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
99
+ thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
100
+ )
101
+ out[prop_key] = row
102
+
103
+ return out
104
+
105
+
106
+ MODEL_ALIAS = {
107
+ "SVM": "svm_gpu",
108
+ "SVR": "svr",
109
+ "ENET": "enet_gpu",
110
+ "CNN": "cnn",
111
+ "MLP": "mlp",
112
+ "TRANSFORMER": "transformer",
113
+ "XGB": "xgb",
114
+ "XGB_REG": "xgb_reg",
115
+ "POOLED": "pooled",
116
+ "UNPOOLED": "unpooled"
117
+ }
118
+ def canon_model(label: Optional[str]) -> Optional[str]:
119
+ if label is None:
120
+ return None
121
+ k = label.strip().upper()
122
+ return MODEL_ALIAS.get(k, label.strip().lower())
123
+
124
+
125
+ # -----------------------------
126
+ # Generic artifact loading
127
+ # -----------------------------
128
+ def find_best_artifact(model_dir: Path) -> Path:
129
+ for pat in ["best_model.json", "best_model.pt", "best_model*.joblib"]:
130
+ hits = sorted(model_dir.glob(pat))
131
+ if hits:
132
+ return hits[0]
133
+ raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
134
+
135
+ def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
136
+ art = find_best_artifact(model_dir)
137
+
138
+ if art.suffix == ".json":
139
+ booster = xgb.Booster()
140
+ print(str(art))
141
+ booster.load_model(str(art))
142
+ return "xgb", booster, art
143
+
144
+ if art.suffix == ".joblib":
145
+ obj = joblib.load(art)
146
+ return "joblib", obj, art
147
+
148
+ if art.suffix == ".pt":
149
+ ckpt = torch.load(art, map_location=device, weights_only=False)
150
+ return "torch_ckpt", ckpt, art
151
+
152
+ raise ValueError(f"Unknown artifact type: {art}")
153
+
154
+
155
+ # -----------------------------
156
+ # NN architectures
157
+ # -----------------------------
158
+ class MaskedMeanPool(nn.Module):
159
+ def forward(self, X, M): # X:(B,L,H), M:(B,L)
160
+ Mf = M.unsqueeze(-1).float()
161
+ denom = Mf.sum(dim=1).clamp(min=1.0)
162
+ return (X * Mf).sum(dim=1) / denom
163
+
164
+ class MLPHead(nn.Module):
165
+ def __init__(self, in_dim, hidden=512, dropout=0.1):
166
+ super().__init__()
167
+ self.pool = MaskedMeanPool()
168
+ self.net = nn.Sequential(
169
+ nn.Linear(in_dim, hidden),
170
+ nn.GELU(),
171
+ nn.Dropout(dropout),
172
+ nn.Linear(hidden, 1),
173
+ )
174
+ def forward(self, X, M):
175
+ z = self.pool(X, M)
176
+ return self.net(z).squeeze(-1)
177
+
178
+ class CNNHead(nn.Module):
179
+ def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
180
+ super().__init__()
181
+ blocks = []
182
+ ch = in_ch
183
+ for _ in range(layers):
184
+ blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
185
+ nn.GELU(),
186
+ nn.Dropout(dropout)]
187
+ ch = c
188
+ self.conv = nn.Sequential(*blocks)
189
+ self.head = nn.Linear(c, 1)
190
+
191
+ def forward(self, X, M):
192
+ Xc = X.transpose(1, 2) # (B,H,L)
193
+ Y = self.conv(Xc).transpose(1, 2) # (B,L,C)
194
+ Mf = M.unsqueeze(-1).float()
195
+ denom = Mf.sum(dim=1).clamp(min=1.0)
196
+ pooled = (Y * Mf).sum(dim=1) / denom
197
+ return self.head(pooled).squeeze(-1)
198
+
199
+ class TransformerHead(nn.Module):
200
+ def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
201
+ super().__init__()
202
+ self.proj = nn.Linear(in_dim, d_model)
203
+ enc_layer = nn.TransformerEncoderLayer(
204
+ d_model=d_model, nhead=nhead, dim_feedforward=ff,
205
+ dropout=dropout, batch_first=True, activation="gelu"
206
+ )
207
+ self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
208
+ self.head = nn.Linear(d_model, 1)
209
+
210
+ def forward(self, X, M):
211
+ pad_mask = ~M
212
+ Z = self.proj(X)
213
+ Z = self.enc(Z, src_key_padding_mask=pad_mask)
214
+ Mf = M.unsqueeze(-1).float()
215
+ denom = Mf.sum(dim=1).clamp(min=1.0)
216
+ pooled = (Z * Mf).sum(dim=1) / denom
217
+ return self.head(pooled).squeeze(-1)
218
+
219
+ def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
220
+ if model_name == "mlp":
221
+ return int(sd["net.0.weight"].shape[1])
222
+ if model_name == "cnn":
223
+ return int(sd["conv.0.weight"].shape[1])
224
+ if model_name == "transformer":
225
+ return int(sd["proj.weight"].shape[1])
226
+ raise ValueError(model_name)
227
+
228
+ def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
229
+ params = ckpt["best_params"]
230
+ sd = ckpt["state_dict"]
231
+ in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name)))
232
+ dropout = float(params.get("dropout", 0.1))
233
+
234
+ if model_name == "mlp":
235
+ model = MLPHead(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
236
+ elif model_name == "cnn":
237
+ model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
238
+ layers=int(params["layers"]), dropout=dropout)
239
+ elif model_name == "transformer":
240
+ model = TransformerHead(in_dim=in_dim, d_model=int(params["d_model"]), nhead=int(params["nhead"]),
241
+ layers=int(params["layers"]), ff=int(params["ff"]), dropout=dropout)
242
+ else:
243
+ raise ValueError(f"Unknown NN model_name={model_name}")
244
+
245
+ model.load_state_dict(sd)
246
+ model.to(device)
247
+ model.eval()
248
+ return model
249
+
250
+
251
+ # -----------------------------
252
+ # Binding affinity models
253
+ # -----------------------------
254
+ def affinity_to_class(y: float) -> int:
255
+ # 0=High(>=9), 1=Moderate(7-9), 2=Low(<7)
256
+ if y >= 9.0: return 0
257
+ if y < 7.0: return 2
258
+ return 1
259
+
260
+ class CrossAttnPooled(nn.Module):
261
+ def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
262
+ super().__init__()
263
+ self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
264
+ self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
265
+
266
+ self.layers = nn.ModuleList([])
267
+ for _ in range(n_layers):
268
+ self.layers.append(nn.ModuleDict({
269
+ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
270
+ "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
271
+ "n1t": nn.LayerNorm(hidden),
272
+ "n2t": nn.LayerNorm(hidden),
273
+ "n1b": nn.LayerNorm(hidden),
274
+ "n2b": nn.LayerNorm(hidden),
275
+ "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
276
+ "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
277
+ }))
278
+
279
+ self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
280
+ self.reg = nn.Linear(hidden, 1)
281
+ self.cls = nn.Linear(hidden, 3)
282
+
283
+ def forward(self, t_vec, b_vec):
284
+ t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
285
+ b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
286
+ for L in self.layers:
287
+ t_attn, _ = L["attn_tb"](t, b, b)
288
+ t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
289
+ t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
290
+
291
+ b_attn, _ = L["attn_bt"](b, t, t)
292
+ b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
293
+ b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
294
+
295
+ z = torch.cat([t[0], b[0]], dim=-1)
296
+ h = self.shared(z)
297
+ return self.reg(h).squeeze(-1), self.cls(h)
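+ # Pooled variant: target and binder are single vectors, lifted to length-1 sequences
+ # so the two can cross-attend; a shared trunk then feeds both a regression head and
+ # 3-way class logits.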
+ 
+ class CrossAttnUnpooled(nn.Module):
+     def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
+         super().__init__()
+         self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
+         self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
+ 
+         self.layers = nn.ModuleList([])
+         for _ in range(n_layers):
+             self.layers.append(nn.ModuleDict({
+                 "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
+                 "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
+                 "n1t": nn.LayerNorm(hidden),
+                 "n2t": nn.LayerNorm(hidden),
+                 "n1b": nn.LayerNorm(hidden),
+                 "n2b": nn.LayerNorm(hidden),
+                 "fft": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
+                 "ffb": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
+             }))
+ 
+         self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
+         self.reg = nn.Linear(hidden, 1)
+         self.cls = nn.Linear(hidden, 3)
+ 
+     def _masked_mean(self, X, M):
+         Mf = M.unsqueeze(-1).float()
+         denom = Mf.sum(dim=1).clamp(min=1.0)
+         return (X * Mf).sum(dim=1) / denom
+ 
+     def forward(self, T, Mt, B, Mb):
+         T = self.t_proj(T)
+         Bx = self.b_proj(B)
+         kp_t = ~Mt
+         kp_b = ~Mb
+ 
+         for L in self.layers:
+             T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
+             T = L["n1t"](T + T_attn)
+             T = L["n2t"](T + L["fft"](T))
+ 
+             B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
+             Bx = L["n1b"](Bx + B_attn)
+             Bx = L["n2b"](Bx + L["ffb"](Bx))
+ 
+         t_pool = self._masked_mean(T, Mt)
+         b_pool = self._masked_mean(Bx, Mb)
+         z = torch.cat([t_pool, b_pool], dim=-1)
+         h = self.shared(z)
+         return self.reg(h).squeeze(-1), self.cls(h)
+ 
+ def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
+     ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
+     params = ckpt["best_params"]
+     sd = ckpt["state_dict"]
+ 
+     # infer Ht/Hb from the projection weights
+     Ht = int(sd["t_proj.0.weight"].shape[1])
+     Hb = int(sd["b_proj.0.weight"].shape[1])
+ 
+     common = dict(
+         Ht=Ht, Hb=Hb,
+         hidden=int(params["hidden_dim"]),
+         n_heads=int(params["n_heads"]),
+         n_layers=int(params["n_layers"]),
+         dropout=float(params["dropout"]),
+     )
+ 
+     if pooled_or_unpooled == "pooled":
+         model = CrossAttnPooled(**common)
+     elif pooled_or_unpooled == "unpooled":
+         model = CrossAttnUnpooled(**common)
+     else:
+         raise ValueError(pooled_or_unpooled)
+ 
+     model.load_state_dict(sd)
+     model.to(device).eval()
+     return model
+ 
+ 
+ # -----------------------------
+ # Embedding generation
+ # -----------------------------
+ def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
+     """
+     torch.isin with a fallback for older PyTorch versions that lack it.
+     """
+     if hasattr(torch, "isin"):
+         return torch.isin(ids, test_ids)
+     # Fallback: compare against each special id
+     # (B,L,1) == (1,1,K) -> (B,L,K)
+     return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)
+ 
+ class SMILESEmbedder:
+     """
+     PeptideCLM RoFormer embeddings for SMILES.
+     - pooled(): mean over tokens where attention_mask == 1 AND token_id not in SPECIAL_IDS
+     - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
+       plus an all-ones mask of length Li (the specials are already filtered out).
+     """
+     def __init__(
+         self,
+         device: torch.device,
+         vocab_path: str,
+         splits_path: str,
+         clm_name: str = "aaronfeller/PeptideCLM-23M-all",
+         max_len: int = 512,
+         use_cache: bool = True,
+     ):
+         self.device = device
+         self.max_len = max_len
+         self.use_cache = use_cache
+ 
+         self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
+         self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()
+ 
+         self.special_ids = self._get_special_ids(self.tokenizer)
+         self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
+                               if len(self.special_ids) else None)
+ 
+         self._cache_pooled: Dict[str, torch.Tensor] = {}
+         self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
+ 
+     @staticmethod
+     def _get_special_ids(tokenizer) -> List[int]:
+         cand = [
+             getattr(tokenizer, "pad_token_id", None),
+             getattr(tokenizer, "cls_token_id", None),
+             getattr(tokenizer, "sep_token_id", None),
+             getattr(tokenizer, "bos_token_id", None),
+             getattr(tokenizer, "eos_token_id", None),
+             getattr(tokenizer, "mask_token_id", None),
+         ]
+         return sorted({int(x) for x in cand if x is not None})
+ 
+     def _tokenize(self, smiles_list: List[str]) -> Dict[str, torch.Tensor]:
+         tok = self.tokenizer(
+             smiles_list,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=self.max_len,
+         )
+         tok = {k: v.to(self.device) for k, v in tok.items()}
+         if "attention_mask" not in tok:
+             tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
+         return tok
+ 
+     @torch.no_grad()
+     def pooled(self, smiles: str) -> torch.Tensor:
+         s = smiles.strip()
+         if self.use_cache and s in self._cache_pooled:
+             return self._cache_pooled[s]
+ 
+         tok = self._tokenize([s])
+         ids = tok["input_ids"]               # (1,L)
+         attn = tok["attention_mask"].bool()  # (1,L)
+ 
+         out = self.model(input_ids=ids, attention_mask=tok["attention_mask"])
+         h = out.last_hidden_state            # (1,L,H)
+ 
+         valid = attn
+         if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
+             valid = valid & (~_safe_isin(ids, self.special_ids_t))
+ 
+         vf = valid.unsqueeze(-1).float()
+         summed = (h * vf).sum(dim=1)           # (1,H)
+         denom = vf.sum(dim=1).clamp(min=1e-9)  # (1,1)
+         pooled = summed / denom                # (1,H)
+ 
+         if self.use_cache:
+             self._cache_pooled[s] = pooled
+         return pooled
+ 
+     @torch.no_grad()
+     def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Returns:
+             X: (1, Li, H) float32 on device
+             M: (1, Li) bool on device
+         where Li excludes padding + special tokens.
+         """
+         s = smiles.strip()
+         if self.use_cache and s in self._cache_unpooled:
+             return self._cache_unpooled[s]
+ 
+         tok = self._tokenize([s])
+         ids = tok["input_ids"]               # (1,L)
+         attn = tok["attention_mask"].bool()  # (1,L)
+ 
+         out = self.model(input_ids=ids, attention_mask=tok["attention_mask"])
+         h = out.last_hidden_state            # (1,L,H)
+ 
+         valid = attn
+         if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
+             valid = valid & (~_safe_isin(ids, self.special_ids_t))
+ 
+         # keep only valid tokens
+         keep = valid[0]    # (L,)
+         X = h[:, keep, :]  # (1,Li,H)
+         M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
+ 
+         if self.use_cache:
+             self._cache_unpooled[s] = (X, M)
+         return X, M
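+ # Rough usage sketch (paths illustrative; see PeptiVersePredictor defaults below):
+ #   sm = SMILESEmbedder(device, vocab_path="tokenizer/new_vocab.txt",
+ #                       splits_path="tokenizer/new_splits.txt")
+ #   vec = sm.pooled("NCC(=O)N[C@H](CS)C(=O)O")  # (1, H) mean over non-special tokens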
+ 
+ 
+ class WTEmbedder:
+     """
+     ESM2 embeddings for AA sequences.
+     - pooled(): mean over tokens where attention_mask == 1 AND token_id not in {CLS, EOS, PAD, ...}
+     - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
+       plus an all-ones mask of length Li (the specials are already filtered out).
+     """
+     def __init__(
+         self,
+         device: torch.device,
+         esm_name: str = "facebook/esm2_t33_650M_UR50D",
+         max_len: int = 1022,
+         use_cache: bool = True,
+     ):
+         self.device = device
+         self.max_len = max_len
+         self.use_cache = use_cache
+ 
+         self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
+         self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
+ 
+         self.special_ids = self._get_special_ids(self.tokenizer)
+         self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
+                               if len(self.special_ids) else None)
+ 
+         self._cache_pooled: Dict[str, torch.Tensor] = {}
+         self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
+ 
+     @staticmethod
+     def _get_special_ids(tokenizer) -> List[int]:
+         cand = [
+             getattr(tokenizer, "pad_token_id", None),
+             getattr(tokenizer, "cls_token_id", None),
+             getattr(tokenizer, "sep_token_id", None),
+             getattr(tokenizer, "bos_token_id", None),
+             getattr(tokenizer, "eos_token_id", None),
+             getattr(tokenizer, "mask_token_id", None),
+         ]
+         return sorted({int(x) for x in cand if x is not None})
+ 
+     def _tokenize(self, seq_list: List[str]) -> Dict[str, torch.Tensor]:
+         tok = self.tokenizer(
+             seq_list,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=self.max_len,
+         )
+         tok = {k: v.to(self.device) for k, v in tok.items()}
+         if "attention_mask" not in tok:
+             tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
+         return tok
+ 
+     @torch.no_grad()
+     def pooled(self, seq: str) -> torch.Tensor:
+         s = seq.strip()
+         if self.use_cache and s in self._cache_pooled:
+             return self._cache_pooled[s]
+ 
+         tok = self._tokenize([s])
+         ids = tok["input_ids"]               # (1,L)
+         attn = tok["attention_mask"].bool()  # (1,L)
+ 
+         out = self.model(**tok)
+         h = out.last_hidden_state            # (1,L,H)
+ 
+         valid = attn
+         if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
+             valid = valid & (~_safe_isin(ids, self.special_ids_t))
+ 
+         vf = valid.unsqueeze(-1).float()
+         summed = (h * vf).sum(dim=1)           # (1,H)
+         denom = vf.sum(dim=1).clamp(min=1e-9)  # (1,1)
+         pooled = summed / denom                # (1,H)
+ 
+         if self.use_cache:
+             self._cache_pooled[s] = pooled
+         return pooled
+ 
+     @torch.no_grad()
+     def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Returns:
+             X: (1, Li, H) float32 on device
+             M: (1, Li) bool on device
+         where Li excludes padding + special tokens.
+         """
+         s = seq.strip()
+         if self.use_cache and s in self._cache_unpooled:
+             return self._cache_unpooled[s]
+ 
+         tok = self._tokenize([s])
+         ids = tok["input_ids"]               # (1,L)
+         attn = tok["attention_mask"].bool()  # (1,L)
+ 
+         out = self.model(**tok)
+         h = out.last_hidden_state            # (1,L,H)
+ 
+         valid = attn
+         if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
+             valid = valid & (~_safe_isin(ids, self.special_ids_t))
+ 
+         keep = valid[0]    # (L,)
+         X = h[:, keep, :]  # (1,Li,H)
+         M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
+ 
+         if self.use_cache:
+             self._cache_unpooled[s] = (X, M)
+         return X, M
+ 
+ 
+ # -----------------------------
+ # Predictor
+ # -----------------------------
+ class PeptiVersePredictor:
+     """
+     - loads the best models listed in the manifest from training_classifiers/
+     - computes embeddings as needed (pooled/unpooled)
+     - supports: xgb, joblib (ENET/SVM/SVR), NN (mlp/cnn/transformer), binding pooled/unpooled
+     """
+     def __init__(
+         self,
+         manifest_path: str | Path,
+         classifier_weight_root: str | Path,
+         esm_name="facebook/esm2_t33_650M_UR50D",
+         clm_name="aaronfeller/PeptideCLM-23M-all",
+         smiles_vocab="tokenizer/new_vocab.txt",
+         smiles_splits="tokenizer/new_splits.txt",
+         device: Optional[str] = None,
+     ):
+         self.root = Path(classifier_weight_root)
+         self.training_root = self.root / "training_classifiers"
+         self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
+ 
+         self.manifest = read_best_manifest_csv(manifest_path)
+ 
+         self.wt_embedder = WTEmbedder(self.device, esm_name=esm_name)
+         self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name,
+                                               vocab_path=str(self.root / smiles_vocab),
+                                               splits_path=str(self.root / smiles_splits))
+ 
+         self.models: Dict[Tuple[str, str], Any] = {}
+         self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}
+ 
+         self._load_all_best_models()
+ 
+     def _resolve_dir(self, prop_key: str, model_name: str, mode: str) -> Path:
+         # map halflife -> half_life folder on disk (common layout)
+         disk_prop = "half_life" if prop_key == "halflife" else prop_key
+         base = self.training_root / disk_prop
+ 
+         # special handling for halflife xgb_wt_log / xgb_smiles
+         if prop_key == "halflife" and model_name in {"xgb_wt_log", "xgb_smiles"}:
+             d = base / model_name
+             if d.exists():
+                 return d
+ 
+         # (optional) allow the manifest to say just "xgb" for halflife
+         if prop_key == "halflife" and model_name == "xgb":
+             d = base / ("xgb_wt_log" if mode == "wt" else "xgb_smiles")
+             if d.exists():
+                 return d
+ 
+         # standard patterns
+         candidates = [
+             base / f"{model_name}_{mode}",
+             base / model_name,
+         ]
+ 
+         for d in candidates:
+             if d.exists():
+                 return d
+ 
+         raise FileNotFoundError(
+             f"Cannot find model directory for {prop_key} {model_name} {mode}. Tried: {candidates}"
+         )
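+     # Assumed on-disk layout (illustrative), matching the patterns above:
+     #   training_classifiers/hemolysis/xgb_wt/best_model.json
+     #   training_classifiers/half_life/xgb_smiles/best_model.json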
+ 
+     def _load_all_best_models(self):
+         for prop_key, row in self.manifest.items():
+             for mode, label, thr in [
+                 ("wt", row.best_wt, row.thr_wt),
+                 ("smiles", row.best_smiles, row.thr_smiles),
+             ]:
+                 m = canon_model(label)
+                 if m is None:
+                     continue
+ 
+                 # ---- binding affinity special case ----
+                 if prop_key == "binding_affinity":
+                     # label is pooled/unpooled; mode chooses folder wt_wt_* vs wt_smiles_*
+                     pooled_or_unpooled = m  # "pooled" or "unpooled"
+                     folder = f"wt_{mode}_{pooled_or_unpooled}"  # wt_wt_pooled / wt_smiles_unpooled etc.
+                     model_dir = self.training_root / "binding_affinity" / folder
+                     art = find_best_artifact(model_dir)
+                     if art.suffix != ".pt":
+                         raise RuntimeError(f"Binding model expected best_model.pt, got {art}")
+                     model = load_binding_model(art, pooled_or_unpooled=pooled_or_unpooled, device=self.device)
+                     self.models[(prop_key, mode)] = model
+                     self.meta[(prop_key, mode)] = {
+                         "task_type": "Regression",
+                         "threshold": None,
+                         "artifact": str(art),
+                         "model_name": pooled_or_unpooled,
+                     }
+                     continue
+ 
+                 model_dir = self._resolve_dir(prop_key, m, mode)
+                 kind, obj, art = load_artifact(model_dir, self.device)
+ 
+                 if kind in {"xgb", "joblib"}:
+                     self.models[(prop_key, mode)] = obj
+                 else:
+                     # rebuild the NN architecture from its checkpoint
+                     self.models[(prop_key, mode)] = build_torch_model_from_ckpt(m, obj, self.device)
+ 
+                 self.meta[(prop_key, mode)] = {
+                     "task_type": row.task_type,
+                     "threshold": thr,
+                     "artifact": str(art),
+                     "model_name": m,
+                     "kind": kind,
+                 }
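+     # Models are keyed by (property, mode): e.g. ("hemolysis", "wt") and
+     # ("hemolysis", "smiles") may hold entirely different architectures.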
+ 
+     def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
+         """
+         Returns either:
+         - a pooled np array of shape (1,H) for xgb/joblib models
+         - unpooled torch tensors (X, M) for NN models
+         """
+         meta = self.meta[(prop_key, mode)]
+         kind = meta.get("kind", None)
+ 
+         if prop_key == "binding_affinity":
+             raise RuntimeError("Use predict_binding_affinity().")
+ 
+         # torch NN: needs unpooled token embeddings
+         if kind == "torch_ckpt":
+             if mode == "wt":
+                 X, M = self.wt_embedder.unpooled(input_str)
+             else:
+                 X, M = self.smiles_embedder.unpooled(input_str)
+             return X, M
+ 
+         # otherwise pooled vectors for xgb/joblib
+         if mode == "wt":
+             v = self.wt_embedder.pooled(input_str)      # (1,H)
+         else:
+             v = self.smiles_embedder.pooled(input_str)  # (1,H)
+         feats = v.detach().cpu().numpy().astype(np.float32)
+         feats = np.nan_to_num(feats, nan=0.0)
+         feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
+         return feats
+ 
+     def predict_property(self, prop_key: str, mode: str, input_str: str) -> Dict[str, Any]:
+         """
+         mode: "wt" for AA-sequence input, "smiles" for SMILES input.
+         Returns a dict with the score, plus a label if a classifier threshold exists.
+         """
+         if (prop_key, mode) not in self.models:
+             raise KeyError(f"No model loaded for ({prop_key}, {mode}). Check manifest and folders.")
+ 
+         meta = self.meta[(prop_key, mode)]
+         model = self.models[(prop_key, mode)]
+         task_type = meta["task_type"].lower()
+         thr = meta.get("threshold", None)
+         kind = meta.get("kind", None)
+ 
+         if prop_key == "binding_affinity":
+             raise RuntimeError("Use predict_binding_affinity().")
+ 
+         # NN path (logits / regression)
+         if kind == "torch_ckpt":
+             X, M = self._get_features_for_model(prop_key, mode, input_str)
+             with torch.no_grad():
+                 y = model(X, M).squeeze().float().cpu().item()
+             if task_type == "classifier":
+                 prob = float(1.0 / (1.0 + np.exp(-y)))  # sigmoid(logit)
+                 out = {"property": prop_key, "mode": mode, "score": prob}
+                 if thr is not None:
+                     out["label"] = int(prob >= float(thr))
+                     out["threshold"] = float(thr)
+                 return out
+             else:
+                 return {"property": prop_key, "mode": mode, "score": float(y)}
+ 
+         # xgb path
+         if kind == "xgb":
+             feats = self._get_features_for_model(prop_key, mode, input_str)  # (1,H)
+             dmat = xgb.DMatrix(feats)
+             pred = float(model.predict(dmat)[0])
+             out = {"property": prop_key, "mode": mode, "score": pred}
+             if task_type == "classifier" and thr is not None:
+                 out["label"] = int(pred >= float(thr))
+                 out["threshold"] = float(thr)
+             return out
+ 
+         # joblib path (svm/enet/svr)
+         if kind == "joblib":
+             feats = self._get_features_for_model(prop_key, mode, input_str)  # (1,H)
+             # classifier vs regressor behavior differs by estimator
+             if task_type == "classifier":
+                 if hasattr(model, "predict_proba"):
+                     pred = float(model.predict_proba(feats)[:, 1][0])
+                 elif hasattr(model, "decision_function"):
+                     logit = float(model.decision_function(feats)[0])
+                     pred = float(1.0 / (1.0 + np.exp(-logit)))
+                 else:
+                     pred = float(model.predict(feats)[0])
+                 out = {"property": prop_key, "mode": mode, "score": pred}
+                 if thr is not None:
+                     out["label"] = int(pred >= float(thr))
+                     out["threshold"] = float(thr)
+                 return out
+             else:
+                 pred = float(model.predict(feats)[0])
+                 return {"property": prop_key, "mode": mode, "score": pred}
+ 
+         raise RuntimeError(f"Unknown model kind={kind}")
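+     # Example classifier output (score value illustrative):
+     #   predict_property("hemolysis", "wt", seq) ->
+     #   {"property": "hemolysis", "mode": "wt", "score": 0.71, "label": 1, "threshold": 0.2801}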
+ 
+     def predict_binding_affinity(self, mode: str, target_seq: str, binder_str: str) -> Dict[str, Any]:
+         """
+         mode: "wt"     (binder is an AA sequence) -> wt_wt_(pooled|unpooled)
+               "smiles" (binder is SMILES)         -> wt_smiles_(pooled|unpooled)
+         """
+         prop_key = "binding_affinity"
+         if (prop_key, mode) not in self.models:
+             raise KeyError(f"No binding model loaded for ({prop_key}, {mode}).")
+ 
+         model = self.models[(prop_key, mode)]
+         pooled_or_unpooled = self.meta[(prop_key, mode)]["model_name"]  # pooled/unpooled
+ 
+         # the target is always a WT sequence (ESM)
+         if pooled_or_unpooled == "pooled":
+             t_vec = self.wt_embedder.pooled(target_seq)          # (1,Ht)
+             if mode == "wt":
+                 b_vec = self.wt_embedder.pooled(binder_str)      # (1,Hb)
+             else:
+                 b_vec = self.smiles_embedder.pooled(binder_str)  # (1,Hb)
+             with torch.no_grad():
+                 reg, logits = model(t_vec, b_vec)
+         else:
+             T, Mt = self.wt_embedder.unpooled(target_seq)
+             if mode == "wt":
+                 B, Mb = self.wt_embedder.unpooled(binder_str)
+             else:
+                 B, Mb = self.smiles_embedder.unpooled(binder_str)
+             with torch.no_grad():
+                 reg, logits = model(T, Mt, B, Mb)
+ 
+         affinity = float(reg.squeeze().cpu().item())
+         cls_pred = int(torch.argmax(logits, dim=-1).cpu().item())
+         cls_thr = affinity_to_class(affinity)
+ 
+         names = {0: "High (≥9)", 1: "Moderate (7–9)", 2: "Low (<7)"}
+         return {
+             "property": "binding_affinity",
+             "mode": mode,
+             "affinity": affinity,
+             "class_by_threshold": names[cls_thr],
+             "class_by_logits": names[cls_pred],
+             "binding_model": pooled_or_unpooled,
+         }
+ 
+ 
+ # -----------------------------
+ # Minimal usage
+ # -----------------------------
+ if __name__ == "__main__":
+     # Example:
+     predictor = PeptiVersePredictor(
+         manifest_path="best_models.txt",
+         classifier_weight_root="/home/enol/Classifier_Weight",
+     )
+     print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))
+     print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))
+ 
+     # Test embeddings #
+     """
+     device = torch.device("cuda:0")
+ 
+     wt = WTEmbedder(device)
+     sm = SMILESEmbedder(device,
+                         vocab_path="/home/enol/PeptideGym/Data_split/tokenizer/new_vocab.txt",
+                         splits_path="/home/enol/PeptideGym/Data_split/tokenizer/new_splits.txt")
+ 
+     p = wt.pooled("GIGAVLKVLTTGLPALISWIKRKRQQ")      # (1,1280)
+     X, M = wt.unpooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,Li,1280), (1,Li)
+ 
+     p2 = sm.pooled("NCC(=O)N[C@H](CS)C(=O)O")        # (1,H_smiles)
+     X2, M2 = sm.unpooled("NCC(=O)N[C@H](CS)C(=O)O")  # (1,Li,H_smiles), (1,Li)
+     """