Spaces:
Paused
Paused
| import os | |
| import shutil | |
| import hashlib | |
| from pathlib import Path | |
| from typing import Tuple | |
| from demucs.separate import main as demucs | |
| import gradio as gr | |
| import numpy as np | |
| import soundfile as sf | |
| from zerorvc import RVC | |
| from .zero import zero | |
| from .model import device | |
| import yt_dlp | |
def download_audio(url):
    """Download the audio behind *url* with yt-dlp and load it as 16-bit PCM.

    The download is written under ``ytdl/`` and converted to WAV by the
    ``FFmpegExtractAudio`` post-processor configured below.

    Args:
        url: Link to a video/audio page that yt-dlp can handle.

    Returns:
        A ``(sample_rate, samples)`` tuple where ``samples`` is an ``int16``
        NumPy array read from the converted WAV file.

    Raises:
        gr.Error: If no link was provided.
    """
    if not url or not url.strip():
        raise gr.Error("Please provide a link to download")

    ydl_opts = {
        "format": "bestaudio/best",
        "outtmpl": "ytdl/%(title)s.%(ext)s",
        "postprocessors": [
            {
                "key": "FFmpegExtractAudio",
                "preferredcodec": "wav",
                "preferredquality": "192",
            }
        ],
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(url, download=True)
        # prepare_filename() reports the pre-conversion name; the
        # post-processor replaces the extension with ".wav".
        file_path = ydl.prepare_filename(info_dict).rsplit(".", 1)[0] + ".wav"

    # The original called an undefined `read()`; soundfile is already imported
    # by this module. Note sf.read returns (data, samplerate), in that order.
    audio_array, sample_rate = sf.read(file_path, dtype="int16")
    return sample_rate, audio_array
def infer(
    exp_dir: str, original_audio: str, pitch_mod: int, protect: float
) -> str:
    """Separate the vocals of *original_audio* and convert them with the RVC model.

    The input is copied into *exp_dir* under its MD5 hash so that the Demucs
    output directory (``separated/htdemucs/<hash>/``) is keyed by content and
    an existing separation can be reused.

    Args:
        exp_dir: Experiment directory; must contain a ``checkpoints`` folder.
        original_audio: Path to the audio file to convert.
        pitch_mod: Pitch modification passed to ``RVC.convert``.
        protect: Protect value passed to ``RVC.convert``.

    Returns:
        Path of the written file, ``<exp_dir>/infer.wav``.

    Raises:
        gr.Error: If the checkpoint directory is missing or no audio was given.
    """
    checkpoint_dir = os.path.join(exp_dir, "checkpoints")
    if not os.path.exists(checkpoint_dir):
        raise gr.Error("Model not found")
    if not original_audio:
        # Without this guard, open(None) fails with an opaque TypeError.
        raise gr.Error("Please upload the original audio first")

    # Hash the file in chunks so large uploads are not loaded into memory
    # at once; the digest is identical to hashing the whole content.
    digest = hashlib.md5()
    with open(original_audio, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    original_audio_hash = digest.hexdigest()

    # Rename (copy) the original audio to its hash so Demucs names its
    # output folder after the hash.
    ext = Path(original_audio).suffix
    original_audio_hashed = os.path.join(exp_dir, f"{original_audio_hash}{ext}")
    shutil.copy(original_audio, original_audio_hashed)

    out = os.path.join("separated", "htdemucs", original_audio_hash, "vocals.wav")
    if not os.path.exists(out):
        # Only run source separation when no separated vocals exist yet.
        demucs(
            [
                "--two-stems",
                "vocals",
                "-d",
                str(device),
                "-n",
                "htdemucs",
                original_audio_hashed,
            ]
        )

    rvc = RVC.from_pretrained(checkpoint_dir)
    samples = rvc.convert(out, pitch_modification=pitch_mod, protect=protect)
    file = os.path.join(exp_dir, "infer.wav")
    sf.write(file, samples, rvc.sr)
    return file
def merge(exp_dir: str, original_audio: str, vocal: Tuple[int, np.ndarray]) -> str:
    """Mix the converted vocal back over the separated accompaniment with ffmpeg.

    Args:
        exp_dir: Experiment directory; temporary and merged files are written here.
        original_audio: Path to the original audio; its MD5 hash locates
            ``separated/htdemucs/<hash>/no_vocals.wav``.
        vocal: ``(sample_rate, samples)`` tuple of the converted vocal.

    Returns:
        Path of the merged file, ``<exp_dir>/tmp.wav.merged.mp3``.

    Raises:
        gr.Error: If ffmpeg cannot be started or exits with a non-zero status.
    """
    # Imported locally so this function does not depend on a change to the
    # module's import block.
    import subprocess

    with open(original_audio, "rb") as f:
        original_audio_hash = hashlib.md5(f.read()).hexdigest()
    music = os.path.join("separated", "htdemucs", original_audio_hash, "no_vocals.wav")

    tmp = os.path.join(exp_dir, "tmp.wav")
    sf.write(tmp, vocal[1], vocal[0])
    merged = f"{tmp}.merged.mp3"

    # Pass an argument list instead of a shell string: exp_dir is user
    # supplied, so interpolating it into a shell command breaks on spaces
    # and allows command injection.
    command = [
        "ffmpeg",
        "-i",
        music,
        "-i",
        tmp,
        "-filter_complex",
        "[1]volume=2[a];[0][a]amix=inputs=2:duration=first:dropout_transition=2",
        "-ac",
        "2",
        "-y",
        merged,
    ]
    try:
        subprocess.run(command, check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        # os.system silently ignored failures and returned a path to a file
        # that might not exist; report the problem to the UI instead.
        raise gr.Error("Failed to merge audio with ffmpeg") from err
    return merged
class InferenceTab:
    """Gradio tab that runs inference and merges the result with the backing track.

    Call ``ui()`` to create the components, then ``build()`` to attach the
    event handlers.
    """

    def __init__(self):
        # No state is set up here; components are created in ui().
        pass

    def ui(self):
        """Create the tab's components and store the ones build() needs on self."""
        # NOTE(review): the nesting of the layout blocks below was reconstructed
        # from statement order only — confirm against the intended layout.
        gr.Markdown("# Inference")
        gr.Markdown(
            "After trained model is pruned, you can use it to infer on new music. \n"
            "Upload the original audio and adjust the F0 add value to generate the inferred audio."
        )
        with gr.Row():
            # type="filepath": handlers receive the upload as a path string.
            self.original_audio = gr.Audio(
                label="Upload original audio",
                type="filepath",
                show_download_button=True,
            )
            # Alternative input: fetch the audio from a link via download_audio.
            with gr.Accordion("inference by Link", open=False):
                with gr.Row():
                    youtube_link = gr.Textbox(
                        label="Link",
                        placeholder="Paste the link here",
                        interactive=True,
                    )
                with gr.Row():
                    gr.Markdown(
                        "You can paste the link to the video/audio from many sites, check the complete list [here](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)"
                    )
                with gr.Row():
                    download_button = gr.Button("Download!", variant="primary")
                    # The downloaded audio is sent to the upload component.
                    download_button.click(
                        download_audio, [youtube_link], [self.original_audio]
                    )
            with gr.Column():
                self.pitch_mod = gr.Slider(
                    label="Pitch Modification +/-",
                    minimum=-16,
                    maximum=16,
                    step=1,
                    value=0,
                )
                self.protect = gr.Slider(
                    label="Protect",
                    minimum=0,
                    maximum=0.5,
                    step=0.01,
                    value=0.33,
                )
        self.infer_btn = gr.Button(value="Infer", variant="primary")
        with gr.Row():
            self.infer_output = gr.Audio(
                label="Inferred audio", show_download_button=True, format="mp3"
            )
        with gr.Row():
            self.merge_output = gr.Audio(
                label="Merged audio", show_download_button=True, format="mp3"
            )

    def build(self, exp_dir: gr.Textbox):
        """Attach the click handlers; ui() must have been called first.

        Args:
            exp_dir: Textbox component holding the experiment directory,
                passed as the first input to both ``infer`` and ``merge``.
        """
        # infer writes its result to infer_output; merge is chained with
        # .success() and takes that output as its vocal input.
        self.infer_btn.click(
            fn=infer,
            inputs=[
                exp_dir,
                self.original_audio,
                self.pitch_mod,
                self.protect,
            ],
            outputs=[self.infer_output],
        ).success(
            fn=merge,
            inputs=[exp_dir, self.original_audio, self.infer_output],
            outputs=[self.merge_output],
        )