# Spaces: No application file
# (Hugging Face Spaces metadata — kept as a comment so this file stays valid Python.)
import argparse
import json
import os
from functools import partial
from typing import Union

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from fish_audio_preprocess.utils import loudness_norm, separate_audio
from loguru import logger
from mmengine import Config

from fish_diffusion.feature_extractors import FEATURE_EXTRACTORS, PITCH_EXTRACTORS
from fish_diffusion.utils.audio import get_mel_from_audio, slice_audio
from fish_diffusion.utils.inference import load_checkpoint
from fish_diffusion.utils.tensor import repeat_expand
def inference(
    in_sample,
    config_path,
    checkpoint,
    input_path,
    output_path,
    speaker_id=0,
    pitch_adjust=0,
    silence_threshold=60,
    max_slice_duration=30.0,
    extract_vocals=True,
    merge_non_vocals=True,
    vocals_loudness_gain=0.0,
    sampler_interval=None,
    sampler_progress=False,
    device="cuda",
    gradio_progress=None,
):
    """Run singing-voice-conversion inference on an audio file.

    Args:
        in_sample: input sample rate. Currently unused — the audio is
            resampled to ``config.sampling_rate`` by ``librosa.load``;
            kept for interface compatibility with callers.
        config_path: path to the mmengine config file
        checkpoint: checkpoint path; a directory is allowed, in which case
            the lexicographically latest entry is used
        input_path: input audio path
        output_path: output audio path, or None to skip writing a file
        speaker_id: speaker id
        pitch_adjust: pitch adjustment in semitones
        silence_threshold: silence threshold (dB) of librosa.effects.split
        max_slice_duration: maximum duration of each slice (seconds)
        extract_vocals: separate vocals with htdemucs before conversion
        merge_non_vocals: mix non-vocal stems back in; only effective when
            extract_vocals is True
        vocals_loudness_gain: loudness gain of generated vocals (dB)
        sampler_interval: diffusion sampler interval, lower value means
            higher quality (and slower sampling)
        sampler_progress: show the diffusion sampler progress bar
        device: torch device string
        gradio_progress: optional gradio progress callback taking
            (fraction, message)

    Returns:
        Tuple of (generated_audio, sr): the converted waveform as a numpy
        array and its sample rate.
    """

    config = Config.fromfile(config_path)

    if sampler_interval is not None:
        config.model.diffusion.sampler_interval = sampler_interval

    if os.path.isdir(checkpoint):
        # Find the latest checkpoint (lexicographic order, so zero-padded
        # step numbers sort correctly).
        checkpoints = sorted(os.listdir(checkpoint))
        assert checkpoints, f"No checkpoints found in directory: {checkpoint}"
        logger.info(f"Found {len(checkpoints)} checkpoints, using {checkpoints[-1]}")
        checkpoint = os.path.join(checkpoint, checkpoints[-1])

    # BUGFIX: `sr` is keyword-only in librosa >= 0.10; passing it
    # positionally raises a TypeError there.
    audio, sr = librosa.load(input_path, sr=config.sampling_rate, mono=True)

    # Extract vocals
    if extract_vocals:
        logger.info("Extracting vocals...")

        if gradio_progress is not None:
            gradio_progress(0, "Extracting vocals...")

        model = separate_audio.init_model("htdemucs", device=device)
        audio = librosa.resample(audio, orig_sr=sr, target_sr=model.samplerate)[None]

        # Duplicate the mono channel: demucs expects stereo input.
        audio = np.concatenate([audio, audio], axis=0)

        audio = torch.from_numpy(audio).to(device)
        tracks = separate_audio.separate_audio(
            model, audio, shifts=1, num_workers=0, progress=True
        )
        audio = separate_audio.merge_tracks(tracks, filter=["vocals"]).cpu().numpy()
        non_vocals = (
            separate_audio.merge_tracks(tracks, filter=["drums", "bass", "other"])
            .cpu()
            .numpy()
        )

        # Back to the working sample rate, first channel only.
        audio = librosa.resample(audio[0], orig_sr=model.samplerate, target_sr=sr)
        non_vocals = librosa.resample(
            non_vocals[0], orig_sr=model.samplerate, target_sr=sr
        )

        # Normalize loudness of the accompaniment
        non_vocals = loudness_norm.loudness_norm(non_vocals, sr)

    # Normalize loudness of the (vocal) audio
    audio = loudness_norm.loudness_norm(audio, sr)

    # Slice into segments at silences so each diffusion pass stays short
    segments = list(
        slice_audio(
            audio, sr, max_duration=max_slice_duration, top_db=silence_threshold
        )
    )
    logger.info(f"Sliced into {len(segments)} segments")

    # Load models
    text_features_extractor = FEATURE_EXTRACTORS.build(
        config.preprocessing.text_features_extractor
    ).to(device)
    text_features_extractor.eval()

    model = load_checkpoint(config, checkpoint, device=device)

    pitch_extractor = PITCH_EXTRACTORS.build(config.preprocessing.pitch_extractor)
    assert pitch_extractor is not None, "Pitch extractor not found"

    generated_audio = np.zeros_like(audio)
    audio_torch = torch.from_numpy(audio).to(device)[None]

    for idx, (start, end) in enumerate(segments):
        if gradio_progress is not None:
            gradio_progress(idx / len(segments), "Generating audio...")

        segment = audio_torch[:, start:end]
        logger.info(
            f"Processing segment {idx + 1}/{len(segments)}, duration: {segment.shape[-1] / sr:.2f}s"
        )

        # Extract mel
        mel = get_mel_from_audio(segment, sr)

        # Extract pitch (f0), padded to the mel length, then shift by the
        # requested number of semitones.
        pitch = pitch_extractor(segment, sr, pad_to=mel.shape[-1]).float()
        pitch *= 2 ** (pitch_adjust / 12)

        # Extract text features, repeated/expanded to the mel length
        text_features = text_features_extractor(segment, sr)[0]
        text_features = repeat_expand(text_features, mel.shape[-1]).T

        # Predict
        src_lens = torch.tensor([mel.shape[-1]]).to(device)

        features = model.model.forward_features(
            speakers=torch.tensor([speaker_id]).long().to(device),
            contents=text_features[None].to(device),
            src_lens=src_lens,
            max_src_len=max(src_lens),
            mel_lens=src_lens,
            max_mel_len=max(src_lens),
            pitches=pitch[None].to(device),
        )

        result = model.model.diffusion(features["features"], progress=sampler_progress)
        wav = model.vocoder.spec2wav(result[0].T, f0=pitch).cpu().numpy()

        # Clip the generated slice so it never writes past the output buffer.
        max_wav_len = generated_audio.shape[-1] - start
        generated_audio[start : start + wav.shape[-1]] = wav[:max_wav_len]

    # Loudness normalization
    generated_audio = loudness_norm.loudness_norm(generated_audio, sr)

    # Loudness gain: dB -> linear amplitude factor
    loudness_float = 10 ** (vocals_loudness_gain / 20)
    generated_audio = generated_audio * loudness_float

    # Merge non-vocals (simple average mix of vocals and accompaniment)
    if extract_vocals and merge_non_vocals:
        generated_audio = (generated_audio + non_vocals) / 2

    logger.info("Done")

    if output_path is not None:
        sf.write(output_path, generated_audio, sr)

    return generated_audio, sr
