"""Load the best per-property models named in a manifest and run inference.

The module provides:

* manifest parsing (``read_best_manifest_csv``),
* artifact discovery/loading for XGBoost, joblib and torch checkpoints,
* the neural heads needed to rebuild torch checkpoints,
* embedders for amino-acid sequences (ESM) and SMILES strings (RoFormer),
* ``PeptiVersePredictor``, which ties all of the above together.
"""
from __future__ import annotations

import csv
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import joblib
import xgboost as xgb
from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
from lightning.pytorch import seed_everything

from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

seed_everything(1986)


# -----------------------------
# Manifest
# -----------------------------
@dataclass(frozen=True)
class BestRow:
    """One manifest row: the best model per input mode for a single property."""

    property_key: str
    best_wt: Optional[str]
    best_smiles: Optional[str]
    task_type: str  # "Classifier" or "Regression"
    thr_wt: Optional[float]
    thr_smiles: Optional[float]


# Cell contents that mean "no value" in the manifest.
_MISSING_TOKENS = frozenset({"", "-", "—", "NA", "N/A"})


def _clean(s: str) -> str:
    """Return ``s`` without surrounding whitespace; falsy input becomes ``""``."""
    return (s or "").strip()


def _none_if_dash(s: str) -> Optional[str]:
    """Return the cleaned cell text, or ``None`` for a "missing" placeholder."""
    s = _clean(s)
    return None if s in _MISSING_TOKENS else s


def _float_or_none(s: str) -> Optional[float]:
    """Parse a cell as ``float``; "missing" placeholders yield ``None``.

    Raises:
        ValueError: if the cell is neither a placeholder nor a valid float.
    """
    s = _clean(s)
    return None if s in _MISSING_TOKENS else float(s)


def normalize_property_key(name: str) -> str:
    """Map a human-readable property name to the canonical lookup key.

    The name is lower-cased, parenthesised suffixes are removed, and ``-`` and
    spaces become ``_``. A few names are then folded onto fixed keys:
    generic "permeability" (not PAMPA / Caco) -> ``permeability_penetrance``,
    ``half_life`` -> ``halflife`` and ``non_fouling`` -> ``nf``.
    """
    n = name.strip().lower()
    n = re.sub(r"\s*\(.*?\)\s*", "", n)
    n = n.replace("-", "_").replace(" ", "_")
    if "permeability" in n and "pampa" not in n and "caco" not in n:
        return "permeability_penetrance"
    if n in {"halflife", "half_life"}:
        return "halflife"
    if n == "non_fouling":
        return "nf"
    return n


def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
    """Parse the best-model manifest into ``{property_key: BestRow}``.

    Expected layout (trailing commas and padding spaces are tolerated)::

        Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
        Hemolysis, SVM, SGB, Classifier, 0.2801, 0.2223,

    Blank rows and rows without a ``Properties`` value are skipped. The first
    non-blank row is taken as the header. If two rows normalise to the same
    property key, the later row wins.
    """
    p = Path(path)
    out: Dict[str, BestRow] = {}
    with p.open("r", newline="") as f:
        header: Optional[List[str]] = None
        for raw in csv.reader(f):
            if not raw or all(_clean(x) == "" for x in raw):
                continue
            # Drop the empty cells produced by trailing commas.
            while raw and _clean(raw[-1]) == "":
                raw = raw[:-1]
            if header is None:
                header = [h.strip() for h in raw]
                continue
            if len(raw) < len(header):
                # Pad short rows so every header column has a (blank) value.
                raw = raw + [""] * (len(header) - len(raw))
            rec = dict(zip(header, raw))

            prop_raw = _clean(rec.get("Properties", ""))
            if not prop_raw:
                continue
            prop_key = normalize_property_key(prop_raw)
            out[prop_key] = BestRow(
                property_key=prop_key,
                best_wt=_none_if_dash(rec.get("Best_Model_WT", "")),
                best_smiles=_none_if_dash(rec.get("Best_Model_SMILES", "")),
                task_type=_clean(rec.get("Type", "Classifier")),
                thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
                thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
            )
    return out


MODEL_ALIAS = {
    "SVM": "svm_gpu",
    "SVR": "svr",
    "ENET": "enet_gpu",
    "CNN": "cnn",
    "MLP": "mlp",
    "TRANSFORMER": "transformer",
    "XGB": "xgb",
    "XGB_REG": "xgb_reg",
    "POOLED": "pooled",
    "UNPOOLED": "unpooled",
    "TRANSFORMER_WT_LOG": "transformer_wt_log",
}


