Spaces:
Runtime error
Runtime error
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| from tqdm import tqdm | |
| import torch | |
| import numpy as np | |
| from torch.utils.data import DataLoader | |
| from torch.nn.parallel import DistributedDataParallel | |
| from optimizer.optimizers import Eve, ScaledAdam | |
| from schedulers.scheduler import NoamScheduler, Eden | |
| from models.tts.valle.valle_dataset import ( | |
| VALLEDataset, | |
| VALLECollator, | |
| VariableSampler, | |
| batch_by_size, | |
| ) | |
| from models.tts.base import TTSTrainer | |
| from models.tts.valle.valle import VALLE | |
class VALLETrainer(TTSTrainer):
    """Trainer for the VALL-E TTS model.

    Extends ``TTSTrainer`` with VALL-E specific model, optimizer, scheduler
    and dataloader construction, plus stage-aware training
    (``train_stage``: 0 = all modules, 1 = AR decoder, 2 = NAR decoder).
    """

    def __init__(self, args, cfg):
        # Delegate all generic setup (accelerator, dataloaders, etc.) to the base.
        TTSTrainer.__init__(self, args, cfg)
| def _build_model(self): | |
| model = VALLE(self.cfg.model) | |
| return model | |
| def _build_dataset(self): | |
| return VALLEDataset, VALLECollator | |
| def _build_optimizer(self): | |
| if self.args.train_stage: | |
| if isinstance(self.model, DistributedDataParallel): | |
| model = self.model.module | |
| else: | |
| model = self.model | |
| model_parameters = model.stage_parameters(self.args.train_stage) | |
| else: | |
| model_parameters = self.model.parameters() | |
| if self.cfg.train.optimizer == "ScaledAdam": | |
| parameters_names = [] | |
| if self.args.train_stage != 0: | |
| parameters_names.append( | |
| [ | |
| name_param_pair[0] | |
| for name_param_pair in model.stage_named_parameters( | |
| self.args.train_stage | |
| ) | |
| ] | |
| ) | |
| else: | |
| parameters_names.append( | |
| [name_param_pair[0] for name_param_pair in model.named_parameters()] | |
| ) | |
| optimizer = ScaledAdam( | |
| model_parameters, | |
| lr=self.cfg.train.base_lr, | |
| betas=(0.9, 0.95), | |
| clipping_scale=2.0, | |
| parameters_names=parameters_names, | |
| show_dominant_parameters=False, | |
| clipping_update_period=1000, | |
| ) | |
| elif self.cfg.train.optimizer == "Eve": | |
| optimizer = Eve( | |
| model_parameters, | |
| lr=self.cfg.train.base_lr, | |
| betas=(0.9, 0.98), | |
| target_rms=0.1, | |
| ) | |
| elif self.cfg.train.optimizer == "AdamW": | |
| optimizer = torch.optim.AdamW( | |
| model_parameters, | |
| lr=self.cfg.train.base_lr, | |
| betas=(0.9, 0.95), | |
| weight_decay=1e-2, | |
| eps=1e-8, | |
| ) | |
| elif self.cfg.train.optimizer == "Adam": | |
| optimizer = torch.optim.Adam( | |
| model_parameters, | |
| lr=self.cfg.train.base_lr, | |
| betas=(0.9, 0.95), | |
| eps=1e-8, | |
| ) | |
| else: | |
| raise NotImplementedError() | |
| return optimizer | |
| def _build_scheduler(self): | |
| if self.cfg.train.scheduler.lower() == "eden": | |
| scheduler = Eden( | |
| self.optimizer, 5000, 4, warmup_batches=self.cfg.train.warmup_steps | |
| ) | |
| elif self.cfg.train.scheduler.lower() == "noam": | |
| scheduler = NoamScheduler( | |
| self.cfg.train.base_lr, | |
| self.optimizer, | |
| self.cfg.model.decoder_dim, | |
| warmup_steps=self.cfg.train.warmup_steps, | |
| ) | |
| elif self.cfg.train.scheduler.lower() == "cosine": | |
| scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( | |
| self.cfg.train.warmup_steps, | |
| self.optimizer, | |
| eta_min=self.cfg.train.base_lr, | |
| ) | |
| else: | |
| raise NotImplementedError(f"{self.cfg.train.scheduler}") | |
| return scheduler | |
| def _train_epoch(self): | |
| r"""Training epoch. Should return average loss of a batch (sample) over | |
| one epoch. See ``train_loop`` for usage. | |
| """ | |
| if isinstance(self.model, dict): | |
| for key in self.model.keys(): | |
| self.model[key].train() | |
| else: | |
| self.model.train() | |
| epoch_sum_loss: float = 0.0 | |
| epoch_losses: dict = {} | |
| epoch_step: int = 0 | |
| for batch in tqdm( | |
| self.train_dataloader, | |
| desc=f"Training Epoch {self.epoch}", | |
| unit="batch", | |
| colour="GREEN", | |
| leave=False, | |
| dynamic_ncols=True, | |
| smoothing=0.04, | |
| disable=not self.accelerator.is_main_process, | |
| ): | |
| # Do training step and BP | |
| with self.accelerator.accumulate(self.model): | |
| total_loss, train_losses = self._train_step(batch) | |
| self.accelerator.backward(total_loss) | |
| self.optimizer.step() | |
| self.optimizer.zero_grad() | |
| self.batch_count += 1 | |
| if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: | |
| if self.cfg.train.optimizer not in ["ScaledAdam", "Eve"]: | |
| torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) | |
| for k in range(self.cfg.train.gradient_accumulation_step): | |
| if isinstance(self.scheduler, Eden): | |
| self.scheduler.step_batch(self.step) | |
| else: | |
| self.scheduler.step() | |
| epoch_sum_loss += total_loss.detach().cpu().item() | |
| if isinstance(train_losses, dict): | |
| for key, value in train_losses.items(): | |
| if key not in epoch_losses.keys(): | |
| epoch_losses[key] = value | |
| else: | |
| epoch_losses[key] += value | |
| if isinstance(train_losses, dict): | |
| for key, loss in train_losses.items(): | |
| self.accelerator.log( | |
| {"Step/Train {}".format(key): "{:.6f}".format(loss)}, | |
| step=self.step, | |
| ) | |
| else: | |
| self.accelerator.log( | |
| {"Step/Train Loss": loss}, | |
| step=self.step, | |
| ) | |
| self.accelerator.log( | |
| {"Step/lr": self.scheduler.get_last_lr()[0]}, | |
| step=self.step, | |
| ) | |
| # print loss every log_epoch_step steps | |
| # if epoch_step % self.cfg.train.log_epoch_step == 0: | |
| # for key, loss in train_losses.items(): | |
| # self.logger.info("Step/Train {}: {:.6f}".format(key, loss)) | |
| # print("Step/Train {}: {:.6f}".format(key, loss)) | |
| self.step += 1 | |
| epoch_step += 1 | |
| self.accelerator.wait_for_everyone() | |
| epoch_sum_loss = ( | |
| epoch_sum_loss | |
| / len(self.train_dataloader) | |
| * self.cfg.train.gradient_accumulation_step | |
| ) | |
| for key in epoch_losses.keys(): | |
| epoch_losses[key] = ( | |
| epoch_losses[key] | |
| / len(self.train_dataloader) | |
| * self.cfg.train.gradient_accumulation_step | |
| ) | |
| return epoch_sum_loss, epoch_losses | |
| def _train_step(self, batch, is_training=True): | |
| text_tokens = batch["phone_seq"].to(self.device) | |
| text_tokens_lens = batch["phone_len"].to(self.device) | |
| assert text_tokens.ndim == 2 | |
| audio_features = batch["acoustic_token"].to(self.device) | |
| audio_features_lens = batch["target_len"].to(self.device) | |
| assert audio_features.ndim == 3 | |
| with torch.set_grad_enabled(is_training): | |
| loss, losses = self.model( | |
| x=text_tokens, | |
| x_lens=text_tokens_lens, | |
| y=audio_features, | |
| y_lens=audio_features_lens, | |
| train_stage=self.args.train_stage, | |
| ) | |
| assert loss.requires_grad == is_training | |
| loss_dict = {} | |
| frames_sum = (audio_features_lens).sum() | |
| avg_loss = loss / frames_sum | |
| loss_dict["loss"] = avg_loss.detach().cpu().item() | |
| for l in losses: | |
| loss_dict[l] = losses[l].detach().cpu().item() / frames_sum.item() | |
| return avg_loss, loss_dict | |
| def _valid_step(self, batch): | |
| valid_losses = {} | |
| total_loss = 0 | |
| valid_stats = {} | |
| total_loss, valid_losses = self._train_step( | |
| batch=batch, | |
| is_training=False, | |
| ) | |
| assert total_loss.requires_grad is False | |
| total_loss = total_loss.detach().cpu().item() | |
| return total_loss, valid_losses, valid_stats | |
| def add_arguments(parser: argparse.ArgumentParser): | |
| parser.add_argument( | |
| "--train_stage", | |
| type=int, | |
| default="1", | |
| help="0: train all modules, 1: AR Decoder, 2: NAR Decoder", | |
| ) | |
| def _build_dataloader(self): | |
| if not self.cfg.train.use_dynamic_batchsize: | |
| return super()._build_dataloader() | |
| if len(self.cfg.dataset) > 1: | |
| raise Exception("use_dynamic_batchsize only supports single dataset now.") | |
| Dataset, Collator = self._build_dataset() | |
| train_dataset = Dataset( | |
| self.cfg, self.cfg.dataset[0], is_valid=False | |
| ) # TODO: support use_dynamic_batchsize for more than one datasets. | |
| train_collate = Collator(self.cfg) | |
| batch_sampler = batch_by_size( | |
| train_dataset.num_frame_indices, | |
| train_dataset.get_num_frames, | |
| max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes, | |
| max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes, | |
| required_batch_size_multiple=self.accelerator.num_processes, | |
| ) | |
| np.random.seed(1234) | |
| np.random.shuffle(batch_sampler) | |
| print(batch_sampler[:1]) | |
| batches = [ | |
| x[self.accelerator.local_process_index :: self.accelerator.num_processes] | |
| for x in batch_sampler | |
| if len(x) % self.accelerator.num_processes == 0 | |
| ] | |
| train_loader = DataLoader( | |
| train_dataset, | |
| collate_fn=train_collate, | |
| num_workers=self.cfg.train.dataloader.num_worker, | |
| batch_sampler=VariableSampler( | |
| batches, drop_last=False, use_random_sampler=True | |
| ), | |
| pin_memory=False, | |
| ) | |
| self.accelerator.wait_for_everyone() | |
| valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True) | |
| valid_collate = Collator(self.cfg) | |
| batch_sampler = batch_by_size( | |
| valid_dataset.num_frame_indices, | |
| valid_dataset.get_num_frames, | |
| max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes, | |
| max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes, | |
| required_batch_size_multiple=self.accelerator.num_processes, | |
| ) | |
| batches = [ | |
| x[self.accelerator.local_process_index :: self.accelerator.num_processes] | |
| for x in batch_sampler | |
| if len(x) % self.accelerator.num_processes == 0 | |
| ] | |
| valid_loader = DataLoader( | |
| valid_dataset, | |
| collate_fn=valid_collate, | |
| num_workers=self.cfg.train.dataloader.num_worker, | |
| batch_sampler=VariableSampler(batches, drop_last=False), | |
| pin_memory=False, | |
| ) | |
| self.accelerator.wait_for_everyone() | |
| return train_loader, valid_loader | |
| def _accelerator_prepare(self): | |
| if not self.cfg.train.use_dynamic_batchsize: | |
| ( | |
| self.train_dataloader, | |
| self.valid_dataloader, | |
| ) = self.accelerator.prepare( | |
| self.train_dataloader, | |
| self.valid_dataloader, | |
| ) | |
| if isinstance(self.model, dict): | |
| for key in self.model.keys(): | |
| self.model[key] = self.accelerator.prepare(self.model[key]) | |
| else: | |
| self.model = self.accelerator.prepare(self.model) | |
| if isinstance(self.optimizer, dict): | |
| for key in self.optimizer.keys(): | |
| self.optimizer[key] = self.accelerator.prepare(self.optimizer[key]) | |
| else: | |
| self.optimizer = self.accelerator.prepare(self.optimizer) | |
| if isinstance(self.scheduler, dict): | |
| for key in self.scheduler.keys(): | |
| self.scheduler[key] = self.accelerator.prepare(self.scheduler[key]) | |
| else: | |
| self.scheduler = self.accelerator.prepare(self.scheduler) | |