Update models/peptide_classifiers.py
models/peptide_classifiers.py  (+210 −43)
@@ -509,7 +509,7 @@ class AffinityModel(nn.Module):
 
 class HemolysisModel:
     def __init__(self, device):
-        self.predictor = xgb.Booster(model_file='…
+        self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/collection/classifiers/ckpt/wt_hemolysis.json')
 
         self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
         self.model.eval()
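Note (editor's sketch, not part of the commit): xgb.Booster(model_file=...) loads a saved XGBoost model directly. The hemolysis featurization sits outside this diff's hunks, but the removed NonfoulingModel code below shows the pattern this class presumably follows: sanitize the ESM-derived feature matrix, wrap it in a DMatrix, and call predict.

    import numpy as np
    import xgboost as xgb

    def score_with_booster(booster: xgb.Booster, features: np.ndarray) -> np.ndarray:
        # Mirror the removed NonfoulingModel path: replace NaNs, clamp to the
        # float32 range, then run the booster on a DMatrix.
        features = np.nan_to_num(features, nan=0.0)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        return booster.predict(xgb.DMatrix(features))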
@@ -544,47 +544,58 @@ class HemolysisModel:
         scores = self.get_scores(input_seqs)
         return scores
 
+# ======================== MLP =========================================
+# Still need mean pooling along lengths
+class MaskedMeanPool(nn.Module):
+    def forward(self, X, M):  # X: (B, L, H), M: (B, L)
+        Mf = M.unsqueeze(-1).float()
+        denom = Mf.sum(dim=1).clamp(min=1.0)
+        return (X * Mf).sum(dim=1) / denom  # (B, H)
+
+class MLPClassifier(nn.Module):
+    def __init__(self, in_dim, hidden=512, dropout=0.1):
+        super().__init__()
+        self.pool = MaskedMeanPool()
+        self.net = nn.Sequential(
+            nn.Linear(in_dim, hidden),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden, 1),
+        )
+    def forward(self, X, M):
+        z = self.pool(X, M)
+        return self.net(z).squeeze(-1)  # logits
+# ======================== MLP =========================================
+
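Note (editor's sketch, not part of the commit): a quick shape check for the new pooling + MLP head. The tensors are random stand-ins with the ESM2-650M hidden size (H=1280); MLPClassifier returns raw logits, so a sigmoid is applied for probabilities.

    import torch

    B, L, H = 4, 20, 1280
    X = torch.randn(B, L, H)                # token embeddings (B, L, H)
    M = torch.ones(B, L, dtype=torch.bool)  # attention mask (B, L)
    M[:, 15:] = False                       # pretend the last 5 positions are padding

    clf = MLPClassifier(in_dim=H).eval()
    logits = clf(X, M)                      # (B,); padded positions are excluded from the mean
    probs = torch.sigmoid(logits)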
 class NonfoulingModel:
     def __init__(self, device):
-        …
-        …
-        …
+        ckpt = torch.load('/scratch/pranamlab/tong/collection/classifiers/ckpt/wt_nonfouling.pt', weights_only=False, map_location=device)
+        best_params = ckpt["best_params"]
+        self.predictor = MLPClassifier(in_dim=1280, hidden=int(best_params["hidden"]), dropout=float(best_params.get("dropout", 0.1)))
+        self.predictor.load_state_dict(ckpt["state_dict"])
+        self.predictor = self.predictor.to(device)
+        self.predictor.eval()
+
         self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
         self.model.eval()
 
         self.device = device
 
-    def …
-        """Generate ESM embeddings for protein sequences"""
+    def get_scores(self, input_ids, attention_mask):
         with torch.no_grad():
-            …
-            …
-            …
-        return embeddings
+            features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
+            scores = self.predictor(features, attention_mask)
+            return scores
 
-    def …
-        …
-        …
-        …
-        if len(features) == 0:
-            return scores
-
-        features = np.nan_to_num(features, nan=0.)
-        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
-
-        features = xgb.DMatrix(features)
-
-        scores = self.predictor.predict(features)
-        return torch.from_numpy(scores).to(self.device)
-
-    def __call__(self, input_seqs: list):
-        scores = self.get_scores(input_seqs)
-        return scores
-
+    def __call__(self, input_ids):
+        attention_mask = torch.ones_like(input_ids).to(self.device)
+        scores = self.get_scores(input_ids, attention_mask)
+        return 1.0 / (1.0 + torch.exp(-scores))
 
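Note (editor's sketch, not part of the commit): a hypothetical end-to-end call. Because __call__ rebuilds the attention mask with torch.ones_like(input_ids), it effectively assumes an un-padded batch; equal-length sequences are the safe case.

    import torch
    from transformers import AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    nf = NonfoulingModel(device)  # loads wt_nonfouling.pt from the hard-coded path
    batch = tokenizer(["GLFDIVKKVV", "KKLLPIVNNL"], return_tensors="pt")  # equal lengths
    probs = nf(batch["input_ids"].to(device))  # sigmoid already applied in __call__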
 class SolubilityModel:
     def __init__(self, device):
         # change model path
-        self.predictor = xgb.Booster(model_file='…
+        self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json')
 
         self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
         self.model.eval()
@@ -624,7 +635,8 @@ class SolubilityModelNew:
         self.device = device
 
     def get_scores(self, x):
-        …
+        a = x[:, 1:-1]
+        mask = (a.unsqueeze(-1) == self.hydro_ids).any(dim=-1)
         ratios = mask.float().mean(dim=1)
         return 1 - ratios
 
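Note (editor's sketch, not part of the commit): a worked example of the hydrophobic-ratio score. hydro_ids here is a stand-in; the real self.hydro_ids is set in __init__, which this hunk does not show.

    import torch

    hydro_ids = torch.tensor([5, 8, 13])     # hypothetical ids of hydrophobic residue tokens
    x = torch.tensor([[0, 5, 8, 9, 13, 2]])  # (B, T) token ids with BOS/EOS at the ends

    a = x[:, 1:-1]                                     # drop BOS/EOS -> (B, L)
    mask = (a.unsqueeze(-1) == hydro_ids).any(dim=-1)  # True where the residue is hydrophobic
    ratios = mask.float().mean(dim=1)                  # here 3/4 = 0.75
    score = 1 - ratios                                 # 0.25; higher score = less hydrophobic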
@@ -663,24 +675,179 @@ class PeptideCNN(nn.Module):
             return features
         return self.predictor(features)  # Output shape: (B, 1)
 
+# class HalfLifeModel:
+#     def __init__(self, device):
+#         input_dim = 1280
+#         hidden_dims = [input_dim // 2, input_dim // 4]
+#         output_dim = input_dim // 8
+#         dropout_rate = 0.3
+#         self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
+#         self.model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
+#         self.model.eval()
+
+#     def __call__(self, x):
+#         prediction = self.model(x, return_features=False)
+#         halflife = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0)
+#         return halflife / 2
+
+
+# -----------------------------
+# Model definition (must match training)
+# -----------------------------
+class TransformerRegressor(nn.Module):
+    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
+        super().__init__()
+        self.proj = nn.Linear(in_dim, d_model)
+        enc_layer = nn.TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=ff,
+            dropout=dropout,
+            batch_first=True,
+            activation="gelu",
+        )
+        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
+        self.head = nn.Linear(d_model, 1)
+
+    def forward(self, X, M):
+        # M: True = keep token, False = padding
+        pad_mask = ~M
+        Z = self.proj(X)
+        Z = self.enc(Z, src_key_padding_mask=pad_mask)
+        Mf = M.unsqueeze(-1).float()
+        denom = Mf.sum(dim=1).clamp(min=1.0)
+        pooled = (Z * Mf).sum(dim=1) / denom
+        return self.head(pooled).squeeze(-1)
+
+
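Note (editor's sketch, not part of the commit): a shape check for TransformerRegressor on a padded batch, using the fixed architecture that build_model below pins (d_model=384, nhead=4, layers=1, ff=512).

    import torch

    reg = TransformerRegressor(in_dim=1280, d_model=384, nhead=4, layers=1, ff=512).eval()
    X = torch.randn(2, 12, 1280)            # (B, L, H) per-residue embeddings
    M = torch.ones(2, 12, dtype=torch.bool)
    M[1, 8:] = False                        # second sequence: 4 padded positions
    with torch.no_grad():
        y = reg(X, M)                       # (B,) scalar regression outputs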
+def build_model(model_name: str, in_dim: int, params: dict) -> nn.Module:
+    # In the training code, the transformer uses fixed architecture values (d_model/nhead/layers/ff/dropout).
+    # (See build_model in finetune_nn_cv.py)
+    if model_name != "transformer":
+        raise ValueError(f"This inference file currently supports model_name='transformer', got: {model_name}")
+    return TransformerRegressor(
+        in_dim=in_dim,
+        d_model=384,
+        nhead=4,
+        layers=1,
+        ff=512,
+        dropout=0.1521676463658988,
+    )
+
+def _clean_state_dict(state_dict: dict) -> dict:
+    cleaned = {}
+    for k, v in state_dict.items():
+        if k.startswith("module."):
+            k = k[len("module."):]
+        if k.startswith("model."):
+            k = k[len("model."):]
+        cleaned[k] = v
+    return cleaned
+
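Note (editor's illustration, not part of the commit): _clean_state_dict strips a leading "module." (DataParallel) and then a leading "model." prefix from each key, in that order.

    sd = {"module.proj.weight": 0, "model.head.bias": 1, "proj.bias": 2}
    print(list(_clean_state_dict(sd)))
    # ['proj.weight', 'head.bias', 'proj.bias']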
 class HalfLifeModel:
-    …                      (old __init__ body; elided in the rendering, preserved commented-out above)
-    def __call__(self, …
-        …
-        halflife = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0)
-        return halflife / 2
+    """
+    Loads:
+      - the ESM2 encoder, to generate *unpooled* per-residue token embeddings
+      - the fine-tuned TransformerRegressor from final_model.pt
+
+    By default, __call__ returns "hours":
+      - if ckpt['target_col'] == 'log_label' -> expm1(pred)
+      - else -> raw pred
+    """
+
+    def __init__(
+        self,
+        device,
+        ckpt_path="/scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_transformer_log/final_model.pt",
+    ):
+        self.device = device
+
+        # --- load NN checkpoint (saved by the finetune script) ---
+        ckpt = torch.load(ckpt_path, map_location="cpu")
+        if not isinstance(ckpt, dict) or "state_dict" not in ckpt:
+            raise ValueError(f"Checkpoint at {ckpt_path} is not the expected dict with a 'state_dict' key.")
+
+        self.best_params = ckpt.get("best_params", {})
+        self.in_dim = int(ckpt.get("in_dim"))
+        self.target_col = ckpt.get("target_col", "label")  # 'log_label' or 'label'
+
+        # --- build + load regressor ---
+        self.regressor = build_model(model_name="transformer", in_dim=self.in_dim, params=self.best_params)
+        self.regressor.load_state_dict(_clean_state_dict(ckpt["state_dict"]), strict=True)
+        self.regressor.to(self.device)
+        self.regressor.eval()
+
+        # --- ESM2 embedding model ---
+        self.emb_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.device)
+        self.emb_model.eval()
+        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+
+        # sanity check: the ESM2 hidden size must match the training in_dim
+        esm_hidden = int(self.emb_model.config.hidden_size)
+        if esm_hidden != self.in_dim:
+            raise ValueError(
+                f"Mismatch: ESM hidden_size={esm_hidden}, but checkpoint in_dim={self.in_dim}.\n"
+                f"Did you train on a different embedding model/dimension than facebook/esm2_t33_650M_UR50D?"
+            )
+
+    @torch.no_grad()
+    def _embed_unpooled_batch(self, sequences):
+        out = self.emb_model(input_ids=sequences)
+        hs = out.last_hidden_state  # (B, T, H)
+
+        per_seq = []
+        lengths = []
+
+        for i in range(hs.shape[0]):
+            emb = hs[i, 1:-1, :]  # (L, H), BOS/EOS stripped
+            per_seq.append(emb)
+            lengths.append(int(emb.shape[0]))
+
+        Lmax = max(lengths) if lengths else 0
+        H = hs.shape[-1]
+        X = hs.new_zeros((len(sequences), Lmax, H), dtype=torch.float32)
+        M = torch.zeros((len(sequences), Lmax), dtype=torch.bool, device=self.device)
+
+        for i, emb in enumerate(per_seq):
+            L = emb.shape[0]
+            if L == 0:
+                continue
+            X[i, :L, :] = emb.to(torch.float32)
+            M[i, :L] = True
+
+        return X, M
+
+    @torch.no_grad()
+    def predict_raw(self, input_seqs):
+        """
+        Returns the regressor output in the same space as the training target_col:
+          - if trained on log_label -> log1p(hours)
+          - if trained on label -> hours (or whatever the label scale was)
+        """
+        if len(input_seqs) == 0:
+            return np.array([], dtype=np.float32)
+
+        X, M = self._embed_unpooled_batch(input_seqs)
+        yhat = self.regressor(X, M).detach().cpu().numpy().astype(np.float32)  # (B,)
+        return yhat
+
+    def predict_hours(self, input_seqs) -> np.ndarray:
+        """
+        If the model was trained on log_label, convert back to hours via expm1;
+        otherwise return the raw predictions.
+        """
+        raw = self.predict_raw(input_seqs)
+        if self.target_col == "log_label":
+            return np.expm1(raw).astype(np.float32)
+        return raw.astype(np.float32)
 
+    def __call__(self, input_seqs) -> torch.Tensor:
+        return torch.from_numpy(self.predict_hours(input_seqs)).to(self.device)
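Note (editor's sketch, not part of the commit): hypothetical usage of the new HalfLifeModel. __call__ feeds input_ids straight into ESM without an attention mask, so an un-padded (equal-length) batch is the safe case; the instance's own tokenizer can be reused.

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hl = HalfLifeModel(device)  # loads final_model.pt from the default ckpt_path

    batch = hl.tokenizer(["GLFDIVKKVV"], return_tensors="pt")
    hours = hl(batch["input_ids"].to(device))  # predicted half-life in hours, on `device`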
 
 
 def load_bindevaluator(checkpoint_path, device):
-    bindevaluator = BindEvaluator.load_from_checkpoint(checkpoint_path, …
+    bindevaluator = BindEvaluator.load_from_checkpoint(checkpoint_path, n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
     bindevaluator.eval()
     for param in bindevaluator.parameters():
         param.requires_grad = False