Spaces:
Running
Running
Update trainer.py
Browse files- trainer.py +12 -12
trainer.py
CHANGED
|
@@ -398,7 +398,7 @@ class Trainer(object):
|
|
| 398 |
|
| 399 |
''' Loading the atom and bond decoders'''
|
| 400 |
|
| 401 |
-
with open("
|
| 402 |
|
| 403 |
return pickle.load(f)
|
| 404 |
|
|
@@ -406,7 +406,7 @@ class Trainer(object):
|
|
| 406 |
|
| 407 |
''' Loading the atom and bond decoders'''
|
| 408 |
|
| 409 |
-
with open("
|
| 410 |
|
| 411 |
return pickle.load(f)
|
| 412 |
|
|
@@ -507,15 +507,15 @@ class Trainer(object):
|
|
| 507 |
|
| 508 |
|
| 509 |
# protein data
|
| 510 |
-
full_smiles = [line for line in open("
|
| 511 |
-
drug_smiles = [line for line in open("
|
| 512 |
|
| 513 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 514 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 515 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 516 |
|
| 517 |
-
akt1_human_adj = torch.load("
|
| 518 |
-
akt1_human_annot = torch.load("
|
| 519 |
|
| 520 |
# Start training.
|
| 521 |
|
|
@@ -705,14 +705,14 @@ class Trainer(object):
|
|
| 705 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
| 706 |
|
| 707 |
|
| 708 |
-
drug_smiles = [line for line in open("
|
| 709 |
|
| 710 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 711 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 712 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 713 |
|
| 714 |
-
akt1_human_adj = torch.load("
|
| 715 |
-
akt1_human_annot = torch.load("
|
| 716 |
|
| 717 |
self.G.eval()
|
| 718 |
#self.D.eval()
|
|
@@ -753,8 +753,8 @@ class Trainer(object):
|
|
| 753 |
#metric_calc_mol = []
|
| 754 |
metric_calc_dr = []
|
| 755 |
date = time.time()
|
| 756 |
-
if not os.path.exists("
|
| 757 |
-
os.makedirs("
|
| 758 |
with torch.inference_mode():
|
| 759 |
|
| 760 |
dataloader_iterator = iter(self.drugs_loader)
|
|
@@ -867,7 +867,7 @@ class Trainer(object):
|
|
| 867 |
|
| 868 |
print("molecule batch {} inferred".format(i))
|
| 869 |
|
| 870 |
-
with open("
|
| 871 |
for molecules in inference_drugs:
|
| 872 |
|
| 873 |
f.write(molecules)
|
|
|
|
| 398 |
|
| 399 |
''' Loading the atom and bond decoders'''
|
| 400 |
|
| 401 |
+
with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
| 402 |
|
| 403 |
return pickle.load(f)
|
| 404 |
|
|
|
|
| 406 |
|
| 407 |
''' Loading the atom and bond decoders'''
|
| 408 |
|
| 409 |
+
with open("data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
|
| 410 |
|
| 411 |
return pickle.load(f)
|
| 412 |
|
|
|
|
| 507 |
|
| 508 |
|
| 509 |
# protein data
|
| 510 |
+
full_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
| 511 |
+
drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
|
| 512 |
|
| 513 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 514 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 515 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 516 |
|
| 517 |
+
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 518 |
+
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 519 |
|
| 520 |
# Start training.
|
| 521 |
|
|
|
|
| 705 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
| 706 |
|
| 707 |
|
| 708 |
+
drug_smiles = [line for line in open("data/akt_test.smi", 'r').read().splitlines()]
|
| 709 |
|
| 710 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 711 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 712 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 713 |
|
| 714 |
+
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 715 |
+
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 716 |
|
| 717 |
self.G.eval()
|
| 718 |
#self.D.eval()
|
|
|
|
| 753 |
#metric_calc_mol = []
|
| 754 |
metric_calc_dr = []
|
| 755 |
date = time.time()
|
| 756 |
+
if not os.path.exists("experiments/inference/{}".format(self.submodel)):
|
| 757 |
+
os.makedirs("experiments/inference/{}".format(self.submodel))
|
| 758 |
with torch.inference_mode():
|
| 759 |
|
| 760 |
dataloader_iterator = iter(self.drugs_loader)
|
|
|
|
| 867 |
|
| 868 |
print("molecule batch {} inferred".format(i))
|
| 869 |
|
| 870 |
+
with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
|
| 871 |
for molecules in inference_drugs:
|
| 872 |
|
| 873 |
f.write(molecules)
|