# (repository web-view residue, not part of the module) Ubuntu | update llm training | b6ec358 | raw | history blame | 21.5 kB
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2023 Horizon Inc. (authors: Xingchen Song)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import json
import os
import re
import warnings

import deepspeed
import torch
import torch.distributed as dist
import torch.optim as optim
import yaml
from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
from loguru import logger
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LinearLR, ConstantLR, SequentialLR, _LRScheduler
from torch.utils.data import DataLoader

from cosyvoice.dataset.dataset import Dataset
class ResumableSequentialLR(_LRScheduler):
    """A resumable version of SequentialLR that properly manages child schedulers.

    The active child scheduler is chosen from ``last_epoch`` and the
    ``milestones``; its ``last_epoch`` is overwritten with the position
    relative to the milestone that activated it before the learning rate is
    read from it. ``set_step`` jumps straight to an arbitrary global step.
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):
        """
        Args:
            optimizer: Wrapped optimizer
            schedulers: List of schedulers to sequentially use
            milestones: List of epoch/step numbers when to switch schedulers
            last_epoch: The index of last epoch/step

        Raises:
            ValueError: If ``len(schedulers) != len(milestones) + 1``.
        """
        if len(schedulers) != len(milestones) + 1:
            raise ValueError("Expected len(schedulers) == len(milestones) + 1")
        # These must exist before the parent constructor runs: it performs an
        # initial step(), which calls get_lr() on this object.
        self.schedulers = schedulers
        self.milestones = milestones
        # Not read anywhere in this class; kept so the attribute set is unchanged.
        self._scheduler_idx = 0
        # Initialize parent class (this sets last_epoch and calls step())
        super().__init__(optimizer, last_epoch)

    def _get_scheduler_info(self, epoch):
        """Determine which scheduler to use and its relative epoch.

        Args:
            epoch: Global epoch/step index.

        Returns:
            Tuple ``(scheduler_idx, relative_epoch)``. ``relative_epoch`` is
            ``epoch`` itself for the first scheduler, otherwise ``epoch`` minus
            the milestone that activated the selected scheduler.
        """
        scheduler_idx = 0
        # Milestones are consumed in order; stop at the first one not yet reached.
        for milestone in self.milestones:
            if epoch < milestone:
                break
            scheduler_idx += 1
        if scheduler_idx == 0:
            return 0, epoch
        return scheduler_idx, epoch - self.milestones[scheduler_idx - 1]

    def get_lr(self):
        """Get learning rate from the appropriate scheduler.

        Returns:
            The list of learning rates produced by the active child scheduler.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        # Get current scheduler and its relative epoch
        scheduler_idx, relative_epoch = self._get_scheduler_info(self.last_epoch)
        scheduler = self.schedulers[scheduler_idx]
        # Set the scheduler's last_epoch to match relative progress
        scheduler.last_epoch = relative_epoch
        if hasattr(scheduler, '_get_closed_form_lr'):
            return scheduler._get_closed_form_lr()
        # Temporarily set the flag to avoid the warning from the child scheduler;
        # restore it even if the child's get_lr() raises.
        scheduler._get_lr_called_within_step = True
        try:
            return scheduler.get_lr()
        finally:
            scheduler._get_lr_called_within_step = False

    def step(self, epoch=None):
        """Step the scheduler"""
        # Step the parent class (updates last_epoch and sets _get_lr_called_within_step)
        super().step(epoch)

    def set_step(self, step):
        """Set the current step for resuming training.

        After the call ``last_epoch == step - 1``, and the optimizer's param
        groups and ``get_last_lr()`` both hold the learning rate computed for
        that position.

        Args:
            step: Global step to resume from.
        """
        self.last_epoch = step - 1
        scheduler_idx, relative_epoch = self._get_scheduler_info(self.last_epoch)
        # Set all previous schedulers to their final state
        previous_milestone = 0
        for i in range(scheduler_idx):
            self.schedulers[i].last_epoch = self.milestones[i] - previous_milestone - 1
            previous_milestone = self.milestones[i]
        # Set current scheduler to its relative position
        self.schedulers[scheduler_idx].last_epoch = relative_epoch
        # Recompute the learning rate for the new position. get_last_lr() would
        # only return the value cached by the previous step() call, so it must
        # not be used here. The flag mirrors what step() sets around get_lr().
        self._get_lr_called_within_step = True
        try:
            lrs = self.get_lr()
        finally:
            self._get_lr_called_within_step = False
        for param_group, lr in zip(self.optimizer.param_groups, lrs):
            param_group['lr'] = lr
        # Keep the cached value consistent with what was just applied.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
