# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2023 Horizon Inc. (authors: Xingchen Song)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import json
import re
import datetime
import warnings
import yaml

import deepspeed
import torch.optim as optim
import torch.distributed as dist

from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from loguru import logger

from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live

from cosyvoice.dataset.dataset import Dataset
from torch.optim.lr_scheduler import LinearLR, ConstantLR, SequentialLR, _LRScheduler


class ResumableSequentialLR(_LRScheduler):
    """A resumable version of SequentialLR that properly manages child schedulers.

    The learning rate is re-derived from the global step (``last_epoch``) on
    every call to :meth:`get_lr`:
    1. The child scheduler whose milestone range contains the step is selected.
    2. Its ``last_epoch`` is set to the offset inside that range.
    3. Its closed-form LR is returned (or its ``get_lr()`` if it has no closed form).

    This lets training jump to an arbitrary step via :meth:`set_step`.
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):
        """
        Args:
            optimizer: Wrapped optimizer.
            schedulers: List of schedulers to use sequentially.
            milestones: List of epoch/step numbers at which to switch to the
                next scheduler; expected to be increasing.
            last_epoch: The index of the last epoch/step.

        Raises:
            ValueError: if ``len(schedulers) != len(milestones) + 1``.
        """
        if len(schedulers) != len(milestones) + 1:
            raise ValueError("Expected len(schedulers) == len(milestones) + 1")

        self.schedulers = schedulers
        self.milestones = milestones
        self._scheduler_idx = 0  # kept for compatibility; not read by this class

        # The parent constructor sets ``last_epoch`` and performs the initial step().
        super().__init__(optimizer, last_epoch)

    def _get_scheduler_info(self, epoch):
        """Return ``(scheduler_idx, relative_epoch)`` for a global ``epoch``.

        ``scheduler_idx`` is the number of leading milestones already reached.
        ``relative_epoch`` is ``epoch`` minus the last reached milestone, or
        ``epoch`` itself when no milestone has been reached yet.
        """
        scheduler_idx = 0
        for milestone in self.milestones:
            if epoch < milestone:
                break
            scheduler_idx += 1

        if scheduler_idx == 0:
            relative_epoch = epoch
        else:
            relative_epoch = epoch - self.milestones[scheduler_idx - 1]
        return scheduler_idx, relative_epoch

    def get_lr(self):
        """Get the learning rates from the child scheduler active at ``last_epoch``."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        scheduler_idx, relative_epoch = self._get_scheduler_info(self.last_epoch)
        scheduler = self.schedulers[scheduler_idx]

        # Point the child at its relative position before asking it for an LR.
        scheduler.last_epoch = relative_epoch

        if hasattr(scheduler, '_get_closed_form_lr'):
            return scheduler._get_closed_form_lr()

        # Set the flag so the child does not emit its own "use get_last_lr()" warning.
        scheduler._get_lr_called_within_step = True
        try:
            return scheduler.get_lr()
        finally:
            scheduler._get_lr_called_within_step = False

    def step(self, epoch=None):
        """Advance the scheduler; the parent class updates ``last_epoch`` and calls :meth:`get_lr`."""
        super().step(epoch)

    def set_step(self, step):
        """Set the current step for resuming training.

        Sets ``last_epoch`` to ``step - 1`` and positions every child
        scheduler accordingly. It then recomputes the learning rate for that
        position and writes it into the optimizer's param groups and into
        ``get_last_lr()``.
        """
        # Clamp at 0 (the freshly-constructed state). Evaluating children at
        # epoch -1 would, for LinearLR, produce a negative learning rate.
        self.last_epoch = max(step - 1, 0)

        scheduler_idx, relative_epoch = self._get_scheduler_info(self.last_epoch)

        # Previous schedulers are parked at the last epoch of their range.
        for i in range(scheduler_idx):
            range_start = self.milestones[i - 1] if i > 0 else 0
            self.schedulers[i].last_epoch = self.milestones[i] - range_start - 1

        # Set the current scheduler to its relative position.
        self.schedulers[scheduler_idx].last_epoch = relative_epoch

        # Recompute the LR for the new position. get_last_lr() would return the
        # stale value cached by the previous step() call.
        self._get_lr_called_within_step = True
        try:
            lrs = self.get_lr()
        finally:
            self._get_lr_called_within_step = False

        for param_group, lr in zip(self.optimizer.param_groups, lrs):
            param_group['lr'] = lr
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def state_dict(self):
        """Return scheduler state with child schedulers stored as their own state dicts.

        Mirrors ``SequentialLR.state_dict``: the optimizer and the child
        scheduler objects (which hold the optimizer) are not stored directly.
        """
        state = {key: value for key, value in self.__dict__.items()
                 if key not in ('optimizer', 'schedulers')}
        state['schedulers'] = [scheduler.state_dict() for scheduler in self.schedulers]
        return state

    def load_state_dict(self, state_dict):
        """Load state produced by :meth:`state_dict` without replacing the child schedulers.

        Older checkpoints that stored child scheduler objects directly are also
        accepted; their state dicts are extracted and loaded into the existing
        children.
        """
        state = dict(state_dict)
        child_states = state.pop('schedulers', None)
        self.__dict__.update(state)
        if child_states is None:
            return
        for scheduler, child_state in zip(self.schedulers, child_states):
            if not isinstance(child_state, dict):
                child_state = child_state.state_dict()
            scheduler.load_state_dict(child_state)


