# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2023 Horizon Inc. (authors: Xingchen Song)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import torch
import json
import re
import datetime
import yaml
import deepspeed
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from loguru import logger
from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
from cosyvoice.dataset.dataset import Dataset
from torch.optim.lr_scheduler import LinearLR, ConstantLR, SequentialLR


def init_distributed(args):
    """Initialise the distributed backend selected by ``args.train_engine``.

    Reads ``WORLD_SIZE``, ``LOCAL_RANK`` and ``RANK`` from the environment
    (defaulting to 1, 0 and 0). For ``'torch_ddp'`` it binds the CUDA device
    to the local rank and initialises the torch process group; for any other
    engine it delegates to ``deepspeed.init_distributed``.

    Returns:
        Tuple ``(world_size, local_rank, rank)``.
    """
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    logger.info(f'training on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}')
    if args.train_engine == 'torch_ddp':
        torch.cuda.set_device(local_rank)
        dist.init_process_group(args.dist_backend)
    else:
        deepspeed.init_distributed(dist_backend=args.dist_backend)
    return world_size, local_rank, rank


def init_dataset_and_dataloader(args, configs, dpo):
    """Build the train/CV datasets and their data loaders.

    Both datasets use ``configs['data_pipeline']`` and ``mode='train'``; the
    training set is shuffled and partitioned, the CV set is neither.

    Returns:
        Tuple ``(train_dataset, cv_dataset, train_data_loader, cv_data_loader)``.
    """
    data_pipeline = configs['data_pipeline']
    train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train',
                            gan=False, dpo=dpo, shuffle=True, partition=True)
    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train',
                         gan=False, dpo=dpo, shuffle=False, partition=False)

    # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file
    # each time when the for loop starts
    loader_kwargs = dict(batch_size=None,
                         pin_memory=args.pin_memory,
                         num_workers=args.num_workers,
                         prefetch_factor=args.prefetch)
    train_data_loader = DataLoader(train_dataset, **loader_kwargs)
    cv_data_loader = DataLoader(cv_dataset, **loader_kwargs)
    return train_dataset, cv_dataset, train_data_loader, cv_data_loader


def check_modify_and_save_config(args, configs):
    """Reconcile ``configs['train_conf']`` with the selected training engine.

    For ``'torch_ddp'`` only the dtype is forced to ``'fp32'``. Otherwise the
    DeepSpeed JSON config at ``args.deepspeed_config`` is loaded and its
    dtype, gradient accumulation, gradient clipping and print interval
    override the corresponding ``train_conf`` entries; ``save_per_step`` is
    rescaled by the ratio of the old to the new accumulation count.

    Returns:
        The same ``configs`` mapping, modified in place.
    """
    train_conf = configs['train_conf']
    if args.train_engine == "torch_ddp":
        train_conf["dtype"] = 'fp32'
        return configs

    with open(args.deepspeed_config, 'r') as fin:
        ds_configs = json.load(fin)
    if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
        train_conf["dtype"] = "fp16"
    elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
        train_conf["dtype"] = "bf16"
    else:
        train_conf["dtype"] = "fp32"
    assert ds_configs["train_micro_batch_size_per_gpu"] == 1
    # if use deepspeed, override ddp config
    train_conf['save_per_step'] = int(train_conf['save_per_step'] *
                                      train_conf['accum_grad'] /
                                      ds_configs["gradient_accumulation_steps"])
    train_conf['accum_grad'] = ds_configs["gradient_accumulation_steps"]
    train_conf['grad_clip'] = ds_configs["gradient_clipping"]
    train_conf['log_interval'] = ds_configs["steps_per_print"]
    return configs


