Update models/peptide_classifiers.py
Browse files
models/peptide_classifiers.py
CHANGED
|
@@ -509,7 +509,7 @@ class AffinityModel(nn.Module):
|
|
| 509 |
|
| 510 |
class HemolysisModel:
|
| 511 |
def __init__(self, device):
|
| 512 |
-
self.predictor = xgb.Booster(model_file='/
|
| 513 |
|
| 514 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 515 |
self.model.eval()
|
|
@@ -569,7 +569,7 @@ class MLPClassifier(nn.Module):
|
|
| 569 |
|
| 570 |
class NonfoulingModel:
|
| 571 |
def __init__(self, device):
|
| 572 |
-
ckpt = torch.load('/
|
| 573 |
best_params = ckpt["best_params"]
|
| 574 |
self.predictor = MLPClassifier(in_dim=1280, hidden=int(best_params["hidden"]), dropout=float(best_params.get("dropout", 0.1)))
|
| 575 |
self.predictor.load_state_dict(ckpt["state_dict"])
|
|
@@ -595,7 +595,7 @@ class NonfoulingModel:
|
|
| 595 |
class SolubilityModel:
|
| 596 |
def __init__(self, device):
|
| 597 |
# change model path
|
| 598 |
-
self.predictor = xgb.Booster(model_file='/
|
| 599 |
|
| 600 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 601 |
self.model.eval()
|
|
@@ -675,21 +675,6 @@ class PeptideCNN(nn.Module):
|
|
| 675 |
return features
|
| 676 |
return self.predictor(features) # Output shape: (B, 1)
|
| 677 |
|
| 678 |
-
# class HalfLifeModel:
|
| 679 |
-
# def __init__(self, device):
|
| 680 |
-
# input_dim = 1280
|
| 681 |
-
# hidden_dims = [input_dim // 2, input_dim // 4]
|
| 682 |
-
# output_dim = input_dim // 8
|
| 683 |
-
# dropout_rate = 0.3
|
| 684 |
-
# self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
|
| 685 |
-
# self.model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
|
| 686 |
-
# self.model.eval()
|
| 687 |
-
|
| 688 |
-
# def __call__(self, x):
|
| 689 |
-
# prediction = self.model(x, return_features=False)
|
| 690 |
-
# halflife = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0)
|
| 691 |
-
# return halflife / 2
|
| 692 |
-
|
| 693 |
|
| 694 |
# -----------------------------
|
| 695 |
# Model definition (must match training)
|
|
@@ -758,7 +743,7 @@ class HalfLifeModel:
|
|
| 758 |
def __init__(
|
| 759 |
self,
|
| 760 |
device,
|
| 761 |
-
ckpt_path = "/
|
| 762 |
):
|
| 763 |
self.device = device
|
| 764 |
|
|
|
|
| 509 |
|
| 510 |
class HemolysisModel:
|
| 511 |
def __init__(self, device):
|
| 512 |
+
self.predictor = xgb.Booster(model_file='../classifier_ckpt/wt_hemolysis.json')
|
| 513 |
|
| 514 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 515 |
self.model.eval()
|
|
|
|
| 569 |
|
| 570 |
class NonfoulingModel:
|
| 571 |
def __init__(self, device):
|
| 572 |
+
ckpt = torch.load('../classifier_ckpt/wt_nonfouling.pt', weights_only=False, map_location=device)
|
| 573 |
best_params = ckpt["best_params"]
|
| 574 |
self.predictor = MLPClassifier(in_dim=1280, hidden=int(best_params["hidden"]), dropout=float(best_params.get("dropout", 0.1)))
|
| 575 |
self.predictor.load_state_dict(ckpt["state_dict"])
|
|
|
|
| 595 |
class SolubilityModel:
|
| 596 |
def __init__(self, device):
|
| 597 |
# change model path
|
| 598 |
+
self.predictor = xgb.Booster(model_file='../classifier_ckpt/best_model_solubility.json')
|
| 599 |
|
| 600 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 601 |
self.model.eval()
|
|
|
|
| 675 |
return features
|
| 676 |
return self.predictor(features) # Output shape: (B, 1)
|
| 677 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
|
| 679 |
# -----------------------------
|
| 680 |
# Model definition (must match training)
|
|
|
|
| 743 |
def __init__(
|
| 744 |
self,
|
| 745 |
device,
|
| 746 |
+
ckpt_path = "../classifier_ckpt/wt_halflife.pt",
|
| 747 |
):
|
| 748 |
self.device = device
|
| 749 |
|