| | |
| | |
| | |
| | |
| |
|
| | import collections |
| | import json |
| | import os |
| | import sys |
| | import time |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | from torch.nn.parallel import DistributedDataParallel |
| | from torch.utils.data import ConcatDataset, DataLoader |
| | from torch.utils.tensorboard import SummaryWriter |
| |
|
| | from models.base.base_sampler import BatchSampler |
| | from utils.util import ( |
| | Logger, |
| | remove_older_ckpt, |
| | save_config, |
| | set_all_random_seed, |
| | ValueWindow, |
| | ) |
| |
|
| |
|
class BaseTrainer(object):
    """Base class for single-GPU / DDP trainers.

    Wires together logging, datasets, model, criterion, optimizer and
    scheduler from *args* (CLI namespace) and *cfg* (experiment config);
    subclasses implement the ``build_*`` hooks and the per-step
    train/eval logic.
    """

    def __init__(self, args, cfg):
        # args: parsed CLI namespace (log_dir, local_rank, num_workers, ...)
        # cfg: experiment config (cfg.train.*, cfg.dataset, cfg.preprocess.*)
        self.args = args
        self.log_dir = args.log_dir
        self.cfg = cfg

        # Checkpoints live under <log_dir>/checkpoints.
        self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        # Only rank 0 (or non-DDP runs) gets a TensorBoard writer and logger.
        if not cfg.train.ddp or args.local_rank == 0:
            self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
            self.logger = self.build_logger()
        # NOTE(review): time_window is created unconditionally, but self.sw /
        # self.logger only exist on rank 0 under DDP — non-zero ranks would hit
        # AttributeError if they ever log; confirm before enabling DDP (DDP
        # currently raises NotImplementedError in build_data_loader anyway).
        self.time_window = ValueWindow(50)

        # Step/epoch counters; epoch starts at -1 so train() begins at epoch 0.
        self.step = 0
        self.epoch = -1
        self.max_epochs = self.cfg.train.epochs
        self.max_steps = self.cfg.train.max_steps

        # Seed all RNGs for reproducibility, then set up the process group
        # when distributed training is requested.
        set_all_random_seed(self.cfg.train.random_seed)
        if cfg.train.ddp:
            dist.init_process_group(backend="nccl")

        # Singer lookup table is skipped for these model types — presumably
        # they are not singer-conditioned; verify against the subclasses.
        if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
            self.singers = self.build_singers_lut()

        # Data loaders: {"train": DataLoader, "valid": DataLoader}.
        self.data_loader = self.build_data_loader()

        # Model may be a single Module or a dict of named Modules.
        self.model = self.build_model()
        print(self.model)

        # Move model(s) to this rank's GPU and wrap with DDP when enabled.
        if isinstance(self.model, dict):
            for key, value in self.model.items():
                value.cuda(self.args.local_rank)
                # "PQMF" is deliberately excluded from DDP wrapping —
                # NOTE(review): presumably a fixed (untrained) filter bank;
                # confirm in the vocoder subclass.
                if key == "PQMF":
                    continue
                if cfg.train.ddp:
                    self.model[key] = DistributedDataParallel(
                        value, device_ids=[self.args.local_rank]
                    )
        else:
            self.model.cuda(self.args.local_rank)
            if cfg.train.ddp:
                self.model = DistributedDataParallel(
                    self.model, device_ids=[self.args.local_rank]
                )

        # Criterion may also be a dict of named losses.
        self.criterion = self.build_criterion()
        if isinstance(self.criterion, dict):
            for key, value in self.criterion.items():
                self.criterion[key].cuda(args.local_rank)
        else:
            self.criterion.cuda(self.args.local_rank)

        # Optimizer and LR scheduler come from subclass hooks.
        self.optimizer = self.build_optimizer()
        self.scheduler = self.build_scheduler()

        # Where save_config_file() dumps the run configuration.
        self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")
