| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import argparse |
| import json |
| import os |
| import torch |
|
|
| |
| from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor |
| from torch.utils.data.distributed import DistributedSampler |
| |
|
|
| from torch.utils.data import DataLoader |
| from glow import WaveGlow, WaveGlowLoss |
| from mel2samp import Mel2Samp |
|
|
def load_checkpoint(checkpoint_path, model, optimizer):
    """Restore model and optimizer state from a checkpoint file.

    The checkpoint is a dict with ``'model'``, ``'optimizer'`` and
    ``'iteration'`` entries. ``'model'`` may be either a full ``nn.Module``
    (the format ``save_checkpoint`` in this file writes) or a plain state dict.

    Args:
        checkpoint_path: Path to the checkpoint file.
        model: Module whose parameters are overwritten in place.
        optimizer: Optimizer whose state is overwritten in place.

    Returns:
        Tuple ``(model, optimizer, iteration)``. ``model`` and ``optimizer``
        are the same objects that were passed in.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
    """
    import inspect  # local: only used to probe torch.load's signature

    if not os.path.isfile(checkpoint_path):
        # Explicit check rather than ``assert`` so it is not stripped by -O.
        raise FileNotFoundError(
            "Checkpoint not found: '{}'".format(checkpoint_path))

    load_kwargs = {'map_location': 'cpu'}
    # The checkpoint pickles a whole nn.Module. Newer torch defaults to
    # weights_only=True, which refuses to unpickle it, so opt out explicitly
    # where the parameter exists.
    # SECURITY: this unpickles arbitrary objects -- only load trusted files.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint_dict = torch.load(checkpoint_path, **load_kwargs)

    iteration = checkpoint_dict['iteration']
    optimizer.load_state_dict(checkpoint_dict['optimizer'])

    saved_model = checkpoint_dict['model']
    # Accept a pickled module (save_checkpoint's format) or a bare state dict.
    if isinstance(saved_model, torch.nn.Module):
        state_dict = saved_model.state_dict()
    else:
        state_dict = saved_model
    model.load_state_dict(state_dict)

    print("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return model, optimizer, iteration
|
|
def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    """Serialize a training checkpoint to ``filepath`` with ``torch.save``.

    A new ``WaveGlow`` is built from the module-level ``waveglow_config``,
    moved to CUDA and loaded with ``model``'s weights. That copy, rather than
    ``model`` itself, is stored. Presumably this avoids pickling wrappers or
    hooks attached to the training model (NOTE(review): confirm).

    The saved dict has keys ``'model'`` (the copied module), ``'iteration'``,
    ``'optimizer'`` (optimizer state dict) and ``'learning_rate'``.
    """
    message = "Saving model and optimizer state at iteration {} to {}"
    print(message.format(iteration, filepath))

    snapshot = WaveGlow(**waveglow_config).cuda()
    snapshot.load_state_dict(model.state_dict())

    payload = {
        'model': snapshot,
        'iteration': iteration,
        'optimizer': optimizer.state_dict(),
        'learning_rate': learning_rate,
    }
    torch.save(payload, filepath)
|
|
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          sigma, iters_per_checkpoint, batch_size, seed, fp16_run,
          checkpoint_path, with_tensorboard):
    """Train a WaveGlow model, optionally distributed and/or with apex AMP.

    Reads the module-level ``waveglow_config`` and ``data_config`` dicts, and
    ``dist_config`` when ``num_gpus > 1``.

    Args:
        num_gpus: Number of processes/GPUs. If greater than 1, distributed
            training is set up via ``init_distributed`` and a
            ``DistributedSampler``.
        rank: Rank of this process. Only rank 0 creates the output directory,
            writes TensorBoard logs and saves checkpoints.
        group_name: Group name forwarded to ``init_distributed``.
        output_directory: Directory for checkpoints and the ``logs`` subdir.
        epochs: Training runs epochs ``range(epoch_offset, epochs)``.
        learning_rate: Adam learning rate.
        sigma: Forwarded to ``WaveGlowLoss``.
        iters_per_checkpoint: Rank 0 saves whenever
            ``iteration % iters_per_checkpoint == 0``.
        batch_size: Per-process batch size.
        seed: Seed for torch's CPU and CUDA RNGs.
        fp16_run: If true, wraps model/optimizer with apex AMP (opt level O1).
        checkpoint_path: Checkpoint to resume from; empty/None starts fresh.
        with_tensorboard: If true, rank 0 logs the loss via tensorboardX.

    Raises:
        ValueError: If the data loader yields no batches, e.g. when the dataset
            is smaller than ``batch_size`` (``drop_last=True``).
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)

    criterion = WaveGlowLoss(sigma)
    model = WaveGlow(**waveglow_config).cuda()

    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # When resuming, continue from the iteration after the saved one.
    iteration = 0
    if checkpoint_path:
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1

    trainset = Mel2Samp(**data_config)
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    # DataLoader rejects shuffle=True together with a sampler. The distributed
    # sampler shuffles on its own, so only request shuffling without one.
    train_loader = DataLoader(trainset, num_workers=1,
                              shuffle=(train_sampler is None),
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)
    batches_per_epoch = len(train_loader)
    if batches_per_epoch == 0:
        raise ValueError(
            "Training data yields no batches; dataset is smaller than "
            "batch_size ({}) with drop_last=True".format(batch_size))

    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    logger = None
    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(os.path.join(output_directory, 'logs'))

    model.train()
    # Skip epochs already fully covered by the resumed iteration count.
    epoch_offset = iteration // batches_per_epoch

    try:
        for epoch in range(epoch_offset, epochs):
            print("Epoch: {}".format(epoch))
            if train_sampler is not None:
                # Without this, DistributedSampler repeats the same order
                # every epoch.
                train_sampler.set_epoch(epoch)

            for mel, audio in train_loader:
                model.zero_grad()

                mel = mel.cuda()
                audio = audio.cuda()
                outputs = model((mel, audio))

                loss = criterion(outputs)
                if num_gpus > 1:
                    # Loss value averaged across processes, for logging only.
                    reduced_loss = reduce_tensor(loss.data, num_gpus).item()
                else:
                    reduced_loss = loss.item()

                if fp16_run:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()

                print("{}:\t{:.9f}".format(iteration, reduced_loss))
                if logger is not None:
                    # Use the global iteration so resumed runs do not
                    # overwrite earlier points.
                    logger.add_scalar('training_loss', reduced_loss, iteration)

                if iteration % iters_per_checkpoint == 0 and rank == 0:
                    save_path = os.path.join(
                        output_directory, "waveglow_{}".format(iteration))
                    save_checkpoint(model, optimizer, learning_rate,
                                    iteration, save_path)

                iteration += 1
    finally:
        if logger is not None:
            logger.close()
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True,
                        help='JSON file for configuration')
    parser.add_argument('-r', '--rank', type=int, default=0,
                        help='rank of process for distributed')
    parser.add_argument('-g', '--group_name', type=str, default='',
                        help='name of group for distributed')
    args = parser.parse_args()

    # Parse the JSON config. data_config, dist_config and waveglow_config are
    # module-level names read by train() and save_checkpoint(). This block
    # already runs at module scope, so no ``global`` statement is needed.
    with open(args.config) as f:
        config = json.load(f)
    train_config = config["train_config"]
    data_config = config["data_config"]
    dist_config = config["dist_config"]
    waveglow_config = config["waveglow_config"]

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # train() moves the model and batches to CUDA unconditionally.
        raise RuntimeError("No CUDA device found; training requires a GPU")
    if num_gpus > 1 and args.group_name == '':
        print("WARNING: Multiple GPUs detected but no distributed group set")
        print("Only running 1 GPU. Use distributed.py for multiple GPUs")
        num_gpus = 1

    if num_gpus == 1 and args.rank != 0:
        raise Exception("Doing single GPU training on rank > 0")

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    train(num_gpus, args.rank, args.group_name, **train_config)
|
|