Spaces:
Running
Running
Update trainer.py
Browse files- trainer.py +10 -12
trainer.py
CHANGED
|
@@ -581,10 +581,10 @@ class Trainer(object):
|
|
| 581 |
drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
|
| 582 |
|
| 583 |
if self.submodel == "CrossLoss":
|
| 584 |
-
GAN1_input_e =
|
| 585 |
-
GAN1_input_x =
|
| 586 |
-
GAN1_disc_e =
|
| 587 |
-
GAN1_disc_x =
|
| 588 |
elif self.submodel == "Ligand":
|
| 589 |
GAN1_input_e = a_tensor
|
| 590 |
GAN1_input_x = x_tensor
|
|
@@ -737,11 +737,13 @@ class Trainer(object):
|
|
| 737 |
drug_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
| 738 |
else:
|
| 739 |
drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
|
| 740 |
-
|
| 741 |
-
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 742 |
-
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 743 |
-
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 744 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 745 |
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 746 |
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 747 |
|
|
@@ -819,10 +821,6 @@ class Trainer(object):
|
|
| 819 |
GAN1_input_x = x_tensor
|
| 820 |
GAN1_disc_e = drugs_a_tensor
|
| 821 |
GAN1_disc_x = drugs_x_tensor
|
| 822 |
-
GAN2_input_e = drugs_a_tensor
|
| 823 |
-
GAN2_input_x = drugs_x_tensor
|
| 824 |
-
GAN2_disc_e = a_tensor
|
| 825 |
-
GAN2_disc_x = x_tensor
|
| 826 |
elif self.submodel == "Ligand":
|
| 827 |
GAN1_input_e = a_tensor
|
| 828 |
GAN1_input_x = x_tensor
|
|
|
|
| 581 |
drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
|
| 582 |
|
| 583 |
if self.submodel == "CrossLoss":
|
| 584 |
+
GAN1_input_e = a_tensor
|
| 585 |
+
GAN1_input_x = x_tensor
|
| 586 |
+
GAN1_disc_e = drugs_a_tensor
|
| 587 |
+
GAN1_disc_x = drugs_x_tensor
|
| 588 |
elif self.submodel == "Ligand":
|
| 589 |
GAN1_input_e = a_tensor
|
| 590 |
GAN1_input_x = x_tensor
|
|
|
|
| 737 |
drug_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
| 738 |
else:
|
| 739 |
drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 740 |
|
| 741 |
+
if self.submodel == "RL":
|
| 742 |
+
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 743 |
+
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 744 |
+
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 745 |
+
else:
|
| 746 |
+
fps_r = None
|
| 747 |
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 748 |
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 749 |
|
|
|
|
| 821 |
GAN1_input_x = x_tensor
|
| 822 |
GAN1_disc_e = drugs_a_tensor
|
| 823 |
GAN1_disc_x = drugs_x_tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
| 824 |
elif self.submodel == "Ligand":
|
| 825 |
GAN1_input_e = a_tensor
|
| 826 |
GAN1_input_x = x_tensor
|