def init_distributed(args):
    """Initialise the distributed backend for the selected training engine.

    Args:
        args: Object providing ``train_engine`` and ``dist_backend``.

    Returns:
        Tuple ``(world_size, local_rank, rank)`` read from the ``WORLD_SIZE``,
        ``LOCAL_RANK`` and ``RANK`` environment variables (defaults 1, 0, 0).
    """
    env_defaults = (('WORLD_SIZE', 1), ('LOCAL_RANK', 0), ('RANK', 0))
    world_size, local_rank, rank = (int(os.environ.get(key, fallback)) for key, fallback in env_defaults)
    logger.info(f'training on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}')
    if args.train_engine != 'torch_ddp':
        # Every engine other than native DDP is initialised through deepspeed.
        deepspeed.init_distributed(dist_backend=args.dist_backend)
        return world_size, local_rank, rank
    torch.cuda.set_device(local_rank)
    dist.init_process_group(args.dist_backend)
    return world_size, local_rank, rank
def init_dataset_and_dataloader(args, configs, dpo):
    """Build the training and cross-validation datasets and their loaders.

    Args:
        args: Object providing ``train_data``, ``cv_data``, ``pin_memory``,
            ``num_workers`` and ``prefetch``.
        configs: Mapping containing the ``'data_pipeline'`` entry.
        dpo: Forwarded unchanged to ``Dataset`` as ``dpo``.

    Returns:
        Tuple ``(train_dataset, cv_dataset, train_data_loader, cv_data_loader)``.
    """
    pipeline = configs['data_pipeline']

    def _make_loader(dataset):
        # do not use persistent_workers=True, as whisper tokenizer opens tiktoken
        # file each time when the for loop starts
        return DataLoader(dataset,
                          batch_size=None,
                          pin_memory=args.pin_memory,
                          num_workers=args.num_workers,
                          prefetch_factor=args.prefetch)

    # Only the training set is shuffled and partitioned.
    train_dataset = Dataset(args.train_data, data_pipeline=pipeline, mode='train',
                            gan=False, dpo=dpo, shuffle=True, partition=True)
    cv_dataset = Dataset(args.cv_data, data_pipeline=pipeline, mode='train',
                         gan=False, dpo=dpo, shuffle=False, partition=False)
    train_data_loader = _make_loader(train_dataset)
    cv_data_loader = _make_loader(cv_dataset)
    return train_dataset, cv_dataset, train_data_loader, cv_data_loader
def check_modify_and_save_config(args, configs):
    """Align ``configs['train_conf']`` with the selected training engine.

    For ``torch_ddp`` only the dtype is set (``'fp32'``). For any other engine
    the deepspeed JSON config at ``args.deepspeed_config`` is loaded and its
    precision, gradient accumulation, clipping and print interval override the
    corresponding ``train_conf`` entries.

    Args:
        args: Object providing ``train_engine`` and, for non-DDP engines,
            ``deepspeed_config`` (path to a JSON file).
        configs: Mapping with a ``'train_conf'`` dict; modified in place.

    Returns:
        The same ``configs`` object.
    """
    if args.train_engine == "torch_ddp":
        configs['train_conf']["dtype"] = 'fp32'
        return configs

    with open(args.deepspeed_config, 'r') as fin:
        ds_configs = json.load(fin)

    # fp16 takes precedence over bf16; fp32 is the fallback.
    dtype = "fp32"
    for precision in ("fp16", "bf16"):
        if precision in ds_configs and ds_configs[precision]["enabled"]:
            dtype = precision
            break
    train_conf = configs['train_conf']
    train_conf["dtype"] = dtype
    assert ds_configs["train_micro_batch_size_per_gpu"] == 1
    # if use deepspeed, override ddp config
    # save_per_step is rescaled by the ratio of the old to the new accumulation count.
    rescaled_save = train_conf['save_per_step'] * train_conf['accum_grad'] / ds_configs["gradient_accumulation_steps"]
    train_conf['save_per_step'] = int(rescaled_save)
    train_conf['accum_grad'] = ds_configs["gradient_accumulation_steps"]
    train_conf['grad_clip'] = ds_configs["gradient_clipping"]
    train_conf['log_interval'] = ds_configs["steps_per_print"]
    return configs
