AlienChen commited on
Commit
f711381
·
verified ·
1 Parent(s): 3c92e07

Update models/peptide_classifiers.py

Browse files
Files changed (1) hide show
  1. models/peptide_classifiers.py +4 -4
models/peptide_classifiers.py CHANGED
@@ -509,7 +509,7 @@ class AffinityModel(nn.Module):
509
 
510
  class HemolysisModel:
511
  def __init__(self, device):
512
- self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_hemolysis.json')
513
 
514
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
515
  self.model.eval()
@@ -547,7 +547,7 @@ class HemolysisModel:
547
  class NonfoulingModel:
548
  def __init__(self, device):
549
  # change model path
550
- self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_nonfouling.json')
551
 
552
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
553
  self.model.eval()
@@ -584,7 +584,7 @@ class NonfoulingModel:
584
  class SolubilityModel:
585
  def __init__(self, device):
586
  # change model path
587
- self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json')
588
 
589
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
590
  self.model.eval()
@@ -670,7 +670,7 @@ class HalfLifeModel:
670
  output_dim = input_dim // 8
671
  dropout_rate = 0.3
672
  self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
673
- self.model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
674
  self.model.eval()
675
 
676
  def __call__(self, x):
 
509
 
510
  class HemolysisModel:
511
  def __init__(self, device):
512
+ self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_hemolysis.json')
513
 
514
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
515
  self.model.eval()
 
547
  class NonfoulingModel:
548
  def __init__(self, device):
549
  # change model path
550
+ self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_nonfouling.json')
551
 
552
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
553
  self.model.eval()
 
584
  class SolubilityModel:
585
  def __init__(self, device):
586
  # change model path
587
+ self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_solubility.json')
588
 
589
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
590
  self.model.eval()
 
670
  output_dim = input_dim // 8
671
  dropout_rate = 0.3
672
  self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
673
+ self.model.load_state_dict(torch.load('./classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
674
  self.model.eval()
675
 
676
  def __call__(self, x):