def canon_model(label: Optional[str]) -> Optional[str]:
    """Translate a manifest model label to its canonical (folder) name.

    Known labels go through ``MODEL_ALIAS`` (case-insensitive); unknown labels
    are simply stripped and lower-cased. ``None`` passes through unchanged.
    """
    if label is None:
        return None
    stripped = label.strip()
    return MODEL_ALIAS.get(stripped.upper(), stripped.lower())


# -----------------------------
# Generic artifact loading
# -----------------------------
def find_best_artifact(model_dir: Path) -> Path:
    """Return the first ``best_model`` artifact found in ``model_dir``.

    Patterns are tried in priority order (``.json``, ``.pt``, ``*.joblib``);
    within a pattern the lexicographically first match is used.

    Raises:
        FileNotFoundError: if no pattern matches.
    """
    for pat in ("best_model.json", "best_model.pt", "best_model*.joblib"):
        hits = sorted(model_dir.glob(pat))
        if hits:
            return hits[0]
    raise FileNotFoundError(f"No best_model artifact found in {model_dir}")


def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
    """Load the best artifact in ``model_dir``.

    Returns:
        ``(kind, obj, path)`` where ``kind`` is ``"xgb"`` (an ``xgb.Booster``),
        ``"joblib"`` (whatever object was pickled) or ``"torch_ckpt"`` (the raw
        checkpoint object returned by ``torch.load``).

    Raises:
        FileNotFoundError: if no artifact exists.
        ValueError: if the artifact has an unexpected suffix.
    """
    art = find_best_artifact(model_dir)

    if art.suffix == ".json":
        booster = xgb.Booster()
        booster.load_model(str(art))
        return "xgb", booster, art

    if art.suffix == ".joblib":
        # NOTE(security): joblib.load unpickles arbitrary objects; only load
        # artifacts from a trusted source.
        return "joblib", joblib.load(art), art

    if art.suffix == ".pt":
        # NOTE(security): weights_only=False allows arbitrary pickled objects;
        # only load checkpoints from a trusted source.
        ckpt = torch.load(art, map_location=device, weights_only=False)
        return "torch_ckpt", ckpt, art

    raise ValueError(f"Unknown artifact type: {art}")


# -----------------------------
# NN architectures
# -----------------------------
def _masked_mean(X: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
    """Mean of ``X`` (B,L,H) over positions where ``M`` (B,L) is set.

    The token count is clamped to at least 1 so fully-masked rows yield zeros
    instead of NaN.
    """
    Mf = M.unsqueeze(-1).float()
    denom = Mf.sum(dim=1).clamp(min=1.0)
    return (X * Mf).sum(dim=1) / denom


class MaskedMeanPool(nn.Module):
    """Parameter-free module wrapper around masked mean pooling."""

    def forward(self, X, M):  # X:(B,L,H), M:(B,L)
        return _masked_mean(X, M)


class MLPHead(nn.Module):
    """Masked-mean pool followed by a two-layer MLP giving one value per item."""

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, X, M):
        z = self.pool(X, M)
        return self.net(z).squeeze(-1)


