# Ubuntu
# update llm training
# b6ec358
# raw
# history blame
# 5.74 kB
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import nullcontext
import torch
import torch.distributed as dist
from cosyvoice.utils.train_utils import (batch_backward, batch_forward,
cosyvoice_join, log_per_save,
log_per_step, save_model,
update_parameter_and_lr)
from loguru import logger
class Executor:
    """Drive the training and cross-validation loops for one process.

    The executor owns the optimizer-step counter (``step``) and the current
    ``epoch``; the forward/backward/update/log/save work itself is delegated
    to the helpers imported from ``cosyvoice.utils.train_utils``.

    Args:
        gan: When ``True``, ``cv`` sets ``batch_dict["turn"] = "generator"``
            before each forward pass.
        ref_model: Optional reference model; switched to eval mode at the
            start of every training epoch and passed to ``batch_forward``.
        dpo_loss: Optional loss module passed to ``batch_forward``.
        use_contrastive_fm: Stored on the instance; not read by any method
            of this class.
    """

    def __init__(
        self,
        gan: bool = False,
        ref_model: torch.nn.Module = None,
        dpo_loss: torch.nn.Module = None,
        use_contrastive_fm: bool = False,
    ):
        self.gan = gan
        self.ref_model = ref_model
        self.dpo_loss = dpo_loss
        self.step = 0
        self.epoch = 0
        self.rank = int(os.environ.get("RANK", 0))
        # CUDA ordinals are numbered per node, so prefer the launcher's
        # LOCAL_RANK; fall back to RANK (identical on a single node) so
        # existing single-node launches keep their device.
        local_rank = int(os.environ.get("LOCAL_RANK", self.rank))
        self.device = torch.device(f"cuda:{local_rank}")
        self.use_contrastive_fm = use_contrastive_fm

    def train_one_epoc(
        self,
        model,
        optimizer,
        scheduler,
        train_data_loader,
        experiment,
        info_dict,
        scaler,
        model_type,
    ):
        """Train ``model`` for one pass over ``train_data_loader``.

        Args:
            model: Module to train; must expose ``no_sync`` when
                ``info_dict["train_engine"] == "torch_ddp"``.
            optimizer: Optimizer; its first param group's ``lr`` is logged.
            scheduler: Passed through to ``update_parameter_and_lr``.
            train_data_loader: Iterable of batch dicts.
            experiment: Logging handle passed to ``log_per_step``.
            info_dict: Run configuration/state. Reads ``"accum_grad"``,
                ``"train_engine"`` and the optional ``"save_per_step"``;
                writes ``"tag"``, ``"step"``, ``"epoch"`` and ``"batch_idx"``.
            scaler: Passed through to the forward/backward/update helpers.
            model_type: Passed to ``update_parameter_and_lr``.
        """
        lr = optimizer.param_groups[0]["lr"]
        logger.info(
            f"Epoch {self.epoch} TRAIN info lr {lr} rank {self.rank}"
        )
        logger.info(
            f"using accumulate grad, new batch size is {info_dict['accum_grad']} times larger than before"
        )
        model.train()
        if self.ref_model is not None:
            self.ref_model.eval()
        use_ddp = info_dict["train_engine"] == "torch_ddp"

        for batch_idx, batch_dict in enumerate(train_data_loader):
            info_dict["tag"] = "TRAIN"
            info_dict["step"] = self.step
            info_dict["epoch"] = self.epoch
            info_dict["batch_idx"] = batch_idx

            # Shape trace for debugging; routed through the logger instead of
            # a bare print so it can be filtered by level.
            for key, value in batch_dict.items():
                if isinstance(value, torch.Tensor):
                    logger.debug(f'{key} {value.shape}')

            # Under DDP, every micro-batch except the last one of an
            # accumulation window runs inside model.no_sync().
            accumulating = (batch_idx + 1) % info_dict["accum_grad"] != 0
            context = model.no_sync if use_ddp and accumulating else nullcontext

            with context():
                logger.info(f'{self.rank} batch_forward')
                info_dict = batch_forward(
                    model,
                    batch_dict,
                    scaler,
                    info_dict,
                    ref_model=self.ref_model,
                    dpo_loss=self.dpo_loss,
                )
                logger.info(f'{self.rank} batch_backward')
                info_dict = batch_backward(model, scaler, info_dict)

            logger.info(f'{self.rank} update_parameter_and_lr')
            info_dict = update_parameter_and_lr(
                model, optimizer, scheduler, scaler, info_dict, model_type=model_type
            )
            logger.info(f'{self.rank} log_per_step')
            log_per_step(experiment, info_dict)

            # NOTE(review): the trigger tests self.step and batch_idx before
            # they advance while the checkpoint name uses self.step + 1; with
            # accum_grad > 1 this fires on the first micro-batch of a window.
            # Behaviour kept as-is -- confirm the intended cadence.
            if (
                info_dict.get("save_per_step", -1) > 0
                and self.step % info_dict["save_per_step"] == 0
                and batch_idx % info_dict["accum_grad"] == 0
            ):
                if dist.is_initialized():
                    dist.barrier()
                model_name = f"epoch_{self.epoch}_step_{self.step + 1}"
                save_model(model, model_name, info_dict)
                model.train()

            # One step is counted per completed accumulation window.
            if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                self.step += 1

        # Guarded like the in-loop barrier so a run without an initialized
        # process group does not raise at the end of the epoch.
        if dist.is_initialized():
            dist.barrier()

    # Correctly spelled alias; ``train_one_epoc`` is kept for existing callers.
    train_one_epoch = train_one_epoc

    @torch.inference_mode()
    def cv(self, model, cv_data_loader, experiment, info_dict, on_batch_end=True):
        """Run cross-validation, log utterance-weighted losses, and save.

        Args:
            model: Module to evaluate (switched to eval mode here).
            cv_data_loader: Iterable of batch dicts, each with an ``"utts"``
                entry whose length is the number of utterances in the batch.
            experiment: Logging handle passed to ``log_per_save``.
            info_dict: Run state; ``"loss_dict"`` is replaced with the
                averaged losses before logging and saving.
            on_batch_end: Selects the checkpoint name:
                ``epoch_{epoch}_whole`` when true, otherwise
                ``epoch_{epoch}_step_{step + 1}``.
        """
        logger.info(f"Epoch {self.epoch} Step {self.step + 1} on_batch_end {on_batch_end} CV rank {self.rank}")
        model.eval()

        total_num_utts = 0
        weighted_loss_sums = {}
        for batch_idx, batch_dict in enumerate(cv_data_loader):
            info_dict["tag"] = "CV"
            info_dict["step"] = self.step
            info_dict["epoch"] = self.epoch
            info_dict["batch_idx"] = batch_idx

            num_utts = len(batch_dict["utts"])
            total_num_utts += num_utts

            if self.gan is True:
                batch_dict["turn"] = "generator"
            info_dict = batch_forward(model, batch_dict, None, info_dict)

            # Weight each batch loss by its utterance count so the final
            # figure is a per-utterance average.
            for name, loss in info_dict["loss_dict"].items():
                weighted_loss_sums[name] = (
                    weighted_loss_sums.get(name, 0) + loss.item() * num_utts
                )
            log_per_step(None, info_dict)

        # With an empty loader there are no entries, so the division below is
        # never evaluated with total_num_utts == 0.
        info_dict["loss_dict"] = {
            name: total / total_num_utts
            for name, total in weighted_loss_sums.items()
        }
        log_per_save(experiment, info_dict)

        if on_batch_end:
            model_name = f"epoch_{self.epoch}_whole"
        else:
            model_name = f"epoch_{self.epoch}_step_{self.step + 1}"
        save_model(model, model_name, info_dict)