def init_distributed(args):
    """Initialise the process group for the selected training engine.

    Reads ``WORLD_SIZE``, ``LOCAL_RANK`` and ``RANK`` from the environment.
    For ``torch_ddp`` it sets the CUDA device and calls ``init_process_group``;
    otherwise it calls ``deepspeed.init_distributed``.

    Returns:
        ``(world_size, local_rank, rank)``.
    """
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    logger.info(f'training on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}')
    if args.train_engine == 'torch_ddp':
        torch.cuda.set_device(local_rank)
        dist.init_process_group(args.dist_backend)
    else:
        deepspeed.init_distributed(dist_backend=args.dist_backend)
    return world_size, local_rank, rank


def init_dataset_and_dataloader(args, configs, dpo):
    """Build train/CV datasets and their DataLoaders.

    The train set is shuffled and partitioned; the CV set is not. Both use
    ``mode='train'`` and ``gan=False``.

    Returns:
        ``(train_dataset, cv_dataset, train_data_loader, cv_data_loader)``.
    """
    data_pipeline = configs['data_pipeline']
    train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=False, dpo=dpo,
                            shuffle=True, partition=True)
    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=False, dpo=dpo,
                         shuffle=False, partition=False)

    # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=None,
                                   pin_memory=args.pin_memory,
                                   num_workers=args.num_workers,
                                   prefetch_factor=args.prefetch)
    cv_data_loader = DataLoader(cv_dataset,
                                batch_size=None,
                                pin_memory=args.pin_memory,
                                num_workers=args.num_workers,
                                prefetch_factor=args.prefetch)
    return train_dataset, cv_dataset, train_data_loader, cv_data_loader


def check_modify_and_save_config(args, configs):
    """Check and modify config.

    For ``torch_ddp`` it forces ``dtype`` to ``fp32``.

    For DeepSpeed it reads the JSON file at ``args.deepspeed_config`` and then:
    - derives ``dtype`` from its ``fp16``/``bf16`` sections;
    - asserts ``train_micro_batch_size_per_gpu == 1``;
    - overrides ``save_per_step``, ``accum_grad``, ``grad_clip`` and
      ``log_interval`` in ``configs['train_conf']``.

    ``configs`` is modified in place and also returned.
    """
    if args.train_engine == "torch_ddp":
        configs['train_conf']["dtype"] = 'fp32'
    else:
        with open(args.deepspeed_config, 'r') as fin:
            ds_configs = json.load(fin)
        if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
            configs['train_conf']["dtype"] = "fp16"
        elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
            configs['train_conf']["dtype"] = "bf16"
        else:
            configs['train_conf']["dtype"] = "fp32"
        assert ds_configs["train_micro_batch_size_per_gpu"] == 1
        # if use deepspeed, override ddp config
        configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
                                                     configs['train_conf']['accum_grad'] /
                                                     ds_configs["gradient_accumulation_steps"])
        configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
        configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
        configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
    return configs


