| import time | |
| import numpy as np | |
| import torch | |
| import sys | |
| import torch.nn as nn | |
def get_model_from_config(model_type, config):
    """Instantiate a separation model from its type name and a config object.

    Parameters:
        model_type: Architecture identifier; only 'mel_band_roformer' is
            supported by this build.
        config: Config object whose ``.model`` attribute holds the
            constructor keyword arguments (e.g. an OmegaConf node).

    Returns:
        The constructed model, or ``None`` when ``model_type`` is unknown
        (callers are expected to check the return value).
    """
    if model_type == 'mel_band_roformer':
        # Imported lazily so the heavy model module is only loaded when needed.
        from models.mel_band_roformer import MelBandRoformer
        model = MelBandRoformer(**dict(config.model))
    else:
        # Best-effort reporting: print and return None rather than raise,
        # preserving the existing caller contract.
        print(f'Unknown model: {model_type}')
        model = None
    return model
def get_windowing_array(window_size, fade_size, device):
    """Build a flat-top weighting window for overlap-add chunk blending.

    The window is all ones, except for a linear fade-in over the first
    ``fade_size`` samples (0 -> 1) and a linear fade-out over the last
    ``fade_size`` samples (1 -> 0).

    Parameters:
        window_size: Total window length in samples.
        fade_size: Length of each linear ramp in samples.
        device: Torch device the window is moved to before returning.

    Returns:
        A 1-D float tensor of length ``window_size`` on ``device``.
    """
    window = torch.ones(window_size)
    window[:fade_size] *= torch.linspace(0, 1, fade_size)
    window[-fade_size:] *= torch.linspace(1, 0, fade_size)
    return window.to(device)
def demix_track(config, model, mix, device, first_chunk_time=None):
    """Separate a full mixture into sources via chunked, overlap-add inference.

    The mixture is processed in windows of ``config.inference.chunk_size``
    samples with ``config.inference.num_overlap``-fold overlap; per-chunk model
    outputs are blended with a faded window and normalized by the accumulated
    window weights.

    Parameters:
        config: Config with ``inference.chunk_size``, ``inference.num_overlap``,
            ``training.target_instrument`` and ``training.instruments``.
        model: Callable separation model; given a batched chunk it returns
            stacked per-source estimates (assumed shape
            (sources, channels, time) after indexing [0] — TODO confirm).
        mix: Mixture tensor; presumably (channels, samples) — verify against
            callers.
        device: Torch device used for inference.
        first_chunk_time: Seconds one chunk took on a previous track, used to
            print ETA; when None it is measured from this track's first chunk.

    Returns:
        Tuple of (dict mapping instrument name -> numpy source estimate,
        first_chunk_time) so callers can reuse the timing for later tracks.
    """
    C = config.inference.chunk_size
    N = config.inference.num_overlap
    step = C // N              # hop between consecutive chunk starts
    fade_size = C // 10        # length of the cross-fade ramps
    border = C - step          # reflect-padding added at each end

    # Pad both ends so chunk windows near the edges see real context;
    # only possible when the track is long enough for 'reflect' mode.
    if mix.shape[1] > 2 * border and border > 0:
        mix = nn.functional.pad(mix, (border, border), mode='reflect')

    windowing_array = get_windowing_array(C, fade_size, device)

    # NOTE(review): torch.cuda.amp.autocast() is deprecated in newer torch in
    # favor of torch.amp.autocast('cuda'), and is used here even when device
    # may be CPU — confirm intended.
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            # One output slot per source: a single target instrument, or all
            # configured instruments.
            if config.training.target_instrument is not None:
                req_shape = (1, ) + tuple(mix.shape)
            else:
                req_shape = (len(config.training.instruments),) + tuple(mix.shape)

            mix = mix.to(device)
            # Overlap-add accumulators: weighted sum of outputs and the sum of
            # the window weights used at each sample.
            result = torch.zeros(req_shape, dtype=torch.float32).to(device)
            counter = torch.zeros(req_shape, dtype=torch.float32).to(device)

            i = 0
            total_length = mix.shape[1]
            num_chunks = (total_length + step - 1) // step  # ceil division

            # Time the first chunk only once per run; reuse the measurement
            # passed in for subsequent tracks.
            if first_chunk_time is None:
                start_time = time.time()  # NOTE(review): assigned but never read
                first_chunk = True
            else:
                start_time = None
                first_chunk = False

            while i < total_length:
                part = mix[:, i:i + C]
                length = part.shape[-1]
                if length < C:
                    # Short final chunk: reflect-pad when more than half a
                    # chunk remains (reflect needs length > pad), else zero-pad.
                    if length > C // 2 + 1:
                        part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
                    else:
                        part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

                if first_chunk and i == 0:
                    chunk_start_time = time.time()

                x = model(part.unsqueeze(0))[0]

                # Flatten the fade at the track boundaries: no fade-in on the
                # first chunk, no fade-out on the last, so edges keep weight 1.
                window = windowing_array.clone()
                if i == 0:
                    window[:fade_size] = 1
                elif i + C >= total_length:
                    window[-fade_size:] = 1

                result[..., i:i+length] += x[..., :length] * window[..., :length]
                counter[..., i:i+length] += window[..., :length]
                i += step

                # After the first chunk finishes, derive a whole-track ETA.
                if first_chunk and i == step:
                    chunk_time = time.time() - chunk_start_time
                    first_chunk_time = chunk_time
                    estimated_total_time = chunk_time * num_chunks
                    print(f"Estimated total processing time for this track: {estimated_total_time:.2f} seconds")
                    first_chunk = False

                # Progress line, rewritten in place via carriage return.
                if first_chunk_time is not None and i > step:
                    chunks_processed = i // step
                    time_remaining = first_chunk_time * (num_chunks - chunks_processed)
                    sys.stdout.write(f"\rEstimated time remaining: {time_remaining:.2f} seconds")
                    sys.stdout.flush()

            print()  # terminate the \r progress line

            # Normalize overlap-add sums by accumulated window weights; any
            # 0/0 NaNs (samples no window covered) are zeroed in place.
            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

            # Strip the reflect padding added above (mix is the padded tensor
            # here, so the same condition matches the earlier padding branch).
            if mix.shape[1] > 2 * border and border > 0:
                estimated_sources = estimated_sources[..., border:-border]

            # Map each estimate back to its instrument name.
            if config.training.target_instrument is None:
                return {k: v for k, v in zip(config.training.instruments, estimated_sources)}, first_chunk_time
            else:
                return {k: v for k, v in zip([config.training.target_instrument], estimated_sources)}, first_chunk_time