# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import time

import numpy as np
from tqdm import tqdm
import torch

from models.svc.base import SVCInference
from models.svc.vits.vits import SynthesizerTrn
from models.svc.base.svc_dataset import SVCTestDataset, SVCTestCollator
from utils.io import save_audio
from utils.audio_slicer import is_silence


class VitsInference(SVCInference):
    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        SVCInference.__init__(self, args, cfg)

    def _build_model(self):
        # The linear spectrogram has n_fft // 2 + 1 frequency bins, and the
        # segment size is converted from samples to frames via hop_size.
        net_g = SynthesizerTrn(
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            self.cfg,
        )
        self.model = net_g
        return net_g

    def build_save_dir(self, dataset, speaker):
        """Build (and create) the directory that inference results are saved to."""
        save_dir = os.path.join(
            self.args.output_dir,
            "svc_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
        )
        if dataset is not None:
            save_dir = os.path.join(save_dir, "data_{}".format(dataset))
            if speaker != -1:
                save_dir = os.path.join(
                    save_dir,
                    "spk_{}".format(speaker),
                )
        os.makedirs(save_dir, exist_ok=True)
        print("Saving to ", save_dir)
        return save_dir
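
    # Illustrative example of the resulting layout (the values below are
    # hypothetical, not taken from this file): with output_dir="result",
    # am_restore_step=200000, mode="infer", dataset="m4singer", and speaker=0,
    # build_save_dir would create and return
    #   result/svc_am_step-200000_infer/data_m4singer/spk_0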

    def inference(self):
        """Run inference over the test dataloader and save one wav per utterance."""
        res = []
        for i, batch in enumerate(self.test_dataloader):
            pred_audio_list = self._inference_each_batch(batch)
            # Note: pairing the full metadata list with each batch's outputs
            # assumes the test dataloader yields all utterances in a single,
            # unshuffled batch; zip() stops at the shorter sequence.
            for it, wav in zip(self.test_dataset.metadata, pred_audio_list):
                uid = it["Uid"]
                file = os.path.join(self.args.output_dir, f"{uid}.wav")
                wav = wav.numpy(force=True)
                save_audio(
                    file,
                    wav,
                    self.cfg.preprocess.sample_rate,
                    add_silence=False,
                    turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
                )
                res.append(file)
        return res

    def _inference_each_batch(self, batch_data, noise_scale=0.667):
        """Synthesize one batch; noise_scale controls the sampling variance."""
        device = self.accelerator.device
        pred_res = []
        self.model.eval()
        with torch.no_grad():
            # Move every tensor in the batch to the accelerator device
            for k, v in batch_data.items():
                batch_data[k] = v.to(device)

            audios, f0 = self.model.infer(batch_data, noise_scale=noise_scale)

            pred_res.extend(audios)

        return pred_res
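

# A minimal usage sketch (not part of the original source). It assumes `args`
# is a namespace carrying at least `output_dir` and `mode`, and that `cfg`
# follows Amphion's config schema defined elsewhere in the repo; the parser
# and config-loading helpers named here are hypothetical.
#
#     args = build_parser().parse_args()   # hypothetical CLI setup
#     cfg = load_config(args.config)       # hypothetical config loader
#     inferencer = VitsInference(args, cfg, infer_type="from_dataset")
#     wav_paths = inferencer.inference()   # list of saved .wav file paths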