| import os | |
| import json | |
| import argparse | |
| import math | |
| import torch | |
| from torch import nn, optim | |
| from torch.nn import functional as F | |
| from torch.utils.data import DataLoader | |
| from data_utils import TextMelLoader, TextMelCollate | |
| import models | |
| import commons | |
| import utils | |
| class FlowGenerator_DDI(models.FlowGenerator): | |
| """A helper for Data-dependent Initialization""" | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| for f in self.decoder.flows: | |
| if getattr(f, "set_ddi", False): | |
| f.set_ddi(True) | |
| def main(): | |
| hps = utils.get_hparams() | |
| logger = utils.get_logger(hps.log_dir) | |
| logger.info(hps) | |
| utils.check_git_hash(hps.log_dir) | |
| torch.manual_seed(hps.train.seed) | |
| train_dataset = TextMelLoader(hps.data.training_files, hps.data) | |
| collate_fn = TextMelCollate(1) | |
| train_loader = DataLoader( | |
| train_dataset, | |
| num_workers=8, | |
| shuffle=True, | |
| batch_size=hps.train.batch_size, | |
| pin_memory=True, | |
| drop_last=True, | |
| collate_fn=collate_fn, | |
| ) | |
| symbols = hps.data.punc + hps.data.chars | |
| generator = FlowGenerator_DDI( | |
| len(symbols) + getattr(hps.data, "add_blank", False), | |
| out_channels=hps.data.n_mel_channels, | |
| **hps.model | |
| ).cuda() | |
| optimizer_g = commons.Adam( | |
| generator.parameters(), | |
| scheduler=hps.train.scheduler, | |
| dim_model=hps.model.hidden_channels, | |
| warmup_steps=hps.train.warmup_steps, | |
| lr=hps.train.learning_rate, | |
| betas=hps.train.betas, | |
| eps=hps.train.eps, | |
| ) | |
| generator.train() | |
| for batch_idx, (x, x_lengths, y, y_lengths) in enumerate(train_loader): | |
| x, x_lengths = x.cuda(), x_lengths.cuda() | |
| y, y_lengths = y.cuda(), y_lengths.cuda() | |
| _ = generator(x, x_lengths, y, y_lengths, gen=False) | |
| break | |
| utils.save_checkpoint( | |
| generator, | |
| optimizer_g, | |
| hps.train.learning_rate, | |
| 0, | |
| os.path.join(hps.model_dir, "ddi_G.pth"), | |
| ) | |
| if __name__ == "__main__": | |
| main() | |