# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import nullcontext
from typing import Optional

import torch
import torch.distributed as dist
from cosyvoice.utils.train_utils import (batch_backward, batch_forward,
                                         cosyvoice_join, log_per_save,
                                         log_per_step, save_model,
                                         update_parameter_and_lr)
from loguru import logger


class Executor:
    """Executor for training and cross validation.

    Holds the running ``step`` / ``epoch`` counters and the optional
    reference model and DPO loss that are forwarded to ``batch_forward``
    during training.
    """

    def __init__(
        self,
        gan: bool = False,
        ref_model: Optional[torch.nn.Module] = None,
        dpo_loss: Optional[torch.nn.Module] = None,
        use_contrastive_fm: bool = False
    ):
        """Store configuration and initialise the counters.

        Args:
            gan: When true, ``cv`` tags every batch with
                ``batch_dict["turn"] = "generator"``.
            ref_model: Optional model passed to ``batch_forward`` as
                ``ref_model``; switched to eval mode during training.
            dpo_loss: Optional module passed to ``batch_forward`` as
                ``dpo_loss``.
            use_contrastive_fm: Stored on the instance; not read by the
                methods of this class.
        """
        self.gan = gan
        self.ref_model = ref_model
        self.dpo_loss = dpo_loss
        self.step = 0
        self.epoch = 0
        self.rank = int(os.environ.get("RANK", 0))
        # CUDA device ordinals are per node, so prefer LOCAL_RANK (set by
        # torchrun) and fall back to the global rank when it is not set.
        local_rank = int(os.environ.get("LOCAL_RANK", self.rank))
        self.device = torch.device(f"cuda:{local_rank}")
        self.use_contrastive_fm = use_contrastive_fm

    def train_one_epoc(
        self, model, optimizer, scheduler, train_data_loader, experiment, info_dict, scaler, model_type
    ):
        """Train one epoch.

        For every batch: run forward and backward, update parameters and
        learning rate, log, and optionally save a step checkpoint.
        ``self.step`` is incremented once every ``info_dict["accum_grad"]``
        batches.

        Args:
            model: Model to train; put into train mode here.
            optimizer: Optimizer; its first param group's ``lr`` is logged.
            scheduler: Passed through to ``update_parameter_and_lr``.
            train_data_loader: Iterable yielding ``batch_dict`` mappings.
            experiment: Passed through to ``log_per_step``.
            info_dict: Mutable run state. Reads ``accum_grad``,
                ``train_engine`` and optional ``save_per_step``; writes
                ``tag``, ``step``, ``epoch`` and ``batch_idx``.
            scaler: Passed through to the forward/backward/update helpers.
            model_type: Passed through to ``update_parameter_and_lr``.
        """
        lr = optimizer.param_groups[0]["lr"]
        logger.info(
            f"Epoch {self.epoch} TRAIN info lr {lr} rank {self.rank}"
        )
        logger.info(
            f"using accumulate grad, new batch size is {info_dict['accum_grad']} times larger than before"
        )
        model.train()
        if self.ref_model is not None:
            self.ref_model.eval()
        use_ddp = info_dict["train_engine"] == "torch_ddp"

        for batch_idx, batch_dict in enumerate(train_data_loader):
            info_dict["tag"] = "TRAIN"
            info_dict["step"] = self.step
            info_dict["epoch"] = self.epoch
            info_dict["batch_idx"] = batch_idx

            # Shape dump for debugging; routed through the logger (debug
            # level) instead of a bare print so it can be filtered.
            for key, value in batch_dict.items():
                if isinstance(value, torch.Tensor):
                    logger.debug(f'{key} {value.shape}')

            # Under DDP, skip gradient synchronisation on every batch that
            # is not the last one of an accumulation window.
            if use_ddp and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                context = model.no_sync
            else:
                context = nullcontext

            with context():
                logger.info(f'{self.rank} batch_forward')
                info_dict = batch_forward(
                    model,
                    batch_dict,
                    scaler,
                    info_dict,
                    ref_model=self.ref_model,
                    dpo_loss=self.dpo_loss,
                )
                logger.info(f'{self.rank} batch_backward')
                info_dict = batch_backward(model, scaler, info_dict)

            logger.info(f'{self.rank} update_parameter_and_lr')
            info_dict = update_parameter_and_lr(
                model, optimizer, scheduler, scaler, info_dict, model_type=model_type
            )
            logger.info(f'{self.rank} log_per_step')
            log_per_step(experiment, info_dict)

            # NOTE(review): the trigger uses ``self.step`` and ``batch_idx``
            # (so it also fires on the very first batch) while the
            # checkpoint name uses ``self.step + 1`` -- confirm this save
            # cadence is intended.
            if (
                info_dict.get("save_per_step", -1) > 0
                and self.step % info_dict["save_per_step"] == 0
                and batch_idx % info_dict["accum_grad"] == 0
            ):
                if dist.is_initialized():
                    dist.barrier()
                model_name = (
                    f"epoch_{self.epoch}_step_{self.step + 1}"
                )
                save_model(model, model_name, info_dict)
                model.train()

            if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                self.step += 1

        # Epoch-end barrier, guarded like the one in the save path so that
        # non-distributed runs do not raise here.
        # NOTE(review): placed after the batch loop -- confirm placement.
        if dist.is_initialized():
            dist.barrier()

    @torch.inference_mode()
    def cv(self, model, cv_data_loader, experiment, info_dict, on_batch_end=True):
        """Run cross validation, log the averaged losses and save the model.

        Each loss is weighted by the number of utterances in its batch
        (``len(batch_dict["utts"])``) and averaged over all utterances seen.

        Args:
            model: Model to evaluate; put into eval mode here and left so.
            cv_data_loader: Iterable yielding ``batch_dict`` mappings that
                contain an ``"utts"`` entry.
            experiment: Passed through to ``log_per_save``.
            info_dict: Mutable run state; ``loss_dict`` is replaced by the
                averaged losses before ``log_per_save`` is called.
            on_batch_end: Selects the checkpoint name:
                ``epoch_{epoch}_whole`` when true, otherwise
                ``epoch_{epoch}_step_{step + 1}``.
        """
        logger.info(f"Epoch {self.epoch} Step {self.step + 1} on_batch_end {on_batch_end} CV rank {self.rank}")
        model.eval()
        total_num_utts = 0
        total_loss_dict = {}
        for batch_idx, batch_dict in enumerate(cv_data_loader):
            info_dict["tag"] = "CV"
            info_dict["step"] = self.step
            info_dict["epoch"] = self.epoch
            info_dict["batch_idx"] = batch_idx

            num_utts = len(batch_dict["utts"])
            total_num_utts += num_utts

            if self.gan is True:
                batch_dict["turn"] = "generator"
            info_dict = batch_forward(model, batch_dict, None, info_dict)

            for k, v in info_dict["loss_dict"].items():
                total_loss_dict.setdefault(k, []).append(v.item() * num_utts)
            log_per_step(None, info_dict)

        # Avoid division by 0 when every batch was empty; for any positive
        # utterance count the divisor is unchanged.
        denominator = max(total_num_utts, 1)
        info_dict["loss_dict"] = {
            k: sum(v) / denominator for k, v in total_loss_dict.items()
        }
        log_per_save(experiment, info_dict)
        model_name = (
            f"epoch_{self.epoch}_whole"
            if on_batch_end
            else f"epoch_{self.epoch}_step_{self.step + 1}"
        )
        save_model(model, model_name, info_dict)