Spaces:
Paused
Paused
| from glob import glob | |
| import os | |
| from typing import Tuple | |
| from demucs.separate import main as demucs | |
| import gradio as gr | |
| import numpy as np | |
| import soundfile as sf | |
| from configs.config import Config | |
| from infer.modules.vc.modules import VC | |
| from zero import zero | |
| from model import device | |
def infer(
    exp_dir: str, original_audio: str, f0add: int, index_rate: float, protect: float
) -> Tuple[int, np.ndarray]:
    """Convert the vocals of an uploaded track with the model stored in ``exp_dir``.

    The track is first split with Demucs (``htdemucs``, two stems); the
    resulting ``vocals.wav`` is then passed to ``VC.vc_single`` using the
    ``rmvpe`` pitch method.

    Args:
        exp_dir: Directory expected to contain ``model.pth`` and, optionally,
            an ``added_*.index`` file.
        original_audio: Path of the audio file to convert.
        f0add: Pitch offset forwarded to ``vc_single``.
        index_rate: Index rate forwarded to ``vc_single``.
        protect: Protect value forwarded to ``vc_single``.

    Returns:
        ``(sample_rate, samples)`` as found in the second value returned by
        ``vc_single``.

    Raises:
        gr.Error: If the model file is missing, no audio was provided, or the
            conversion produced no audio.
    """
    model = os.path.join(exp_dir, "model.pth")
    if not os.path.exists(model):
        raise gr.Error("Model not found")

    # Without an upload the audio component hands over None, which would
    # otherwise surface as an opaque TypeError from os.path.basename.
    if not original_audio:
        raise gr.Error("Please upload an audio file first")

    # The feature index is optional; use the first match when one exists.
    index_files = glob(f"{exp_dir}/added_*.index")
    index = index_files[0] if index_files else None

    base = os.path.splitext(os.path.basename(original_audio))[0]

    demucs(
        ["--two-stems", "vocals", "-d", str(device), "-n", "htdemucs", original_audio]
    )
    # Path where the separated vocals are expected after the Demucs run above.
    out = os.path.join("separated", "htdemucs", base, "vocals.wav")

    cfg = Config()
    vc = VC(cfg)
    vc.get_vc(model)
    info, wav_opt = vc.vc_single(
        0,
        out,
        f0add,
        None,
        "rmvpe",
        index,
        None,
        index_rate,
        3,  # this only has effect when f0_method is "harvest"
        0,
        1,
        protect,
    )

    # NOTE(review): vc_single appears to signal failure by returning no audio
    # alongside a status message -- confirm the exact shape against
    # infer.modules.vc.modules.
    if wav_opt is None or wav_opt[0] is None or wav_opt[1] is None:
        raise gr.Error(f"Voice conversion failed: {info}")

    sr = wav_opt[0]
    data = wav_opt[1]
    return sr, data
def merge(exp_dir: str, original_audio: str, vocal: Tuple[int, np.ndarray]) -> str:
    """Mix converted vocals with the separated accompaniment into an MP3.

    The vocal samples are written to ``<exp_dir>/tmp.wav`` and mixed by
    ffmpeg with ``separated/htdemucs/<track name>/no_vocals.wav``. The vocal
    input is boosted with ``volume=2`` before ``amix``.

    Args:
        exp_dir: Directory that receives ``tmp.wav`` and the merged file.
        original_audio: Path of the originally uploaded audio; only its base
            name is used, to locate the accompaniment stem.
        vocal: ``(sample_rate, samples)`` of the converted vocals.

    Returns:
        Path of the merged file, ``<exp_dir>/tmp.wav.merged.mp3``.

    Raises:
        gr.Error: If ffmpeg cannot be started or exits with a non-zero status.
    """
    # Imported locally so this function does not depend on edits elsewhere.
    import subprocess

    base = os.path.splitext(os.path.basename(original_audio))[0]
    music = os.path.join("separated", "htdemucs", base, "no_vocals.wav")
    tmp = os.path.join(exp_dir, "tmp.wav")
    merged = f"{tmp}.merged.mp3"

    sf.write(tmp, vocal[1], vocal[0])

    # Argument list instead of a shell string: the paths derive from a
    # user-supplied file name and text box, so they may contain spaces,
    # quotes or shell metacharacters.
    command = [
        "ffmpeg",
        "-i",
        music,
        "-i",
        tmp,
        "-filter_complex",
        "[1]volume=2[a];[0][a]amix=inputs=2:duration=first:dropout_transition=2",
        "-ac",
        "2",
        "-y",
        merged,
    ]
    try:
        completed = subprocess.run(command, check=False)
    except OSError as err:
        raise gr.Error("Could not run ffmpeg") from err

    # Report the failure here rather than returning a path to a missing file.
    if completed.returncode != 0:
        raise gr.Error(f"ffmpeg failed with exit code {completed.returncode}")

    return merged
class InferenceTab:
    """Gradio tab that runs voice conversion and merges the result with the music."""

    # NOTE(review): the source this was recovered from had no indentation, so
    # the nesting of the layout blocks in ui() (Column inside the first Row,
    # button outside of it) is a reconstruction -- confirm against the
    # rendered page.

    def __init__(self):
        """Create an empty tab; the widgets are created later by ui()."""

    def ui(self):
        """Create the widgets of the tab and store them as attributes."""
        intro = (
            "After trained model is pruned, you can use it to infer on new music. \n"
            "Upload the original audio and adjust the F0 add value to generate the inferred audio."
        )
        # Settings shared by the two sliders that range over [0, 1].
        unit_range = dict(minimum=0, maximum=1, step=0.01)
        # Settings shared by the two output players.
        mp3_player = dict(show_download_button=True, format="mp3")

        gr.Markdown("# Inference")
        gr.Markdown(intro)

        with gr.Row():
            self.original_audio = gr.Audio(
                label="Upload original audio",
                type="filepath",
                show_download_button=True,
            )

            with gr.Column():
                self.f0add = gr.Slider(
                    label="F0 +/-", minimum=-16, maximum=16, step=1, value=0
                )
                self.index_rate = gr.Slider(
                    label="Index rate", value=0.5, **unit_range
                )
                self.protect = gr.Slider(label="Protect", value=0.33, **unit_range)

        self.infer_btn = gr.Button(value="Infer", variant="primary")

        with gr.Row():
            self.infer_output = gr.Audio(label="Inferred audio", **mp3_player)

        with gr.Row():
            self.merge_output = gr.Audio(label="Merged audio", **mp3_player)

    def build(self, exp_dir: gr.Textbox):
        """Wire the button: run infer on click, then merge once infer succeeds.

        Args:
            exp_dir: Textbox whose value is passed as ``exp_dir`` to both steps.
        """
        infer_inputs = [
            exp_dir,
            self.original_audio,
            self.f0add,
            self.index_rate,
            self.protect,
        ]
        merge_inputs = [exp_dir, self.original_audio, self.infer_output]

        inference = self.infer_btn.click(
            fn=infer, inputs=infer_inputs, outputs=[self.infer_output]
        )
        # Merging is chained on success so it only runs with a fresh vocal track.
        inference.success(fn=merge, inputs=merge_inputs, outputs=[self.merge_output])