def wrap_cuda_model(args, model):
    """Prepare ``model`` for the selected training engine.

    For ``'torch_ddp'`` the model is moved to CUDA and wrapped in
    ``DistributedDataParallel``. For other engines the model is returned
    unchanged; rank 0 additionally prints a ZeRO-2 memory estimate.
    """
    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if args.train_engine == "torch_ddp":  # native pytorch ddp
        assert (torch.cuda.is_available())
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
    else:
        if int(os.environ.get('RANK', 0)) == 0:
            logging.info("Estimating model states memory needs (zero2)...")
            estimate_zero2_model_states_mem_needs_all_live(
                model,
                num_gpus_per_node=local_world_size,
                num_nodes=world_size // local_world_size)
    return model


def init_optimizer_and_scheduler(configs, model):
    """Create the optimizer and a warmup-then-constant LR scheduler.

    The optimizer is chosen by ``configs['train_conf']['optim']`` (``'adam'``
    or ``'adamw'``) and configured with ``configs['train_conf']['optim_conf']``.

    Returns:
        Tuple ``(model, optimizer, scheduler)``.

    Raises:
        ValueError: if the configured optimizer name is not supported.
    """
    train_conf = configs['train_conf']
    optim_name = train_conf['optim']
    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(), **train_conf['optim_conf'])
    elif optim_name == 'adamw':
        optimizer = optim.AdamW(model.parameters(), **train_conf['optim_conf'])
    else:
        # Format the name explicitly: concatenating the train_conf dict onto a
        # str raised TypeError and masked the intended ValueError.
        raise ValueError("unknown optimizer: {}".format(optim_name))

    warmup_steps = 5000  # 5k warmup steps

    # Create schedulers
    warmup_scheduler = LinearLR(
        optimizer,
        start_factor=1e-9,  # Start at nearly 0
        end_factor=1.0,  # End at base learning rate
        total_iters=warmup_steps
    )
    constant_scheduler = ConstantLR(
        optimizer,
        factor=1.0,  # Keep learning rate constant
        total_iters=float('inf')  # Run indefinitely
    )
    # Combine schedulers: warmup for 5k steps, then constant
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, constant_scheduler],
        milestones=[warmup_steps]  # Switch after 5k steps
    )
    return model, optimizer, scheduler


def save_model(model, model_name, info_dict):
    """Save a checkpoint and, on rank 0, a YAML file describing it.

    With ``'torch_ddp'`` rank 0 writes ``<model_dir>/<model_name>.pt``
    containing the wrapped module's state dict plus ``epoch`` and ``step``.
    With any other engine every rank calls ``model.save_checkpoint``.
    Rank 0 then stamps ``info_dict['save_time']`` and dumps ``info_dict`` to
    the sibling ``.yaml`` file.
    """
    rank = int(os.environ.get('RANK', 0))
    model_dir = info_dict["model_dir"]
    save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))

    if info_dict["train_engine"] == "torch_ddp":
        if rank == 0:
            torch.save({**model.module.state_dict(),
                        'epoch': info_dict['epoch'],
                        'step': info_dict['step']},
                       save_model_path)
    else:
        with torch.no_grad():
            model.save_checkpoint(save_dir=model_dir,
                                  tag=model_name,
                                  client_state=info_dict)
    if rank == 0:
        # Escape the dot so only a literal ".pt" suffix is rewritten.
        info_path = re.sub(r'\.pt$', '.yaml', save_model_path)
        info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
        with open(info_path, 'w') as fout:
            data = yaml.dump(info_dict)
            fout.write(data)
        logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))


def cosyvoice_join(group_join, info_dict):
    """Synchronise all ranks and report whether this worker should stop.

    For the first batch (``batch_idx == 0``) no barrier is attempted.
    Otherwise a monitored barrier is run on ``group_join`` using the group's
    configured timeout.

    Returns:
        ``True`` if the barrier raised ``RuntimeError`` (the caller is
        expected to break out of its loop), ``False`` otherwise.
    """
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))

    if info_dict["batch_idx"] == 0:
        return False

    # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
    try:
        dist.monitored_barrier(group=group_join,
                               timeout=group_join.options._timeout)
    except RuntimeError as e:
        logging.info("Detected uneven workload distribution: {}\n".format(e) +
                     "Break current worker to manually join all workers, " +
                     "world_size {}, current rank {}, current local_rank {}\n".
                     format(world_size, rank, local_rank))
        return True
    return False


def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
    """Run the forward pass under autocast and store losses in ``info_dict``.

    ``model(batch, device)`` is called with ``device`` set to the integer
    ``LOCAL_RANK``; its result becomes ``info_dict['loss_dict']``. When both
    ``ref_model`` and ``dpo_loss`` are given, the reference model is run
    without gradients and the preference loss is added to the model's own
    loss, with the individual terms and reward statistics recorded alongside.

    Returns:
        ``info_dict`` with ``'loss_dict'`` populated.
    """
    device = int(os.environ.get('LOCAL_RANK', 0))

    # Anything other than fp16/bf16 is treated as fp32.
    torch_dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
    dtype = torch_dtypes.get(info_dict["dtype"], torch.float32)

    if info_dict['train_engine'] == 'torch_ddp':
        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
    else:
        autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)

    with autocast:
        loss_dict = model(batch, device)
        info_dict['loss_dict'] = loss_dict
        if ref_model is not None and dpo_loss is not None:
            chosen_logps = loss_dict["chosen_logps"]
            rejected_logps = loss_dict["rejected_logps"]
            sft_loss = loss_dict['loss']
            with torch.no_grad():
                ref_loss_dict = ref_model(batch, device)
            reference_chosen_logps = ref_loss_dict["chosen_logps"]
            reference_rejected_logps = ref_loss_dict["rejected_logps"]
            preference_loss, chosen_reward, reject_reward = dpo_loss(
                chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps
            )
            dpo_acc = (chosen_reward > reject_reward).float().mean()
            loss_dict["loss"] = preference_loss + sft_loss
            loss_dict["sft_loss"] = sft_loss
            loss_dict["dpo_loss"] = preference_loss
            loss_dict["dpo_acc"] = dpo_acc
            loss_dict["chosen_reward"] = chosen_reward.mean()
            loss_dict["reject_reward"] = reject_reward.mean()
    return info_dict