def wrap_cuda_model(args, model):
    """Prepare ``model`` for the selected training engine.

    For ``torch_ddp`` the model is moved to CUDA and wrapped in
    ``DistributedDataParallel``. For any other engine the model is returned
    unchanged; rank 0 additionally logs a ZeRO-2 memory estimate.

    Args:
        args: Object providing ``train_engine``.
        model: The model to wrap.

    Returns:
        The DDP-wrapped model for ``torch_ddp``, otherwise ``model`` itself.
    """
    env = os.environ
    local_world_size = int(env.get('LOCAL_WORLD_SIZE', 1))
    world_size = int(env.get('WORLD_SIZE', 1))

    if args.train_engine != "torch_ddp":
        is_main_rank = int(env.get('RANK', 0)) == 0
        if is_main_rank:
            logger.info("Estimating model states memory needs (zero2)...")
            estimate_zero2_model_states_mem_needs_all_live(
                model,
                num_gpus_per_node=local_world_size,
                num_nodes=world_size // local_world_size)
        return model

    # native pytorch ddp
    assert (torch.cuda.is_available())
    model.cuda()
    return torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
def init_optimizer_and_scheduler(configs, model):
    """Init optimizer and scheduler.

    Builds an Adam or AdamW optimizer over ``model.parameters()`` and a
    two-phase schedule: a linear warm-up followed by a constant phase, joined
    by ``ResumableSequentialLR`` at ``warmup_steps``.

    Args:
        configs: Mapping whose ``'train_conf'`` provides ``optim``
            (``'adam'`` or ``'adamw'``), ``optim_conf['lr']``,
            ``scheduler_conf['warmup_steps']`` and ``total_iters``.
        model: Model whose parameters are optimized.

    Returns:
        Tuple ``(model, optimizer, scheduler)``; ``model`` is returned unchanged.

    Raises:
        ValueError: If ``train_conf['optim']`` is neither ``'adam'`` nor ``'adamw'``.
    """
    train_conf = configs['train_conf']
    lr = train_conf['optim_conf']['lr']
    logger.info(f"lr base: {lr}")

    optim_name = train_conf['optim']
    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif optim_name == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=lr)
    else:
        # Report the offending name. Concatenating the whole train_conf dict to
        # the message would itself fail with a TypeError.
        raise ValueError(f"unknown optimizer: {optim_name}")

    warm_up_steps = train_conf['scheduler_conf']['warmup_steps']
    total_iters = train_conf['total_iters']

    # Create schedulers
    warmup_scheduler = LinearLR(
        optimizer,
        start_factor=1e-4,  # Start at nearly 0
        end_factor=1.0,  # End at base learning rate
        total_iters=warm_up_steps  # Length of the warm-up ramp
    )
    constant_scheduler = ConstantLR(
        optimizer,
        factor=1.0,  # Keep learning rate constant
        total_iters=total_iters
    )
    # Combine schedulers: warm up for warm_up_steps, then constant
    scheduler = ResumableSequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, constant_scheduler],
        milestones=[warm_up_steps]
    )
    return model, optimizer, scheduler
def save_model(model, model_name, info_dict):
    """Save a checkpoint and, on rank 0, a YAML file describing it.

    For ``torch_ddp`` rank 0 saves ``model.module.state_dict()`` together with
    ``epoch`` and ``step`` to ``<model_dir>/<model_name>.pt``. For any other
    engine every rank calls ``model.save_checkpoint``. Rank 0 then writes
    ``info_dict`` (with an added ``save_time``) next to the checkpoint.

    Args:
        model: DDP-wrapped model, or an engine exposing ``save_checkpoint``.
        model_name: Checkpoint name without extension.
        info_dict: Mapping providing ``model_dir``, ``train_engine``, ``epoch``
            and ``step``. ``save_time`` is added to it on rank 0.
    """
    rank = int(os.environ.get('RANK', 0))
    model_dir = info_dict["model_dir"]
    os.makedirs(model_dir, exist_ok=True)
    save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
    if info_dict["train_engine"] == "torch_ddp":
        if rank == 0:
            torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']},
                       save_model_path)
    else:
        with torch.no_grad():
            model.save_checkpoint(save_dir=model_dir,
                                  tag=model_name,
                                  client_state=info_dict)
    if rank == 0:
        # Escaped dot in a raw string: replace only a literal ".pt" suffix.
        info_path = re.sub(r'\.pt$', '.yaml', save_model_path)
        info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
        with open(info_path, 'w') as fout:
            data = yaml.dump(info_dict)
            fout.write(data)
        logger.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
