import torch
import numpy as np
import gradio as gr
import yaml
import librosa
from tqdm.auto import tqdm
import spaces

import look2hear.models
from ml_collections import ConfigDict

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_audio(file_path):
    # librosa returns float32 audio resampled here to the 44.1 kHz the Apollo
    # checkpoints expect; mono=False preserves the channel layout
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    return torch.from_numpy(audio), samplerate
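
# Shape note (illustrative): a stereo file loads as audio.shape == (2, n_samples),
# while a mono file stays one-dimensional as (n_samples,); enhance() below adds
# the missing channel axis for the mono case.
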
def get_config(config_path):
    # Parse a YAML config file into an attribute-accessible ConfigDict
    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    return config
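
# Each YAML is expected to carry (at least) a 'model' section whose keys are
# splatted into BaseModel.from_pretrain below. A rough sketch of the layout,
# with illustrative key names only (the real keys ship with the checkpoints):
#
# model:
#   sr: 44100
#   win: 20
#   feature_dim: 256
#   layer: 6
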
def _getWindowingArray(window_size, fade_size):
    # Trapezoidal cross-fade window: linear ramp up over fade_size samples,
    # flat at 1 in the middle, linear ramp down at the end, so overlapping
    # chunks blend smoothly during overlap-add
    fadein = torch.linspace(0, 1, fade_size)
    fadeout = torch.linspace(1, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window
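
# Quick illustrative check of the window shape (not executed by the app):
#   _getWindowingArray(10, 3)
#   -> tensor([0.0, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0])
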
description = '''
This is an unofficial Space for the audio restoration model Apollo: https://github.com/JusperLee/Apollo
'''
apollo_config = get_config('configs/apollo.yaml')
apollo_vocal2_config = get_config('configs/config_apollo_vocal.yaml')
apollo_uni_config = get_config('configs/config_apollo_uni.yaml')

apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).to(device)
apollo_vocal = look2hear.models.BaseModel.from_pretrain('weights/apollo_vocal.bin', **apollo_config['model']).to(device)
apollo_vocal2 = look2hear.models.BaseModel.from_pretrain('weights/apollo_vocal2.bin', **apollo_vocal2_config['model']).to(device)
apollo_uni = look2hear.models.BaseModel.from_pretrain('weights/apollo_model_uni.ckpt', **apollo_uni_config['model']).to(device)
models = {
    'apollo': apollo_model,
    'apollo_vocal': apollo_vocal,
    'apollo_vocal2': apollo_vocal2,
    'apollo_uni': apollo_uni,
}

choices = [
    ('MP3 restore', 'apollo'),
    ('Apollo vocal', 'apollo_vocal'),
    ('Apollo vocal2', 'apollo_vocal2'),
    ('Apollo universal', 'apollo_uni'),
]
@spaces.GPU
def enhance(choice, audio):
    print(choice)
    model = models[choice]
    test_data, samplerate = load_audio(audio)

    C = 10 * samplerate    # chunk size: 10 s of samples (441000 at 44.1 kHz)
    N = 2                  # overlap factor: consecutive chunks overlap by half
    step = C // N          # hop between chunk starts
    fade_size = 3 * 44100  # 3 s cross-fade at the chunk edges
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")

    border = C - step      # edge padding so the border fades can be trimmed later
    # Give mono input an explicit channel dimension: (samples,) -> (1, samples)
    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0)

    # Reflect-pad both ends so the fades at the borders land on padding
    # that is trimmed off after processing
    if test_data.shape[1] > 2 * border and border > 0:
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)
    # Overlap-add accumulators: windowed model output and the sum of window
    # weights, used for normalization afterwards
    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    i = 0
    progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)
    while i < test_data.shape[1]:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            # Pad the last short chunk up to C samples: reflect-pad when more
            # than half a chunk remains (reflect cannot pad past the signal
            # length), otherwise zero-pad
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

        chunk = part.unsqueeze(0).to(device)
        with torch.no_grad():
            out = model(chunk).squeeze(0).squeeze(0).cpu()
        # Clone so the edge-case tweaks below don't mutate the shared window
        # and leak into later iterations
        window = windowingArray.clone()
        if i == 0:  # first chunk: nothing to cross-fade into, keep the head at 1
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # last chunk: keep the tail at 1
            window[-fade_size:] = 1
        result[..., i:i + length] += out[..., :length] * window[..., :length]
        counter[..., i:i + length] += window[..., :length]

        i += step
        progress_bar.update(step)

    progress_bar.close()
    # Normalize by the accumulated window weights; any region no chunk covered
    # becomes 0/0, so map the resulting NaNs back to silence
    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Trim the reflection padding added before processing
    if test_data.shape[1] > 2 * border and border > 0:
        final_output = final_output[..., border:-border]

    # Gradio's numpy audio output expects (samples, channels)
    return samplerate, final_output.T
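
# Minimal local usage sketch, bypassing the Gradio UI (assumes an input file
# 'example.wav' exists and that soundfile is installed; neither ships with this app):
#
#   sr, restored = enhance('apollo', 'example.wav')
#   import soundfile as sf
#   sf.write('restored.wav', restored, sr)  # restored is (samples, channels)
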
if __name__ == "__main__":
    demo = gr.Interface(
        fn=enhance,
        description=description,
        inputs=[
            gr.Dropdown(label="Model", choices=choices, value=choices[0][1]),
            gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=3000, waveform_options={'waveform_progress_color': '#3C82F6'}),
        ],
        outputs=[
            gr.Audio(
                label="Output Audio",
                autoplay=False,
                streaming=False,
                type="numpy",
            ),
        ],
        allow_flagging='never',
        cache_examples=False,
        title='Apollo audio restoration',
    )
    demo.queue(max_size=20, default_concurrency_limit=4)
    demo.launch(share=False, server_name="0.0.0.0")