import spaces
from timbre_trap.framework.modules import TimbreTrap
from pyharp import *
import gradio as gr
import torchaudio
import torch
import os
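
# Instantiate the Timbre-Trap model; these hyperparameters must match those of
# the checkpoint loaded below (load_state_dict requires a matching architecture).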
model = TimbreTrap(sample_rate=22050,
                   n_octaves=9,
                   bins_per_octave=60,
                   secs_per_block=3,
                   latent_size=128,
                   model_complexity=2,
                   skip_connections=False)
model.eval()
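
# Load the pre-trained weights onto the CPU; they are copied into the model
# inside process_fn via load_state_dict().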
model_path_orig = os.path.join('models', 'tt-orig.pt')
#model_path_demo = os.path.join('models', 'tt-demo.pt')

tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
#tt_weights_demo = torch.load(model_path_demo, map_location='cpu')

# if torch.cuda.is_available():
#     model = model.cuda()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
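
# Metadata describing this model to HARP clients (pyharp ModelCard).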
model_card = ModelCard(
    name='Timbre-Trap',
    description='De-timbre your audio!',
    author='Frank Cwitkowitz',
    tags=['example', 'music transcription', 'multi-pitch estimation', 'timbre filtering']
)
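

# Processing callback invoked for each request: receives the input audio path
# and the 'Remove Timbre' checkbox value, and returns the path to the rendered
# output audio along with an (empty) list of output labels.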
def process_fn(audio_path, transcribe):#, demo):
    # Load the audio with torchaudio
    audio, fs = torchaudio.load(audio_path)
    # Average channels to obtain mono-channel audio
    audio = torch.mean(audio, dim=0, keepdim=True)
    # Resample audio to the model's sampling rate (22050 Hz)
    audio = torchaudio.functional.resample(audio, fs, 22050)
    # Add a batch dimension
    audio = audio.unsqueeze(0)
    # Record the number of samples (used to trim the model output)
    n_samples = audio.size(-1)
| """ | |
| if demo: | |
| # Load weights of the demo version | |
| model.load_state_dict(tt_weights_demo) | |
| else: | |
| """ | |
| # Load weights of the original model | |
| model.load_state_dict(tt_weights_orig) | |

    audio = audio.to(device)

    # Obtain transcription or reconstructed spectral coefficients
    coefficients = model.chunked_inference(audio, transcribe)

    # Invert coefficients to produce audio
    audio = model.sliCQ.decode(coefficients)
    # Trim to the original number of samples
    audio = audio[..., :n_samples]
    # Remove batch dimension
    audio = audio.squeeze(0)
    # Low-pass filter the audio in an attempt to remove artifacts
    audio = torchaudio.functional.lowpass_biquad(audio, 22050, 8000)
    # Resample audio back to the original sampling rate
    audio = torchaudio.functional.resample(audio, 22050, fs)
    audio = audio.cpu()

    # Create a temporary directory for output
    os.makedirs('_outputs', exist_ok=True)
    # Create a path for saving the audio
    save_path = os.path.join('_outputs', 'output.wav')
    # Save the audio
    torchaudio.save(save_path, audio, fs)

    # No output labels
    output_labels = LabelList()

    return save_path, output_labels
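

# The block below wires up the HARP endpoint. The single checkbox maps to
# process_fn's `transcribe` argument: when checked, Timbre-Trap decodes
# transcription (pitch-only) coefficients instead of a reconstruction, which
# is what removes the timbre from the output audio.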
# Build Gradio endpoint
with gr.Blocks() as demo:
    components = [
        gr.Checkbox(
            value=False,
            label='Remove Timbre'
        )
    ]

    app = build_endpoint(model_card=model_card,
                         components=components,
                         process_fn=process_fn)

demo.queue()
demo.launch(share=True)
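
# Illustrative usage of process_fn outside the Gradio UI (placeholder path):
#     out_path, labels = process_fn('some_input.wav', transcribe=True)
#     print(out_path)  # -> _outputs/output.wav
# Note: share=True has no effect when this script runs on Hugging Face Spaces;
# it only matters for local runs.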