class CNNHead(nn.Module):
    """Stack of 1-D convolutions, masked-mean pooled, then a linear output."""

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks = []
        ch = in_ch
        for _ in range(layers):
            blocks += [
                nn.Conv1d(ch, c, kernel_size=k, padding=k // 2),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        Xc = X.transpose(1, 2)  # (B,H,L)
        Y = self.conv(Xc).transpose(1, 2)  # (B,L,C)
        return self.head(_masked_mean(Y, M)).squeeze(-1)


class TransformerHead(nn.Module):
    """Linear projection, Transformer encoder, masked-mean pool, linear output."""

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        # ``M`` marks valid tokens; the encoder wants the padding mask instead.
        pad_mask = ~M
        Z = self.proj(X)
        Z = self.enc(Z, src_key_padding_mask=pad_mask)
        return self.head(_masked_mean(Z, M)).squeeze(-1)


def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
    """Read the input feature size from the first layer's weight shape.

    Raises:
        ValueError: for a ``model_name`` other than mlp / cnn / transformer.
    """
    first_weight = {
        "mlp": "net.0.weight",
        "cnn": "conv.0.weight",
        "transformer": "proj.weight",
    }
    if model_name not in first_weight:
        raise ValueError(model_name)
    return int(sd[first_weight[model_name]].shape[1])


def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int:
    """Count encoder layers from keys like ``enc.layers.0.*``, ``enc.layers.1.*``.

    Returns 1 when no such key is present.
    """
    idxs = set()
    for k in sd.keys():
        if k.startswith(prefix):
            m = re.match(r"(\d+)\.", k[len(prefix):])
            if m:
                idxs.add(int(m.group(1)))
    return (max(idxs) + 1) if idxs else 1


def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]:
    """
    Returns (d_model, layers, ff) inferred from weights.
    - d_model from proj.weight (shape: [d_model, in_dim])
    - layers from count of enc.layers.*
    - ff from enc.layers.0.linear1.weight (shape: [ff, d_model]),
      falling back to 4 * d_model when that key is absent

    Raises:
        KeyError: if ``proj.weight`` is missing.
    """
    if "proj.weight" not in sd:
        raise KeyError("Missing proj.weight in state_dict; cannot infer transformer d_model.")
    d_model = int(sd["proj.weight"].shape[0])
    layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.")
    if "enc.layers.0.linear1.weight" in sd:
        ff = int(sd["enc.layers.0.linear1.weight"].shape[0])
    else:
        ff = 4 * d_model
    return d_model, layers, ff


def _pick_nhead(d_model: int) -> int:
    """Pick a head count that divides ``d_model``, preferring common values."""
    for h in (8, 6, 4, 3, 2, 1):
        if d_model % h == 0:
            return h
    return 1


def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
    """Rebuild an NN head from a checkpoint and load its weights.

    Args:
        model_name: ``"mlp"``, ``"cnn"`` or ``"transformer"``.
        ckpt: checkpoint dict with ``"best_params"`` and ``"state_dict"``;
            ``"in_dim"`` is optional and inferred from the weights if absent.
        device: where the model is moved before being returned.

    Returns:
        The model in ``eval()`` mode on ``device``.

    Raises:
        ValueError: for an unknown ``model_name``.
        KeyError: if required hyper-parameters / weights are missing.
    """
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]
    # Only fall back to weight-shape inference when the checkpoint does not
    # carry ``in_dim`` itself.
    if "in_dim" in ckpt:
        in_dim = int(ckpt["in_dim"])
    else:
        in_dim = _infer_in_dim_from_sd(sd, model_name)
    dropout = float(params.get("dropout", 0.1))

    if model_name == "mlp":
        model = MLPHead(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        model = CNNHead(
            in_ch=in_dim,
            c=int(params["channels"]),
            k=int(params["kernel"]),
            layers=int(params["layers"]),
            dropout=dropout,
        )
    elif model_name == "transformer":
        d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim")
        if d_model is None:
            # If a transfer-learning ckpt omits arch params, infer them from
            # the state_dict (special case for transformer_wt_log).
            d_model, default_layers, default_ff = _infer_transformer_arch_from_sd(sd)
        else:
            d_model = int(d_model)
            default_layers, default_ff = 2, 4 * d_model
        model = TransformerHead(
            in_dim=in_dim,
            d_model=d_model,
            nhead=int(params.get("nhead", _pick_nhead(d_model))),
            layers=int(params.get("layers", default_layers)),
            ff=int(params.get("ff", default_ff)),
            dropout=dropout,
        )
    else:
        raise ValueError(f"Unknown NN model_name={model_name}")

    model.load_state_dict(sd)
    model.to(device)
    model.eval()
    return model


# -----------------------------
# Binding affinity models
# -----------------------------
def affinity_to_class(y: float) -> int:
    """Bucket an affinity value: 0=High(>=9), 1=Moderate(7-9), 2=Low(<7)."""
    if y >= 9.0:
        return 0
    if y < 7.0:
        return 2
    return 1


class _CrossAttnBase(nn.Module):
    """Shared layer construction for the pooled / unpooled binding models.

    Sub-module names and creation order are part of the checkpoint format
    (``state_dict`` keys), so they must not be changed.
    """

    def __init__(self, Ht, Hb, hidden, n_heads, n_layers, dropout, batch_first):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        def ffn():
            return nn.Sequential(
                nn.Linear(hidden, 4 * hidden),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(4 * hidden, hidden),
            )

        def attn():
            return nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=batch_first)

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": attn(),
                "attn_bt": attn(),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": ffn(),
                "ffb": ffn(),
            }))

        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)