def wrap_cuda_model(args, model):
    """Wrap model to cuda.

    For ``torch_ddp`` the model is moved to CUDA and wrapped in
    ``DistributedDataParallel(find_unused_parameters=True)``. Otherwise rank 0
    logs a ZeRO-2 memory estimate and the model is returned unwrapped.
    NOTE(review): presumably DeepSpeed wraps it elsewhere — confirm.
    """
    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if args.train_engine == "torch_ddp":  # native pytorch ddp
        assert (torch.cuda.is_available())
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
    else:
        if int(os.environ.get('RANK', 0)) == 0:
            logger.info("Estimating model states memory needs (zero2)...")
            estimate_zero2_model_states_mem_needs_all_live(
                model,
                num_gpus_per_node=local_world_size,
                num_nodes=world_size // local_world_size)
    return model


def init_optimizer_and_scheduler(configs, model):
    """Init optimizer and scheduler.

    Builds Adam/AdamW with ``train_conf.optim_conf.lr``. The schedule is a
    linear warmup from ``1e-4 * lr`` to ``lr`` over
    ``scheduler_conf.warmup_steps`` steps, then a constant ``lr``.

    Returns:
        ``(model, optimizer, scheduler)``.

    Raises:
        ValueError: if ``train_conf.optim`` is neither ``adam`` nor ``adamw``.
    """
    train_conf = configs['train_conf']
    lr = train_conf['optim_conf']['lr']
    logger.info(f"lr base: {lr}")

    optim_name = train_conf['optim']
    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif optim_name == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=lr)
    else:
        raise ValueError(f"unknown optimizer: {optim_name}")

    warm_up_steps = train_conf['scheduler_conf']['warmup_steps']
    total_iters = train_conf['total_iters']

    warmup_scheduler = LinearLR(
        optimizer,
        start_factor=1e-4,           # start at nearly 0
        end_factor=1.0,              # end at the base learning rate
        total_iters=warm_up_steps,   # length of the warmup, from the config
    )
    # factor=1.0 leaves the LR unchanged, so total_iters has no effect on the value.
    constant_scheduler = ConstantLR(
        optimizer,
        factor=1.0,
        total_iters=total_iters,
    )

    # Warmup until warm_up_steps, constant afterwards.
    scheduler = ResumableSequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, constant_scheduler],
        milestones=[warm_up_steps],
    )
    return model, optimizer, scheduler


def save_model(model, model_name, info_dict):
    """Save model.

    ``torch_ddp``: rank 0 writes ``<model_dir>/<model_name>.pt`` holding the
    unwrapped state dict plus ``epoch``/``step``.

    DeepSpeed: ``save_checkpoint`` is called with ``info_dict`` as client state.

    In both cases rank 0 also writes ``info_dict`` (with ``save_time`` added)
    to a ``.yaml`` file next to the ``.pt`` path.
    """
    rank = int(os.environ.get('RANK', 0))
    model_dir = info_dict["model_dir"]
    os.makedirs(model_dir, exist_ok=True)
    save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))

    if info_dict["train_engine"] == "torch_ddp":
        if rank == 0:
            torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']},
                       save_model_path)
    else:
        with torch.no_grad():
            model.save_checkpoint(save_dir=model_dir, tag=model_name, client_state=info_dict)

    if rank == 0:
        # Escape the dot so only a literal ".pt" suffix is replaced.
        info_path = re.sub(r'\.pt$', '.yaml', save_model_path)
        info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
        with open(info_path, 'w') as fout:
            data = yaml.dump(info_dict)
            fout.write(data)
        logger.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))


