import torch
import numpy as np
import gradio as gr
import yaml
import librosa
from tqdm.auto import tqdm
import spaces
import look2hear.models
from ml_collections import ConfigDict

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_audio(file_path):
    # Load at 44.1 kHz and keep the original channel layout (stereo stays 2-D).
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    #audio = dBgain(audio, -6)
    return torch.from_numpy(audio), samplerate

def get_config(config_path):
    with open(config_path) as f:
        #config = OmegaConf.load(config_path)
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    return config

def _getWindowingArray(window_size, fade_size):
    # Note: there is no real fade here. The "fade-in" is all ones and the
    # "fade-out" is all zeros, so the only effect is zeroing out the unreliable
    # tail of each chunk; the overlapping next chunk supplies those samples.
    fadein = torch.linspace(1, 1, fade_size)
    fadeout = torch.linspace(0, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window
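
# Example (sketch, toy sizes): _getWindowingArray(10, 3) yields
# [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]; the zeroed tail is filled in by the
# overlapping start of the next chunk during the overlap-add loop below.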

description = '''
This is an unofficial Space for the audio restoration model Apollo: https://github.com/JusperLee/Apollo
'''

apollo_config = get_config('configs/apollo.yaml')
apollo_vocal2_config = get_config('configs/config_apollo_vocal.yaml')
apollo_uni_config = get_config('configs/config_apollo_uni.yaml')

apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).to(device)
apollo_vocal = look2hear.models.BaseModel.from_pretrain('weights/apollo_vocal.bin', **apollo_config['model']).to(device)
apollo_vocal2 = look2hear.models.BaseModel.from_pretrain('weights/apollo_vocal2.bin', **apollo_vocal2_config['model']).to(device)
apollo_uni = look2hear.models.BaseModel.from_pretrain('weights/apollo_model_uni.ckpt', **apollo_uni_config['model']).to(device)

models = {
    'apollo': apollo_model,
    'apollo_vocal': apollo_vocal,
    'apollo_vocal2': apollo_vocal2,
    'apollo_uni': apollo_uni
}

choices = [
    ('MP3 restore', 'apollo'),
    ('Apollo vocal', 'apollo_vocal'),
    ('Apollo vocal2', 'apollo_vocal2'),
    ('Apollo universal', 'apollo_uni')
]
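
# Inference strategy used below: the input is processed in 10-second chunks with
# 50% overlap (step = C // N with N = 2). Each chunk's output is weighted by the
# window from _getWindowingArray (which zeroes the last 3 s of every chunk except
# the final one), accumulated into `result`, and the weights are accumulated into
# `counter`; dividing result by counter recovers the blended signal.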

@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call (the Space runs on Zero)
def enhance(choice, audio):
    print(choice)
    model = models[choice]
    test_data, samplerate = load_audio(audio)

    C = 10 * samplerate  # chunk size: 10 seconds in samples
    N = 2
    step = C // N
    fade_size = 3 * 44100  # 3 seconds
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")

    border = C - step

    # Handle mono inputs correctly: ensure shape is (channels, samples).
    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0)

    # Pad the input if necessary
    if test_data.shape[1] > 2 * border and (border > 0):
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)

    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    i = 0
    progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)

    while i < test_data.shape[1]:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

        chunk = part.unsqueeze(0).to(device)
        with torch.no_grad():
            out = model(chunk).squeeze(0).squeeze(0).cpu()

        window = windowingArray
        if i == 0:  # First audio chunk, no fade-in
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # Last audio chunk, no fade-out
            window[-fade_size:] = 1

        result[..., i:i+length] += out[..., :length] * window[..., :length]
        counter[..., i:i+length] += window[..., :length]

        i += step
        progress_bar.update(step)

    progress_bar.close()

    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Remove the reflect padding added earlier
    if test_data.shape[1] > 2 * border and (border > 0):
        final_output = final_output[..., border:-border]

    return samplerate, final_output.T
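
# Local usage sketch (hypothetical file names; assumes the soundfile package is
# installed, which the Gradio UI below does not require):
#   import soundfile as sf
#   sr, restored = enhance('apollo', 'input.wav')
#   sf.write('restored.wav', restored, sr)  # restored has shape (samples, channels)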

if __name__ == "__main__":
    i = gr.Interface(
        fn=enhance,
        description=description,
        inputs=[
            gr.Dropdown(label="Model", choices=choices, value=choices[0][1]),
            gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=3000, waveform_options={'waveform_progress_color': '#3C82F6'}),
        ],
        outputs=[
            gr.Audio(
                label="Output Audio",
                autoplay=False,
                streaming=False,
                type="numpy",
            ),
        ],
        allow_flagging='never',
        cache_examples=False,
        title='Apollo audio restoration',
    )
    i.queue(max_size=20, default_concurrency_limit=4)
    i.launch(share=False, server_name="0.0.0.0")