Update models/peptide_classifiers.py
Browse files
models/peptide_classifiers.py
CHANGED
|
@@ -509,7 +509,7 @@ class AffinityModel(nn.Module):
|
|
| 509 |
|
| 510 |
class HemolysisModel:
|
| 511 |
def __init__(self, device):
|
| 512 |
-
self.predictor = xgb.Booster(model_file='
|
| 513 |
|
| 514 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 515 |
self.model.eval()
|
|
@@ -547,7 +547,7 @@ class HemolysisModel:
|
|
| 547 |
class NonfoulingModel:
|
| 548 |
def __init__(self, device):
|
| 549 |
# change model path
|
| 550 |
-
self.predictor = xgb.Booster(model_file='
|
| 551 |
|
| 552 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 553 |
self.model.eval()
|
|
@@ -584,7 +584,7 @@ class NonfoulingModel:
|
|
| 584 |
class SolubilityModel:
|
| 585 |
def __init__(self, device):
|
| 586 |
# change model path
|
| 587 |
-
self.predictor = xgb.Booster(model_file='
|
| 588 |
|
| 589 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 590 |
self.model.eval()
|
|
@@ -670,7 +670,7 @@ class HalfLifeModel:
|
|
| 670 |
output_dim = input_dim // 8
|
| 671 |
dropout_rate = 0.3
|
| 672 |
self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
|
| 673 |
-
self.model.load_state_dict(torch.load('
|
| 674 |
self.model.eval()
|
| 675 |
|
| 676 |
def __call__(self, x):
|
|
|
|
| 509 |
|
| 510 |
class HemolysisModel:
|
| 511 |
def __init__(self, device):
|
| 512 |
+
self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_hemolysis.json')
|
| 513 |
|
| 514 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 515 |
self.model.eval()
|
|
|
|
| 547 |
class NonfoulingModel:
|
| 548 |
def __init__(self, device):
|
| 549 |
# change model path
|
| 550 |
+
self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_nonfouling.json')
|
| 551 |
|
| 552 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 553 |
self.model.eval()
|
|
|
|
| 584 |
class SolubilityModel:
|
| 585 |
def __init__(self, device):
|
| 586 |
# change model path
|
| 587 |
+
self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_solubility.json')
|
| 588 |
|
| 589 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 590 |
self.model.eval()
|
|
|
|
| 670 |
output_dim = input_dim // 8
|
| 671 |
dropout_rate = 0.3
|
| 672 |
self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
|
| 673 |
+
self.model.load_state_dict(torch.load('./classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
|
| 674 |
self.model.eval()
|
| 675 |
|
| 676 |
def __call__(self, x):
|