class SvcFish:
    """Thin wrapper around :func:`inference` that holds fixed conversion
    settings so repeated calls to :meth:`infer` reuse them."""

    def __init__(
        self,
        checkpoint_path,
        config_path,
        sampler_interval=None,
        extract_vocals=True,
        merge_non_vocals=True,
        vocals_loudness_gain=0.0,
        silence_threshold=60,
        max_slice_duration=30.0,
    ):
        # Persist every setting verbatim; they are forwarded to inference()
        # on each call.
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.sampler_interval = sampler_interval
        self.extract_vocals = extract_vocals
        self.merge_non_vocals = merge_non_vocals
        self.vocals_loudness_gain = vocals_loudness_gain
        self.silence_threshold = silence_threshold
        self.max_slice_duration = max_slice_duration

    def infer(self, input_path, pitch_adjust, speaker_id, in_sample):
        """Convert ``input_path`` and return ``(audio, sr)`` without
        writing an output file (``output_path=None``)."""
        return inference(
            in_sample=in_sample,
            config_path=self.config_path,
            checkpoint=self.checkpoint_path,
            input_path=input_path,
            output_path=None,
            speaker_id=speaker_id,
            pitch_adjust=pitch_adjust,
            silence_threshold=self.silence_threshold,
            max_slice_duration=self.max_slice_duration,
            extract_vocals=self.extract_vocals,
            merge_non_vocals=self.merge_non_vocals,
            vocals_loudness_gain=self.vocals_loudness_gain,
            sampler_interval=self.sampler_interval,
            sampler_progress=True,
            device="cuda",
            gradio_progress=None,
        )