Spaces:
No application file
No application file
| import argparse | |
| import json | |
| import math | |
| import os | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| from fish_audio_preprocess.utils import loudness_norm | |
| from loguru import logger | |
| from mmengine import Config | |
| from fish_diffusion.feature_extractors import FEATURE_EXTRACTORS, PITCH_EXTRACTORS | |
| from fish_diffusion.utils.tensor import repeat_expand | |
| from train import FishDiffusion | |
def inference(
    config,
    checkpoint,
    input_path,
    output_path,
    dictionary_path="dictionaries/opencpop-strict.txt",
    speaker_id=0,
    sampler_interval=None,
    sampler_progress=False,
    device="cuda",
):
    """Synthesize audio from a ds (JSON) file and write it to ``output_path``.

    The ds file is read with ``json.load`` and is expected to be a list of
    segments, each providing the keys ``offset``, ``ph_seq``, ``ph_dur``,
    ``f0_seq`` and ``f0_timestep``. Every segment is rendered on its own and
    copied into one output buffer at the position given by its ``offset``.
    The buffer is loudness-normalized before it is written to disk.

    Args:
        config: mmengine config; ``config.model.diffusion.sampler_interval``
            is overwritten in place when ``sampler_interval`` is given.
        checkpoint: checkpoint file, or a directory of checkpoints. For a
            directory, the entry whose name sorts last is used.
        input_path: path to the ds (JSON) file.
        output_path: path of the audio file to write.
        dictionary_path: tab-separated dictionary whose second column is a
            whitespace-separated list of phonemes.
        speaker_id: speaker id passed to the model.
        sampler_interval: sampler interval, lower value means higher quality.
            ``None`` keeps the value from ``config``.
        sampler_progress: show sampler progress.
        device: device used for the models and tensors.

    Raises:
        ValueError: if a segment contains a phoneme that is missing from the
            dictionary, or a non-blank dictionary line does not consist of
            exactly two tab-separated fields.
    """
    if sampler_interval is not None:
        config.model.diffusion.sampler_interval = sampler_interval

    if os.path.isdir(checkpoint):
        # "Latest" means the entry whose file name sorts last.
        checkpoints = sorted(os.listdir(checkpoint))
        logger.info(f"Found {len(checkpoints)} checkpoints, using {checkpoints[-1]}")
        checkpoint = os.path.join(checkpoint, checkpoints[-1])

    # Load models
    phoneme_features_extractor = FEATURE_EXTRACTORS.build(
        config.preprocessing.phoneme_features_extractor
    ).to(device)
    phoneme_features_extractor.eval()

    model = FishDiffusion(config)
    state_dict = torch.load(checkpoint, map_location="cpu")
    if "state_dict" in state_dict:  # Checkpoint is saved by pl
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    pitch_extractor = PITCH_EXTRACTORS.build(config.preprocessing.pitch_extractor)
    assert pitch_extractor is not None, "Pitch extractor not found"

    # Load dictionary. The context manager closes the file, and blank lines
    # are skipped so they do not break the two-field unpacking.
    unique_phones = set()
    with open(dictionary_path) as dictionary_file:
        for line in dictionary_file:
            line = line.strip()
            if not line:
                continue
            _, phones = line.split("\t")
            unique_phones.update(phones.split())

    phones_list = ["<PAD>", "<EOS>", "<UNK>", "AP", "SP"] + sorted(unique_phones)

    # Map each phoneme to the index of its FIRST occurrence, which is what
    # ``list.index`` returns when the dictionary itself contains "AP" or "SP".
    phone_to_id = {}
    for index, phone in enumerate(phones_list):
        phone_to_id.setdefault(phone, index)

    # Load ds file
    with open(input_path) as f:
        ds = json.load(f)

    # The output buffer is sized from the end time of the LAST segment only
    # (offset + number of f0 frames * frame step), converted to samples.
    generated_audio = np.zeros(
        math.ceil(
            (
                float(ds[-1]["offset"])
                + float(ds[-1]["f0_timestep"]) * len(ds[-1]["f0_seq"].split(" "))
            )
            * config.sampling_rate
        )
    )

    # Inference only: do not track autograd state.
    with torch.no_grad():
        for idx, chunk in enumerate(ds):
            offset = float(chunk["offset"])

            try:
                phones = np.array(
                    [phone_to_id[name] for name in chunk["ph_seq"].split(" ")]
                )
            except KeyError as err:
                # Keep the exception type raised by the former list.index().
                raise ValueError(
                    f"Phoneme {err.args[0]!r} in segment {idx + 1} is not in "
                    f"the dictionary {dictionary_path!r}"
                ) from None

            # Cumulative phoneme boundaries in seconds, starting at 0.
            durations = np.array([0] + [float(i) for i in chunk["ph_dur"].split(" ")])
            durations = np.cumsum(durations)

            f0_timestep = float(chunk["f0_timestep"])
            f0_seq = torch.FloatTensor([float(i) for i in chunk["f0_seq"].split(" ")])
            # Scales every f0 value by 2^(6/12), i.e. six equal-tempered
            # semitones up. NOTE(review): hard-coded; confirm it is intended.
            f0_seq *= 2 ** (6 / 12)

            total_duration = f0_timestep * len(f0_seq)

            logger.info(
                f"Processing segment {idx + 1}/{len(ds)}, duration: {total_duration:.2f}s"
            )

            # NOTE(review): 512 is presumably the mel hop length in samples;
            # confirm against the config.
            n_mels = round(total_duration * config.sampling_rate / 512)
            f0_seq = repeat_expand(f0_seq, n_mels, mode="linear")
            f0_seq = f0_seq.to(device)

            # aligned is in 20ms (50 frames per second)
            aligned_phones = torch.zeros(int(total_duration * 50), dtype=torch.long)
            for i, phone in enumerate(phones):
                # Seconds -> f0 frames -> groups of 4 f0 frames.
                # NOTE(review): matches the 20 ms grid only when f0_timestep
                # is 5 ms; confirm.
                start = int(durations[i] / f0_timestep / 4)
                end = int(durations[i + 1] / f0_timestep / 4)
                aligned_phones[start:end] = phone

            # Extract text features
            phoneme_features = phoneme_features_extractor.forward(
                aligned_phones.to(device)
            )[0]
            phoneme_features = repeat_expand(phoneme_features, n_mels).T

            # Predict
            src_lens = torch.tensor([phoneme_features.shape[0]]).to(device)

            features = model.model.forward_features(
                speakers=torch.tensor([speaker_id]).long().to(device),
                contents=phoneme_features[None].to(device),
                src_lens=src_lens,
                max_src_len=max(src_lens),
                mel_lens=src_lens,
                max_mel_len=max(src_lens),
                pitches=f0_seq[None],
            )

            result = model.model.diffusion(features["features"], progress=sampler_progress)
            wav = model.vocoder.spec2wav(result[0].T, f0=f0_seq).cpu().numpy()

            # Paste the segment at its offset, truncating whatever would run
            # past the end of the output buffer.
            start = round(offset * config.sampling_rate)
            max_wav_len = generated_audio.shape[-1] - start
            generated_audio[start : start + wav.shape[-1]] = wav[:max_wav_len]

    # Loudness normalization
    generated_audio = loudness_norm.loudness_norm(generated_audio, config.sampling_rate)

    sf.write(output_path, generated_audio, config.sampling_rate)
    logger.info("Done")
