Spaces:
Runtime error
Runtime error
| import uuid | |
| import ffmpeg | |
| import gradio as gr | |
| from pathlib import Path | |
| from denoisers.SpectralGating import SpectralGating | |
| from huggingface_hub import hf_hub_download | |
| from denoisers.demucs import Demucs | |
| import torch | |
| import torchaudio | |
| import yaml | |
| import argparse | |
| import os | |
# Blank out the curl-compatible CA-bundle setting for this process at import
# time, i.e. before run_app() calls hf_hub_download().
# NOTE(review): presumably a workaround for certificate errors while
# downloading from the Hub. Depending on the installed `requests` version, an
# empty CURL_CA_BUNDLE can disable TLS certificate verification entirely.
# Confirm it is still needed.
os.environ.update(CURL_CA_BUNDLE='')

# Sample rate (Hz) that every input clip is resampled to by ffmpeg before it
# is handed to a denoiser.
SAMPLE_RATE = 32_000
def denoising_transform(audio, model):
    """Convert *audio* to a mono 16-bit WAV at SAMPLE_RATE and denoise it with *model*.

    Both the converted input and the model output are written under
    ``cache_wav/`` using a fresh UUID-based file name. Nothing is deleted
    afterwards, so the cache directory grows with every call.

    Args:
        audio: Path of the source audio file (any format ffmpeg can decode).
        model: Object exposing ``predict(wav)``. It receives the tensor
            returned by ``torchaudio.load`` and must return something
            ``torchaudio.save`` accepts.

    Returns:
        Tuple ``(src_path, tgt_path)`` of ``pathlib.Path`` objects: the
        converted input WAV and the denoised WAV.

    Raises:
        ValueError: If *audio* is ``None`` or empty, e.g. "Enhance" was pressed
            before anything was recorded or uploaded.
        ffmpeg.Error: If ffmpeg fails to convert the input.
    """
    # Fail early with a readable message instead of an obscure error from
    # ffmpeg/subprocess when no file was provided.
    if not audio:
        raise ValueError("No audio provided: record or upload a clip first.")

    # One id for both files so an original and its denoised output can be paired.
    clip_id = uuid.uuid4()
    src_path = Path(f"cache_wav/original/{clip_id}.wav")
    tgt_path = Path(f"cache_wav/denoised/{clip_id}.wav")
    src_path.parent.mkdir(exist_ok=True, parents=True)
    tgt_path.parent.mkdir(exist_ok=True, parents=True)

    # Re-encode the input as mono 16-bit PCM at SAMPLE_RATE.
    (ffmpeg.input(audio)
     .output(src_path.as_posix(), acodec='pcm_s16le', ac=1, ar=SAMPLE_RATE)
     .run())

    wav, rate = torchaudio.load(src_path)
    # Inference only: no autograd graph is needed. Harmless if predict()
    # already disables gradients itself.
    with torch.no_grad():
        reduced_noise = model.predict(wav)
    torchaudio.save(tgt_path, reduced_noise, rate)
    return src_path, tgt_path
# Markdown shown at the top of the input column.
_INSTRUCTIONS = """
# Denoising
## Instruction: \n
1. Press "Record from microphone"
2. Press "Stop recording"
3. Press "Enhance" \n
- You can switch to the tab "File" to upload a prerecorded .wav audio instead of recording from microphone.
"""


def _load_model(model_filename, config_filename):
    """Download a Demucs checkpoint and YAML config from the Hub and return the model for inference.

    Args:
        model_filename: Checkpoint path inside the ``BorisovMaksim/demucs`` repo.
        config_filename: YAML config path inside the same repo. Its ``demucs``
            section is passed to ``Demucs``.

    Returns:
        The Demucs model with the checkpoint's ``model_state_dict`` loaded
        (tensors mapped to CPU) and switched to eval mode.
    """
    model_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=model_filename)
    config_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=config_filename)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    model = Demucs(config['demucs'])
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a repository you trust.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    # The model is only used for inference, so put any dropout/normalization
    # layers into evaluation mode. This assumes Demucs is a torch.nn.Module,
    # as its load_state_dict() suggests; confirm.
    model.eval()
    return model


