Spaces:
Runtime error
Runtime error
| import sys, os | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| import torch.optim as optim | |
| from torch.optim import lr_scheduler | |
| import time | |
| from time import perf_counter | |
| import pickle | |
| from model.config import load_config | |
| from model.genconvit_ed import GenConViTED | |
| from model.genconvit_vae import GenConViTVAE | |
| from dataset.loader import load_data, load_checkpoint | |
| import optparse | |
| config = load_config() | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
def load_pretrained(pretrained_model_filename, model=None, optimizer=None):
    """Restore a model/optimizer pair from a saved checkpoint file.

    Args:
        pretrained_model_filename: Path to the checkpoint ``.pth`` file.
        model: Model instance to load weights into. BUG FIX: the original
            body referenced an undefined local ``model`` (NameError); callers
            should now pass the instance explicitly. Default ``None`` keeps
            the old call signature importable.
        optimizer: Optimizer whose state should be restored (same fix).

    Returns:
        Tuple ``(model, optimizer, start_epoch, min_loss)`` as produced by
        ``load_checkpoint``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not os.path.isfile(pretrained_model_filename):
        raise FileNotFoundError("Saved model file does not exist. Exiting.")
    model, optimizer, start_epoch, min_loss = load_checkpoint(
        model, optimizer, filename=pretrained_model_filename
    )
    # Checkpoint tensors are loaded wherever they were saved; move the
    # optimizer's state tensors to the active device so training can resume.
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)
    return model, optimizer, start_epoch, min_loss
def train_model(
    dir_path, mod, num_epochs, pretrained_model_filename, test_model, batch_size
):
    """Train a GenConViT variant and save weights + loss/accuracy history.

    Args:
        dir_path: Root directory of the training data.
        mod: Model variant, ``"ed"`` (encoder-decoder) or anything else
            for the VAE variant.
        num_epochs: Number of training epochs to run.
        pretrained_model_filename: Optional checkpoint path to resume from.
        test_model: Truthy to run the test split after training.
        batch_size: Mini-batch size handed to the data loaders.

    Side effects:
        Writes ``weight/genconvit_<mod>_<timestamp>.pkl`` (loss/acc history)
        and ``.pth`` (model + optimizer state).
    """
    print("Loading data...")
    dataloaders, dataset_sizes = load_data(dir_path, batch_size)
    print("Done.")

    # Pick the variant-specific train/valid loops and model class.
    if mod == "ed":
        from train.train_ed import train, valid

        model = GenConViTED(config)
    else:
        from train.train_vae import train, valid

        model = GenConViTVAE(config)

    optimizer = optim.Adam(
        model.parameters(),
        lr=float(config["learning_rate"]),
        weight_decay=float(config["weight_decay"]),
    )
    criterion = nn.CrossEntropyLoss()
    criterion.to(device)
    mse = nn.MSELoss()
    scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

    start_epoch = 0
    if pretrained_model_filename:
        # BUG FIX: the original called load_pretrained(filename), whose body
        # referenced undefined locals and raised NameError. Restore the
        # checkpoint directly here with the model/optimizer just built.
        if not os.path.isfile(pretrained_model_filename):
            raise FileNotFoundError("Saved model file does not exist. Exiting.")
        model, optimizer, start_epoch, _min_loss = load_checkpoint(
            model, optimizer, filename=pretrained_model_filename
        )
        # Move restored optimizer state tensors onto the active device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    model.to(device)
    torch.manual_seed(1)

    train_loss, train_acc, valid_loss, valid_acc = [], [], [], []
    # Fallback so the saved "min_loss" is defined even when num_epochs == 0
    # (the original crashed with NameError in that case).
    epoch_loss = int(config["min_val_loss"])
    since = time.time()
    # BUG FIX: resume from the checkpoint's epoch instead of always 0.
    for epoch in range(start_epoch, num_epochs):
        train_loss, train_acc, epoch_loss = train(
            model,
            device,
            dataloaders["train"],
            criterion,
            optimizer,
            epoch,
            train_loss,
            train_acc,
            mse,
        )
        valid_loss, valid_acc = valid(
            model,
            device,
            dataloaders["validation"],
            criterion,
            epoch,
            valid_loss,
            valid_acc,
            mse,
        )
        scheduler.step()

    time_elapsed = time.time() - since
    print(
        "Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60
        )
    )

    print("\nSaving model...\n")
    file_path = os.path.join(
        "weight",
        f'genconvit_{mod}_{time.strftime("%b_%d_%Y_%H_%M_%S", time.localtime())}',
    )
    with open(f"{file_path}.pkl", "wb") as f:
        pickle.dump([train_loss, train_acc, valid_loss, valid_acc], f)

    state = {
        "epoch": num_epochs + 1,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "min_loss": epoch_loss,
    }
    weight = f"{file_path}.pth"
    torch.save(state, weight)
    print("Done.")

    if test_model:
        test(model, dataloaders, dataset_sizes, mod, weight)
def test(model, dataloaders, dataset_sizes, mod, weight):
    """Evaluate a saved checkpoint on the test split and print accuracy.

    Args:
        model: Model instance matching the checkpoint architecture.
        dataloaders: Dict providing a ``"test"`` loader of (inputs, labels).
        dataset_sizes: Dict providing the ``"test"`` split size.
        mod: ``"ed"`` (model returns logits) or VAE variant (model returns
            a tuple whose first element is the logits).
        weight: Path to the ``.pth`` checkpoint written by ``train_model``.
    """
    print("\nRunning test...\n")
    checkpoint = torch.load(weight, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()  # single eval() call, *after* weights are loaded

    correct = 0
    seen = 0
    # no_grad: inference only — avoids building autograd graphs.
    with torch.no_grad():
        for inputs, labels in dataloaders["test"]:
            inputs = inputs.to(device)
            labels = labels.to(device)
            # VAE variant returns (logits, ...); ED returns logits directly.
            if mod == "ed":
                output = model(inputs).to(device).float()
            else:
                output = model(inputs)[0].to(device).float()
            _, prediction = torch.max(output, 1)
            # BUG FIX: the original did `labels[prediction]`, indexing the
            # label tensor by predicted class ids instead of comparing
            # predictions with ground truth. Compare them directly.
            pred_label = prediction.detach().cpu().numpy()
            main_label = labels.detach().cpu().numpy()
            correct += int(np.sum(pred_label == main_label))
            # BUG FIX: running denominator was len(inputs)*batch_count, which
            # overcounts when the final batch is smaller; count actual samples.
            seen += len(inputs)
            print(f"Prediction: {correct}/{seen}")
    print(
        f'Prediction: {correct}/{dataset_sizes["test"]} {(correct / dataset_sizes["test"]) * 100:.2f}%'
    )
def gen_parser():
    """Read command-line options for the training script.

    Returns:
        Tuple of ``(dir_path, mod, epoch, pretrained_model_filename,
        test_model, batch_size)`` where ``batch_size`` is an int and
        ``mod`` is ``"ed"`` or ``"vae"``.
    """
    parser = optparse.OptionParser("Train GenConViT model.")
    parser.add_option(
        "-e",
        "--epoch",
        type=int,
        dest="epoch",
        help="Number of epochs used for training the GenConvNextViT model.",
    )
    parser.add_option("-v", "--version", dest="version", help="Version 0.1.")
    parser.add_option("-d", "--dir", dest="dir", help="Training data path.")
    parser.add_option(
        "-m",
        "--model",
        dest="model",
        help="model ed or model vae, model variant: genconvit (A) ed or genconvit (B) vae.",
    )
    parser.add_option(
        "-p",
        "--pretrained",
        dest="pretrained",
        help="Saved model file name. If you want to continue from the previous trained model.",
    )
    parser.add_option("-t", "--test", dest="test", help="run test on test dataset.")
    parser.add_option("-b", "--batch_size", dest="batch_size", help="batch size.")

    opts, _ = parser.parse_args()

    # Anything other than an explicit "ed" selects the VAE variant.
    variant = "ed" if opts.model == "ed" else "vae"
    # Legacy flag value: "y" when testing was requested, else None.
    run_test = "y" if opts.test else None
    resume_file = opts.pretrained or None
    # Fall back to the configured batch size when none was given.
    chosen_batch = opts.batch_size or config["batch_size"]

    return (
        opts.dir,
        variant,
        opts.epoch,
        resume_file,
        run_test,
        int(chosen_batch),
    )
def main():
    """CLI entry point: parse options, train, and report wall-clock time."""
    t0 = perf_counter()
    args = gen_parser()
    train_model(*args)
    print("\n\n--- %s seconds ---" % (perf_counter() - t0))


if __name__ == "__main__":
    main()