Spaces:
Sleeping
Sleeping
File size: 5,722 Bytes
32d5b2b ba2c5eb 32d5b2b ba2c5eb 1c43d7b ba2c5eb 32d5b2b ba2c5eb 32d5b2b ba2c5eb 55ac664 ba2c5eb 32d5b2b ba2c5eb 55ac664 ba2c5eb a7dc8e9 ba2c5eb a7dc8e9 ba2c5eb 32d5b2b ca7dd21 32d5b2b 55ac664 ca7dd21 32d5b2b b6ec358 55ac664 ca7dd21 55ac664 ca7dd21 b6ec358 ca7dd21 ba2c5eb b6ec358 ca7dd21 b6ec358 ca7dd21 b6ec358 ca7dd21 5805255 ca7dd21 32d5b2b ca7dd21 32d5b2b a7dc8e9 ba2c5eb 32d5b2b ba2c5eb 32d5b2b ba2c5eb 32d5b2b ba2c5eb a7dc8e9 ba2c5eb 434855f ba2c5eb 434855f ba2c5eb 32d5b2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import nullcontext
import torch
import torch.distributed as dist
from cosyvoice.utils.train_utils import (batch_backward, batch_forward,
log_per_save,
log_per_step, save_model,
update_parameter_and_lr)
from loguru import logger
class Executor:
    """Drives training epochs and cross validation for (optionally GAN/DPO) models."""

    def __init__(
        self,
        gan: bool = False,
        ref_model: torch.nn.Module = None,
        dpo_loss: torch.nn.Module = None,
        use_contrastive_fm: bool = False
    ):
        """Set up step/epoch counters and distributed context.

        Args:
            gan: enables GAN mode; CV then runs the generator turn only.
            ref_model: frozen reference model for DPO-style training, or None.
            dpo_loss: DPO loss module forwarded to ``batch_forward``, or None.
            use_contrastive_fm: stored flag for contrastive flow matching
                (only stored here; consumed elsewhere — TODO confirm usage).
        """
        self.gan = gan
        self.ref_model = ref_model
        self.dpo_loss = dpo_loss
        self.step = 0   # optimizer steps taken (bumped once per accum_grad batches)
        self.epoch = 0
        # RANK is injected by torchrun/the launcher; defaults to 0 for single-process runs.
        self.rank = int(os.environ.get("RANK", 0))
        self.device = torch.device(f"cuda:{self.rank}")
        self.use_contrastive_fm = use_contrastive_fm

    def train_one_epoc(
        self,
        model,
        optimizer,
        scheduler,
        train_data_loader,
        experiment,
        info_dict,
        scaler,
        model_type
    ):
        """Train one epoch.

        Runs forward/backward with gradient accumulation, stepping the optimizer
        every ``info_dict['accum_grad']`` batches, logging per step, and saving a
        checkpoint every ``info_dict['save_per_step']`` optimizer steps.
        """
        lr = optimizer.param_groups[0]["lr"]
        logger.info(
            f"Epoch {self.epoch} TRAIN info lr {lr} rank {self.rank}"
        )
        logger.info(
            f"using accumulate grad, new batch size is {info_dict['accum_grad']} times larger than before"
        )
        model.train()
        if self.ref_model is not None:
            # The reference model is frozen: only used for DPO comparisons.
            self.ref_model.eval()
        use_ddp = info_dict["train_engine"] == "torch_ddp"
        for batch_idx, batch_dict in enumerate(train_data_loader):
            info_dict["tag"] = "TRAIN"
            info_dict["step"] = self.step
            info_dict["epoch"] = self.epoch
            info_dict["batch_idx"] = batch_idx
            for key, value in batch_dict.items():
                if isinstance(value, torch.Tensor):
                    # FIX: was a bare print() spamming stdout on every batch;
                    # tensor-shape tracing now goes through the logger at debug level.
                    logger.debug(f'{key} {value.shape}')
            # Skip DDP gradient sync on accumulation (non-update) batches.
            if use_ddp and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                context = model.no_sync
            else:
                context = nullcontext
            with context():
                logger.info(f'{self.rank} batch_forward')
                info_dict = batch_forward(
                    model,
                    batch_dict,
                    scaler,
                    info_dict,
                    ref_model=self.ref_model,
                    dpo_loss=self.dpo_loss,
                )
                logger.info(f'{self.rank} batch_backward')
                info_dict = batch_backward(model, scaler, info_dict)
            logger.info(f'{self.rank} update_parameter_and_lr')
            info_dict = update_parameter_and_lr(
                model, optimizer, scheduler, scaler, info_dict, model_type=model_type
            )
            logger.info(f'{self.rank} log_per_step')
            log_per_step(experiment, info_dict)
            # Mid-epoch checkpoint: only on optimizer-step boundaries.
            if (
                info_dict.get("save_per_step", -1) > 0
                and (self.step) % info_dict["save_per_step"] == 0
                and (batch_idx) % info_dict["accum_grad"] == 0
            ):
                if dist.is_initialized():
                    dist.barrier()
                model_name = (
                    f"epoch_{self.epoch}_step_{self.step + 1}"
                )
                save_model(model, model_name, info_dict)
                model.train()
            if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                self.step += 1
        # FIX: guard the end-of-epoch barrier like the save branch above does;
        # a bare dist.barrier() crashes single-process (non-distributed) runs.
        if dist.is_initialized():
            dist.barrier()

    @torch.inference_mode()
    def cv(self, model, cv_data_loader, experiment, info_dict, on_batch_end=True):
        """Cross validation: average per-utterance losses over the loader and save.

        Args:
            on_batch_end: True when called at epoch end (checkpoint named
                ``epoch_N_whole``), False for mid-epoch CV (``epoch_N_step_M``).
        """
        logger.info(f"Epoch {self.epoch} Step {self.step + 1} on_batch_end {on_batch_end} CV rank {self.rank}")
        model.eval()
        total_num_utts, total_loss_dict = 0, {}
        for batch_idx, batch_dict in enumerate(cv_data_loader):
            info_dict["tag"] = "CV"
            info_dict["step"] = self.step
            info_dict["epoch"] = self.epoch
            info_dict["batch_idx"] = batch_idx
            num_utts = len(batch_dict["utts"])
            total_num_utts += num_utts
            if self.gan:
                # In GAN mode CV evaluates the generator turn only.
                batch_dict["turn"] = "generator"
            info_dict = batch_forward(model, batch_dict, None, info_dict)
            # Accumulate loss * num_utts so the final average is utterance-weighted.
            for k, v in info_dict["loss_dict"].items():
                total_loss_dict.setdefault(k, []).append(v.item() * num_utts)
            log_per_step(None, info_dict)
        # FIX: the original comment promised "avoid division by 0" but never
        # guarded it — an empty CV loader would raise ZeroDivisionError here.
        denom = max(total_num_utts, 1)
        for k, v in total_loss_dict.items():
            total_loss_dict[k] = sum(v) / denom
        info_dict["loss_dict"] = total_loss_dict
        log_per_save(experiment, info_dict)
        model_name = (
            f"epoch_{self.epoch}_whole"
            if on_batch_end
            else f"epoch_{self.epoch}_step_{self.step + 1}"
        )
        save_model(model, model_name, info_dict)
|