Spaces:
Runtime error
Runtime error
| import os | |
| import copy | |
| import torch | |
| import shutil | |
| from collections import OrderedDict | |
| import logging | |
| import numpy as np | |
def save_ckpt(
    model,
    optimizer,
    train_epoch_loss,
    val_epoch_loss,
    train_epoch_nap,
    val_epoch_nap,
    epoch,
    save_path,
    name_pre,
    name_post="best",
):
    """Write a training checkpoint to ``<save_path>/<name_pre>_<name_post>.pth``.

    The checkpoint is a dict holding the epoch, the model weights (each tensor
    copied to CPU), the optimizer state dict, and the four metric values
    passed in. ``save_path`` is created, including any missing parent
    directories, when it does not exist yet.

    Args:
        model: Object exposing ``state_dict()`` whose values support ``.cpu()``.
        optimizer: Object exposing ``state_dict()``.
        train_epoch_loss: Stored under the key ``"train_loss"``.
        val_epoch_loss: Stored under the key ``"val_loss"``.
        train_epoch_nap: Stored under the key ``"train_map"``.
        val_epoch_nap: Stored under the key ``"val_map"``.
        epoch: Stored under the key ``"epoch"``.
        save_path: Target directory.
        name_pre: First part of the file name.
        name_post: Second part of the file name (default ``"best"``).
    """
    # Move every weight to CPU so the file does not depend on the device the
    # model was trained on.
    model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
    state = {
        "epoch": epoch,
        "model_state_dict": model_cpu,
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": train_epoch_loss,
        "val_loss": val_epoch_loss,
        "train_map": train_epoch_nap,
        "val_map": val_epoch_nap,
    }

    if not os.path.exists(save_path):
        # makedirs (not mkdir) so nested target directories work, and
        # exist_ok so a concurrent creation between the check and the call
        # is not an error.
        os.makedirs(save_path, exist_ok=True)
        print("Directory ", save_path, " is created.")

    filename = "{}/{}_{}.pth".format(save_path, name_pre, name_post)
    torch.save(state, filename)
    print("model has been saved as {}".format(filename))
def load_pretrained_models(
    model, pretrained_model, phase, ismax=True
):
    """Load model weights from a checkpoint file into ``model``.

    The checkpoint must be a mapping with the keys ``"state_dict"`` and
    ``"epoch"``; ``"best_value"`` is optional. If exactly one of the model and
    the checkpoint uses ``module.``-prefixed parameter names, the checkpoint
    keys are renamed to match the model before loading.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        pretrained_model: Path to the checkpoint; a falsy value means
            "nothing to load".
        phase: When equal to ``"train"`` the checkpoint's epoch is returned,
            otherwise ``-1``.
        ismax: Selects the default best value when the checkpoint has none:
            ``-inf`` if True (larger is better), ``+inf`` otherwise.

    Returns:
        Tuple ``(model, best_value, epoch)``.

    Raises:
        ImportError: ``pretrained_model`` is set but is not an existing file.
    """
    best_value = -np.inf if ismax else np.inf
    epoch = -1

    if not pretrained_model:
        logging.info("===> No pre-trained model")
        return model, best_value, epoch

    if not os.path.isfile(pretrained_model):
        raise ImportError(
            "===> No checkpoint found at '{}'".format(pretrained_model)
        )

    logging.info("===> Loading checkpoint '{}'".format(pretrained_model))
    # Deserialize onto CPU so checkpoints written on a GPU can be read on a
    # CPU-only host; load_state_dict copies values into the model's own
    # parameters afterwards.
    checkpoint = torch.load(pretrained_model, map_location="cpu")

    try:
        best_value = checkpoint["best_value"]
    except KeyError:
        # No stored value: keep the default chosen from ``ismax``.
        show_best_value = False
    else:
        # An infinite stored value carries no information worth logging.
        show_best_value = not (best_value == -np.inf or best_value == np.inf)

    ckpt_model_state_dict = checkpoint["state_dict"]

    # Detect the "module." prefix on the first key of each side. Testing the
    # full prefix (rather than only the first letter) keeps ordinary names
    # that merely start with "m" from being mistaken for prefixed ones.
    prefix = "module."
    is_model_multi_gpus = next(iter(model.state_dict()), "").startswith(prefix)
    is_ckpt_multi_gpus = next(iter(ckpt_model_state_dict), "").startswith(prefix)

    if is_model_multi_gpus != is_ckpt_multi_gpus:
        renamed = OrderedDict()
        for k, v in ckpt_model_state_dict.items():
            if is_ckpt_multi_gpus:
                name = k[len(prefix):]  # remove 'module.'
            else:
                name = prefix + k  # add 'module.'
            renamed[name] = v
        ckpt_model_state_dict = renamed

    model.load_state_dict(ckpt_model_state_dict)

    if show_best_value:
        logging.info(
            "The pretrained_model is at checkpoint {}. \t "
            "Best value: {}".format(checkpoint["epoch"], best_value)
        )
    else:
        logging.info(
            "The pretrained_model is at checkpoint {}.".format(
                checkpoint["epoch"]
            )
        )

    if phase == "train":
        epoch = checkpoint["epoch"]

    return model, best_value, epoch