| |
|
| | def build_logger(self): |
| | log_file = os.path.join(self.checkpoint_dir, "train.log") |
| | logger = Logger(log_file, level=self.args.log_level).logger |
| |
|
| | return logger |
| |
|
| | def build_dataset(self): |
| | raise NotImplementedError |
| |
|
| | def build_data_loader(self): |
| | Dataset, Collator = self.build_dataset() |
| | |
| | datasets_list = [] |
| | for dataset in self.cfg.dataset: |
| | subdataset = Dataset(self.cfg, dataset, is_valid=False) |
| | datasets_list.append(subdataset) |
| | train_dataset = ConcatDataset(datasets_list) |
| |
|
| | train_collate = Collator(self.cfg) |
| | |
| | if self.cfg.train.ddp: |
| | raise NotImplementedError("DDP is not supported yet.") |
| |
|
| | |
| | batch_sampler = BatchSampler( |
| | cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list |
| | ) |
| |
|
| | |
| | train_loader = DataLoader( |
| | train_dataset, |
| | collate_fn=train_collate, |
| | num_workers=self.args.num_workers, |
| | batch_sampler=batch_sampler, |
| | pin_memory=False, |
| | ) |
| | if not self.cfg.train.ddp or self.args.local_rank == 0: |
| | datasets_list = [] |
| | for dataset in self.cfg.dataset: |
| | subdataset = Dataset(self.cfg, dataset, is_valid=True) |
| | datasets_list.append(subdataset) |
| | valid_dataset = ConcatDataset(datasets_list) |
| | valid_collate = Collator(self.cfg) |
| | batch_sampler = BatchSampler( |
| | cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list |
| | ) |
| | valid_loader = DataLoader( |
| | valid_dataset, |
| | collate_fn=valid_collate, |
| | num_workers=1, |
| | batch_sampler=batch_sampler, |
| | ) |
| | else: |
| | raise NotImplementedError("DDP is not supported yet.") |
| | |
| | data_loader = {"train": train_loader, "valid": valid_loader} |
| | return data_loader |
| |
|
| | def build_singers_lut(self): |
| | |
| | if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)): |
| | singers = collections.OrderedDict() |
| | else: |
| | with open( |
| | os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r" |
| | ) as singer_file: |
| | singers = json.load(singer_file) |
| | singer_count = len(singers) |
| | for dataset in self.cfg.dataset: |
| | singer_lut_path = os.path.join( |
| | self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id |
| | ) |
| | with open(singer_lut_path, "r") as singer_lut_path: |
| | singer_lut = json.load(singer_lut_path) |
| | for singer in singer_lut.keys(): |
| | if singer not in singers: |
| | singers[singer] = singer_count |
| | singer_count += 1 |
| | with open( |
| | os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w" |
| | ) as singer_file: |
| | json.dump(singers, singer_file, indent=4, ensure_ascii=False) |
| | print( |
| | "singers have been dumped to {}".format( |
| | os.path.join(self.log_dir, self.cfg.preprocess.spk2id) |
| | ) |
| | ) |
| | return singers |
| |
|
| | def build_model(self): |
| | raise NotImplementedError() |
| |
|
| | def build_optimizer(self): |
| | raise NotImplementedError |
| |
|
| | def build_scheduler(self): |
| | raise NotImplementedError() |
| |
|
| | def build_criterion(self): |
| | raise NotImplementedError |
| |
|
| | def get_state_dict(self): |
| | raise NotImplementedError |
| |
|
| | def save_config_file(self): |
| | save_config(self.config_save_path, self.cfg) |
| |
|
| | |
| | def save_checkpoint(self, state_dict, saved_model_path): |
| | torch.save(state_dict, saved_model_path) |
| |
|
| | def load_checkpoint(self): |
| | checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint") |
| | assert os.path.exists(checkpoint_path) |
| | checkpoint_filename = open(checkpoint_path).readlines()[-1].strip() |
| | model_path = os.path.join(self.checkpoint_dir, checkpoint_filename) |
| | assert os.path.exists(model_path) |
| | if not self.cfg.train.ddp or self.args.local_rank == 0: |
| | self.logger.info(f"Re(store) from {model_path}") |
| | checkpoint = torch.load(model_path, map_location="cpu") |
| | return checkpoint |
| |
|
| | def load_model(self, checkpoint): |
| | raise NotImplementedError |
| |
|
| | def restore(self): |
| | checkpoint = self.load_checkpoint() |
| | self.load_model(checkpoint) |
| |
|
| | def train_step(self, data): |
| | raise NotImplementedError( |
| | f"Need to implement function {sys._getframe().f_code.co_name} in " |
| | f"your sub-class of {self.__class__.__name__}. " |
| | ) |
| |
|
| | @torch.no_grad() |
| | def eval_step(self): |
| | raise NotImplementedError( |
| | f"Need to implement function {sys._getframe().f_code.co_name} in " |
| | f"your sub-class of {self.__class__.__name__}. " |
| | ) |
| |
|
| | def write_summary(self, losses, stats): |
| | raise NotImplementedError( |
| | f"Need to implement function {sys._getframe().f_code.co_name} in " |
| | f"your sub-class of {self.__class__.__name__}. " |
| | ) |
| |
|
| | def write_valid_summary(self, losses, stats): |
| | raise NotImplementedError( |
| | f"Need to implement function {sys._getframe().f_code.co_name} in " |
| | f"your sub-class of {self.__class__.__name__}. " |
| | ) |
| |
|
| | def echo_log(self, losses, mode="Training"): |
| | message = [ |
| | "{} - Epoch {} Step {}: [{:.3f} s/step]".format( |
| | mode, self.epoch + 1, self.step, self.time_window.average |
| | ) |
| | ] |
| |
|
| | for key in sorted(losses.keys()): |
| | if isinstance(losses[key], dict): |
| | for k, v in losses[key].items(): |
| | message.append( |
| | str(k).split("/")[-1] + "=" + str(round(float(v), 5)) |
| | ) |
| | else: |
| | message.append( |
| | str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5)) |
| | ) |
| | self.logger.info(", ".join(message)) |
| |
|
| | def eval_epoch(self): |
| | self.logger.info("Validation...") |
| | valid_losses = {} |
| | for i, batch_data in enumerate(self.data_loader["valid"]): |
| | for k, v in batch_data.items(): |
| | if isinstance(v, torch.Tensor): |
| | batch_data[k] = v.cuda() |
| | valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i) |
| | for key in valid_loss: |
| | if key not in valid_losses: |
| | valid_losses[key] = 0 |
| | valid_losses[key] += valid_loss[key] |
| |
|
| | |
| | |
| | for key in valid_losses: |
| | valid_losses[key] /= i + 1 |
| | self.echo_log(valid_losses, "Valid") |
| | return valid_losses, valid_stats |
| |
|
    def train_epoch(self):
        """Run one epoch over the training loader: one train_step per batch,
        plus (on rank 0 / non-DDP) stdout logging, TensorBoard summaries,
        periodic checkpointing and validation.
        """
        for i, batch_data in enumerate(self.data_loader["train"]):
            start_time = time.time()
            # Move tensor fields of the batch onto this rank's GPU.
            for k, v in batch_data.items():
                if isinstance(v, torch.Tensor):
                    batch_data[k] = v.cuda(self.args.local_rank)

            # One optimization step (implemented by the subclass); the step
            # wall-time feeds the s/step average reported by echo_log.
            train_losses, train_stats, total_loss = self.train_step(batch_data)
            self.time_window.append(time.time() - start_time)

            # Logging / checkpointing / validation only on rank 0 (or when
            # not using DDP).
            if self.args.local_rank == 0 or not self.cfg.train.ddp:
                if self.step % self.args.stdout_interval == 0:
                    self.echo_log(train_losses, "Training")

                if self.step % self.cfg.train.save_summary_steps == 0:
                    self.logger.info(f"Save summary as step {self.step}")
                    self.write_summary(train_losses, train_stats)

                if (
                    self.step % self.cfg.train.save_checkpoints_steps == 0
                    and self.step != 0
                ):
                    # Checkpoint name encodes the step and current total loss.
                    saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
                        self.step, total_loss
                    )
                    saved_model_path = os.path.join(
                        self.checkpoint_dir, saved_model_name
                    )
                    saved_state_dict = self.get_state_dict()
                    self.save_checkpoint(saved_state_dict, saved_model_path)
                    self.save_config_file()
                    # Keep only the newest keep_checkpoint_max checkpoints.
                    remove_older_ckpt(
                        saved_model_name,
                        self.checkpoint_dir,
                        max_to_keep=self.cfg.train.keep_checkpoint_max,
                    )

                if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
                    # Switch to eval mode (model may be a dict of modules).
                    if isinstance(self.model, dict):
                        for key in self.model.keys():
                            self.model[key].eval()
                    else:
                        self.model.eval()

                    valid_losses, valid_stats = self.eval_epoch()

                    # Back to train mode before resuming optimization.
                    if isinstance(self.model, dict):
                        for key in self.model.keys():
                            self.model[key].train()
                    else:
                        self.model.train()

                    self.write_valid_summary(valid_losses, valid_stats)
            # Global step counter advances once per batch on every rank.
            self.step += 1
| |
|
| | def train(self): |
| | for epoch in range(max(0, self.epoch), self.max_epochs): |
| | self.train_epoch() |
| | self.epoch += 1 |
| | if self.step > self.max_steps: |
| | self.logger.info("Training finished!") |
| | break |
| |
|