def cosyvoice_join(group_join, info_dict):
    """Join all ranks.

    Args:
        group_join: Process group used for the barrier; its
            ``options._timeout`` is used as the barrier timeout.
        info_dict: Mapping providing ``batch_idx``.

    Returns:
        ``True`` when the monitored barrier raised ``RuntimeError`` (the
        caller should stop this worker), otherwise ``False``. The barrier is
        skipped when ``batch_idx`` is 0.
    """
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))

    if info_dict["batch_idx"] == 0:
        return False

    # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
    try:
        dist.monitored_barrier(group=group_join,
                               timeout=group_join.options._timeout)
    except RuntimeError as e:
        message = ("Detected uneven workload distribution: {}\n".format(e) +
                   "Break current worker to manually join all workers, " +
                   "world_size {}, current rank {}, current local_rank {}\n".
                   format(world_size, rank, local_rank))
        logger.info(message)
        return True
    return False
def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
    """Forward batch and compute loss.

    Runs ``model(batch, device)`` under CUDA autocast and stores the result in
    ``info_dict['loss_dict']``. When both ``ref_model`` and ``dpo_loss`` are
    given, the preference loss is added on top of the model's own loss.

    Args:
        model: Callable ``model(batch, device)`` returning a dict with at least
            ``'loss'`` (plus ``'chosen_logps'`` / ``'rejected_logps'`` for DPO).
        batch: Batch forwarded unchanged to the model(s).
        scaler: Gradient scaler or ``None``; for ``torch_ddp`` autocast is
            enabled only when a scaler is present.
        info_dict: Mapping providing ``dtype`` (``'fp16'``, ``'bf16'``, anything
            else means fp32) and ``train_engine``; updated in place.
        ref_model: Optional reference model, evaluated under ``torch.no_grad()``.
        dpo_loss: Optional callable returning
            ``(preference_loss, chosen_reward, reject_reward)``.

    Returns:
        ``info_dict`` with ``'loss_dict'`` populated. In DPO mode ``'loss'`` is
        ``preference_loss + sft_loss`` and ``sft_loss``, ``dpo_loss``,
        ``dpo_acc``, ``chosen_reward`` and ``reject_reward`` are added.
    """
    device = int(os.environ.get('LOCAL_RANK', 0))
    # Anything that is not fp16/bf16 is treated as fp32.
    dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(info_dict["dtype"], torch.float32)

    if info_dict['train_engine'] == 'torch_ddp':
        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
    else:
        autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)

    with autocast:
        # No per-batch print of the loss dict here: it wrote to stdout from
        # every rank on every batch, outside the logger.
        info_dict['loss_dict'] = model(batch, device)
        if ref_model is not None and dpo_loss is not None:
            loss_dict = info_dict['loss_dict']
            chosen_logps = loss_dict["chosen_logps"]
            rejected_logps = loss_dict["rejected_logps"]
            sft_loss = loss_dict['loss']
            # The reference model only supplies log-probabilities; no gradients needed.
            with torch.no_grad():
                ref_loss_dict = ref_model(batch, device)
            reference_chosen_logps = ref_loss_dict["chosen_logps"]
            reference_rejected_logps = ref_loss_dict["rejected_logps"]
            preference_loss, chosen_reward, reject_reward = dpo_loss(
                chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps
            )
            # Fraction of samples whose chosen reward exceeds the rejected reward.
            dpo_acc = (chosen_reward > reject_reward).float().mean()
            loss_dict["loss"] = preference_loss + sft_loss
            loss_dict["sft_loss"] = sft_loss
            loss_dict["dpo_loss"] = preference_loss
            loss_dict["dpo_acc"] = dpo_acc
            loss_dict["chosen_reward"] = chosen_reward.mean()
            loss_dict["reject_reward"] = reject_reward.mean()
    return info_dict
def batch_backward(model, scaler, info_dict):
    """Run the backward pass for the current batch.

    Args:
        model: For ``deepspeed``, an engine exposing ``backward(loss)``;
            unused otherwise.
        scaler: Gradient scaler or ``None`` (non-deepspeed engines only).
        info_dict: Mapping providing ``train_engine``, ``loss_dict['loss']``
            and, for non-deepspeed engines, ``accum_grad``.

    Returns:
        ``info_dict``, with ``loss_dict['loss']`` replaced by the value that
        was back-propagated (the engine's return value for deepspeed, the loss
        divided by ``accum_grad`` otherwise).
    """
    uses_deepspeed = info_dict["train_engine"] == "deepspeed"
    losses = info_dict['loss_dict']
    if uses_deepspeed:
        scaled_loss = model.backward(losses['loss'])
    else:
        # Average over the accumulation window before back-propagating.
        scaled_loss = losses['loss'] / info_dict['accum_grad']
        backward_target = scaled_loss if scaler is None else scaler.scale(scaled_loss)
        backward_target.backward()
    losses['loss'] = scaled_loss
    return info_dict
