Spaces:
Running
Running
| # Standard Library Imports | |
| # Third Party Imports | |
| import torch | |
| import onnxruntime as ort | |
| # Local Imports | |
| from src.models.MDX_net.mdx_net import Conv_TDF_net_trimm | |
| from src.loader import Loader | |
| # Global Variables | |
| from src.constants import EXECUTION_PROVIDER_LIST, COMPUTATION_DEVICE, ONNX_MODEL_PATH | |
class KimVocal:
    """
    Vocal demixing wrapper that runs an MDX-Net ONNX model chunk by chunk.

    The ONNX model file and execution providers come from the module-level
    constants ``ONNX_MODEL_PATH`` and ``EXECUTION_PROVIDER_LIST``. The STFT
    front/back end and the chunking parameters come from the ``model`` object
    passed to :meth:`demix_vocals`.

    TODO: Put something here for flexibility purposes (model types).
    """

    def __init__(self):
        # The ONNX Runtime session is created lazily on first use and then
        # reused, so repeated demixing calls do not reload the model file.
        self._ort_session = None

    def _get_ort_session(self):
        """Return the cached ONNX Runtime session, creating it on first use."""
        # getattr() keeps this working for instances built without __init__.
        session = getattr(self, "_ort_session", None)
        if session is None:
            session = ort.InferenceSession(
                ONNX_MODEL_PATH, providers=EXECUTION_PROVIDER_LIST
            )
            self._ort_session = session
        return session

    def demix_vocals(self, music_tensor, sample_rate, model, streamlit_progressbar):
        """
        Run vocal demixing on ``music_tensor`` with the ONNX model.

        The zero-padded mixture is cut into overlapping windows of
        ``model.chunk_size`` samples. For each window, the code:

        1. transforms it with ``model.stft``;
        2. runs the result through the ONNX session (input name ``"input"``);
        3. transforms the output back with ``model.stft.inverse``;
        4. trims ``model.overlap`` samples from each edge.

        The trimmed pieces are concatenated along the last axis.

        Args:
            music_tensor (torch.Tensor): Input audio. Assumed to be laid out
                as (channels, samples), since dim 1 is treated as the sample
                axis.
            sample_rate (int): Sample rate of ``music_tensor``. Currently
                unused; kept for interface compatibility.
            model: Object providing ``overlap``, ``chunk_size`` and an
                ``stft`` callable that has an ``inverse`` method.
            streamlit_progressbar: Object with a ``progress(float)`` method.
                It is called with values in [0, 1].

        Returns:
            torch.Tensor: The concatenated model output, trimmed to the
            input's sample count and moved to the CPU.

        Raises:
            ValueError: If ``model.chunk_size`` is not larger than
                ``2 * model.overlap`` (no usable samples per window).
        """
        number_of_samples = music_tensor.shape[1]
        overlap = model.overlap
        chunk_size = model.chunk_size
        # Samples kept from each window once ``overlap`` is trimmed per side.
        gen_size = chunk_size - 2 * overlap
        if gen_size <= 0:
            raise ValueError(
                f"model.chunk_size ({chunk_size}) must exceed 2 * model.overlap "
                f"({2 * overlap})"
            )

        # Use enough windows to cover the input, and at least one so an
        # empty input still yields an empty result. Rounding up avoids an
        # extra all-padding window when the length is an exact multiple of
        # gen_size.
        num_chunks = max(1, -(-number_of_samples // gen_size))
        total = num_chunks * gen_size
        pad_size = total - number_of_samples

        # Zero-pad on the input's device. The leading ``overlap`` zeros and
        # the trailing ``pad_size + overlap`` zeros make every window exactly
        # chunk_size samples wide.
        channels = music_tensor.shape[0]
        device = music_tensor.device
        mix_padded = torch.cat(
            [
                torch.zeros(channels, overlap, device=device),
                music_tensor,
                torch.zeros(channels, pad_size + overlap, device=device),
            ],
            dim=1,
        )

        ort_session = self._get_ort_session()

        # TODO: any way to optimize against silence? I think that's what skips are for, gotta double check.
        # Process one window at a time (batch_size=1).
        demixed_chunks = []
        # Nothing here needs autograd. Disabling it covers the torch STFT and
        # inverse-STFT calls too, not just the ONNX call.
        with torch.no_grad():
            for start in range(0, total, gen_size):
                streamlit_progressbar.progress(start / total)
                window = mix_padded[:, start : start + chunk_size]
                spec = model.stft(window.unsqueeze(0).to(COMPUTATION_DEVICE))
                onnx_out = ort_session.run(None, {"input": spec.cpu().numpy()})[0]
                # Invert on the same device the forward STFT ran on. This is
                # a no-op when COMPUTATION_DEVICE is the CPU.
                wave = model.stft.inverse(
                    torch.tensor(onnx_out).to(COMPUTATION_DEVICE)
                ).squeeze(0)
                # Drop ``overlap`` samples from each edge. An explicit end
                # index is used because ``[..., overlap:-overlap]`` would be
                # an empty slice when overlap == 0.
                demixed_chunks.append(wave[..., overlap : wave.shape[-1] - overlap])
        streamlit_progressbar.progress(1.0)

        # Remove the trailing padding so the output matches the input length.
        vocals_output = torch.cat(demixed_chunks, dim=-1)[..., :number_of_samples].cpu()
        return vocals_output
if __name__ == "__main__":
    # KimVocal defines no main() method, so there is nothing to run directly.
    # demix_vocals() needs an audio tensor, a loaded model and a progress-bar
    # object supplied by the calling application.
    raise SystemExit(
        "KimVocal has no standalone entry point; call "
        "KimVocal().demix_vocals(...) from the application instead."
    )