class CrossAttnPooled(_CrossAttnBase):
    """Bidirectional cross-attention over one pooled vector per partner.

    ``forward`` returns ``(regression, class_logits)``.
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__(Ht, Hb, hidden, n_heads, n_layers, dropout, batch_first=False)

    def forward(self, t_vec, b_vec):
        t = self.t_proj(t_vec).unsqueeze(0)  # (1,B,H)
        b = self.b_proj(b_vec).unsqueeze(0)  # (1,B,H)

        for L in self.layers:
            t_attn, _ = L["attn_tb"](t, b, b)
            t = L["n1t"]((t + t_attn).transpose(0, 1)).transpose(0, 1)
            t = L["n2t"]((t + L["fft"](t)).transpose(0, 1)).transpose(0, 1)

            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"]((b + b_attn).transpose(0, 1)).transpose(0, 1)
            b = L["n2b"]((b + L["ffb"](b)).transpose(0, 1)).transpose(0, 1)

        z = torch.cat([t[0], b[0]], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)


class CrossAttnUnpooled(_CrossAttnBase):
    """Bidirectional cross-attention over token-level embeddings.

    ``forward`` takes token tensors plus boolean validity masks and returns
    ``(regression, class_logits)``.
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__(Ht, Hb, hidden, n_heads, n_layers, dropout, batch_first=True)

    def _masked_mean(self, X, M):
        return _masked_mean(X, M)

    def forward(self, T, Mt, B, Mb):
        T = self.t_proj(T)
        Bx = self.b_proj(B)

        # Attention expects "True = padding", the inverse of the valid masks.
        kp_t = ~Mt
        kp_b = ~Mb

        for L in self.layers:
            T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = L["n1t"](T + T_attn)
            T = L["n2t"](T + L["fft"](T))

            B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = L["n1b"](Bx + B_attn)
            Bx = L["n2b"](Bx + L["ffb"](Bx))

        t_pool = self._masked_mean(T, Mt)
        b_pool = self._masked_mean(Bx, Mb)
        z = torch.cat([t_pool, b_pool], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)


def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
    """Load a binding-affinity checkpoint into the matching architecture.

    Args:
        best_model_pt: path to the ``.pt`` checkpoint.
        pooled_or_unpooled: ``"pooled"`` or ``"unpooled"``.
        device: target device.

    Returns:
        The model in ``eval()`` mode on ``device``.

    Raises:
        ValueError: if ``pooled_or_unpooled`` is neither accepted value.
    """
    # NOTE(security): weights_only=False allows arbitrary pickled objects;
    # only load checkpoints from a trusted source.
    ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]

    # Infer Ht/Hb from the projection weights.
    common = dict(
        Ht=int(sd["t_proj.0.weight"].shape[1]),
        Hb=int(sd["b_proj.0.weight"].shape[1]),
        hidden=int(params["hidden_dim"]),
        n_heads=int(params["n_heads"]),
        n_layers=int(params["n_layers"]),
        dropout=float(params["dropout"]),
    )

    if pooled_or_unpooled == "pooled":
        model = CrossAttnPooled(**common)
    elif pooled_or_unpooled == "unpooled":
        model = CrossAttnUnpooled(**common)
    else:
        raise ValueError(pooled_or_unpooled)

    model.load_state_dict(sd)
    model.to(device).eval()
    return model


# -----------------------------
# Embedding generation
# -----------------------------
def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
    """``torch.isin`` with a broadcast-compare fallback when it is unavailable."""
    if hasattr(torch, "isin"):
        return torch.isin(ids, test_ids)
    # Fallback: compare against each special id
    # (B,L,1) == (1,1,K) -> (B,L,K)
    return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)