def batch_backward(model, scaler, info_dict):
    """Run the backward pass for the loss stored in ``info_dict``.

    With ``'deepspeed'`` the engine's own ``backward`` is used. Otherwise the
    loss is divided by ``info_dict['accum_grad']`` and back-propagated,
    through ``scaler`` when one is supplied. The loss actually
    back-propagated replaces ``info_dict['loss_dict']['loss']``.
    """
    loss = info_dict['loss_dict']['loss']
    if info_dict["train_engine"] == "deepspeed":
        scaled_loss = model.backward(loss)
    else:
        scaled_loss = loss / info_dict['accum_grad']
        if scaler is not None:
            scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

    info_dict['loss_dict']['loss'] = scaled_loss
    return info_dict


def _new_key_components(model_type):
    """Return fresh, empty gradient buckets for ``model_type``.

    Unknown model types get a single ``'other'`` bucket so callers always
    receive a usable mapping.
    """
    if model_type == 'llm':
        return {
            # Text processing components
            'text_embedding': [],
            'text_encoder': [],
            'text_encoder_affine': [],
            # LLM core components
            'llm_embedding': [],
            'llm.model': [],  # Qwen2 model layers
            'llm_decoder': [],
            # Speech components
            'speech_embedding': [],
            'spk_embed_affine': [],
            # Other components
            'other': []
        }
    if model_type == 'flow':
        return {
            # Input processing
            'input_embedding': [],
            'spk_embed_affine': [],
            # Encoder components
            'encoder': [],
            'encoder_proj': [],
            # Flow/Diffusion components
            'decoder.cfm': [],  # Conditional Flow Matching
            'decoder.unet': [],  # UNet backbone
            'decoder.estimator': [],  # Score/velocity estimator
            'decoder.time_embedding': [],  # Time embeddings
            'decoder.conv': [],  # Convolutional layers
            'decoder.attention': [],  # Attention layers
            # Length regulation
            'length_regulator': [],
            # Other components
            'other': []
        }
    return {'other': []}


def _component_matches(component_key, param_name, model_type):
    """Tell whether ``param_name`` belongs to the bucket ``component_key``."""
    # Special handling for decoder sub-components in flow models
    if model_type == 'flow' and component_key.startswith('decoder.'):
        component_pattern = component_key.replace('decoder.', '')
        return 'decoder' in param_name and component_pattern in param_name
    return component_key in param_name


def _select_component(param_name, key_components, model_type):
    """Pick the bucket for ``param_name``, preferring the most specific key.

    The first matching key in declaration order wins, unless a later matching
    key contains it (e.g. ``'encoder_proj'`` over ``'encoder'``); without
    that refinement the more specific buckets could never be filled.
    """
    matches = [key for key in key_components
               if key != 'other' and _component_matches(key, param_name, model_type)]
    if not matches:
        return 'other'
    chosen = matches[0]
    for key in matches[1:]:
        if chosen in key:
            chosen = key
    return chosen