def parse_args():
    """Parse the command-line arguments of the inference script.

    Returns:
        argparse.Namespace with the attributes ``config``, ``checkpoint``,
        ``input``, ``output``, ``speaker_id``, ``sampler_interval``,
        ``sampler_progress`` and ``device``. ``checkpoint``, ``input`` and
        ``output`` are required; ``sampler_interval`` and ``device`` default
        to ``None``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config",
        type=str,
        default="configs/svc_hubert_soft.py",
        help="Path to the config file",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint file",
    )

    # The input is parsed with json.load by inference(), so it is a ds (JSON)
    # file rather than an audio file.
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to the input ds (JSON) file",
    )

    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Path to the output audio file",
    )

    parser.add_argument(
        "--speaker_id",
        type=int,
        default=0,
        help="Speaker id",
    )

    parser.add_argument(
        "--sampler_interval",
        type=int,
        default=None,
        required=False,
        help="Sampler interval, if not specified, will be taken from config",
    )

    parser.add_argument(
        "--sampler_progress",
        action="store_true",
        help="Show sampler progress",
    )

    parser.add_argument(
        "--device",
        type=str,
        default=None,
        required=False,
        help="Device to use",
    )

    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: read the CLI options, pick a device, run inference.
    cli_args = parse_args()

    # An explicit --device wins; otherwise use CUDA when it is available and
    # fall back to the CPU.
    device_name = (
        cli_args.device
        if cli_args.device is not None
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )
    device = torch.device(device_name)

    inference(
        config=Config.fromfile(cli_args.config),
        checkpoint=cli_args.checkpoint,
        input_path=cli_args.input,
        output_path=cli_args.output,
        speaker_id=cli_args.speaker_id,
        sampler_interval=cli_args.sampler_interval,
        sampler_progress=cli_args.sampler_progress,
        device=device,
    )