class _TokenEmbedderBase:
    """Caching and special-token filtering shared by both embedders.

    Subclasses set ``device``, ``max_len``, ``use_cache``, ``tokenizer`` and
    ``model`` and implement ``_tokenize`` and ``_run_model``.
    """

    _SPECIAL_ID_ATTRS = (
        "pad_token_id",
        "cls_token_id",
        "sep_token_id",
        "bos_token_id",
        "eos_token_id",
        "mask_token_id",
    )

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        """Sorted, de-duplicated special-token ids the tokenizer defines."""
        cand = (getattr(tokenizer, attr, None) for attr in _TokenEmbedderBase._SPECIAL_ID_ATTRS)
        return sorted({int(x) for x in cand if x is not None})

    def _init_special_ids_and_caches(self) -> None:
        """Derive the special-id tensor from the tokenizer and reset caches."""
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = (
            torch.tensor(self.special_ids, device=self.device, dtype=torch.long)
            if len(self.special_ids)
            else None
        )
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _tokenize(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def _run_model(self, tok: Dict[str, torch.Tensor]):
        raise NotImplementedError

    def _hidden_and_valid(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return hidden states (1,L,H) and the mask (1,L) of real tokens."""
        tok = self._tokenize([text])
        ids = tok["input_ids"]  # (1,L)
        valid = tok["attention_mask"].bool()  # (1,L)
        h = self._run_model(tok).last_hidden_state  # (1,L,H)
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))
        return h, valid

    @torch.no_grad()
    def _pooled_impl(self, text: str) -> torch.Tensor:
        """Mean embedding over valid tokens, shape (1,H); cached per string."""
        s = text.strip()
        if self.use_cache and s in self._cache_pooled:
            return self._cache_pooled[s]

        h, valid = self._hidden_and_valid(s)
        vf = valid.unsqueeze(-1).float()
        summed = (h * vf).sum(dim=1)  # (1,H)
        denom = vf.sum(dim=1).clamp(min=1e-9)  # (1,1)
        pooled = summed / denom  # (1,H)

        if self.use_cache:
            self._cache_pooled[s] = pooled
        return pooled

    @torch.no_grad()
    def _unpooled_impl(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Token embeddings with padding/specials removed, plus an all-True mask."""
        s = text.strip()
        if self.use_cache and s in self._cache_unpooled:
            return self._cache_unpooled[s]

        h, valid = self._hidden_and_valid(s)
        keep = valid[0]  # (L,)
        X = h[:, keep, :]  # (1,Li,H)
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)

        if self.use_cache:
            self._cache_unpooled[s] = (X, M)
        return X, M


class SMILESEmbedder(_TokenEmbedderBase):
    """
    PeptideCLM RoFormer embeddings for SMILES.
    - pooled(): mean over tokens where attention_mask==1 AND token_id not in SPECIAL_IDS
    - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
      plus a 1-mask of length Li (since already filtered).
    """

    def __init__(
        self,
        device: torch.device,
        vocab_path: str,
        splits_path: str,
        clm_name: str = "aaronfeller/PeptideCLM-23M-all",
        max_len: int = 512,
        use_cache: bool = True,
    ):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache

        self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
        self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()
        self._init_special_ids_and_caches()

    def _tokenize(self, smiles_list: List[str]) -> Dict[str, torch.Tensor]:
        tok = self.tokenizer(
            smiles_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len,
        )
        for k in tok:
            tok[k] = tok[k].to(self.device)
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    def _run_model(self, tok):
        return self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"])

    def pooled(self, smiles: str) -> torch.Tensor:
        """Return the (1,H) mean embedding of ``smiles`` over non-special tokens."""
        return self._pooled_impl(smiles)

    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
          X: (1, Li, H) float32 on device
          M: (1, Li) bool on device
        where Li excludes padding + special tokens.
        """
        return self._unpooled_impl(smiles)


class WTEmbedder(_TokenEmbedderBase):
    """
    ESM2 embeddings for AA sequences.
    - pooled(): mean over tokens where attention_mask==1 AND token_id not in {CLS, EOS, PAD,...}
    - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
      plus a 1-mask of length Li (since already filtered).
    """

    def __init__(
        self,
        device: torch.device,
        esm_name: str = "facebook/esm2_t33_650M_UR50D",
        max_len: int = 1022,
        use_cache: bool = True,
    ):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache

        self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
        self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
        self._init_special_ids_and_caches()

    def _tokenize(self, seq_list: List[str]) -> Dict[str, torch.Tensor]:
        tok = self.tokenizer(
            seq_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len,
        )
        tok = {k: v.to(self.device) for k, v in tok.items()}
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    def _run_model(self, tok):
        return self.model(**tok)

    def pooled(self, seq: str) -> torch.Tensor:
        """Return the (1,H) mean embedding of ``seq`` over non-special tokens."""
        return self._pooled_impl(seq)

    def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
          X: (1, Li, H) float32 on device
          M: (1, Li) bool on device
        where Li excludes padding + special tokens.
        """
        return self._unpooled_impl(seq)


# -----------------------------
# Predictor
# -----------------------------
class PeptiVersePredictor:
    """
    - loads best models from training_classifiers/
    - computes embeddings as needed (pooled/unpooled)
    - supports: xgb, joblib(ENET/SVM/SVR), NN(mlp/cnn/transformer), binding pooled/unpooled.
    """

    # Model names whose half-life outputs get the log1p(hours) inversion.
    _LOG1P_HALFLIFE_MODELS = frozenset({"xgb_wt_log", "transformer_wt_log"})

    def __init__(
        self,
        manifest_path: str | Path,
        classifier_weight_root: str | Path,
        esm_name="facebook/esm2_t33_650M_UR50D",
        clm_name="aaronfeller/PeptideCLM-23M-all",
        smiles_vocab="tokenizer/new_vocab.txt",
        smiles_splits="tokenizer/new_splits.txt",
        device: Optional[str] = None,
    ):
        """Read the manifest, build both embedders and load every listed model.

        Args:
            manifest_path: CSV manifest (see ``read_best_manifest_csv``).
            classifier_weight_root: root containing ``training_classifiers/``
                and the SMILES tokenizer files.
            esm_name: ESM checkpoint used for amino-acid sequences.
            clm_name: masked-LM checkpoint used for SMILES.
            smiles_vocab: vocab path relative to ``classifier_weight_root``.
            smiles_splits: splits path relative to ``classifier_weight_root``.
            device: torch device string; defaults to CUDA when available.
        """
        self.root = Path(classifier_weight_root)
        self.training_root = self.root / "training_classifiers"
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        self.manifest = read_best_manifest_csv(manifest_path)

        # Forward ``esm_name`` so a caller-supplied checkpoint is honoured
        # (it used to be accepted but ignored).
        self.wt_embedder = WTEmbedder(self.device, esm_name=esm_name)
        self.smiles_embedder = SMILESEmbedder(
            self.device,
            clm_name=clm_name,
            vocab_path=str(self.root / smiles_vocab),
            splits_path=str(self.root / smiles_splits),
        )

        self.models: Dict[Tuple[str, str], Any] = {}
        self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}

        self._load_all_best_models()

    def _resolve_dir(self, prop_key: str, model_name: str, mode: str) -> Path:
        """Find the on-disk folder holding ``model_name`` for a property/mode.

        Half-life has dedicated folder names that are tried first; otherwise
        ``<model>_<mode>`` and then ``<model>`` are tried.

        Raises:
            FileNotFoundError: if none of the candidate folders exists.
        """
        # map halflife -> half_life folder on disk (common layout)
        disk_prop = "half_life" if prop_key == "halflife" else prop_key
        base = self.training_root / disk_prop

        tried: List[Path] = []
        if prop_key == "halflife":
            # special handling for halflife xgb_wt_log / xgb_smiles
            if model_name in {"xgb_wt_log", "xgb_smiles"}:
                tried.append(base / model_name)
            # special handling for halflife transformer wt log folder
            if mode == "wt" and model_name == "transformer":
                tried.append(base / "transformer_wt_log")
            if model_name == "xgb":
                tried.append(base / ("xgb_wt_log" if mode == "wt" else "xgb_smiles"))

        tried += [base / f"{model_name}_{mode}", base / model_name]

        for d in tried:
            if d.exists():
                return d

        raise FileNotFoundError(
            f"Cannot find model directory for {prop_key} {model_name} {mode}. "
            f"Tried: {tried}"
        )

    def _load_all_best_models(self):
        """Populate ``self.models`` / ``self.meta`` for every manifest entry."""
        for prop_key, row in self.manifest.items():
            for mode, label, thr in [
                ("wt", row.best_wt, row.thr_wt),
                ("smiles", row.best_smiles, row.thr_smiles),
            ]:
                m = canon_model(label)
                if m is None:
                    continue

                # ---- binding affinity special ----
                if prop_key == "binding_affinity":
                    # label is pooled/unpooled; mode chooses folder wt_wt_* vs wt_smiles_*
                    pooled_or_unpooled = m
                    folder = f"wt_{mode}_{pooled_or_unpooled}"  # wt_wt_pooled / wt_smiles_unpooled etc.
                    model_dir = self.training_root / "binding_affinity" / folder
                    art = find_best_artifact(model_dir)
                    if art.suffix != ".pt":
                        raise RuntimeError(f"Binding model expected best_model.pt, got {art}")
                    self.models[(prop_key, mode)] = load_binding_model(
                        art, pooled_or_unpooled=pooled_or_unpooled, device=self.device
                    )
                    self.meta[(prop_key, mode)] = {
                        "task_type": "Regression",
                        "threshold": None,
                        "artifact": str(art),
                        "model_name": pooled_or_unpooled,
                    }
                    continue

                model_dir = self._resolve_dir(prop_key, m, mode)
                kind, obj, art = load_artifact(model_dir, self.device)

                if kind in {"xgb", "joblib"}:
                    self.models[(prop_key, mode)] = obj
                else:
                    # Rebuild the NN architecture; names such as
                    # "transformer_wt_log" map onto their base architecture.
                    arch = next((a for a in ("transformer", "mlp", "cnn") if m.startswith(a)), m)
                    self.models[(prop_key, mode)] = build_torch_model_from_ckpt(arch, obj, self.device)

                self.meta[(prop_key, mode)] = {
                    "task_type": row.task_type,
                    "threshold": thr,
                    "artifact": str(art),
                    "model_name": m,
                    "kind": kind,
                }

    def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
        """
        Returns either:
          - pooled np array shape (1,H) for xgb/joblib
          - unpooled torch tensors (X,M) for NN

        Raises:
            KeyError: if no model is registered for ``(prop_key, mode)``.
            RuntimeError: for ``binding_affinity`` (use predict_binding_affinity).
        """
        meta = self.meta[(prop_key, mode)]
        kind = meta.get("kind", None)

        if prop_key == "binding_affinity":
            raise RuntimeError("Use predict_binding_affinity().")

        embedder = self.wt_embedder if mode == "wt" else self.smiles_embedder

        # Torch NNs consume token-level embeddings.
        if kind == "torch_ckpt":
            return embedder.unpooled(input_str)

        # Otherwise pooled vectors for xgb/joblib.
        v = embedder.pooled(input_str)  # (1,H)
        feats = v.detach().cpu().numpy().astype(np.float32)
        # Sanitise so downstream estimators never see NaN / inf.
        feats = np.nan_to_num(feats, nan=0.0)
        feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
        return feats

    @classmethod
    def _invert_log1p_if_needed(cls, prop_key: str, mode: str, model_name: str, value: float) -> float:
        """Undo log1p(hours), ONLY for WT half-life log models.

        NOTE(review): the check uses the manifest model name. A "transformer"
        or "xgb" label that ``_resolve_dir`` redirects to a ``*_wt_log`` folder
        is therefore NOT inverted -- confirm whether that is intended.
        """
        if prop_key == "halflife" and mode == "wt" and model_name in cls._LOG1P_HALFLIFE_MODELS:
            return float(np.expm1(value))
        return value

    @staticmethod
    def _classifier_output(prop_key: str, mode: str, score: float, thr: Optional[float]) -> Dict[str, Any]:
        """Build a classifier result, adding label/threshold when ``thr`` is set."""
        out: Dict[str, Any] = {"property": prop_key, "mode": mode, "score": score}
        if thr is not None:
            out["label"] = int(score >= float(thr))
            out["threshold"] = float(thr)
        return out

    def predict_property(self, prop_key: str, mode: str, input_str: str) -> Dict[str, Any]:
        """
        mode: "wt" for AA sequence input, "smiles" for SMILES input
        Returns dict with score + label if classifier threshold exists.

        Raises:
            KeyError: if no model is loaded for ``(prop_key, mode)``.
            RuntimeError: for ``binding_affinity`` or an unknown model kind.
        """
        if (prop_key, mode) not in self.models:
            raise KeyError(f"No model loaded for ({prop_key}, {mode}). Check manifest and folders.")

        meta = self.meta[(prop_key, mode)]
        model = self.models[(prop_key, mode)]
        is_classifier = meta["task_type"].lower() == "classifier"
        thr = meta.get("threshold", None)
        kind = meta.get("kind", None)
        model_name = meta.get("model_name", "")

        if prop_key == "binding_affinity":
            raise RuntimeError("Use predict_binding_affinity().")

        # NN path (logits / regression)
        if kind == "torch_ckpt":
            X, M = self._get_features_for_model(prop_key, mode, input_str)
            with torch.no_grad():
                y = model(X, M).squeeze().float().cpu().item()
            y = self._invert_log1p_if_needed(prop_key, mode, model_name, y)
            if is_classifier:
                prob = float(1.0 / (1.0 + np.exp(-y)))  # sigmoid(logit)
                return self._classifier_output(prop_key, mode, prob, thr)
            return {"property": prop_key, "mode": mode, "score": float(y)}

        # xgboost path
        if kind == "xgb":
            feats = self._get_features_for_model(prop_key, mode, input_str)
            pred = float(model.predict(xgb.DMatrix(feats))[0])
            pred = self._invert_log1p_if_needed(prop_key, mode, model_name, pred)
            if is_classifier:
                # Apply the manifest threshold like the NN / joblib paths do.
                return self._classifier_output(prop_key, mode, pred, thr)
            return {"property": prop_key, "mode": mode, "score": pred}

        # joblib path (svm/enet/svr)
        if kind == "joblib":
            feats = self._get_features_for_model(prop_key, mode, input_str)  # (1,H)
            if not is_classifier:
                pred = float(model.predict(feats)[0])
                return {"property": prop_key, "mode": mode, "score": pred}

            # Classifier score source differs by estimator.
            if hasattr(model, "predict_proba"):
                pred = float(model.predict_proba(feats)[:, 1][0])
            elif hasattr(model, "decision_function"):
                logit = float(model.decision_function(feats)[0])
                pred = float(1.0 / (1.0 + np.exp(-logit)))
            else:
                pred = float(model.predict(feats)[0])
            return self._classifier_output(prop_key, mode, pred, thr)

        raise RuntimeError(f"Unknown model kind={kind}")

    def predict_binding_affinity(self, mode: str, target_seq: str, binder_str: str) -> Dict[str, Any]:
        """
        mode: "wt" (binder is AA sequence) -> wt_wt_(pooled|unpooled)
              "smiles" (binder is SMILES)  -> wt_smiles_(pooled|unpooled)

        Returns a dict with the predicted affinity, the class derived from the
        affinity thresholds, the class from the model's classification head,
        and which binding model (pooled/unpooled) was used.

        Raises:
            KeyError: if no binding model is loaded for ``mode``.
        """
        prop_key = "binding_affinity"
        if (prop_key, mode) not in self.models:
            raise KeyError(f"No binding model loaded for ({prop_key}, {mode}).")

        model = self.models[(prop_key, mode)]
        pooled_or_unpooled = self.meta[(prop_key, mode)]["model_name"]  # pooled/unpooled

        # target is always WT sequence (ESM); the binder embedder depends on mode
        binder_embedder = self.wt_embedder if mode == "wt" else self.smiles_embedder

        if pooled_or_unpooled == "pooled":
            t_vec = self.wt_embedder.pooled(target_seq)  # (1,Ht)
            b_vec = binder_embedder.pooled(binder_str)  # (1,Hb)
            with torch.no_grad():
                reg, logits = model(t_vec, b_vec)
        else:
            T, Mt = self.wt_embedder.unpooled(target_seq)
            B, Mb = binder_embedder.unpooled(binder_str)
            with torch.no_grad():
                reg, logits = model(T, Mt, B, Mb)

        affinity = float(reg.squeeze().cpu().item())
        cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
        cls_thr = affinity_to_class(affinity)

        names = {0: "High (≥9)", 1: "Moderate (7-9)", 2: "Low (<7)"}
        return {
            "property": "binding_affinity",
            "mode": mode,
            "affinity": affinity,
            "class_by_threshold": names[cls_thr],
            "class_by_logits": names[cls_logit],
            "binding_model": pooled_or_unpooled,
        }


if __name__ == "__main__":
    predictor = PeptiVersePredictor(
        manifest_path="best_models.txt",
        classifier_weight_root="./"
    )
    print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))
    print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))

    # Test Embedding #
    # (kept as an inert string; paste into the block above to run)
    """
    device = torch.device("cuda:0")

    wt = WTEmbedder(device)
    sm = SMILESEmbedder(device,
        vocab_path="./tokenizer/new_vocab.txt",
        splits_path="./tokenizer/new_splits.txt"
    )

    p = wt.pooled("GIGAVLKVLTTGLPALISWIKRKRQQ")       # (1,1280)
    X, M = wt.unpooled("GIGAVLKVLTTGLPALISWIKRKRQQ")  # (1,Li,1280), (1,Li)

    p2 = sm.pooled("NCC(=O)N[C@H](CS)C(=O)O")         # (1,H_smiles)
    X2, M2 = sm.unpooled("NCC(=O)N[C@H](CS)C(=O)O")   # (1,Li,H_smiles), (1,Li)
    """