def cosyvoice_join(group_join, info_dict):
    """Join all ranks.

    On every batch except ``batch_idx == 0`` this runs a monitored barrier on
    ``group_join``. It returns True when the barrier times out (an uneven
    workload), telling the caller to stop this worker; otherwise it returns False.
    """
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))

    if info_dict["batch_idx"] == 0:
        return False

    # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
    try:
        dist.monitored_barrier(group=group_join, timeout=group_join.options._timeout)
        return False
    except RuntimeError as e:
        logger.info("Detected uneven workload distribution: {}\n".format(e) +
                    "Break current worker to manually join all workers, " +
                    "world_size {}, current rank {}, current local_rank {}\n".
                    format(world_size, rank, local_rank))
        return True


def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
    """Forward batch and compute loss.

    Runs ``model(batch, device)`` under CUDA autocast and stores the result in
    ``info_dict['loss_dict']``. Autocast is enabled iff ``scaler`` is set for
    ``torch_ddp``, and always (with ``info_dict['dtype']``) otherwise.

    When both ``ref_model`` and ``dpo_loss`` are given, ``loss`` becomes
    ``preference_loss + sft_loss``. The DPO statistics (``sft_loss``,
    ``dpo_loss``, ``dpo_acc``, ``chosen_reward``, ``reject_reward``) are
    added to the loss dict.
    """
    device = int(os.environ.get('LOCAL_RANK', 0))

    dtype = info_dict["dtype"]
    if dtype == "fp16":
        dtype = torch.float16
    elif dtype == "bf16":
        dtype = torch.bfloat16
    else:  # fp32
        dtype = torch.float32

    if info_dict['train_engine'] == 'torch_ddp':
        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
    else:
        autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)

    with autocast:
        info_dict['loss_dict'] = model(batch, device)
        # Debug-level so it does not flood stdout on every batch and rank.
        logger.debug(f"info_dict loss_dict : {info_dict['loss_dict']}")
        if ref_model is not None and dpo_loss is not None:
            chosen_logps = info_dict['loss_dict']["chosen_logps"]
            rejected_logps = info_dict['loss_dict']["rejected_logps"]
            sft_loss = info_dict['loss_dict']['loss']
            with torch.no_grad():
                ref_loss_dict = ref_model(batch, device)
            reference_chosen_logps = ref_loss_dict["chosen_logps"]
            reference_rejected_logps = ref_loss_dict["rejected_logps"]
            preference_loss, chosen_reward, reject_reward = dpo_loss(
                chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps
            )
            dpo_acc = (chosen_reward > reject_reward).float().mean()
            info_dict['loss_dict']["loss"] = preference_loss + sft_loss
            info_dict['loss_dict']["sft_loss"] = sft_loss
            info_dict['loss_dict']["dpo_loss"] = preference_loss
            info_dict['loss_dict']["dpo_acc"] = dpo_acc
            info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean()
            info_dict['loss_dict']["reject_reward"] = reject_reward.mean()
    return info_dict


def batch_backward(model, scaler, info_dict):
    """Backward batch.

    DeepSpeed: calls ``model.backward(loss)``.

    Otherwise: divides the loss by ``accum_grad`` and back-propagates it
    (through ``scaler`` when given).

    The value used is stored back into ``info_dict['loss_dict']['loss']``.
    """
    if info_dict["train_engine"] == "deepspeed":
        scaled_loss = model.backward(info_dict['loss_dict']['loss'])
    else:
        scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
        if scaler is not None:
            scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

    info_dict['loss_dict']['loss'] = scaled_loss
    return info_dict


