AlienChen committed on
Commit
f122106
·
verified ·
1 Parent(s): fe763fa

Update models/peptide_classifiers.py

Browse files
Files changed (1) hide show
  1. models/peptide_classifiers.py +210 -43
models/peptide_classifiers.py CHANGED
@@ -509,7 +509,7 @@ class AffinityModel(nn.Module):
509
 
510
  class HemolysisModel:
511
  def __init__(self, device):
512
- self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_hemolysis.json')
513
 
514
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
515
  self.model.eval()
@@ -544,47 +544,58 @@ class HemolysisModel:
544
  scores = self.get_scores(input_seqs)
545
  return scores
546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  class NonfoulingModel:
548
  def __init__(self, device):
549
- # change model path
550
- self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_nonfouling.json')
551
-
 
 
 
 
552
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
553
  self.model.eval()
554
 
555
  self.device = device
556
 
557
- def generate_embeddings(self, sequences):
558
- """Generate ESM embeddings for protein sequences"""
559
  with torch.no_grad():
560
- embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
561
- embeddings = embeddings.cpu().numpy()
562
-
563
- return embeddings
564
 
565
- def get_scores(self, input_seqs):
566
- scores = np.zeros(len(input_seqs))
567
- features = self.generate_embeddings(input_seqs)
568
-
569
- if len(features) == 0:
570
- return scores
571
-
572
- features = np.nan_to_num(features, nan=0.)
573
- features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
574
-
575
- features = xgb.DMatrix(features)
576
-
577
- scores = self.predictor.predict(features)
578
- return torch.from_numpy(scores).to(self.device)
579
 
580
- def __call__(self, input_seqs: list):
581
- scores = self.get_scores(input_seqs)
582
- return scores
583
-
584
  class SolubilityModel:
585
  def __init__(self, device):
586
  # change model path
587
- self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_solubility.json')
588
 
589
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
590
  self.model.eval()
@@ -624,7 +635,8 @@ class SolubilityModelNew:
624
  self.device = device
625
 
626
  def get_scores(self, x):
627
- mask = (x.unsqueeze(-1) == self.hydro_ids).any(dim=-1)
 
628
  ratios = mask.float().mean(dim=1)
629
  return 1 - ratios
630
 
@@ -663,24 +675,179 @@ class PeptideCNN(nn.Module):
663
  return features
664
  return self.predictor(features) # Output shape: (B, 1)
665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
666
  class HalfLifeModel:
667
- def __init__(self, device):
668
- input_dim = 1280
669
- hidden_dims = [input_dim // 2, input_dim // 4]
670
- output_dim = input_dim // 8
671
- dropout_rate = 0.3
672
- self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
673
- self.model.load_state_dict(torch.load('./classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
674
- self.model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
 
676
- def __call__(self, x):
677
- prediction = self.model(x, return_features=False)
678
- halflife = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0)
679
- return halflife / 2
680
 
681
 
682
  def load_bindevaluator(checkpoint_path, device):
683
- bindevaluator = BindEvaluator.load_from_checkpoint(checkpoint_path, weights_only=False, n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
684
  bindevaluator.eval()
685
  for param in bindevaluator.parameters():
686
  param.requires_grad = False
 
509
 
510
  class HemolysisModel:
511
  def __init__(self, device):
512
+ self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/collection/classifiers/ckpt/wt_hemolysis.json')
513
 
514
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
515
  self.model.eval()
 
544
  scores = self.get_scores(input_seqs)
545
  return scores
546
 
547
+ # ======================== MLP =========================================
548
+ # Still need mean pooling along lengths
549
class MaskedMeanPool(nn.Module):
    """Average features over dim 1, counting only positions where the mask is non-zero.

    ``forward(X, M)`` takes X of shape (B, L, H) and M of shape (B, L)
    (shapes as annotated by the original author) and returns (B, H).
    """

    def forward(self, X, M):
        # Broadcast the mask over the feature axis so masked positions contribute 0.
        weights = M.unsqueeze(-1).float()
        masked_sum = (X * weights).sum(dim=1)
        # Clamp so a row with an all-zero mask divides by 1 instead of 0.
        n_valid = weights.sum(dim=1).clamp(min=1.0)
        return masked_sum / n_valid
554
+
555
class MLPClassifier(nn.Module):
    """Masked mean pooling followed by a two-layer MLP that emits one logit per sample.

    Args:
        in_dim: feature size of the per-token inputs.
        hidden: width of the single hidden layer.
        dropout: dropout probability applied after the GELU.
    """

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        # Layer order is kept as-is so state_dict keys stay net.0.* and net.3.*.
        stack = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, X, M):
        """Pool X under mask M, then return raw logits with the trailing size-1 dim squeezed."""
        pooled = self.pool(X, M)
        logits = self.net(pooled)
        return logits.squeeze(-1)
568
+ # ======================== MLP =========================================
569
+
570
  class NonfoulingModel:
571
  def __init__(self, device):
572
+ ckpt = torch.load('/scratch/pranamlab/tong/collection/classifiers/ckpt/wt_nonfouling.pt', weights_only=False, map_location=device)
573
+ best_params = ckpt["best_params"]
574
+ self.predictor = MLPClassifier(in_dim=1280, hidden=int(best_params["hidden"]), dropout=float(best_params.get("dropout", 0.1)))
575
+ self.predictor.load_state_dict(ckpt["state_dict"])
576
+ self.predictor = self.predictor.to(device)
577
+ self.predictor.eval()
578
+
579
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
580
  self.model.eval()
581
 
582
  self.device = device
583
 
584
+ def get_scores(self, input_ids, attention_mask):
 
585
  with torch.no_grad():
586
+ features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
587
+ scores = self.predictor(features, attention_mask)
588
+ return scores
 
589
 
590
+ def __call__(self, input_ids):
591
+ attention_mask = torch.ones_like(input_ids).to(self.device)
592
+ scores = self.get_scores(input_ids, attention_mask)
593
+ return 1.0 / (1.0 + torch.exp(-scores))
 
 
 
 
 
 
 
 
 
 
594
 
 
 
 
 
595
  class SolubilityModel:
596
  def __init__(self, device):
597
  # change model path
598
+ self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json')
599
 
600
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
601
  self.model.eval()
 
635
  self.device = device
636
 
637
  def get_scores(self, x):
638
+ a = x[:, 1:-1]
639
+ mask = (a.unsqueeze(-1) == self.hydro_ids).any(dim=-1)
640
  ratios = mask.float().mean(dim=1)
641
  return 1 - ratios
642
 
 
675
  return features
676
  return self.predictor(features) # Output shape: (B, 1)
677
 
678
+ # class HalfLifeModel:
679
+ # def __init__(self, device):
680
+ # input_dim = 1280
681
+ # hidden_dims = [input_dim // 2, input_dim // 4]
682
+ # output_dim = input_dim // 8
683
+ # dropout_rate = 0.3
684
+ # self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
685
+ # self.model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
686
+ # self.model.eval()
687
+
688
+ # def __call__(self, x):
689
+ # prediction = self.model(x, return_features=False)
690
+ # halflife = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0)
691
+ # return halflife / 2
692
+
693
+
694
+ # -----------------------------
695
+ # Model definition (must match training)
696
+ # -----------------------------
697
class TransformerRegressor(nn.Module):
    """Linear projection -> TransformerEncoder -> masked mean pool -> scalar head.

    The architecture must match the one used when the checkpoint was trained,
    otherwise ``load_state_dict`` will not line up.

    Args:
        in_dim: feature size of the incoming per-token embeddings.
        d_model: encoder width after projection.
        nhead: number of attention heads.
        layers: number of encoder layers.
        ff: feed-forward width inside each encoder layer.
        dropout: dropout probability inside each encoder layer.
    """

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        """Return one scalar per sequence.

        Args:
            X: per-token features, batch-first.
            M: keep-mask; truthy = real token, falsy = padding. Bool masks and
                0/1 integer or float masks are both accepted.
        """
        # `~` is only a logical NOT on bool tensors; on an integer mask it is a
        # bitwise NOT and yields an invalid padding mask, so coerce first.
        keep = M.bool()
        pad_mask = ~keep
        Z = self.proj(X)
        Z = self.enc(Z, src_key_padding_mask=pad_mask)
        Mf = keep.unsqueeze(-1).float()
        # Clamp so a fully masked row divides by 1 rather than 0.
        denom = Mf.sum(dim=1).clamp(min=1.0)
        pooled = (Z * Mf).sum(dim=1) / denom
        return self.head(pooled).squeeze(-1)
721
+
722
+
723
def build_model(model_name: str, in_dim: int, params: dict) -> nn.Module:
    """Construct the regressor architecture used by this inference file.

    Only ``model_name == "transformer"`` is supported. The transformer
    hyperparameters are fixed literals below; ``params`` is accepted but not read.
    NOTE(review): the literals presumably mirror build_model in the training
    script (finetune_nn_cv.py) — confirm they match the checkpoint.

    Raises:
        ValueError: if ``model_name`` is anything other than "transformer".
    """
    if model_name == "transformer":
        fixed_hparams = {
            "d_model": 384,
            "nhead": 4,
            "layers": 1,
            "ff": 512,
            "dropout": 0.1521676463658988,
        }
        return TransformerRegressor(in_dim=in_dim, **fixed_hparams)
    raise ValueError(f"This inference file currently supports model_name='transformer', got: {model_name}")
736
+
737
+ def _clean_state_dict(state_dict: dict) -> dict:
738
+ cleaned = {}
739
+ for k, v in state_dict.items():
740
+ if k.startswith("module."):
741
+ k = k[len("module.") :]
742
+ if k.startswith("model."):
743
+ k = k[len("model.") :]
744
+ cleaned[k] = v
745
+ return cleaned
746
+
747
  class HalfLifeModel:
748
+ """
749
+ Loads:
750
+ - ESM2 encoder to generate *unpooled* token embeddings (per residue)
751
+ - Your fine-tuned TransformerRegressor from final_model.pt
752
+
753
+ By default, __call__ returns "hours":
754
+ - if ckpt['target_col'] == 'log_label' -> expm1(pred)
755
+ - else -> raw pred
756
+ """
757
+
758
+ def __init__(
759
+ self,
760
+ device,
761
+ ckpt_path = "/scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_transformer_log/final_model.pt",
762
+ ):
763
+ self.device = device
764
+
765
+ # --- load NN checkpoint (saved by your finetune script) ---
766
+ ckpt = torch.load(ckpt_path, map_location="cpu")
767
+ if not isinstance(ckpt, dict) or "state_dict" not in ckpt:
768
+ raise ValueError(f"Checkpoint at {ckpt_path} is not the expected dict with a 'state_dict' key.")
769
+
770
+ self.best_params = ckpt.get("best_params", {})
771
+ self.in_dim = int(ckpt.get("in_dim"))
772
+ self.target_col = ckpt.get("target_col", "label") # 'log_label' or 'label'
773
+
774
+ # --- build + load regressor ---
775
+ self.regressor = build_model(model_name="transformer", in_dim=self.in_dim, params=self.best_params)
776
+ self.regressor.load_state_dict(_clean_state_dict(ckpt["state_dict"]), strict=True)
777
+ self.regressor.to(self.device)
778
+ self.regressor.eval()
779
+
780
+ # --- ESM2 embedding model ---
781
+ self.emb_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.device)
782
+ self.emb_model.eval()
783
+ self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
784
+
785
+ # sanity: ESM2 hidden size should match training in_dim
786
+ esm_hidden = int(self.emb_model.config.hidden_size)
787
+ if esm_hidden != self.in_dim:
788
+ raise ValueError(
789
+ f"Mismatch: ESM hidden_size={esm_hidden}, but checkpoint in_dim={self.in_dim}.\n"
790
+ f"Did you train on a different embedding model/dimension than facebook/esm2_t33_650M_UR50D?"
791
+ )
792
+
793
+ @torch.no_grad()
794
+ def _embed_unpooled_batch(self, sequences):
795
+ out = self.emb_model(input_ids=sequences)
796
+ hs = out.last_hidden_state # (B, T, H)
797
+
798
+ per_seq = []
799
+ lengths = []
800
+
801
+ for i in range(hs.shape[0]):
802
+ emb = hs[i, 1:-1, :] # (L, H)
803
+ per_seq.append(emb)
804
+ lengths.append(int(emb.shape[0]))
805
+
806
+ Lmax = max(lengths) if lengths else 0
807
+ H = hs.shape[-1]
808
+ X = hs.new_zeros((len(sequences), Lmax, H), dtype=torch.float32)
809
+ M = torch.zeros((len(sequences), Lmax), dtype=torch.bool, device=self.device)
810
+
811
+ for i, emb in enumerate(per_seq):
812
+ L = emb.shape[0]
813
+ if L == 0:
814
+ continue
815
+ X[i, :L, :] = emb.to(torch.float32)
816
+ M[i, :L] = True
817
+
818
+ return X, M
819
+
820
+ @torch.no_grad()
821
+ def predict_raw(self, input_seqs):
822
+ """
823
+ Returns the regressor output in the same space as training target_col:
824
+ - if trained on log_label -> returns log1p(hours)
825
+ - if trained on label -> returns hours (or whatever label scale was)
826
+ """
827
+ if len(input_seqs) == 0:
828
+ return np.array([], dtype=np.float32)
829
+
830
+ X, M = self._embed_unpooled_batch(input_seqs)
831
+ yhat = self.regressor(X, M).detach().cpu().numpy().astype(np.float32) # (B,)
832
+ # pdb.set_trace()
833
+ return yhat
834
+
835
+ def predict_hours(self, input_seqs) -> np.ndarray:
836
+ """
837
+ If your model was trained on log_label, convert back to hours via expm1.
838
+ Otherwise returns raw predictions.
839
+ """
840
+ raw = self.predict_raw(input_seqs)
841
+ if self.target_col == "log_label":
842
+ return np.expm1(raw).astype(np.float32)
843
+ return raw.astype(np.float32)
844
 
845
+ def __call__(self, input_seqs) -> np.ndarray:
846
+ return torch.from_numpy(self.predict_hours(input_seqs)).to(self.device)
 
 
847
 
848
 
849
  def load_bindevaluator(checkpoint_path, device):
850
+ bindevaluator = BindEvaluator.load_from_checkpoint(checkpoint_path, n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
851
  bindevaluator.eval()
852
  for param in bindevaluator.parameters():
853
  param.requires_grad = False