# Polyphemus training script.
import argparse
import json
import os
import uuid

import torch
import torch.optim as optim
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader

from data import PolyphemusDataset
from model import VAE
from training import PolyphemusTrainer, ExpDecayLRScheduler, StepBetaScheduler
from utils import set_seed, print_params, print_divider
def main():
    """Train Polyphemus from the command line.

    Parses CLI arguments, loads the JSON training configuration, splits the
    dataset into train/validation/test subsets, builds the VAE model with its
    optimizer and LR/beta schedulers, and hands everything to
    ``PolyphemusTrainer`` for the actual training loop.
    """
    parser = argparse.ArgumentParser(
        description='Trains Polyphemus.'
    )
    parser.add_argument(
        'dataset_dir',
        type=str,
        help='Directory of the Polyphemus dataset to be used for training.'
    )
    parser.add_argument(
        'output_dir',
        type=str,
        help='Directory to save the output of the training.'
    )
    parser.add_argument(
        'config_file',
        type=str,
        help='Path to the JSON training configuration file.'
    )
    parser.add_argument(
        '--model_name',
        type=str,
        help='Name of the model to be trained.'
    )
    parser.add_argument(
        '--save_every',
        type=int,
        default=10,
        help="If set to n, the script will save the model every n batches. "
             "Default is 10."
    )
    parser.add_argument(
        '--print_every',
        type=int,
        default=1,
        help="If set to n, the script will print statistics every n batches. "
             "Default is 1."
    )
    parser.add_argument(
        '--eval',
        action='store_true',
        default=False,
        help='Flag to enable evaluation on a validation set.'
    )
    parser.add_argument(
        '--eval_every',
        type=int,
        help="If the eval flag is set, when set to n, the script will evaluate "
             "the model on the validation set every n batches. "
             "Default is every epoch."
    )
    parser.add_argument(
        '--use_gpu',
        action='store_true',
        default=False,
        help='Flag to enable or disable GPU usage. Default is False.'
    )
    # BUG FIX: numeric defaults were string literals (e.g. default='0');
    # argparse re-parses string defaults through `type`, but plain int/float
    # defaults are the correct, robust form.
    parser.add_argument(
        '--gpu_id',
        type=int,
        default=0,
        help='Index of the GPU to be used. Default is 0.'
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=10,
        help="The number of processes to use for loading the data. "
             "Default is 10."
    )
    parser.add_argument(
        '--tr_split',
        type=float,
        default=0.7,
        help="Percentage of samples in the dataset used for the training split."
             " Default is 0.7."
    )
    parser.add_argument(
        '--vl_split',
        type=float,
        default=0.1,
        help="Percentage of samples in the dataset used for the validation "
             "split. Default is 0.1. This value is ignored if the --eval option is "
             "not specified."
    )
    parser.add_argument(
        '--max_epochs',
        type=int,
        default=100,
        help="Maximum number of training epochs. Default is 100."
    )
    parser.add_argument(
        '--seed',
        type=int,
        help='Random seed for reproducibility.'
    )
    args = parser.parse_args()

    print_divider()

    if args.seed is not None:
        set_seed(args.seed)

    device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
    if args.use_gpu:
        torch.cuda.set_device(args.gpu_id)

    # Load config file
    print("Loading the configuration file {}...".format(args.config_file))
    with open(args.config_file, 'r') as f:
        training_config = json.load(f)
    n_bars = training_config['model']['n_bars']
    batch_size = training_config['batch_size']

    print("Preparing datasets and dataloaders...")
    dataset = PolyphemusDataset(args.dataset_dir, n_bars)

    # Split sizes: the remainder after train (and optionally validation)
    # becomes a held-out test split, even though only train/validation
    # loaders are built here.
    tr_len = int(args.tr_split * len(dataset))
    if args.eval:
        vl_len = int(args.vl_split * len(dataset))
        ts_len = len(dataset) - tr_len - vl_len
        lengths = (tr_len, vl_len, ts_len)
    else:
        ts_len = len(dataset) - tr_len
        lengths = (tr_len, ts_len)

    split = random_split(dataset, lengths)
    tr_set = split[0]
    vl_set = split[1] if args.eval else None

    trainloader = DataLoader(tr_set, batch_size=batch_size, shuffle=True,
                             num_workers=args.num_workers)
    if args.eval:
        validloader = DataLoader(vl_set, batch_size=batch_size, shuffle=False,
                                 num_workers=args.num_workers)
        # BUG FIX: --eval_every was parsed but ignored; evaluation always
        # happened once per epoch. Honor the option when given, falling back
        # to once-per-epoch (len(trainloader) batches) as documented.
        eval_every = (args.eval_every if args.eval_every is not None
                      else len(trainloader))
    else:
        validloader = None
        eval_every = None

    # Name the model after the user-provided name, or a fresh UUID otherwise.
    model_name = (args.model_name if args.model_name is not None
                  else str(uuid.uuid1()))
    model_dir = os.path.join(args.output_dir, model_name)

    # Create output directory if it does not exist
    os.makedirs(args.output_dir, exist_ok=True)
    # Create model output directory (raise error if it already exists to avoid
    # overwriting a trained model)
    os.makedirs(model_dir, exist_ok=False)

    # Create the model
    print("Creating the model and moving it on {} device...".format(device))
    vae = VAE(**training_config['model'], device=device).to(device)
    print_params(vae)
    print()

    # Creating optimizer and schedulers
    optimizer = optim.Adam(vae.parameters(), **training_config['optimizer'])
    lr_scheduler = ExpDecayLRScheduler(
        optimizer=optimizer,
        **training_config['lr_scheduler']
    )
    beta_scheduler = StepBetaScheduler(**training_config['beta_scheduler'])

    # Save config alongside the model so training can be reproduced/resumed.
    config_path = os.path.join(model_dir, 'configuration')
    torch.save(training_config, config_path)

    print("Starting training...")
    print_divider()
    trainer = PolyphemusTrainer(
        model_dir,
        vae,
        optimizer,
        lr_scheduler=lr_scheduler,
        beta_scheduler=beta_scheduler,
        save_every=args.save_every,
        print_every=args.print_every,
        eval_every=eval_every,
        device=device
    )
    trainer.train(trainloader, validloader=validloader, epochs=args.max_epochs)


if __name__ == '__main__':
    main()