# Component names used to group per-parameter gradient norms for logging.
# Matching is by substring, so prefixes (e.g. DDP's "module.") and suffixes on
# attribute names do not prevent a match.
_LLM_COMPONENT_KEYS = (
    'text_embedding', 'text_encoder', 'text_encoder_affine', 'llm_embedding',
    'llm.model', 'llm_decoder', 'speech_embedding', 'spk_embed_affine',
)
_FLOW_COMPONENT_KEYS = (
    'input_embedding', 'spk_embed_affine', 'encoder', 'encoder_proj',
    'decoder.cfm', 'decoder.unet', 'decoder.estimator', 'decoder.time_embedding',
    'decoder.conv', 'decoder.attention', 'length_regulator',
)


def _component_keys(model_type):
    """Return the component keys for ``model_type``; raise ValueError if unknown."""
    if model_type == 'llm':
        return _LLM_COMPONENT_KEYS
    if model_type == 'flow':
        return _FLOW_COMPONENT_KEYS
    raise ValueError(f"Unknown model_type: {model_type}")


def _categorize_param(name, component_keys, model_type):
    """Return the component key a parameter ``name`` belongs to, or ``'other'``."""
    decoder_keys = []
    plain_matches = []
    for key in component_keys:
        if model_type == 'flow' and key.startswith('decoder.'):
            decoder_keys.append(key)
        elif key in name:
            plain_matches.append(key)

    if plain_matches:
        # Prefer the most specific key so that e.g. "encoder_proj.*" is not filed
        # under "encoder" and "text_encoder_affine*" not under "text_encoder".
        return max(plain_matches, key=len)

    # Flow decoder sub-components: the name must mention "decoder" and the sub-key.
    if 'decoder' in name:
        for key in decoder_keys:
            if key.replace('decoder.', '') in name:
                return key
    return 'other'


def _collect_grad_norms(model, component_keys, model_type, key_components):
    """Record the L2 norm of every parameter gradient.

    Each ``(name, norm)`` pair is appended to the matching list in
    ``key_components``.

    Returns:
        A dict mapping parameter name to gradient norm.
    """
    layer_grad_norms = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        param_grad_norm = param.grad.data.norm(2).item()
        layer_grad_norms[name] = param_grad_norm
        component = _categorize_param(name, component_keys, model_type)
        key_components[component].append((name, param_grad_norm))
    return layer_grad_norms


