import os
import time
import argparse
import math
from numpy import finfo

import torch
from distributed import apply_gradient_allreduce
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader

from model import Tacotron2
from data_utils import TextMelLoader, TextMelCollate
from loss_function import Tacotron2Loss
from logger import Tacotron2Logger
from hparams import create_hparams

def reduce_tensor(tensor, n_gpus):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.reduce_op.SUM)
    rt /= n_gpus
    return rt
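
# Note: reduce_tensor averages a scalar across ranks, e.g. with two GPUs holding
# losses 1.0 and 3.0 every rank gets back 2.0. dist.reduce_op is a deprecated
# alias that newer PyTorch releases replace with dist.ReduceOp; this script
# assumes a PyTorch version that still exposes the old name.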

def init_distributed(hparams, n_gpus, rank, group_name):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing Distributed")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(
        backend=hparams.dist_backend, init_method=hparams.dist_url,
        world_size=n_gpus, rank=rank, group_name=group_name)

    print("Done initializing distributed")

def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    trainset = TextMelLoader(hparams.training_files, hparams)
    valset = TextMelLoader(hparams.validation_files, hparams)
    collate_fn = TextMelCollate(hparams.n_frames_per_step)

    if hparams.distributed_run:
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn
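
# Note: with a DistributedSampler each rank draws a distinct shard of the
# training set, and the sampler handles shuffling itself, which is why the
# DataLoader's shuffle flag is turned off in the distributed branch.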

def prepare_directories_and_logger(output_directory, log_directory, rank):
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
    else:
        logger = None
    return logger

def load_model(hparams):
    model = Tacotron2(hparams).cuda()
    if hparams.fp16_run:
        model.decoder.attention_layer.score_mask_value = finfo('float16').min

    if hparams.distributed_run:
        model = apply_gradient_allreduce(model)

    return model
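
# Under fp16 the attention score mask value is lowered to the smallest
# representable float16, so masked attention scores stay representable in
# half precision instead of underflowing to -inf.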

def warm_start_model(checkpoint_path, model, ignore_layers):
    assert os.path.isfile(checkpoint_path)
    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model_dict = checkpoint_dict['state_dict']
    if len(ignore_layers) > 0:
        model_dict = {k: v for k, v in model_dict.items()
                      if k not in ignore_layers}
        dummy_dict = model.state_dict()
        dummy_dict.update(model_dict)
        model_dict = dummy_dict
    model.load_state_dict(model_dict)
    return model
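
# Illustrative warm-start invocation (the checkpoint file name below is a
# placeholder, not part of this repository); layers named in
# hparams.ignore_layers are dropped from the loaded state dict and keep
# their fresh initialization:
#   python train.py -o outdir -l logdir -c pretrained_model.pt --warm_start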

def load_checkpoint(checkpoint_path, model, optimizer):
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint_dict['state_dict'])
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
    learning_rate = checkpoint_dict['learning_rate']
    iteration = checkpoint_dict['iteration']
    print("Loaded checkpoint '{}' from iteration {}".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration

def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    torch.save({'iteration': iteration,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)

def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             collate_fn, logger, distributed_run, rank):
    """Handles all the validation scoring and printing"""
    model.eval()
    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
                                shuffle=False, batch_size=batch_size,
                                pin_memory=False, collate_fn=collate_fn)

        val_loss = 0.0
        for i, batch in enumerate(val_loader):
            x, y = model.parse_batch(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_val_loss = loss.item()
            val_loss += reduced_val_loss
        val_loss = val_loss / (i + 1)

    model.train()
    if rank == 0:
        print("Validation loss {}: {:9f}".format(iteration, val_loss))
        logger.log_validation(val_loss, model, y, y_pred, iteration)

def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
          rank, group_name, hparams):
    """Training and validation loop, logging results to tensorboard and stdout

    Params
    ------
    output_directory (string): directory to save checkpoints
    log_directory (string): directory to save tensorboard logs
    checkpoint_path (string): checkpoint path
    warm_start (bool): load model weights only, ignoring layers in hparams.ignore_layers
    n_gpus (int): number of gpus
    rank (int): rank of current gpu
    group_name (string): distributed group name
    hparams (object): comma separated list of "name=value" pairs.
    """
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams)
    learning_rate = hparams.learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 weight_decay=hparams.weight_decay)

    if hparams.fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(
            model, optimizer, opt_level='O2')

    if hparams.distributed_run:
        model = apply_gradient_allreduce(model)

    criterion = Tacotron2Loss()

    logger = prepare_directories_and_logger(
        output_directory, log_directory, rank)

    train_loader, valset, collate_fn = prepare_dataloaders(hparams)

    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if checkpoint_path is not None:
        if warm_start:
            model = warm_start_model(
                checkpoint_path, model, hparams.ignore_layers)
        else:
            model, optimizer, _learning_rate, iteration = load_checkpoint(
                checkpoint_path, model, optimizer)
            if hparams.use_saved_learning_rate:
                learning_rate = _learning_rate
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))

    model.train()
    is_overflow = False
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, hparams.epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            start = time.perf_counter()
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

            model.zero_grad()
            x, y = model.parse_batch(batch)
            y_pred = model(x)

            loss = criterion(y_pred, y)
            if hparams.distributed_run:
                reduced_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_loss = loss.item()
            if hparams.fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if hparams.fp16_run:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), hparams.grad_clip_thresh)
                is_overflow = math.isnan(grad_norm)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hparams.grad_clip_thresh)

            optimizer.step()

            if not is_overflow and rank == 0:
                duration = time.perf_counter() - start
                print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
                    iteration, reduced_loss, grad_norm, duration))
                logger.log_training(
                    reduced_loss, grad_norm, learning_rate, duration, iteration)

            if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
                validate(model, criterion, valset, iteration,
                         hparams.batch_size, n_gpus, collate_fn, logger,
                         hparams.distributed_run, rank)
                if rank == 0:
                    checkpoint_path = os.path.join(
                        output_directory, "checkpoint_{}".format(iteration))
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--output_directory', type=str,
                        help='directory to save checkpoints')
    parser.add_argument('-l', '--log_directory', type=str,
                        help='directory to save tensorboard logs')
    parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
                        required=False, help='checkpoint path')
    parser.add_argument('--warm_start', action='store_true',
                        help='load model weights only, ignore specified layers')
    parser.add_argument('--n_gpus', type=int, default=1,
                        required=False, help='number of gpus')
    parser.add_argument('--rank', type=int, default=0,
                        required=False, help='rank of current gpu')
    parser.add_argument('--group_name', type=str, default='group_name',
                        required=False, help='Distributed group name')
    parser.add_argument('--hparams', type=str,
                        required=False, help='comma separated name=value pairs')

    args = parser.parse_args()
    hparams = create_hparams(args.hparams)

    torch.backends.cudnn.enabled = hparams.cudnn_enabled
    torch.backends.cudnn.benchmark = hparams.cudnn_benchmark

    print("FP16 Run:", hparams.fp16_run)
    print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
    print("Distributed Run:", hparams.distributed_run)
    print("cuDNN Enabled:", hparams.cudnn_enabled)
    print("cuDNN Benchmark:", hparams.cudnn_benchmark)

    train(args.output_directory, args.log_directory, args.checkpoint_path,
          args.warm_start, args.n_gpus, args.rank, args.group_name, hparams)
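
# Example invocations (assumed, using only the arguments defined above; the
# directory names are placeholders):
#   single GPU:  python train.py --output_directory=outdir --log_directory=logdir
#   distributed: launch one process per GPU with --hparams=distributed_run=True,
#                giving every process the same --group_name and --n_gpus but a
#                unique --rank.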