# Waveformer: real-time target-sound extraction demo (Gradio app).
| import argparse | |
| import os | |
| import json | |
| import wget | |
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| from dcc_tf import Net as Waveformer | |
# The 41 sound classes the model can isolate (FSDKaggle2018 label set).
# Order matters: the query vector indexes into this list.
TARGETS = [
    "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
    "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
    "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
    "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
    "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
    "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
    "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
    "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
    "Trumpet", "Violin_or_fiddle", "Writing",
]
| if not os.path.exists('default_config.json'): | |
| config_url = 'https://targetsound.cs.washington.edu/files/default_config.json' | |
| print("Downloading model configuration from %s:" % config_url) | |
| wget.download(config_url) | |
| if not os.path.exists('default_ckpt.pt'): | |
| ckpt_url = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt' | |
| print("\nDownloading the checkpoint from %s:" % ckpt_url) | |
| wget.download(ckpt_url) | |
# Instantiate the Waveformer from the downloaded JSON config, then restore
# its weights from the checkpoint. Inference runs on CPU (map_location),
# and eval() disables training-only layers such as dropout.
with open('default_config.json') as config_file:
    params = json.load(config_file)

model = Waveformer(**params['model_params'])
checkpoint = torch.load('default_ckpt.pt', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
def waveformer(audio, label_choices):
    """Run target-sound extraction on a Gradio audio input.

    Args:
        audio: ``(sample_rate, samples)`` tuple as produced by ``gr.Audio``;
            samples are expected to be 16-bit integer PCM (scaled by 2**15
            below) at 44.1 kHz — TODO confirm mono input is assumed.
        label_choices: iterable of class names (each must appear in
            ``TARGETS``) selecting which sounds to keep.

    Returns:
        ``(sample_rate, samples)`` tuple with the extracted audio as
        ``int16`` samples.

    Raises:
        ValueError: if the sampling rate is not 44100 Hz.
    """
    sample_rate, samples = audio
    if sample_rate != 44100:
        raise ValueError("Sampling rate must be 44100, but got %d" % sample_rate)

    # Normalize int16 PCM to [-1, 1) floats and add batch/channel dims.
    mixture = torch.from_numpy(samples)
    mixture = mixture.unsqueeze(0).unsqueeze(0).to(torch.float) / (2.0 ** 15)

    # One-hot(-ish) multi-label query over the supported target classes.
    query = torch.zeros(1, len(TARGETS))
    for choice in label_choices:
        query[0, TARGETS.index(choice)] = 1.0

    # Inference only — no gradients needed; rescale back to int16 range.
    with torch.no_grad():
        separated = model(mixture, query) * (2.0 ** 15)

    return sample_rate, separated.squeeze(0).squeeze(0).to(torch.short).numpy()
# Wire up the Gradio UI: an audio upload, a multi-select of target classes,
# and the separated audio as output.
input_audio = gr.Audio(label="Input audio")
label_checkbox = gr.CheckboxGroup(choices=TARGETS, label="Input target selection(s)")
output_audio = gr.Audio(label="Output audio")

demo = gr.Interface(
    fn=waveformer,
    inputs=[input_audio, label_checkbox],
    outputs=output_audio,
)
demo.launch(show_error=True)