def _build_app(model, title="Denoising"):
    """Build the Gradio UI that denoises one clip with *model* and with spectral gating.

    Returns:
        The not-yet-launched ``gr.Blocks`` app.
    """
    with gr.Blocks(title=title) as app:
        with gr.Row():
            with gr.Column():
                gr.Markdown(_INSTRUCTIONS)
                # NOTE(review): `source=` is the Gradio 3.x argument (Gradio 4
                # uses `sources=[...]`). Confirm the pinned gradio version.
                with gr.Tab("Microphone"):
                    microphone = gr.Audio(label="Source Audio", source="microphone", type='filepath')
                    with gr.Row():
                        microphone_button = gr.Button("Enhance", variant="primary")
                with gr.Tab("File"):
                    upload = gr.Audio(label="Upload Audio", source="upload", type='filepath')
                    with gr.Row():
                        upload_button = gr.Button("Enhance", variant="primary")
                clear_btn = gr.Button("Clear")
                # Sorted so the examples appear in a stable order; Path.glob
                # yields files in arbitrary, filesystem-dependent order.
                # NOTE(review): each example row holds a single value while two
                # input components are listed. Confirm which one should receive it.
                gr.Examples(examples=[[path] for path in sorted(Path("testing/wavs/").glob("*.wav"))],
                            inputs=[microphone, upload])
            with gr.Column():
                outputs = [gr.Audio(label="Input Audio", type='filepath'),
                           gr.Audio(label="Demucs Enhancement", type='filepath'),
                           gr.Audio(label="Spectral Gating Enhancement", type='filepath'),
                           ]

        def submit(audio):
            # Fill the three output players, then hide both input widgets.
            if not audio:
                # Shown to the user instead of a stack trace.
                raise gr.Error("Please record or upload an audio clip first.")
            src_path, demucs_tgt_path = denoising_transform(audio, model)
            _, spectral_gating_tgt_path = denoising_transform(audio, SpectralGating())
            return (src_path, demucs_tgt_path, spectral_gating_tgt_path,
                    gr.update(visible=False), gr.update(visible=False))

        microphone_button.click(submit, microphone, outputs + [microphone, upload])
        upload_button.click(submit, upload, outputs + [microphone, upload])

        def restart():
            # Show and empty both inputs again, and clear the three outputs.
            return (gr.update(visible=True, value=None), gr.update(visible=True, value=None),
                    None, None, None)

        clear_btn.click(restart, inputs=[], outputs=[microphone, upload] + outputs)
    return app


def run_app(model_filename, config_filename, port, concurrency_count, max_size):
    """Load the Demucs model, build the demo UI and serve it on ``0.0.0.0:port``.

    Args:
        model_filename: Checkpoint path inside the ``BorisovMaksim/demucs`` Hub repo.
        config_filename: YAML config path inside the same repo.
        port: TCP port passed to ``app.launch``.
        concurrency_count: Forwarded to ``app.queue``.
        max_size: Maximum queue length, forwarded to ``app.queue``.
    """
    model = _load_model(model_filename, config_filename)
    app = _build_app(model, title="Denoising")
    # NOTE(review): `concurrency_count` is a Gradio 3.x queue() argument that
    # was removed in Gradio 4. Confirm the pinned gradio version.
    app.queue(concurrency_count=concurrency_count, max_size=max_size)
    app.launch(
        server_name='0.0.0.0',
        server_port=port,
    )
if __name__ == "__main__":
    # Command-line entry point. Each flag's dest name matches a run_app()
    # parameter, so the parsed namespace can be passed through as keywords.
    cli = argparse.ArgumentParser(description='Running demo.')
    for flag, flag_type, flag_default in (
        ('--port', int, 7860),
        ('--model_filename', str,
         "paper_replica_10_epoch/Demucs_replicate_paper_continue_epoch45.pt"),
        ('--config_filename', str, "paper_replica_10_epoch/config.yaml"),
        ('--concurrency_count', int, 4),
        ('--max_size', int, 15),
    ):
        cli.add_argument(flag, type=flag_type, default=flag_default)
    options = cli.parse_args()
    run_app(**vars(options))