def _grad_component_keys(model_type):
    """Return the ordered component bucket names used to group gradient norms."""
    if model_type == 'llm':
        return (
            'text_embedding',
            'text_encoder',
            'text_encoder_affine',
            'llm_embedding',
            'llm.model',
            'llm_decoder',
            'speech_embedding',
            'spk_embed_affine',
        )
    if model_type == 'flow':
        return (
            'input_embedding',
            'spk_embed_affine',
            'encoder',
            'encoder_proj',
            'decoder.cfm',
            'decoder.unet',
            'decoder.estimator',
            'decoder.time_embedding',
            'decoder.conv',
            'decoder.attention',
            'length_regulator',
        )
    raise ValueError(f"Unknown model_type: {model_type}")


def _grad_component_for(name, component_keys, model_type):
    """Return the bucket a parameter name belongs to, or ``'other'``."""

    def _hits(key):
        # Flow decoder sub-components match on 'decoder' plus the sub-component
        # name anywhere in the parameter name; everything else is a plain
        # substring match on the key.
        if model_type == 'flow' and key.startswith('decoder.'):
            return 'decoder' in name and key.replace('decoder.', '') in name
        return key in name

    for key in component_keys:
        if not _hits(key):
            continue
        # A key that is a substring of another matching key (e.g. 'text_encoder'
        # inside 'text_encoder_affine') must yield to the more specific one;
        # otherwise the specific bucket could never receive anything.
        if any(other != key and key in other and _hits(other) for other in component_keys):
            continue
        return key
    return 'other'


