Spaces:
Running
Running
Update trainer.py
Browse files- trainer.py +101 -74
trainer.py
CHANGED
|
@@ -6,7 +6,7 @@ import torch
|
|
| 6 |
from utils import *
|
| 7 |
from models import Generator, Generator2, simple_disc
|
| 8 |
import torch_geometric.utils as geoutils
|
| 9 |
-
#import
|
| 10 |
import re
|
| 11 |
from torch_geometric.loader import DataLoader
|
| 12 |
from new_dataloader import DruggenDataset
|
|
@@ -19,7 +19,7 @@ RDLogger.DisableLog('rdApp.*')
|
|
| 19 |
from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
|
| 20 |
from training_data import load_data
|
| 21 |
import random
|
| 22 |
-
|
| 23 |
|
| 24 |
class Trainer(object):
|
| 25 |
|
|
@@ -27,6 +27,19 @@ class Trainer(object):
|
|
| 27 |
|
| 28 |
def __init__(self, config):
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
|
| 31 |
"""Initialize configurations."""
|
| 32 |
self.submodel = config.submodel
|
|
@@ -57,7 +70,10 @@ class Trainer(object):
|
|
| 57 |
|
| 58 |
self.inf_drugs_dataset_file = config.inf_drug_dataset_file # Drug dataset file name for the second GAN.
|
| 59 |
# Contains drug molecules only. (In this case AKT1 inhibitors.)
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
| 61 |
self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
|
| 62 |
|
| 63 |
self.drug_data_dir = config.drug_data_dir # Directory where the drug dataset files are stored.
|
|
@@ -219,6 +235,14 @@ class Trainer(object):
|
|
| 219 |
self.clipping_value = config.clipping_value
|
| 220 |
# Miscellaneous.
|
| 221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
self.mode = config.mode
|
| 223 |
|
| 224 |
self.noise_strength_0 = torch.nn.Parameter(torch.zeros([]))
|
|
@@ -398,7 +422,7 @@ class Trainer(object):
|
|
| 398 |
|
| 399 |
''' Loading the atom and bond decoders'''
|
| 400 |
|
| 401 |
-
with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
| 402 |
|
| 403 |
return pickle.load(f)
|
| 404 |
|
|
@@ -406,7 +430,7 @@ class Trainer(object):
|
|
| 406 |
|
| 407 |
''' Loading the atom and bond decoders'''
|
| 408 |
|
| 409 |
-
with open("data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
|
| 410 |
|
| 411 |
return pickle.load(f)
|
| 412 |
|
|
@@ -429,17 +453,17 @@ class Trainer(object):
|
|
| 429 |
print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
|
| 430 |
|
| 431 |
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
|
| 432 |
-
|
| 433 |
|
| 434 |
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
| 435 |
-
|
| 436 |
|
| 437 |
|
| 438 |
G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(epoch, iteration))
|
| 439 |
-
|
| 440 |
|
| 441 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
| 442 |
-
|
| 443 |
|
| 444 |
|
| 445 |
def save_model(self, model_directory, idx,i):
|
|
@@ -507,16 +531,19 @@ class Trainer(object):
|
|
| 507 |
|
| 508 |
|
| 509 |
# protein data
|
| 510 |
-
full_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
| 511 |
-
drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
|
| 512 |
|
| 513 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 514 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 515 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 516 |
|
| 517 |
-
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 518 |
-
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 519 |
-
|
|
|
|
|
|
|
|
|
|
| 520 |
# Start training.
|
| 521 |
|
| 522 |
print('Start training...')
|
|
@@ -577,8 +604,8 @@ class Trainer(object):
|
|
| 577 |
GAN2_disc_e = drugs_a_tensor
|
| 578 |
GAN2_disc_x = drugs_x_tensor
|
| 579 |
elif self.submodel == "RL":
|
| 580 |
-
GAN1_input_e =
|
| 581 |
-
GAN1_input_x =
|
| 582 |
GAN1_disc_e = a_tensor
|
| 583 |
GAN1_disc_x = x_tensor
|
| 584 |
GAN2_input_e = drugs_a_tensor
|
|
@@ -586,8 +613,8 @@ class Trainer(object):
|
|
| 586 |
GAN2_disc_e = drugs_a_tensor
|
| 587 |
GAN2_disc_x = drugs_x_tensor
|
| 588 |
elif self.submodel == "NoTarget":
|
| 589 |
-
GAN1_input_e =
|
| 590 |
-
GAN1_input_x =
|
| 591 |
GAN1_disc_e = a_tensor
|
| 592 |
GAN1_disc_x = x_tensor
|
| 593 |
|
|
@@ -639,9 +666,10 @@ class Trainer(object):
|
|
| 639 |
GAN1_input_x,
|
| 640 |
self.batch_size,
|
| 641 |
sim_reward,
|
| 642 |
-
self.dataset.
|
| 643 |
fps_r,
|
| 644 |
-
self.submodel
|
|
|
|
| 645 |
|
| 646 |
g_loss, fake_mol, g_edges_hat_sample, g_nodes_hat_sample, node, edge = generator_output
|
| 647 |
|
|
@@ -659,7 +687,8 @@ class Trainer(object):
|
|
| 659 |
fps_r,
|
| 660 |
GAN2_input_e,
|
| 661 |
GAN2_input_x,
|
| 662 |
-
self.submodel
|
|
|
|
| 663 |
|
| 664 |
g2_loss, fake_mol_g, dr_g_edges_hat_sample, dr_g_nodes_hat_sample = output
|
| 665 |
|
|
@@ -695,31 +724,31 @@ class Trainer(object):
|
|
| 695 |
|
| 696 |
# Load the trained generator.
|
| 697 |
self.G.to(self.device)
|
| 698 |
-
#self.D.to(self.device)
|
| 699 |
self.G2.to(self.device)
|
| 700 |
-
#self.D2.to(self.device)
|
| 701 |
|
| 702 |
G_path = os.path.join(self.inference_model, '{}-G.ckpt'.format(self.submodel))
|
| 703 |
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
| 704 |
-
|
| 705 |
-
|
|
|
|
| 706 |
|
| 707 |
-
|
| 708 |
-
drug_smiles = [line for line in open("data/akt_test.smi", 'r').read().splitlines()]
|
| 709 |
|
| 710 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 711 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 712 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 713 |
|
| 714 |
-
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 715 |
-
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 716 |
|
| 717 |
self.G.eval()
|
| 718 |
#self.D.eval()
|
| 719 |
self.G2.eval()
|
| 720 |
#self.D2.eval()
|
|
|
|
|
|
|
| 721 |
|
| 722 |
-
self.inf_batch_size =256
|
| 723 |
self.inf_dataset = DruggenDataset(self.mol_data_dir,
|
| 724 |
self.inf_dataset_file,
|
| 725 |
self.inf_raw_file,
|
|
@@ -753,24 +782,25 @@ class Trainer(object):
|
|
| 753 |
#metric_calc_mol = []
|
| 754 |
metric_calc_dr = []
|
| 755 |
date = time.time()
|
| 756 |
-
if not os.path.exists("experiments/inference/{}".format(self.submodel)):
|
| 757 |
-
os.makedirs("experiments/inference/{}".format(self.submodel))
|
| 758 |
with torch.inference_mode():
|
| 759 |
|
| 760 |
-
dataloader_iterator = iter(self.
|
| 761 |
-
|
| 762 |
-
for
|
|
|
|
| 763 |
try:
|
| 764 |
drugs = next(dataloader_iterator)
|
| 765 |
except StopIteration:
|
| 766 |
-
dataloader_iterator = iter(self.
|
| 767 |
drugs = next(dataloader_iterator)
|
| 768 |
|
| 769 |
# Preprocess both dataset
|
| 770 |
|
| 771 |
bulk_data = load_data(data,
|
| 772 |
drugs,
|
| 773 |
-
self.
|
| 774 |
self.device,
|
| 775 |
self.b_dim,
|
| 776 |
self.m_dim,
|
|
@@ -809,8 +839,8 @@ class Trainer(object):
|
|
| 809 |
GAN2_disc_e = drugs_a_tensor
|
| 810 |
GAN2_disc_x = drugs_x_tensor
|
| 811 |
elif self.submodel == "RL":
|
| 812 |
-
GAN1_input_e =
|
| 813 |
-
GAN1_input_x =
|
| 814 |
GAN1_disc_e = a_tensor
|
| 815 |
GAN1_disc_x = x_tensor
|
| 816 |
GAN2_input_e = drugs_a_tensor
|
|
@@ -818,8 +848,8 @@ class Trainer(object):
|
|
| 818 |
GAN2_disc_e = drugs_a_tensor
|
| 819 |
GAN2_disc_x = drugs_x_tensor
|
| 820 |
elif self.submodel == "NoTarget":
|
| 821 |
-
GAN1_input_e =
|
| 822 |
-
GAN1_input_x =
|
| 823 |
GAN1_disc_e = a_tensor
|
| 824 |
GAN1_disc_x = x_tensor
|
| 825 |
# =================================================================================== #
|
|
@@ -830,53 +860,50 @@ class Trainer(object):
|
|
| 830 |
self.V,
|
| 831 |
GAN1_input_e,
|
| 832 |
GAN1_input_x,
|
| 833 |
-
self.
|
| 834 |
sim_reward,
|
| 835 |
-
self.dataset.
|
| 836 |
fps_r,
|
| 837 |
-
self.submodel
|
|
|
|
| 838 |
|
| 839 |
-
_,
|
| 840 |
|
| 841 |
# =================================================================================== #
|
| 842 |
# 3. GAN2 Inference #
|
| 843 |
# =================================================================================== #
|
| 844 |
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
|
|
|
|
|
|
| 857 |
|
| 858 |
-
|
| 859 |
|
| 860 |
inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
|
|
|
|
| 861 |
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
#inference_smiles = [Chem.MolToSmiles(line) for line in fake_mol]
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
print("molecule batch {} inferred".format(i))
|
| 869 |
-
|
| 870 |
-
with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
|
| 871 |
for molecules in inference_drugs:
|
| 872 |
|
| 873 |
f.write(molecules)
|
| 874 |
f.write("\n")
|
| 875 |
metric_calc_dr.append(molecules)
|
| 876 |
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
|
|
|
| 880 |
break
|
| 881 |
|
| 882 |
et = time.time() - start_time
|
|
@@ -885,8 +912,8 @@ class Trainer(object):
|
|
| 885 |
|
| 886 |
print("Metrics calculation started using MOSES.")
|
| 887 |
|
| 888 |
-
print("Validity: ", fraction_valid(
|
| 889 |
-
print("Uniqueness: ", fraction_unique(
|
| 890 |
-
print("Validity: ", novelty(
|
| 891 |
|
| 892 |
-
print("Metrics are calculated.")
|
|
|
|
| 6 |
from utils import *
|
| 7 |
from models import Generator, Generator2, simple_disc
|
| 8 |
import torch_geometric.utils as geoutils
|
| 9 |
+
#import wandb
|
| 10 |
import re
|
| 11 |
from torch_geometric.loader import DataLoader
|
| 12 |
from new_dataloader import DruggenDataset
|
|
|
|
| 19 |
from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
|
| 20 |
from training_data import load_data
|
| 21 |
import random
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
|
| 24 |
class Trainer(object):
|
| 25 |
|
|
|
|
| 27 |
|
| 28 |
def __init__(self, config):
|
| 29 |
|
| 30 |
+
if config.set_seed:
|
| 31 |
+
np.random.seed(config.seed)
|
| 32 |
+
random.seed(config.seed)
|
| 33 |
+
torch.manual_seed(config.seed)
|
| 34 |
+
torch.cuda.manual_seed(config.seed)
|
| 35 |
+
|
| 36 |
+
torch.backends.cudnn.deterministic = True
|
| 37 |
+
torch.backends.cudnn.benchmark = False
|
| 38 |
+
|
| 39 |
+
os.environ["PYTHONHASHSEED"] = str(config.seed)
|
| 40 |
+
|
| 41 |
+
print(f'Using seed {config.seed}')
|
| 42 |
+
|
| 43 |
self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
|
| 44 |
"""Initialize configurations."""
|
| 45 |
self.submodel = config.submodel
|
|
|
|
| 70 |
|
| 71 |
self.inf_drugs_dataset_file = config.inf_drug_dataset_file # Drug dataset file name for the second GAN.
|
| 72 |
# Contains drug molecules only. (In this case AKT1 inhibitors.)
|
| 73 |
+
self.inference_iterations = config.inference_iterations
|
| 74 |
+
|
| 75 |
+
self.inf_batch_size = config.inf_batch_size
|
| 76 |
+
|
| 77 |
self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
|
| 78 |
|
| 79 |
self.drug_data_dir = config.drug_data_dir # Directory where the drug dataset files are stored.
|
|
|
|
| 235 |
self.clipping_value = config.clipping_value
|
| 236 |
# Miscellaneous.
|
| 237 |
|
| 238 |
+
# resume training
|
| 239 |
+
|
| 240 |
+
self.resume = config.resume
|
| 241 |
+
self.resume_epoch = config.resume_epoch
|
| 242 |
+
self.resume_iter = config.resume_iter
|
| 243 |
+
self.resume_directory = config.resume_directory
|
| 244 |
+
|
| 245 |
+
|
| 246 |
self.mode = config.mode
|
| 247 |
|
| 248 |
self.noise_strength_0 = torch.nn.Parameter(torch.zeros([]))
|
|
|
|
| 422 |
|
| 423 |
''' Loading the atom and bond decoders'''
|
| 424 |
|
| 425 |
+
with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
| 426 |
|
| 427 |
return pickle.load(f)
|
| 428 |
|
|
|
|
| 430 |
|
| 431 |
''' Loading the atom and bond decoders'''
|
| 432 |
|
| 433 |
+
with open("DrugGEN/data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
|
| 434 |
|
| 435 |
return pickle.load(f)
|
| 436 |
|
|
|
|
| 453 |
print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
|
| 454 |
|
| 455 |
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
|
| 456 |
+
D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
|
| 457 |
|
| 458 |
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
| 459 |
+
self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
|
| 460 |
|
| 461 |
|
| 462 |
G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(epoch, iteration))
|
| 463 |
+
D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(epoch, iteration))
|
| 464 |
|
| 465 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
| 466 |
+
self.D2.load_state_dict(torch.load(D2_path, map_location=lambda storage, loc: storage))
|
| 467 |
|
| 468 |
|
| 469 |
def save_model(self, model_directory, idx,i):
|
|
|
|
| 531 |
|
| 532 |
|
| 533 |
# protein data
|
| 534 |
+
full_smiles = [line for line in open("DrugGEN/data/chembl_train.smi", 'r').read().splitlines()]
|
| 535 |
+
drug_smiles = [line for line in open("DrugGEN/data/akt_train.smi", 'r').read().splitlines()]
|
| 536 |
|
| 537 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 538 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 539 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 540 |
|
| 541 |
+
akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 542 |
+
akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 543 |
+
|
| 544 |
+
if self.resume:
|
| 545 |
+
self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
|
| 546 |
+
|
| 547 |
# Start training.
|
| 548 |
|
| 549 |
print('Start training...')
|
|
|
|
| 604 |
GAN2_disc_e = drugs_a_tensor
|
| 605 |
GAN2_disc_x = drugs_x_tensor
|
| 606 |
elif self.submodel == "RL":
|
| 607 |
+
GAN1_input_e = a_tensor
|
| 608 |
+
GAN1_input_x = x_tensor
|
| 609 |
GAN1_disc_e = a_tensor
|
| 610 |
GAN1_disc_x = x_tensor
|
| 611 |
GAN2_input_e = drugs_a_tensor
|
|
|
|
| 613 |
GAN2_disc_e = drugs_a_tensor
|
| 614 |
GAN2_disc_x = drugs_x_tensor
|
| 615 |
elif self.submodel == "NoTarget":
|
| 616 |
+
GAN1_input_e = a_tensor
|
| 617 |
+
GAN1_input_x = x_tensor
|
| 618 |
GAN1_disc_e = a_tensor
|
| 619 |
GAN1_disc_x = x_tensor
|
| 620 |
|
|
|
|
| 666 |
GAN1_input_x,
|
| 667 |
self.batch_size,
|
| 668 |
sim_reward,
|
| 669 |
+
self.dataset.matrices2mol,
|
| 670 |
fps_r,
|
| 671 |
+
self.submodel,
|
| 672 |
+
self.dataset_name)
|
| 673 |
|
| 674 |
g_loss, fake_mol, g_edges_hat_sample, g_nodes_hat_sample, node, edge = generator_output
|
| 675 |
|
|
|
|
| 687 |
fps_r,
|
| 688 |
GAN2_input_e,
|
| 689 |
GAN2_input_x,
|
| 690 |
+
self.submodel,
|
| 691 |
+
self.drugs_name)
|
| 692 |
|
| 693 |
g2_loss, fake_mol_g, dr_g_edges_hat_sample, dr_g_nodes_hat_sample = output
|
| 694 |
|
|
|
|
| 724 |
|
| 725 |
# Load the trained generator.
|
| 726 |
self.G.to(self.device)
|
|
|
|
| 727 |
self.G2.to(self.device)
|
|
|
|
| 728 |
|
| 729 |
G_path = os.path.join(self.inference_model, '{}-G.ckpt'.format(self.submodel))
|
| 730 |
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
| 731 |
+
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
|
| 732 |
+
G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
|
| 733 |
+
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
| 734 |
|
| 735 |
+
|
| 736 |
+
drug_smiles = [line for line in open("DrugGEN/data/akt_test.smi", 'r').read().splitlines()]
|
| 737 |
|
| 738 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 739 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 740 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 741 |
|
| 742 |
+
akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 743 |
+
akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 744 |
|
| 745 |
self.G.eval()
|
| 746 |
#self.D.eval()
|
| 747 |
self.G2.eval()
|
| 748 |
#self.D2.eval()
|
| 749 |
+
|
| 750 |
+
step = self.inference_iterations
|
| 751 |
|
|
|
|
| 752 |
self.inf_dataset = DruggenDataset(self.mol_data_dir,
|
| 753 |
self.inf_dataset_file,
|
| 754 |
self.inf_raw_file,
|
|
|
|
| 782 |
#metric_calc_mol = []
|
| 783 |
metric_calc_dr = []
|
| 784 |
date = time.time()
|
| 785 |
+
if not os.path.exists("DrugGEN/experiments/inference/{}".format(self.submodel)):
|
| 786 |
+
os.makedirs("DrugGEN/experiments/inference/{}".format(self.submodel))
|
| 787 |
with torch.inference_mode():
|
| 788 |
|
| 789 |
+
dataloader_iterator = iter(self.inf_drugs_loader)
|
| 790 |
+
pbar = tqdm(range(self.inference_sample_num))
|
| 791 |
+
pbar.set_description('Inference mode for {} model started'.format(self.submodel))
|
| 792 |
+
for i, data in enumerate(self.inf_loader):
|
| 793 |
try:
|
| 794 |
drugs = next(dataloader_iterator)
|
| 795 |
except StopIteration:
|
| 796 |
+
dataloader_iterator = iter(self.inf_drugs_loader)
|
| 797 |
drugs = next(dataloader_iterator)
|
| 798 |
|
| 799 |
# Preprocess both dataset
|
| 800 |
|
| 801 |
bulk_data = load_data(data,
|
| 802 |
drugs,
|
| 803 |
+
self.inf_batch_size,
|
| 804 |
self.device,
|
| 805 |
self.b_dim,
|
| 806 |
self.m_dim,
|
|
|
|
| 839 |
GAN2_disc_e = drugs_a_tensor
|
| 840 |
GAN2_disc_x = drugs_x_tensor
|
| 841 |
elif self.submodel == "RL":
|
| 842 |
+
GAN1_input_e = a_tensor
|
| 843 |
+
GAN1_input_x = x_tensor
|
| 844 |
GAN1_disc_e = a_tensor
|
| 845 |
GAN1_disc_x = x_tensor
|
| 846 |
GAN2_input_e = drugs_a_tensor
|
|
|
|
| 848 |
GAN2_disc_e = drugs_a_tensor
|
| 849 |
GAN2_disc_x = drugs_x_tensor
|
| 850 |
elif self.submodel == "NoTarget":
|
| 851 |
+
GAN1_input_e = a_tensor
|
| 852 |
+
GAN1_input_x = x_tensor
|
| 853 |
GAN1_disc_e = a_tensor
|
| 854 |
GAN1_disc_x = x_tensor
|
| 855 |
# =================================================================================== #
|
|
|
|
| 860 |
self.V,
|
| 861 |
GAN1_input_e,
|
| 862 |
GAN1_input_x,
|
| 863 |
+
self.inf_batch_size,
|
| 864 |
sim_reward,
|
| 865 |
+
self.dataset.matrices2mol,
|
| 866 |
fps_r,
|
| 867 |
+
self.submodel,
|
| 868 |
+
self.dataset_name)
|
| 869 |
|
| 870 |
+
_, fake_mol_g, _, _, node, edge = generator_output
|
| 871 |
|
| 872 |
# =================================================================================== #
|
| 873 |
# 3. GAN2 Inference #
|
| 874 |
# =================================================================================== #
|
| 875 |
|
| 876 |
+
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
|
| 877 |
+
output = generator2_loss(self.G2,
|
| 878 |
+
self.D2,
|
| 879 |
+
self.V2,
|
| 880 |
+
edge,
|
| 881 |
+
node,
|
| 882 |
+
self.inf_batch_size,
|
| 883 |
+
sim_reward,
|
| 884 |
+
self.dataset.matrices2mol_drugs,
|
| 885 |
+
fps_r,
|
| 886 |
+
GAN2_input_e,
|
| 887 |
+
GAN2_input_x,
|
| 888 |
+
self.submodel,
|
| 889 |
+
self.drugs_name)
|
| 890 |
|
| 891 |
+
_, fake_mol_g, edges, nodes = output
|
| 892 |
|
| 893 |
inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
|
| 894 |
+
inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
|
| 895 |
|
| 896 |
+
with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 897 |
for molecules in inference_drugs:
|
| 898 |
|
| 899 |
f.write(molecules)
|
| 900 |
f.write("\n")
|
| 901 |
metric_calc_dr.append(molecules)
|
| 902 |
|
| 903 |
+
if len(inference_drugs) > 0:
|
| 904 |
+
pbar.update(1)
|
| 905 |
+
|
| 906 |
+
if len(metric_calc_dr) == self.inference_sample_num:
|
| 907 |
break
|
| 908 |
|
| 909 |
et = time.time() - start_time
|
|
|
|
| 912 |
|
| 913 |
print("Metrics calculation started using MOSES.")
|
| 914 |
|
| 915 |
+
print("Validity: ", fraction_valid(metric_calc_dr), "\n")
|
| 916 |
+
print("Uniqueness: ", fraction_unique(metric_calc_dr), "\n")
|
| 917 |
+
print("Validity: ", novelty(metric_calc_dr, drug_smiles), "\n")
|
| 918 |
|
| 919 |
+
print("Metrics are calculated.")
|