Spaces:
Configuration error
Configuration error
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| import time | |
| import random | |
| from pathlib import Path | |
| import re | |
| import glob | |
| import accelerate | |
| import json | |
| import numpy as np | |
| import torch | |
| from accelerate.utils import ProjectConfiguration | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import torch | |
| import torch.nn.functional as F | |
| import torchaudio | |
| from accelerate.logging import get_logger | |
| from models.codec.facodec.facodec_dataset import FAcodecDataset, FAcodecCollator | |
| from models.codec.codec_sampler import build_samplers | |
| from models.codec.codec_trainer import CodecTrainer | |
| from modules.dac.nn.loss import ( | |
| MultiScaleSTFTLoss, | |
| MelSpectrogramLoss, | |
| GANLoss, | |
| L1Loss, | |
| FocalLoss, | |
| ) | |
| from audiotools import AudioSignal | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
| try: | |
| import nemo.collections.asr as nemo_asr | |
| except ImportError: | |
| print( | |
| "Unable to import nemo_asr, titanet outputs will be set to random values, you may only run debugging mode. DO NOT USE THIS FOR TRAINING" | |
| ) | |
| nemo_asr = None | |
| from models.codec.facodec.modules.commons import ( | |
| build_model, | |
| load_checkpoint, | |
| load_F0_models, | |
| log_norm, | |
| ) | |
| from models.codec.facodec.optimizer import build_optimizer | |
class FAcodecTrainer(CodecTrainer):
    """Trainer for the FAcodec model.

    ``self.model`` is a mapping of sub-networks (this class uses the
    ``encoder``, ``quantizer``, ``decoder``, ``fa_predictors`` and
    ``discriminator`` entries), and ``self.optimizer`` is a multi-optimizer
    wrapper that is stepped per sub-network key.
    """

    # Single source of truth for checkpoint naming, so that saving and
    # checkpoint discovery cannot drift apart.
    _CKPT_NAME_FORMAT = "FAcodec_epoch_%05d_step_%05d.pth"
    _CKPT_GLOB_PATTERN = "FAcodec_epoch_*_step_*.pth"
    _CKPT_INDEX_NAME = "ckpts.json"

    # Pretrained helper models used to build training targets.
    _W2V_MODEL_NAME = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
    _SPEAKER_MODEL_NAME = "nvidia/speakerverification_en_titanet_large"

    def __init__(self, args, cfg):
        """Build every component needed for training.

        Args:
            args: Parsed command-line arguments. This method reads
                ``exp_name``, ``log_level``, ``resume_type`` and
                ``checkpoint``.
            cfg: Experiment configuration. ``cfg.exp_name`` is overwritten
                with ``args.exp_name``.
        """
        super().__init__()

        self.args = args
        self.cfg = cfg

        cfg.exp_name = args.exp_name

        # Init accelerator
        self._init_accelerator()
        self.accelerator.wait_for_everyone()

        # Init logger
        with self.accelerator.main_process_first():
            self.logger = get_logger(args.exp_name, log_level=args.log_level)

        self.logger.info("=" * 56)
        self.logger.info("||\t\t" + "New training process started." + "\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")
        self.logger.debug(f"Using {args.log_level.upper()} logging level.")
        self.logger.info(f"Experiment name: {args.exp_name}")
        self.logger.info(f"Experiment directory: {self.exp_dir}")
        self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
        if self.accelerator.is_main_process:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # Init training status
        self.batch_count: int = 0
        self.step: int = 0
        self.epoch: int = 0

        self.max_epoch = (
            self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
        )
        self.logger.info(
            "Max epoch: {}".format(
                self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
            )
        )

        # Check potential errors
        if self.accelerator.is_main_process:
            self._check_basic_configs()
        # Checkpoint bookkeeping is assigned on every process so that the
        # resume path below can safely overwrite ``checkpoints_path``.
        self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
        self.checkpoints_path = [[] for _ in range(len(self.save_checkpoint_stride))]
        self.run_eval = self.cfg.train.run_eval

        # Set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(
                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
            )
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # Build dataloader
        with self.accelerator.main_process_first():
            self.logger.info("Building dataset...")
            start = time.monotonic_ns()
            self.train_dataloader, self.valid_dataloader = self._build_dataloader()
            end = time.monotonic_ns()
            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")

        # Build model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            for _, model in self.model.items():
                self.logger.debug(model)
            self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
            self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M")

        # Build optimizers and schedulers
        with self.accelerator.main_process_first():
            self.logger.info("Building optimizer and scheduler...")
            start = time.monotonic_ns()
            self.optimizer = self._build_optimizer()
            end = time.monotonic_ns()
            self.logger.info(
                f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
            )

        # Build helper models
        with self.accelerator.main_process_first():
            self.logger.info("Building helper models...")
            start = time.monotonic_ns()
            self._built_helper_model()
            end = time.monotonic_ns()
            self.logger.info(
                f"Building helper models done in {(end - start) / 1e6:.2f}ms"
            )

        # Accelerator preparing
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        # Iterate over key snapshots because the mappings are updated in place.
        for key in list(self.model):
            self.model[key] = self.accelerator.prepare(self.model[key])
        for key in list(self.optimizer.optimizers):
            self.optimizer.optimizers[key] = self.accelerator.prepare(
                self.optimizer.optimizers[key]
            )
            self.optimizer.schedulers[key] = self.accelerator.prepare(
                self.optimizer.schedulers[key]
            )
        end = time.monotonic_ns()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")

        # Build criterions
        with self.accelerator.main_process_first():
            self.logger.info("Building criterion...")
            start = time.monotonic_ns()
            self.criterions = self._build_criterion()
            end = time.monotonic_ns()
            self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")

        # Resume checkpoints
        with self.accelerator.main_process_first():
            self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
            if args.resume_type:
                self.logger.info("Resuming from checkpoint...")
                start = time.monotonic_ns()
                # ``args.checkpoint`` may be None, in which case ``_load_model``
                # picks the latest checkpoint found in ``checkpoint_dir``.
                ckpt_path = self._load_model(
                    args.checkpoint, resume_type=args.resume_type
                )
                end = time.monotonic_ns()
                self.logger.info(
                    f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
                )
                self._load_facodec_checkpoint_index(ckpt_path)

            if self.accelerator.is_main_process:
                os.makedirs(self.checkpoint_dir, exist_ok=True)

            self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # Save config
        self.config_save_path = os.path.join(self.exp_dir, "args.json")

    def _load_facodec_checkpoint_index(self, ckpt_path):
        """Restore ``self.checkpoints_path`` from a ``ckpts.json`` index file.

        The index is looked up, in order, under ``ckpt_path`` itself (when it
        is a directory), next to ``ckpt_path`` (when it is a checkpoint
        file), and finally in ``self.checkpoint_dir``, which is where
        ``train_loop`` writes it. When no index exists, the bookkeeping
        initialised in ``__init__`` is kept and a warning is logged.

        Args:
            ckpt_path: Path returned by ``_load_model``; may be ``None``.
        """
        candidates = []
        if ckpt_path is not None:
            ckpt_path = str(ckpt_path)
            candidates.append(os.path.join(ckpt_path, self._CKPT_INDEX_NAME))
            candidates.append(
                os.path.join(os.path.dirname(ckpt_path), self._CKPT_INDEX_NAME)
            )
        candidates.append(os.path.join(self.checkpoint_dir, self._CKPT_INDEX_NAME))

        for index_path in candidates:
            if os.path.isfile(index_path):
                with open(index_path, "r") as index_file:
                    self.checkpoints_path = json.load(index_file)
                return

        self.logger.warning(
            f"No {self._CKPT_INDEX_NAME} found for checkpoint {ckpt_path}; "
            "starting with an empty checkpoint index."
        )

    def _build_dataset(self):
        """Return the dataset class and the collator class used for training."""
        return FAcodecDataset, FAcodecCollator

    def _build_criterion(self):
        """Create the loss modules, keyed by ``stft``/``mel``/``l1``/``content``."""
        return {
            "stft": MultiScaleSTFTLoss(),
            "mel": MelSpectrogramLoss(
                n_mels=[5, 10, 20, 40, 80, 160, 320],
                window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
                mel_fmin=[0, 0, 0, 0, 0, 0, 0],
                mel_fmax=[None, None, None, None, None, None, None],
                pow=1.0,
                mag_weight=0.0,
                clamp_eps=1e-5,
            ),
            "l1": L1Loss(),
            "content": FocalLoss(gamma=2),
        }

    def _build_model(self):
        """Build all sub-networks and move each one to the accelerator device."""
        model = build_model(self.cfg.model_params)
        for key in model:
            model[key].to(self.accelerator.device)
        return model

    def _built_helper_model(self):
        """Load the helper models that produce training targets.

        Sets ``self.pitch_extractor``, ``self.w2v_processor``,
        ``self.w2v_model`` and ``self.speaker_model``. The speaker model is
        ``None`` when ``nemo`` could not be imported.
        """
        device = self.accelerator.device
        self.pitch_extractor = load_F0_models(self.cfg.F0_path).to(device)

        # load model and processor
        self.w2v_processor = Wav2Vec2Processor.from_pretrained(self._W2V_MODEL_NAME)
        self.w2v_model = Wav2Vec2ForCTC.from_pretrained(self._W2V_MODEL_NAME).to(
            device
        )
        self.w2v_model.eval()

        if nemo_asr is None:
            self.speaker_model = None
            return

        self.speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
            self._SPEAKER_MODEL_NAME
        )
        self.speaker_model = self.speaker_model.to(device)
        self.speaker_model.eval()

    def _build_optimizer(self):
        """Build the per-sub-network optimizers and schedulers."""
        scheduler_params = {
            "warmup_steps": self.cfg.loss_params.warmup_steps,
            "base_lr": self.cfg.loss_params.base_lr,
        }
        return build_optimizer(
            {key: self.model[key] for key in self.model},
            scheduler_params_dict={key: scheduler_params.copy() for key in self.model},
            lr=float(scheduler_params["base_lr"]),
        )

    def _save_facodec_checkpoint(self, write_index=False):
        """Serialize model, optimizer and scheduler state to ``checkpoint_dir``.

        Must only be called on the main process.

        Args:
            write_index: When true, also dump ``self.checkpoints_path`` to
                ``ckpts.json`` in the checkpoint directory.

        Returns:
            The path of the written checkpoint file.
        """
        print("Saving..")
        state = {
            "net": {key: self.model[key].state_dict() for key in self.model},
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.optimizer.scheduler_state_dict(),
            "iters": self.step,
            "epoch": self.epoch,
        }
        # The step counter of this trainer is ``self.step``; there is no
        # ``self.iters`` attribute.
        save_path = os.path.join(
            self.checkpoint_dir,
            self._CKPT_NAME_FORMAT % (self.epoch, self.step),
        )
        torch.save(state, save_path)
        if write_index:
            index_path = os.path.join(self.checkpoint_dir, self._CKPT_INDEX_NAME)
            with open(index_path, "w") as index_file:
                json.dump(
                    self.checkpoints_path,
                    index_file,
                    ensure_ascii=False,
                    indent=4,
                )
        return save_path

    def train_loop(self):
        """Run training until ``self.max_epoch`` is reached.

        After every epoch the averaged losses are logged, and a checkpoint is
        written on the main process whenever the epoch index is a multiple of
        one of the configured ``save_checkpoint_stride`` values. A final
        checkpoint is written when the loop ends. This loop does not run a
        validation pass.
        """
        self.accelerator.wait_for_everyone()

        # Dump config
        if self.accelerator.is_main_process:
            self._dump_cfg(self.config_save_path)
        for key in self.model:
            self.model[key].train()
        self.optimizer.zero_grad()

        # Sync and start training
        self.accelerator.wait_for_everyone()
        while self.epoch < self.max_epoch:
            self.logger.info("\n")
            self.logger.info("-" * 32)
            self.logger.info("Epoch {}: ".format(self.epoch))

            # Train
            train_total_loss, train_losses = self._train_epoch()
            for key, loss in train_losses.items():
                self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
                self.accelerator.log(
                    {"Epoch/Train {} Loss".format(key): loss},
                    step=self.epoch,
                )
            self.accelerator.log(
                {
                    "Epoch/Train Total Loss": train_total_loss,
                },
                step=self.epoch,
            )

            self.accelerator.wait_for_everyone()

            # Check save checkpoint interval and save (main process only)
            if self.accelerator.is_main_process and any(
                self.epoch % stride == 0 for stride in self.save_checkpoint_stride
            ):
                self._save_facodec_checkpoint(write_index=True)

            self.accelerator.wait_for_everyone()

            self.epoch += 1

        # Finish training
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            self._save_facodec_checkpoint()

    def _train_epoch(self):
        """Train for one epoch.

        Losses are accumulated only on batches that complete a
        gradient-accumulation window, which is why the averages are rescaled
        by ``gradient_accumulation_step``.

        Returns:
            A tuple ``(epoch_total_loss, epoch_losses)`` with the averaged
            total loss and a dict of the averaged individual losses.
        """
        for key in self.model:
            self.model[key].train()

        accumulation = self.cfg.train.gradient_accumulation_step
        epoch_losses: dict = {}
        epoch_total_loss: int = 0

        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Get losses
            total_loss, losses = self._train_step(batch)
            self.batch_count += 1

            if self.batch_count % accumulation != 0:
                continue

            # Log info
            self.accelerator.log(
                {
                    "Step/Learning Rate": (
                        self.optimizer.schedulers["encoder"].get_last_lr()[0]
                        if self.step != 0
                        else 0
                    )
                },
                step=self.step,
            )
            for key, value in losses.items():
                self.accelerator.log(
                    {
                        "Step/Train {} Loss".format(key): value,
                    },
                    step=self.step,
                )

            if not epoch_losses:
                epoch_losses = dict(losses)
            else:
                for key, value in losses.items():
                    epoch_losses[key] += value
            epoch_total_loss += total_loss
            self.step += 1

        # Get and log total losses
        self.accelerator.wait_for_everyone()
        num_batches = len(self.train_dataloader)
        epoch_total_loss = epoch_total_loss / num_batches * accumulation
        for key in epoch_losses:
            epoch_losses[key] = epoch_losses[key] / num_batches * accumulation
        return epoch_total_loss, epoch_losses

    def _train_step(self, data):
        """Run one optimisation step on a batch.

        The discriminator is updated first, then the generator side (every
        sub-network except ``discriminator``).

        Args:
            data: Batch unpacked as ``(waves, mels, wave_lengths,
                mel_input_length)``; every element is moved to the
                accelerator device.

        Returns:
            A tuple ``(total_loss, train_losses)``: the generator objective as
            a float and a dict of the individual loss values as floats.
        """
        train_losses = {}
        device = self.accelerator.device

        data = [b.to(device, non_blocking=True) for b in data]
        waves, mels, wave_lengths, mel_input_length = data

        # Extract frame-level phone ids with the w2v model.
        waves_16k = torchaudio.functional.resample(waves, 24000, 16000)
        w2v_input = self.w2v_processor(
            waves_16k, sampling_rate=16000, return_tensors="pt"
        ).input_values.to(device)
        with torch.no_grad():
            w2v_outputs = self.w2v_model(w2v_input.squeeze(0)).logits
            predicted_ids = torch.argmax(w2v_outputs, dim=-1)
            # Resample the id sequence to the mel time axis.
            phone_ids = (
                F.interpolate(
                    predicted_ids.unsqueeze(0).float(), mels.size(-1), mode="nearest"
                )
                .long()
                .squeeze(0)
            )

        # Get clips: one random crop of ``mel_seg_len`` frames per item.
        mel_seg_len = min(
            int(mel_input_length.min().item()), self.cfg.train.max_frame_len
        )

        gt_mel_seg = []
        wav_seg = []
        w2v_seg = []

        for bib, length in enumerate(mel_input_length):
            mel_length = int(length.item())

            random_start = (
                np.random.randint(0, mel_length - mel_seg_len)
                if mel_length != mel_seg_len
                else 0
            )
            random_end = random_start + mel_seg_len
            gt_mel_seg.append(mels[bib, :, random_start:random_end])
            w2v_seg.append(phone_ids[bib, random_start:random_end])

            # 300 is presumably the hop length in samples -- TODO confirm.
            y = waves[bib][random_start * 300 : random_end * 300]
            wav_seg.append(y.to(device))

        gt_mel_seg = torch.stack(gt_mel_seg).detach()
        wav_seg = torch.stack(wav_seg).float().detach().unsqueeze(1)
        w2v_seg = torch.stack(w2v_seg).float().detach()

        with torch.no_grad():
            real_norm = log_norm(gt_mel_seg.unsqueeze(1)).squeeze(1).detach()
            F0_real, _, _ = self.pitch_extractor(gt_mel_seg.unsqueeze(1))

        # Normalize f0 per utterance in the log2 domain; unvoiced frames
        # (F0 <= 5.0) are assigned -10.
        f0_targets = []
        for f0 in F0_real:
            voiced_indices = f0 > 5.0
            f0_voiced = f0[voiced_indices]
            normalized_sequence = torch.full_like(f0, -10.0)
            if len(f0_voiced) != 0:
                log_f0 = f0_voiced.log2()
                normalized_sequence[voiced_indices] = (
                    log_f0 - log_f0.mean()
                ) / log_f0.std()
            f0_targets.append(normalized_sequence)
        f0_targets = torch.stack(f0_targets).to(device)
        # fill nan with -10
        f0_targets[torch.isnan(f0_targets)] = -10.0
        # fill inf with -10
        f0_targets[torch.isinf(f0_targets)] = -10.0

        # if frame_rate not equal to 80, interpolate targets from frame rate
        # of 80 to target frame rate
        frame_rate = self.cfg.preprocess_params.frame_rate
        if frame_rate != 80:
            target_len = mel_seg_len // 80 * frame_rate
            f0_targets = F.interpolate(
                f0_targets.unsqueeze(1),
                target_len,
                mode="nearest",
            ).squeeze(1)
            # ``w2v_seg`` is 2-D here, while ``F.interpolate`` requires a
            # channel dimension, so add one for the call and remove it after.
            w2v_seg = F.interpolate(
                w2v_seg.unsqueeze(1),
                target_len,
                mode="nearest",
            ).squeeze(1)

        wav_seg_input = wav_seg
        wav_seg_target = wav_seg

        z = self.model.encoder(wav_seg_input)
        z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
            z, wav_seg_input, n_c=2, full_waves=waves, wave_lens=wave_lengths
        )
        preds, rev_preds = self.model.fa_predictors(quantized, timbre)

        pred_wave = self.model.decoder(z)

        # Centre-crop the target when the prediction is shorter.
        len_diff = wav_seg_target.size(-1) - pred_wave.size(-1)
        if len_diff > 0:
            wav_seg_target = wav_seg_target[..., len_diff // 2 : -len_diff // 2]

        # discriminator loss
        d_fake = self.model.discriminator(pred_wave.detach())
        d_real = self.model.discriminator(wav_seg_target)
        loss_d = 0
        for x_fake, x_real in zip(d_fake, d_real):
            loss_d += torch.mean(x_fake[-1] ** 2)
            loss_d += torch.mean((1 - x_real[-1]) ** 2)

        self.optimizer.zero_grad()
        self.accelerator.backward(loss_d)
        torch.nn.utils.clip_grad_norm_(self.model.discriminator.parameters(), 10.0)
        self.optimizer.step("discriminator")
        self.optimizer.scheduler(key="discriminator")

        # generator loss
        signal = AudioSignal(wav_seg_target, sample_rate=24000)
        recons = AudioSignal(pred_wave, sample_rate=24000)
        stft_loss = self.criterions["stft"](recons, signal)
        mel_loss = self.criterions["mel"](recons, signal)
        waveform_loss = self.criterions["l1"](recons, signal)

        d_fake = self.model.discriminator(pred_wave)
        d_real = self.model.discriminator(wav_seg_target)

        loss_g = 0
        for x_fake in d_fake:
            loss_g += torch.mean((1 - x_fake[-1]) ** 2)

        # Feature matching over every discriminator output except the last.
        loss_feature = 0
        for fake_feats, real_feats in zip(d_fake, d_real):
            for j in range(len(fake_feats) - 1):
                loss_feature += F.l1_loss(fake_feats[j], real_feats[j].detach())

        pred_f0, pred_uv = preds["f0"], preds["uv"]
        rev_pred_f0, rev_pred_uv = rev_preds["rev_f0"], rev_preds["rev_uv"]

        common_min_size = min(pred_f0.size(-2), f0_targets.size(-1))
        f0_targets = f0_targets[..., :common_min_size]
        real_norm = real_norm[..., :common_min_size]

        f0_loss = F.smooth_l1_loss(
            f0_targets, pred_f0.squeeze(-1)[..., :common_min_size]
        )
        uv_loss = F.smooth_l1_loss(
            real_norm, pred_uv.squeeze(-1)[..., :common_min_size]
        )
        # Reverse-predictor outputs may be None; their loss is then zero.
        rev_f0_loss = (
            F.smooth_l1_loss(f0_targets, rev_pred_f0.squeeze(-1)[..., :common_min_size])
            if rev_pred_f0 is not None
            else torch.zeros(1, device=device)
        )
        rev_uv_loss = (
            F.smooth_l1_loss(real_norm, rev_pred_uv.squeeze(-1)[..., :common_min_size])
            if rev_pred_uv is not None
            else torch.zeros(1, device=device)
        )

        tot_f0_loss = f0_loss + rev_f0_loss
        tot_uv_loss = uv_loss + rev_uv_loss

        pred_content = preds["content"]
        rev_pred_content = rev_preds["rev_content"]

        target_content_latents = w2v_seg[..., :common_min_size]

        content_loss = self.criterions["content"](
            pred_content.transpose(1, 2)[..., :common_min_size],
            target_content_latents.long(),
        )
        rev_content_loss = (
            self.criterions["content"](
                rev_pred_content.transpose(1, 2)[..., :common_min_size],
                target_content_latents.long(),
            )
            if rev_pred_content is not None
            else torch.zeros(1, device=device)
        )

        tot_content_loss = content_loss + rev_content_loss

        if self.speaker_model is not None:
            spk_logits = torch.cat(
                [
                    self.speaker_model.infer_segment(w16.cpu()[..., :wl])[1]
                    for w16, wl in zip(waves_16k, wave_lengths)
                ],
                dim=0,
            )
            spk_labels = spk_logits.argmax(dim=-1)
        else:
            # Debug-only fallback when nemo is unavailable: constant labels.
            spk_labels = torch.zeros([len(waves_16k)], dtype=torch.long).to(device)

        spk_pred_logits = preds["timbre"]
        spk_loss = F.cross_entropy(spk_pred_logits, spk_labels)
        x_spk_pred_logits = rev_preds["x_timbre"]

        x_spk_loss = (
            F.cross_entropy(x_spk_pred_logits, spk_labels)
            if x_spk_pred_logits is not None
            else torch.zeros(1, device=device)
        )

        tot_spk_loss = spk_loss + x_spk_loss

        # Note: stft_loss and waveform_loss are logged but not optimised.
        loss_gen_all = (
            mel_loss * 15.0
            + loss_feature * 1.0
            + loss_g * 1.0
            + commitment_loss * 0.25
            + codebook_loss * 1.0
            + tot_f0_loss * 1.0
            + tot_uv_loss * 1.0
            + tot_content_loss * 5.0
            + tot_spk_loss * 5.0
        )

        self.optimizer.zero_grad()
        self.accelerator.backward(loss_gen_all)
        # Apply the generator gradients. Without these steps the encoder,
        # quantizer, decoder and predictors would never be updated. The
        # discriminator was already stepped above on its own loss.
        for key in self.model:
            if key == "discriminator":
                continue
            self.optimizer.step(key)
            self.optimizer.scheduler(key=key)

        with torch.no_grad():
            total_loss = loss_gen_all.item()
            train_losses["stft"] = stft_loss.item()
            train_losses["mel"] = mel_loss.item()
            train_losses["l1"] = waveform_loss.item()
            train_losses["f0"] = f0_loss.item()
            train_losses["uv"] = uv_loss.item()
            train_losses["content"] = content_loss.item()
            train_losses["speaker"] = spk_loss.item()
            train_losses["rev_f0"] = rev_f0_loss.item()
            train_losses["rev_uv"] = rev_uv_loss.item()
            train_losses["rev_content"] = rev_content_loss.item()
            train_losses["rev_speaker"] = x_spk_loss.item()

            train_losses["feature"] = loss_feature.item()
            train_losses["generator"] = loss_g.item()
            train_losses["commitment"] = commitment_loss.item()
            train_losses["codebook"] = codebook_loss.item()

            # discriminators
            train_losses["discriminator"] = loss_d.item()

        return total_loss, train_losses

    def _inference(self, eval_wave):
        """Reconstruct a single test waveform during training.

        Args:
            eval_wave: Waveform tensor; two leading singleton dimensions are
                added before it is fed to the model.

        Returns:
            The first item of the decoder output.
        """
        # Use the same device-resident float tensor for both the encoder and
        # the quantizer, so the two inputs cannot differ in device or dtype.
        wave = eval_wave[None, None, ...].to(self.accelerator.device).float()
        z = self.model.encoder(wave)
        z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
            z, wave, n_c=self.cfg.model_params.n_c_codebooks
        )
        full_pred_wave = self.model.decoder(z)
        return full_pred_wave[0]

    def _load_model(self, checkpoint_path=None, resume_type="resume"):
        """Load model, optimizer, epoch and step from a checkpoint.

        If ``checkpoint_path`` is None, the checkpoint with the highest step
        number in ``self.checkpoint_dir`` is loaded; when more than four
        checkpoints exist there, the one with the lowest step number is
        removed by the main process. **Only use this method after**
        ``accelerator.prepare()``.

        Args:
            checkpoint_path: Checkpoint file to load, or None to auto-discover.
            resume_type: Only ``"resume"`` is supported.

        Returns:
            The path of the checkpoint that was loaded.

        Raises:
            ValueError: If ``resume_type`` is not ``"resume"``, or if no
                checkpoint can be found during auto-discovery.
        """
        if resume_type != "resume":
            raise ValueError("Invalid resume type")

        if checkpoint_path is None:
            available_checkpoints = glob.glob(
                os.path.join(self.checkpoint_dir, self._CKPT_GLOB_PATTERN)
            )
            if not available_checkpoints:
                raise ValueError(
                    f"No checkpoint matching {self._CKPT_GLOB_PATTERN} "
                    f"found in {self.checkpoint_dir}"
                )

            def step_of(path):
                # File names end in "_step_<number>.pth".
                return int(path.split("_")[-1].split(".")[0])

            # find the checkpoint that has the highest step number
            latest_checkpoint = max(available_checkpoints, key=step_of)
            earliest_checkpoint = min(available_checkpoints, key=step_of)
            # delete the earliest checkpoint
            if (
                earliest_checkpoint != latest_checkpoint
                and self.accelerator.is_main_process
                and len(available_checkpoints) > 4
            ):
                os.remove(earliest_checkpoint)
                print(f"Removed {earliest_checkpoint}")
        else:
            latest_checkpoint = checkpoint_path

        self.model, self.optimizer, self.epoch, self.step = load_checkpoint(
            self.model,
            self.optimizer,
            latest_checkpoint,
            load_only_params=False,
            ignore_modules=[],
            is_distributed=self.accelerator.num_processes > 1,
        )
        return latest_checkpoint

    def _count_parameters(self):
        """Return the total number of parameters across all sub-networks."""
        return sum(
            p.numel() for key in self.model for p in self.model[key].parameters()
        )