def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict, model_type='llm'):
    """Update parameters and learning rate.

    On a gradient-accumulation boundary
    (``(batch_idx + 1) % accum_grad == 0``) this records per-parameter
    gradient norms, clips gradients, steps the optimizer when the total norm
    is finite, zeroes the gradients and steps the scheduler. Off the boundary
    nothing is updated.

    Args:
        model: Model exposing ``named_parameters()`` and ``parameters()``.
        optimizer: Optimizer to step.
        scheduler: LR scheduler to step.
        scaler: Gradient scaler, or ``None`` when mixed precision is not used.
        info_dict: Mapping providing ``batch_idx``, ``accum_grad`` and
            ``grad_clip``; updated in place.
        model_type: ``'llm'`` or ``'flow'``; selects the component buckets.

    Returns:
        ``info_dict`` with ``lr``, ``grad_norm`` (``0.0`` off the boundary),
        ``layer_grad_norms`` (name -> norm) and ``key_component_grads``
        (bucket -> list of ``(name, norm)``) set.

    Raises:
        ValueError: If ``model_type`` is neither ``'llm'`` nor ``'flow'``.
    """
    component_keys = _grad_component_keys(model_type)
    key_components = {key: [] for key in component_keys}
    key_components['other'] = []
    grad_norm = 0.0
    layer_grad_norms = {}

    if (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
        # Use mixed precision training: unscale first so that the per-layer
        # norms below are measured on true gradients and are comparable with
        # the clipped total norm.
        if scaler is not None:
            scaler.unscale_(optimizer)

        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            # Calculate gradient norm for this parameter (before clipping)
            param_grad_norm = param.grad.data.norm(2).item()
            layer_grad_norms[name] = param_grad_norm
            bucket = _grad_component_for(name, component_keys, model_type)
            key_components[bucket].append((name, param_grad_norm))

        grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
        if torch.isfinite(grad_norm):
            if scaler is not None:
                scaler.step(optimizer)
            else:
                optimizer.step()
        else:
            logger.warning('get infinite grad_norm, check your code/data if it appears frequently')
        if scaler is not None:
            scaler.update()
        optimizer.zero_grad()
        scheduler.step()
        logger.info(f"lr after step {optimizer.param_groups[0]['lr']}")

    info_dict["lr"] = optimizer.param_groups[0]['lr']
    info_dict["grad_norm"] = grad_norm
    info_dict["layer_grad_norms"] = layer_grad_norms
    info_dict["key_component_grads"] = key_components
    return info_dict
def log_per_step(experiment, info_dict):
    """Log per step using Comet ML.

    Metrics are sent through ``experiment.log_metric`` on rank 0 at
    gradient-accumulation boundaries; a text line is written through the
    logger every ``log_interval`` batches on every rank.

    Args:
        experiment: Object exposing ``log_metric(name, value, step=...)``, or
            ``None`` to disable metric logging.
        info_dict: Mapping providing ``tag``, ``step``, ``batch_idx``,
            ``loss_dict``, ``train_engine``, ``log_interval``, ``lr`` and
            ``grad_norm``; ``accum_grad`` for ``torch_ddp`` and
            ``is_gradient_accumulation_boundary`` for ``deepspeed``.
            ``epoch`` is optional and defaults to 0.
    """
    tag = info_dict["tag"]
    epoch = info_dict.get('epoch', 0)
    step = info_dict["step"]
    batch_idx = info_dict["batch_idx"]
    loss_dict = info_dict['loss_dict']
    rank = int(os.environ.get('RANK', 0))

    # Only rank 0 writes to Comet ML to avoid multi-process write
    if experiment is not None and rank == 0:
        engine = info_dict['train_engine']
        at_boundary = (
            (engine == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True)
            or (engine == 'torch_ddp' and (batch_idx + 1) % info_dict['accum_grad'] == 0)
        )
        if at_boundary:
            # Use the defaulted epoch (a missing 'epoch' must not raise) and log
            # lr/grad_norm as plain numbers, exactly like the losses.
            metrics = [('epoch', epoch), ('lr', info_dict['lr']), ('grad_norm', info_dict['grad_norm'])]
            metrics.extend(loss_dict.items())
            for k, v in metrics:
                if isinstance(v, torch.Tensor):
                    v = v.item()
                experiment.log_metric(f'{tag}_{k}', v, step=step + 1)

    # TRAIN & CV, Shell log (stdout)
    if (batch_idx + 1) % info_dict['log_interval'] == 0:
        pieces = [f'{tag} Batch {epoch}/{batch_idx + 1} step {step} ']
        for name, value in loss_dict.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            pieces.append(f'{name} {value:.6f} ')
        if tag == "TRAIN":
            pieces.append(f'lr {info_dict["lr"]:.15f} grad_norm {info_dict["grad_norm"]:.6f}')
        pieces.append(f' rank {rank}')
        logger.info(''.join(pieces))
def log_per_save(experiment, info_dict):
    """Log per save using Comet ML.

    Always writes one summary line through the logger. On rank 0, when an
    experiment is given, also sends epoch, lr and every loss entry via
    ``experiment.log_metric``; for tag ``"CV"`` the ``'loss'`` entry is
    additionally logged as ``cv_avg_loss_per_epoch`` keyed by epoch.

    Args:
        experiment: Object exposing ``log_metric``, or ``None``.
        info_dict: Mapping providing ``tag``, ``epoch``, ``step``,
            ``loss_dict`` and ``lr``.
    """
    tag, epoch, step, loss_dict, lr = (
        info_dict[key] for key in ("tag", "epoch", "step", "loss_dict", "lr")
    )
    rank = int(os.environ.get('RANK', 0))

    # Create loss string for logger
    rendered = []
    for k, v in loss_dict.items():
        rendered.append(f"{k} {v.item() if isinstance(v, torch.Tensor) else v}")
    loss_str = ' '.join(rendered)
    logger.info(f'Epoch {epoch} Step {step + 1} CV info lr {lr} {rank} {loss_str}')

    # Metrics are only sent from rank 0, and only when an experiment exists.
    if experiment is None or rank != 0:
        return

    global_step = step + 1
    experiment.log_metric(f'{tag}_epoch', epoch, step=global_step)
    experiment.log_metric(f'{tag}_lr', lr, step=global_step)
    # Log all losses
    for k, v in loss_dict.items():
        if isinstance(v, torch.Tensor):
            v = v.item()
        experiment.log_metric(f'{tag}_{k}', v, step=global_step)

    # Log additional validation info
    if tag != "CV":
        return
    # The 'loss' entry (0 when absent) is reported as the per-epoch CV loss.
    avg_loss = loss_dict.get('loss', 0)
    if isinstance(avg_loss, torch.Tensor):
        avg_loss = avg_loss.item()
    experiment.log_metric('cv_avg_loss_per_epoch', avg_loss, epoch=epoch)