| def load_pretrained_optimizer( | |
| pretrained_model, optimizer, scheduler, lr, use_ckpt_lr=True | |
| ): | |
| if pretrained_model: | |
| if os.path.isfile(pretrained_model): | |
| checkpoint = torch.load(pretrained_model) | |
| if "optimizer_state_dict" in checkpoint.keys(): | |
| optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) | |
| for state in optimizer.state.values(): | |
| for k, v in state.items(): | |
| if torch.is_tensor(v): | |
| state[k] = v.cuda() | |
| if "scheduler_state_dict" in checkpoint.keys(): | |
| scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) | |
| if use_ckpt_lr: | |
| try: | |
| lr = scheduler.get_lr()[0] | |
| except: | |
| lr = lr | |
| return optimizer, scheduler, lr | |
def save_checkpoint(state, is_best, save_path, postname):
    """Save ``state`` as ``<save_path>/<postname>_<epoch>.pth``.

    ``<epoch>`` is ``state["epoch"]`` converted to ``int``. When ``is_best``
    is truthy the written file is additionally copied to
    ``<save_path>/<postname>_best.pth``.
    """
    epoch_tag = int(state["epoch"])
    epoch_file = "{}/{}_{}.pth".format(save_path, postname, epoch_tag)
    torch.save(state, epoch_file)

    if not is_best:
        return

    best_file = "{}/{}_best.pth".format(save_path, postname)
    shutil.copyfile(epoch_file, best_file)
def change_ckpt_dict(model, optimizer, scheduler, opt):
    """Step the scheduler, update ``opt.best_value`` and save a checkpoint.

    The scheduler is stepped ``opt.epoch`` times. ``opt.test_value`` is then
    compared with ``opt.best_value`` (smaller counts as better) and
    ``opt.best_value`` is set to the smaller of the two. Finally a checkpoint
    with CPU copies of the model weights, the optimizer and scheduler state
    dicts, the epoch and the updated best value is passed to
    ``save_checkpoint`` together with ``opt.save_path`` and ``opt.post``.
    """
    for _step in range(opt.epoch):
        scheduler.step()

    # Decide "is best" before best_value is overwritten below.
    improved = opt.test_value < opt.best_value
    opt.best_value = min(opt.test_value, opt.best_value)

    cpu_weights = {}
    for key, tensor in model.state_dict().items():
        cpu_weights[key] = tensor.cpu()

    ckpt_state = {
        "epoch": opt.epoch,
        "state_dict": cpu_weights,
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "best_value": opt.best_value,
    }
    save_checkpoint(ckpt_state, improved, opt.save_path, opt.post)
def load_models(model, device):
    """Build four copies of ``model``, one per pretrained fold checkpoint.

    Each copy is a ``copy.deepcopy`` of ``model`` loaded with the
    ``"model_state_dict"`` entry of
    ``./PLA-Net/pretrained-models/BINARY_ada/Fold<i>/Best_Model.pth``
    (``i`` = 1..4, relative to the current working directory) and then moved
    to ``device``. Before loading, the nine
    ``molecule_gcn.atom_encoder.atom_embedding_list.<0-8>.weight`` tensors of
    each checkpoint are transposed.

    Args:
        model: Template module; it is deep-copied, never loaded into directly.
        device: Passed to ``.to()`` on each copy.

    Returns:
        Tuple of the four loaded copies, ordered Fold1..Fold4.
    """
    n_folds = 4
    n_atom_embeddings = 9

    # All copies are made before any checkpoint is read, so the progress
    # messages appear in the same order as before.
    predictors = []
    for fold in range(1, n_folds + 1):
        print("------Copying model {}---------".format(fold))
        predictors.append(copy.deepcopy(model))

    test_model_path = "./PLA-Net/pretrained-models/BINARY_ada"

    # LOAD MODELS
    print("------- Loading weights----------")
    for fold, predictor in enumerate(predictors, start=1):
        ckpt_path = "{}/Fold{}/Best_Model.pth".format(test_model_path, fold)
        # Deserialize onto CPU; the copy is moved to ``device`` afterwards.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        state_dict = ckpt["model_state_dict"]

        # NOTE(review): presumably the checkpoints store these weights
        # transposed relative to the layout the module expects - confirm.
        for idx in range(n_atom_embeddings):
            key = "molecule_gcn.atom_encoder.atom_embedding_list.{}.weight".format(
                idx
            )
            state_dict[key] = state_dict[key].t()

        predictor.load_state_dict(state_dict)
        predictor.to(device)

    return tuple(predictors)