def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict, model_type='llm'):
    """Update parameters and learning rate.

    On a gradient-accumulation boundary (``(batch_idx + 1) % accum_grad == 0``):
    1. Unscale gradients when ``scaler`` is set.
    2. Record per-parameter gradient norms, grouped by component.
    3. Clip to ``grad_clip``.
    4. Step the optimizer (skipped with a warning if the total norm is
       non-finite), then zero gradients and step the scheduler.

    Sets ``lr``, ``grad_norm``, ``layer_grad_norms`` and
    ``key_component_grads`` in ``info_dict``.

    NOTE(review): there is no DeepSpeed engine step here. If
    ``train_engine == 'deepspeed'``, the caller presumably steps the engine —
    confirm.

    Raises:
        ValueError: if ``model_type`` is not ``'llm'`` or ``'flow'``.
    """
    component_keys = _component_keys(model_type)
    key_components = {key: [] for key in component_keys}
    key_components['other'] = []

    grad_norm = 0.0
    layer_grad_norms = {}

    if (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
        # Unscale first so that both the per-parameter norms and clipping see true gradients.
        if scaler is not None:
            scaler.unscale_(optimizer)

        layer_grad_norms = _collect_grad_norms(model, component_keys, model_type, key_components)

        grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
        if torch.isfinite(grad_norm):
            if scaler is not None:
                scaler.step(optimizer)
            else:
                optimizer.step()
        else:
            logger.warning('get infinite grad_norm, check your code/data if it appears frequently')
        if scaler is not None:
            scaler.update()

        # Only clear grads / advance the LR once per accumulation window.
        optimizer.zero_grad()
        scheduler.step()
        logger.info(f"lr after step {optimizer.param_groups[0]['lr']}")

    info_dict["lr"] = optimizer.param_groups[0]['lr']
    info_dict["grad_norm"] = grad_norm
    info_dict["layer_grad_norms"] = layer_grad_norms
    info_dict["key_component_grads"] = key_components
    return info_dict


def log_per_step(experiment, info_dict):
    """Log per step using Comet ML.

    Only rank 0 sends metrics to Comet, and only on optimizer-update steps.
    Every ``log_interval`` batches, each rank also writes a line to the logger.
    """
    tag = info_dict["tag"]
    epoch = info_dict.get('epoch', 0)
    step = info_dict["step"]
    batch_idx = info_dict["batch_idx"]
    loss_dict = info_dict['loss_dict']
    rank = int(os.environ.get('RANK', 0))

    # Only rank 0 writes to Comet ML to avoid multi-process write
    if experiment is not None and rank == 0:
        if info_dict['train_engine'] == 'deepspeed':
            # Not set anywhere in this file; .get avoids a KeyError if the caller omits it.
            at_update_step = info_dict.get('is_gradient_accumulation_boundary') is True
        else:
            at_update_step = (info_dict['train_engine'] == 'torch_ddp'
                              and (batch_idx + 1) % info_dict['accum_grad'] == 0)
        if at_update_step:
            experiment.log_metric(f'{tag}_epoch', epoch, step=step + 1)
            experiment.log_metric(f'{tag}_lr', info_dict['lr'], step=step + 1)
            experiment.log_metric(f'{tag}_grad_norm', info_dict['grad_norm'], step=step + 1)
            for k, v in loss_dict.items():
                if isinstance(v, torch.Tensor):
                    v = v.item()
                experiment.log_metric(f'{tag}_{k}', v, step=step + 1)

    # TRAIN & CV, Shell log (stdout)
    if (batch_idx + 1) % info_dict['log_interval'] == 0:
        log_str = f'{tag} Batch {epoch}/{batch_idx + 1} step {step} '
        for name, value in loss_dict.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            log_str += f'{name} {value:.6f} '
        if tag == "TRAIN":
            log_str += f'lr {info_dict["lr"]:.15f} grad_norm {info_dict["grad_norm"]:.6f}'
        log_str += f' rank {rank}'
        logger.info(log_str)


def log_per_save(experiment, info_dict):
    """Log per save using Comet ML.

    Every rank writes a summary line to the logger. Rank 0 also sends epoch,
    LR and all losses to Comet. For ``tag == "CV"`` it additionally sends
    ``cv_avg_loss_per_epoch``, taken from ``loss_dict['loss']`` (0 if absent).
    """
    tag = info_dict["tag"]
    epoch = info_dict["epoch"]
    step = info_dict["step"]
    loss_dict = info_dict["loss_dict"]
    lr = info_dict['lr']
    rank = int(os.environ.get('RANK', 0))

    loss_str = ' '.join([f"{k} {v.item() if isinstance(v, torch.Tensor) else v}" for k, v in loss_dict.items()])
    logger.info(f'Epoch {epoch} Step {step + 1} CV info lr {lr} {rank} {loss_str}')

    if experiment is not None and rank == 0:
        experiment.log_metric(f'{tag}_epoch', info_dict['epoch'], step=step + 1)
        experiment.log_metric(f'{tag}_lr', info_dict['lr'], step=step + 1)
        for k, v in loss_dict.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            experiment.log_metric(f'{tag}_{k}', v, step=step + 1)

        if tag == "CV":
            avg_loss = loss_dict.get('loss', 0)
            if isinstance(avg_loss, torch.Tensor):
                avg_loss = avg_loss.item()
            experiment.log_metric('cv_avg_loss_per_epoch', avg_loss, epoch=epoch)