Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import PyTorchModelHubMixin | |
| ################################# | |
| # Latent Space Distance Metrics # | |
| ################################# | |
class Cosine(nn.Module):
    """Row-wise cosine similarity between ``x1`` and ``x2``.

    Uses the same settings as ``nn.CosineSimilarity()`` defaults
    (``dim=1``, ``eps=1e-8``), so ``(B, D)`` inputs produce a ``(B,)`` result.
    """

    def forward(self, x1, x2):
        return nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8)
class SquaredCosine(nn.Module):
    """Square of the row-wise cosine similarity between ``x1`` and ``x2``.

    Similarity is taken along dim 1 with ``eps=1e-8`` (the
    ``nn.CosineSimilarity()`` defaults), then squared element-wise.
    """

    def forward(self, x1, x2):
        similarity = nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8)
        return similarity.pow(2)
class Euclidean(nn.Module):
    """Euclidean (p=2) distances between the rows of ``x1`` and ``x2``.

    NOTE(review): ``torch.cdist`` compares every row of ``x1`` with every row
    of ``x2``, so ``(B, D)`` inputs yield a ``(B, B)`` matrix. ``Cosine``
    instead yields one value per row pair, ``(B,)``. Confirm the all-pairs
    behavior is intended before using this as the classify metric on batches.
    """

    def forward(self, x1, x2):
        distances = torch.cdist(x1, x2, p=2.0)
        return distances
class SquaredEuclidean(nn.Module):
    """Squared Euclidean distances between the rows of ``x1`` and ``x2``.

    NOTE(review): built on ``torch.cdist``, so ``(B, D)`` inputs yield a
    ``(B, B)`` all-pairs matrix rather than one value per row pair.
    """

    def forward(self, x1, x2):
        return torch.cdist(x1, x2, p=2.0).pow(2)
# Registry of latent-space comparison modules, keyed by the name accepted by
# ``ConPLex_DTI(latent_distance=...)``. Values are classes; the model
# instantiates the selected one.
DISTANCE_METRICS = dict(
    Cosine=Cosine,
    SquaredCosine=SquaredCosine,
    Euclidean=Euclidean,
    SquaredEuclidean=SquaredEuclidean,
)
# Registry of activation classes applied after each latent projection, keyed
# by the name accepted by ``ConPLex_DTI(latent_activation=...)``.
ACTIVATIONS = dict(
    ReLU=nn.ReLU,
    GELU=nn.GELU,
    ELU=nn.ELU,
    Sigmoid=nn.Sigmoid,
)
class ConPLex_DTI(nn.Module, PyTorchModelHubMixin):
    """Two-tower drug-target interaction model.

    Drug and target feature tensors are each projected into a shared latent
    space of width ``latent_dimension`` by one ``nn.Linear`` layer (weights
    Xavier-normal initialised) followed by an activation from ``ACTIVATIONS``.
    In classification mode the two projections are compared with a metric
    from ``DISTANCE_METRICS``; otherwise their inner product is returned.

    Args:
        drug_shape: Size of the last dimension of the drug input.
        target_shape: Size of the last dimension of the target input.
        latent_dimension: Width of the shared latent space.
        latent_activation: Key into ``ACTIVATIONS``; unknown keys raise
            ``KeyError``.
        latent_distance: Key into ``DISTANCE_METRICS``. It is only looked up
            when ``classify`` is true.
        classify: When true, ``forward`` delegates to ``classify``; when false,
            to ``regress``.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation="ReLU",
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify
        # The activation *class* is stored; each projector gets its own instance.
        self.latent_activation = ACTIVATIONS[latent_activation]

        # Drug tower is built (and initialised) before the target tower; keep
        # this order so seeded initialisation matches previous behavior.
        self.drug_projector = self._make_projector(self.drug_shape, latent_dimension)
        self.target_projector = self._make_projector(
            self.target_shape, latent_dimension
        )

        if self.do_classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[latent_distance]()

    def _make_projector(self, in_features, out_features):
        """Return ``Sequential(Linear, activation)`` with a Xavier-normal weight."""
        linear = nn.Linear(in_features, out_features)
        nn.init.xavier_normal_(linear.weight)
        return nn.Sequential(linear, self.latent_activation())

    def forward(self, drug, target):
        """Score drug/target inputs with ``classify`` or ``regress``."""
        head = self.classify if self.do_classify else self.regress
        return head(drug, target)

    def regress(self, drug, target):
        """Return the inner product of the drug and target latent projections.

        Projections are viewed as ``(-1, 1, latent_dimension)`` and
        ``(-1, latent_dimension, 1)`` and multiplied with ``torch.bmm``. All
        size-1 dimensions are then squeezed away, so a single pair yields a
        0-d tensor.
        """
        drug_latent = self.drug_projector(drug).view(-1, 1, self.latent_dimension)
        target_latent = self.target_projector(target).view(
            -1, self.latent_dimension, 1
        )
        return torch.bmm(drug_latent, target_latent).squeeze()

    def classify(self, drug, target):
        """Apply the configured distance metric to the two latent projections.

        Size-1 dimensions of the metric's output are squeezed away.
        """
        drug_latent = self.drug_projector(drug)
        target_latent = self.target_projector(target)
        return self.activator(drug_latent, target_latent).squeeze()
if __name__ == "__main__":
    # One-off conversion script:
    #   1. load a local checkpoint into a default-configured model,
    #   2. save it with the Hub mixin and push it,
    #   3. reload it from the "samsl" Hub repository.
    checkpoint_path = "./models/conplex_v1_bindingdb.pt"
    repo_name = "ConPLex_V1_BindingDB"

    model = ConPLex_DTI()
    # NOTE(review): torch.load unpickles the file, so only use it on trusted
    # checkpoints. Consider weights_only=True if the installed torch supports it.
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.save_pretrained(repo_name)
    model.push_to_hub(repo_name)

    model = ConPLex_DTI.from_pretrained("samsl/ConPLex_V1_BindingDB")