| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import random |
| | import re |
| | import time |
| | from abc import abstractmethod |
| | from pathlib import Path |
| |
|
| | import accelerate |
| | import json5 |
| | import numpy as np |
| | import torch |
| | from accelerate.logging import get_logger |
| | from torch.utils.data import DataLoader |
| |
|
| | from models.vocoders.vocoder_inference import synthesis |
| | from utils.io import save_audio |
| | from utils.util import load_config |
| | from utils.audio_slicer import is_silence |
| |
|
| | EPS = 1.0e-12 |
| |
|
| |
|
| | class BaseInference(object): |
| | def __init__(self, args=None, cfg=None, infer_type="from_dataset"): |
| | super().__init__() |
| |
|
| | start = time.monotonic_ns() |
| | self.args = args |
| | self.cfg = cfg |
| |
|
| | assert infer_type in ["from_dataset", "from_file"] |
| | self.infer_type = infer_type |
| |
|
| | |
| | self.accelerator = accelerate.Accelerator() |
| | self.accelerator.wait_for_everyone() |
| |
|
| | |
| | with self.accelerator.main_process_first(): |
| | self.logger = get_logger("inference", log_level=args.log_level) |
| |
|
| | |
| | self.logger.info("=" * 56) |
| | self.logger.info("||\t\t" + "New inference process started." + "\t\t||") |
| | self.logger.info("=" * 56) |
| | self.logger.info("\n") |
| | self.logger.debug(f"Using {args.log_level.upper()} logging level.") |
| |
|
| | self.acoustics_dir = args.acoustics_dir |
| | self.logger.debug(f"Acoustic dir: {args.acoustics_dir}") |
| | self.vocoder_dir = args.vocoder_dir |
| | self.logger.debug(f"Vocoder dir: {args.vocoder_dir}") |
| | |
| | |
| | |
| | |
| | |
| |
|
| | os.makedirs(args.output_dir, exist_ok=True) |
| |
|
| | |
| | with self.accelerator.main_process_first(): |
| | start = time.monotonic_ns() |
| | self._set_random_seed(self.cfg.train.random_seed) |
| | end = time.monotonic_ns() |
| | self.logger.debug( |
| | f"Setting random seed done in {(end - start) / 1e6:.2f}ms" |
| | ) |
| | self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") |
| |
|
| | |
| | with self.accelerator.main_process_first(): |
| | self.logger.info("Building dataset...") |
| | start = time.monotonic_ns() |
| | self.test_dataloader = self._build_dataloader() |
| | end = time.monotonic_ns() |
| | self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") |
| |
|
| | |
| | with self.accelerator.main_process_first(): |
| | self.logger.info("Building model...") |
| | start = time.monotonic_ns() |
| | self.model = self._build_model() |
| | end = time.monotonic_ns() |
| | |
| | self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms") |
| |
|
| | |
| | self.logger.info("Initializing accelerate...") |
| | start = time.monotonic_ns() |
| | self.accelerator = accelerate.Accelerator() |
| | self.model = self.accelerator.prepare(self.model) |
| | end = time.monotonic_ns() |
| | self.accelerator.wait_for_everyone() |
| | self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms") |
| |
|
| | with self.accelerator.main_process_first(): |
| | self.logger.info("Loading checkpoint...") |
| | start = time.monotonic_ns() |
| | |
| | self.__load_model(os.path.join(args.acoustics_dir, "checkpoint")) |
| | end = time.monotonic_ns() |
| | self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms") |
| |
|
| | self.model.eval() |
| | self.accelerator.wait_for_everyone() |
| |
|
| | |
| | @abstractmethod |
| | def _build_test_dataset(self): |
| | pass |
| |
|
| | @abstractmethod |
| | def _build_model(self): |
| | pass |
| |
|
| | @abstractmethod |
| | @torch.inference_mode() |
| | def _inference_each_batch(self, batch_data): |
| | pass |
| |
|
| | |
| |
|
| | @torch.inference_mode() |
| | def inference(self): |
| | for i, batch in enumerate(self.test_dataloader): |
| | y_pred = self._inference_each_batch(batch).cpu() |
| | mel_min, mel_max = self.test_dataset.target_mel_extrema |
| | y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min |
| | y_ls = y_pred.chunk(self.test_batch_size) |
| | tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size) |
| | j = 0 |
| | for it, l in zip(y_ls, tgt_ls): |
| | l = l.item() |
| | it = it.squeeze(0)[:l] |
| | uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"] |
| | torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt")) |
| | j += 1 |
| |
|
| | vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir) |
| |
|
| | res = synthesis( |
| | cfg=vocoder_cfg, |
| | vocoder_weight_file=vocoder_ckpt, |
| | n_samples=None, |
| | pred=[ |
| | torch.load( |
| | os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"])) |
| | ).numpy(force=True) |
| | for i in self.test_dataset.metadata |
| | ], |
| | ) |
| |
|
| | output_audio_files = [] |
| | for it, wav in zip(self.test_dataset.metadata, res): |
| | uid = it["Uid"] |
| | file = os.path.join(self.args.output_dir, f"{uid}.wav") |
| | output_audio_files.append(file) |
| |
|
| | wav = wav.numpy(force=True) |
| | save_audio( |
| | file, |
| | wav, |
| | self.cfg.preprocess.sample_rate, |
| | add_silence=False, |
| | turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate), |
| | ) |
| | os.remove(os.path.join(self.args.output_dir, f"{uid}.pt")) |
| |
|
| | return sorted(output_audio_files) |
| |
|
| | |
| | def _build_dataloader(self): |
| | datasets, collate = self._build_test_dataset() |
| | self.test_dataset = datasets(self.args, self.cfg, self.infer_type) |
| | self.test_collate = collate(self.cfg) |
| | self.test_batch_size = min( |
| | self.cfg.train.batch_size, len(self.test_dataset.metadata) |
| | ) |
| | test_dataloader = DataLoader( |
| | self.test_dataset, |
| | collate_fn=self.test_collate, |
| | num_workers=1, |
| | batch_size=self.test_batch_size, |
| | shuffle=False, |
| | ) |
| | return test_dataloader |
| |
|
| | def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None): |
| | r"""Load model from checkpoint. If checkpoint_path is None, it will |
| | load the latest checkpoint in checkpoint_dir. If checkpoint_path is not |
| | None, it will load the checkpoint specified by checkpoint_path. **Only use this |
| | method after** ``accelerator.prepare()``. |
| | """ |
| | if checkpoint_path is None: |
| | ls = [] |
| | for i in Path(checkpoint_dir).iterdir(): |
| | if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)): |
| | ls.append(i) |
| | ls.sort( |
| | key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True |
| | ) |
| | checkpoint_path = ls[0] |
| | else: |
| | checkpoint_path = Path(checkpoint_path) |
| | self.accelerator.load_state(str(checkpoint_path)) |
| | |
| | self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1]) |
| | self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1]) |
| | return str(checkpoint_path) |
| |
|
| | @staticmethod |
| | def _set_random_seed(seed): |
| | r"""Set random seed for all possible random modules.""" |
| | random.seed(seed) |
| | np.random.seed(seed) |
| | torch.random.manual_seed(seed) |
| |
|
| | @staticmethod |
| | def _parse_vocoder(vocoder_dir): |
| | r"""Parse vocoder config""" |
| | vocoder_dir = os.path.abspath(vocoder_dir) |
| | ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")] |
| | ckpt_list.sort(key=lambda x: int(x.stem), reverse=True) |
| | ckpt_path = str(ckpt_list[0]) |
| | vocoder_cfg = load_config( |
| | os.path.join(vocoder_dir, "args.json"), lowercase=True |
| | ) |
| | return vocoder_cfg, ckpt_path |
| |
|
| | @staticmethod |
| | def __count_parameters(model): |
| | return sum(p.numel() for p in model.parameters()) |
| |
|
| | def __dump_cfg(self, path): |
| | os.makedirs(os.path.dirname(path), exist_ok=True) |
| | json5.dump( |
| | self.cfg, |
| | open(path, "w"), |
| | indent=4, |
| | sort_keys=True, |
| | ensure_ascii=False, |
| | quote_keys=True, |
| | ) |
| |
|