| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import copy |
| import datetime |
| import logging |
| import sys |
| from contextlib import nullcontext |
|
|
| |
| |
| import torch |
| from wenet.utils.common import StepTimer |
|
|
| from wenet.utils.train_utils import (wenet_join, batch_forward, batch_backward, |
| update_parameter_and_lr, log_per_step, |
| save_model) |
|
|
|
|
class Executor:
    ''' Run the per-epoch training loop and the cross-validation loop.

    The instance carries the global step counter, the target device and one
    lazily created ``StepTimer`` per loop.
    '''

    def __init__(self,
                 global_step: int = 0,
                 device: torch.device = torch.device("cpu")):
        ''' Initialise the step counter and remember the device.

        Args:
            global_step: step count to resume from; the internal counter
                starts at ``global_step + 1``.
            device: device forwarded to ``batch_forward`` for every batch.
        '''
        self.step = global_step + 1
        # Timers are created on first use by train() / cv().
        self.train_step_timer = None
        self.cv_step_timer = None
        self.device = device

    @staticmethod
    def _is_accum_boundary(batch_idx, info_dict):
        ''' Return True when this batch closes a gradient-accumulation group.
        '''
        return (batch_idx + 1) % info_dict["accum_grad"] == 0

    @staticmethod
    def _barrier():
        ''' Synchronise all ranks; no-op without a default process group.
        '''
        import torch.distributed as dist
        # barrier() raises when no process group has been initialised, e.g.
        # in a plain single-process run, so only call it when one exists.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

    def train(self, model, optimizer, scheduler, train_data_loader,
              cv_data_loader, writer, configs, scaler, group_join):
        ''' Train one epoch.

        Args:
            model: module to train; switched to train mode on entry. When it
                is a ``DistributedDataParallel`` instance the loop runs
                inside ``model.join()``.
            optimizer: forwarded to ``update_parameter_and_lr``; its
                ``param_groups`` learning rates are recorded on save.
            scheduler: forwarded to ``update_parameter_and_lr``.
            train_data_loader: iterable of batch dicts; each must contain a
                ``"target_lengths"`` tensor. Batches whose first dimension
                is 0 are skipped.
            cv_data_loader: loader handed to ``cv`` at save points.
            writer: forwarded to ``log_per_step``.
            configs: dict of settings; deep-copied, never mutated. Must
                contain ``"accum_grad"``; ``"train_engine"`` and
                ``"save_interval"`` are optional.
            scaler: forwarded to the forward/backward/update helpers.
            group_join: forwarded to ``wenet_join``; a truthy result from
                that helper ends the epoch early.
        '''
        if self.train_step_timer is None:
            self.train_step_timer = StepTimer(self.step)
        model.train()
        info_dict = copy.deepcopy(configs)
        logging.info('using accumulate grad, new batch size is {} times'
                     ' larger than before'.format(info_dict['accum_grad']))

        # DDP's join() context lets ranks with uneven numbers of batches
        # finish the epoch together; other models need no wrapper.
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model_context = model.join
        else:
            model_context = nullcontext

        with model_context():
            for batch_idx, batch_dict in enumerate(train_data_loader):
                info_dict["tag"] = "TRAIN"
                info_dict["step"] = self.step
                info_dict["batch_idx"] = batch_idx
                if wenet_join(group_join, info_dict):
                    break

                if batch_dict["target_lengths"].size(0) == 0:
                    continue

                # For the torch engines, gradient synchronisation is only
                # needed on the batch that closes an accumulation group;
                # in between, no_sync() defers it. A model without
                # no_sync() (an unwrapped module) has nothing to defer.
                if info_dict.get("train_engine", "torch_ddp") in [
                        "torch_ddp", "torch_fsdp"
                ] and not self._is_accum_boundary(batch_idx, info_dict):
                    context = getattr(model, "no_sync", nullcontext)
                else:
                    context = nullcontext

                with context():
                    info_dict = batch_forward(model, batch_dict, scaler,
                                              info_dict, self.device)
                    info_dict = batch_backward(model, scaler, info_dict)

                info_dict = update_parameter_and_lr(model, optimizer,
                                                    scheduler, scaler,
                                                    info_dict)
                log_per_step(writer, info_dict, timer=self.train_step_timer)

                save_interval = info_dict.get('save_interval', sys.maxsize)
                if (self.step + 1) % save_interval == 0 \
                        and self.step != 0 \
                        and self._is_accum_boundary(batch_idx, info_dict):
                    # All ranks enter cross validation together ...
                    self._barrier()
                    loss_dict = self.cv(model, cv_data_loader, configs)
                    model.train()  # cv() left the model in eval mode
                    info_dict.update({
                        "tag":
                        "step_{}".format(self.step),
                        "loss_dict":
                        loss_dict,
                        "save_time":
                        datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
                        "lrs":
                        [group['lr'] for group in optimizer.param_groups]
                    })
                    save_model(model, info_dict)
                    log_per_step(writer, info_dict)
                    # ... and resume training together.
                    self._barrier()

                # The global step advances once per accumulation group, not
                # once per batch.
                if self._is_accum_boundary(batch_idx, info_dict):
                    self.step += 1

    def cv(self, model, cv_data_loader, configs):
        ''' Run cross validation over the whole CV loader.

        Args:
            model: module to evaluate; switched to eval mode and left there.
            cv_data_loader: iterable of batch dicts; each must contain a
                ``"target_lengths"`` tensor. Batches whose first dimension
                is 0 are skipped.
            configs: dict of settings; deep-copied, never mutated.

        Returns:
            dict mapping every accumulated loss name to its
            utterance-weighted mean, plus ``"acc"``: the unweighted mean of
            the per-batch ``th_accuracy`` values (a batch without one counts
            as 0.0). When no batch was evaluated the result is
            ``{"acc": 0.0}``.
        '''
        if self.cv_step_timer is None:
            self.cv_step_timer = StepTimer(0.0)
        else:
            self.cv_step_timer.last_iteration = 0.0
        model.eval()
        info_dict = copy.deepcopy(configs)
        num_seen_utts, loss_dict, total_acc = 0, {}, []
        with torch.no_grad():
            for batch_idx, batch_dict in enumerate(cv_data_loader):
                info_dict["tag"] = "CV"
                info_dict["step"] = self.step
                info_dict["batch_idx"] = batch_idx
                info_dict["cv_step"] = batch_idx

                num_utts = batch_dict["target_lengths"].size(0)
                if num_utts == 0:
                    continue

                info_dict = batch_forward(model, batch_dict, None, info_dict,
                                          self.device)
                _dict = info_dict["loss_dict"]

                num_seen_utts += num_utts
                th_accuracy = _dict.get('th_accuracy', None)
                total_acc.append(
                    th_accuracy.item() if th_accuracy is not None else 0.0)
                # Only finite entries whose name contains "loss" are
                # accumulated, each weighted by the batch's utterance count.
                for loss_name, loss_value in _dict.items():
                    if loss_value is not None and "loss" in loss_name \
                            and torch.isfinite(loss_value):
                        loss_value = loss_value.item()
                        loss_dict[loss_name] = loss_dict.get(loss_name, 0) + \
                            loss_value * num_utts

                log_per_step(writer=None,
                             info_dict=info_dict,
                             timer=self.cv_step_timer)

        # loss_dict only gains entries after at least one utterance has been
        # counted, so num_seen_utts is positive whenever this loop runs.
        for loss_name in loss_dict:
            loss_dict[loss_name] = loss_dict[loss_name] / num_seen_utts
        # An empty (or all-empty) loader must not raise ZeroDivisionError.
        loss_dict["acc"] = sum(total_acc) / len(total_acc) \
            if total_acc else 0.0
        return loss_dict
|
|