def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict, model_type='llm'):
    """Clip gradients, step the optimizer/scheduler and record gradient norms.

    Work is only done on gradient-accumulation boundaries, i.e. when
    ``(batch_idx + 1) % accum_grad == 0``. On a boundary the per-parameter L2
    gradient norms are collected and grouped into component buckets, the
    gradients are clipped to ``info_dict['grad_clip']``, and the optimizer is
    stepped only when the total norm is finite.

    Args:
        model_type: ``'llm'`` or ``'flow'`` selects the bucket layout; any
            other value groups every parameter under ``'other'``.

    Returns:
        ``info_dict`` updated with ``'lr'``, ``'grad_norm'``,
        ``'layer_grad_norms'`` and ``'key_component_grads'``.
    """
    # Define key components based on model type
    key_components = _new_key_components(model_type)

    grad_norm = 0.0
    layer_grad_norms = {}
    # NOTE(review): there is no deepspeed-specific path here; the same
    # clip/step logic runs for every engine - confirm that is intended.
    if (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
        # NOTE(review): these norms are measured before scaler.unscale_, so
        # with a GradScaler they still include the loss scale.
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            # Calculate gradient norm for this parameter
            param_grad_norm = param.grad.data.norm(2).item()
            layer_grad_norms[name] = param_grad_norm
            # Categorize into key components
            bucket = _select_component(name, key_components, model_type)
            key_components[bucket].append((name, param_grad_norm))

        # Use mixed precision training
        if scaler is not None:
            scaler.unscale_(optimizer)
            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
            if torch.isfinite(grad_norm):
                scaler.step(optimizer)
            else:
                logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
            scaler.update()
        else:
            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
            if torch.isfinite(grad_norm):
                optimizer.step()
            else:
                logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
        optimizer.zero_grad()
        scheduler.step()

    info_dict["lr"] = optimizer.param_groups[0]['lr']
    info_dict["grad_norm"] = grad_norm
    info_dict["layer_grad_norms"] = layer_grad_norms
    info_dict["key_component_grads"] = key_components
    return info_dict


def _to_scalar(value):
    """Return ``value.item()`` for tensors and ``value`` unchanged otherwise."""
    if isinstance(value, torch.Tensor):
        return value.item()
    return value


def log_per_step(experiment, info_dict):
    """Log per step using Comet ML.

    On rank 0, when ``experiment`` is given and the current batch closes a
    gradient-accumulation window, epoch, learning rate, gradient norm and all
    losses are logged at ``step + 1``. Independently, every
    ``log_interval`` batches a summary line is emitted via ``logging.debug``.
    """
    tag = info_dict["tag"]
    epoch = info_dict.get('epoch', 0)
    step = info_dict["step"]
    batch_idx = info_dict["batch_idx"]
    loss_dict = info_dict['loss_dict']
    rank = int(os.environ.get('RANK', 0))

    # Only rank 0 writes to Comet ML to avoid multi-process write
    if experiment is not None and rank == 0:
        engine = info_dict['train_engine']
        # NOTE(review): 'is_gradient_accumulation_boundary' is not written by
        # anything in this module; it must be supplied by the caller.
        at_boundary = (
            (engine == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or
            (engine == 'torch_ddp' and (batch_idx + 1) % info_dict['accum_grad'] == 0)
        )
        if at_boundary:
            # Log metrics to Comet ML. Use the epoch fetched above so a
            # missing 'epoch' key behaves the same as in the shell log.
            experiment.log_metric(f'{tag}_epoch', epoch, step=step + 1)
            experiment.log_metric(f'{tag}_lr', info_dict['lr'], step=step + 1)
            experiment.log_metric(f'{tag}_grad_norm', _to_scalar(info_dict['grad_norm']), step=step + 1)
            # Log all losses
            for k, v in loss_dict.items():
                experiment.log_metric(f'{tag}_{k}', _to_scalar(v), step=step + 1)

    # TRAIN & CV, Shell log (stdout)
    if (batch_idx + 1) % info_dict['log_interval'] == 0:
        parts = [f'{tag} Batch {epoch}/{batch_idx + 1} ']
        parts.extend(f'{name} {_to_scalar(value):.6f} ' for name, value in loss_dict.items())
        if tag == "TRAIN":
            parts.append(f'lr {info_dict["lr"]:.8f} grad_norm {info_dict["grad_norm"]:.6f}')
        parts.append(f' rank {rank}')
        logging.debug(''.join(parts))


def log_per_save(experiment, info_dict):
    """Log per save using Comet ML.

    Always emits one summary line through the loguru logger. On rank 0 with
    an ``experiment``, epoch, learning rate and all losses are logged at
    ``step + 1``; for the ``"CV"`` tag the ``'loss'`` entry is additionally
    logged as ``cv_avg_loss_per_epoch`` keyed by epoch.
    """
    tag = info_dict["tag"]
    epoch = info_dict["epoch"]
    step = info_dict["step"]
    loss_dict = info_dict["loss_dict"]
    lr = info_dict['lr']
    rank = int(os.environ.get('RANK', 0))

    # Create loss string for logging
    loss_str = ' '.join([f"{k} {_to_scalar(v)}" for k, v in loss_dict.items()])
    logger.info(f'Epoch {epoch} Step {step + 1} CV info lr {lr} {rank} {loss_str}')

    if experiment is None or rank != 0:
        return

    # Log metrics to Comet ML
    experiment.log_metric(f'{tag}_epoch', epoch, step=step + 1)
    experiment.log_metric(f'{tag}_lr', lr, step=step + 1)

    # Log all losses
    for k, v in loss_dict.items():
        experiment.log_metric(f'{tag}_{k}', _to_scalar(v), step=step + 1)

    # Log additional validation info
    if tag == "CV":
        # Calculate average CV loss for the epoch
        avg_loss = _to_scalar(loss_dict.get('loss', 0))
        experiment.log_metric('cv_avg_loss_per_epoch', avg_loss, epoch=epoch)