Spaces:
Runtime error
Runtime error
Upload test.py
Browse files
test.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import museval
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
import data.utils
|
| 8 |
+
import model.utils as model_utils
|
| 9 |
+
import utils
|
| 10 |
+
|
| 11 |
+
def compute_model_output(model, inputs):
|
| 12 |
+
'''
|
| 13 |
+
Computes outputs of model with given inputs. Does NOT allow propagating gradients! See compute_loss for training.
|
| 14 |
+
Procedure depends on whether we have one model for each source or not
|
| 15 |
+
:param model: Model to train with
|
| 16 |
+
:param compute_grad: Whether to compute gradients
|
| 17 |
+
:return: Model outputs, Average loss over batch
|
| 18 |
+
'''
|
| 19 |
+
all_outputs = {}
|
| 20 |
+
|
| 21 |
+
if model.separate:
|
| 22 |
+
for inst in model.instruments:
|
| 23 |
+
output = model(inputs, inst)
|
| 24 |
+
all_outputs[inst] = output[inst].detach().clone()
|
| 25 |
+
else:
|
| 26 |
+
all_outputs = model(inputs)
|
| 27 |
+
|
| 28 |
+
return all_outputs
|
| 29 |
+
|
| 30 |
+
def predict(audio, model):
|
| 31 |
+
'''
|
| 32 |
+
Predict sources for a given audio input signal, with a given model. Audio is split into chunks to make predictions on each chunk before they are concatenated.
|
| 33 |
+
:param audio: Audio input tensor, either Pytorch tensor or numpy array
|
| 34 |
+
:param model: Pytorch model
|
| 35 |
+
:return: Source predictions, dictionary with source names as keys
|
| 36 |
+
'''
|
| 37 |
+
if isinstance(audio, torch.Tensor):
|
| 38 |
+
is_cuda = audio.is_cuda()
|
| 39 |
+
audio = audio.detach().cpu().numpy()
|
| 40 |
+
return_mode = "pytorch"
|
| 41 |
+
else:
|
| 42 |
+
return_mode = "numpy"
|
| 43 |
+
|
| 44 |
+
expected_outputs = audio.shape[1]
|
| 45 |
+
|
| 46 |
+
# Pad input if it is not divisible in length by the frame shift number
|
| 47 |
+
output_shift = model.shapes["output_frames"]
|
| 48 |
+
pad_back = audio.shape[1] % output_shift
|
| 49 |
+
pad_back = 0 if pad_back == 0 else output_shift - pad_back
|
| 50 |
+
if pad_back > 0:
|
| 51 |
+
audio = np.pad(audio, [(0,0), (0, pad_back)], mode="constant", constant_values=0.0)
|
| 52 |
+
|
| 53 |
+
target_outputs = audio.shape[1]
|
| 54 |
+
outputs = {key: np.zeros(audio.shape, np.float32) for key in model.instruments}
|
| 55 |
+
|
| 56 |
+
# Pad mixture across time at beginning and end so that neural network can make prediction at the beginning and end of signal
|
| 57 |
+
pad_front_context = model.shapes["output_start_frame"]
|
| 58 |
+
pad_back_context = model.shapes["input_frames"] - model.shapes["output_end_frame"]
|
| 59 |
+
audio = np.pad(audio, [(0,0), (pad_front_context, pad_back_context)], mode="constant", constant_values=0.0)
|
| 60 |
+
|
| 61 |
+
# Iterate over mixture magnitudes, fetch network prediction
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
for target_start_pos in range(0, target_outputs, model.shapes["output_frames"]):
|
| 64 |
+
# Prepare mixture excerpt by selecting time interval
|
| 65 |
+
curr_input = audio[:, target_start_pos:target_start_pos + model.shapes["input_frames"]] # Since audio was front-padded input of [targetpos:targetpos+inputframes] actually predicts [targetpos:targetpos+outputframes] target range
|
| 66 |
+
|
| 67 |
+
# Convert to Pytorch tensor for model prediction
|
| 68 |
+
curr_input = torch.from_numpy(curr_input).unsqueeze(0)
|
| 69 |
+
|
| 70 |
+
# Predict
|
| 71 |
+
for key, curr_targets in compute_model_output(model, curr_input).items():
|
| 72 |
+
outputs[key][:,target_start_pos:target_start_pos+model.shapes["output_frames"]] = curr_targets.squeeze(0).cpu().numpy()
|
| 73 |
+
|
| 74 |
+
# Crop to expected length (since we padded to handle the frame shift)
|
| 75 |
+
outputs = {key : outputs[key][:,:expected_outputs] for key in outputs.keys()}
|
| 76 |
+
|
| 77 |
+
if return_mode == "pytorch":
|
| 78 |
+
outputs = torch.from_numpy(outputs)
|
| 79 |
+
if is_cuda:
|
| 80 |
+
outputs = outputs.cuda()
|
| 81 |
+
return outputs
|
| 82 |
+
|
| 83 |
+
def predict_song(args, audio_path, model):
|
| 84 |
+
'''
|
| 85 |
+
Predicts sources for an audio file for which the file path is given, using a given model.
|
| 86 |
+
Takes care of resampling the input audio to the models sampling rate and resampling predictions back to input sampling rate.
|
| 87 |
+
:param args: Options dictionary
|
| 88 |
+
:param audio_path: Path to mixture audio file
|
| 89 |
+
:param model: Pytorch model
|
| 90 |
+
:return: Source estimates given as dictionary with keys as source names
|
| 91 |
+
'''
|
| 92 |
+
model.eval()
|
| 93 |
+
|
| 94 |
+
# Load mixture in original sampling rate
|
| 95 |
+
mix_audio, mix_sr = data.utils.load(audio_path, sr=None, mono=False)
|
| 96 |
+
mix_channels = mix_audio.shape[0]
|
| 97 |
+
mix_len = mix_audio.shape[1]
|
| 98 |
+
|
| 99 |
+
# Adapt mixture channels to required input channels
|
| 100 |
+
if args.channels == 1:
|
| 101 |
+
mix_audio = np.mean(mix_audio, axis=0, keepdims=True)
|
| 102 |
+
else:
|
| 103 |
+
if mix_channels == 1: # Duplicate channels if input is mono but model is stereo
|
| 104 |
+
mix_audio = np.tile(mix_audio, [args.channels, 1])
|
| 105 |
+
else:
|
| 106 |
+
assert(mix_channels == args.channels)
|
| 107 |
+
|
| 108 |
+
# resample to model sampling rate
|
| 109 |
+
mix_audio = data.utils.resample(mix_audio, mix_sr, args.sr)
|
| 110 |
+
|
| 111 |
+
sources = predict(mix_audio, model)
|
| 112 |
+
|
| 113 |
+
# Resample back to mixture sampling rate in case we had model on different sampling rate
|
| 114 |
+
sources = {key : data.utils.resample(sources[key], args.sr, mix_sr) for key in sources.keys()}
|
| 115 |
+
|
| 116 |
+
# In case we had to pad the mixture at the end, or we have a few samples too many due to inconsistent down- and upsamṕling, remove those samples from source prediction now
|
| 117 |
+
for key in sources.keys():
|
| 118 |
+
diff = sources[key].shape[1] - mix_len
|
| 119 |
+
if diff > 0:
|
| 120 |
+
print("WARNING: Cropping " + str(diff) + " samples")
|
| 121 |
+
sources[key] = sources[key][:, :-diff]
|
| 122 |
+
elif diff < 0:
|
| 123 |
+
print("WARNING: Padding output by " + str(diff) + " samples")
|
| 124 |
+
sources[key] = np.pad(sources[key], [(0,0), (0, -diff)], "constant", 0.0)
|
| 125 |
+
|
| 126 |
+
# Adapt channels
|
| 127 |
+
if mix_channels > args.channels:
|
| 128 |
+
assert(args.channels == 1)
|
| 129 |
+
# Duplicate mono predictions
|
| 130 |
+
sources[key] = np.tile(sources[key], [mix_channels, 1])
|
| 131 |
+
elif mix_channels < args.channels:
|
| 132 |
+
assert(mix_channels == 1)
|
| 133 |
+
# Reduce model output to mono
|
| 134 |
+
sources[key] = np.mean(sources[key], axis=0, keepdims=True)
|
| 135 |
+
|
| 136 |
+
sources[key] = np.asfortranarray(sources[key]) # So librosa does not complain if we want to save it
|
| 137 |
+
|
| 138 |
+
return sources
|
| 139 |
+
|
| 140 |
+
def evaluate(args, dataset, model, instruments):
|
| 141 |
+
'''
|
| 142 |
+
Evaluates a given model on a given dataset
|
| 143 |
+
:param args: Options dict
|
| 144 |
+
:param dataset: Dataset object
|
| 145 |
+
:param model: Pytorch model
|
| 146 |
+
:param instruments: List of source names
|
| 147 |
+
:return: Performance metric dictionary, list with each element describing one dataset sample's results
|
| 148 |
+
'''
|
| 149 |
+
perfs = list()
|
| 150 |
+
model.eval()
|
| 151 |
+
with torch.no_grad():
|
| 152 |
+
for example in dataset:
|
| 153 |
+
print("Evaluating " + example["mix"])
|
| 154 |
+
|
| 155 |
+
# Load source references in their original sr and channel number
|
| 156 |
+
target_sources = np.stack([data.utils.load(example[instrument], sr=None, mono=False)[0].T for instrument in instruments])
|
| 157 |
+
|
| 158 |
+
# Predict using mixture
|
| 159 |
+
pred_sources = predict_song(args, example["mix"], model)
|
| 160 |
+
pred_sources = np.stack([pred_sources[key].T for key in instruments])
|
| 161 |
+
|
| 162 |
+
# Evaluate
|
| 163 |
+
SDR, ISR, SIR, SAR, _ = museval.metrics.bss_eval(target_sources, pred_sources)
|
| 164 |
+
song = {}
|
| 165 |
+
for idx, name in enumerate(instruments):
|
| 166 |
+
song[name] = {"SDR" : SDR[idx], "ISR" : ISR[idx], "SIR" : SIR[idx], "SAR" : SAR[idx]}
|
| 167 |
+
perfs.append(song)
|
| 168 |
+
|
| 169 |
+
return perfs
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def validate(args, model, criterion, test_data):
|
| 173 |
+
'''
|
| 174 |
+
Iterate with a given model over a given test dataset and compute the desired loss
|
| 175 |
+
:param args: Options dictionary
|
| 176 |
+
:param model: Pytorch model
|
| 177 |
+
:param criterion: Loss function to use (similar to Pytorch criterions)
|
| 178 |
+
:param test_data: Test dataset (Pytorch dataset)
|
| 179 |
+
:return:
|
| 180 |
+
'''
|
| 181 |
+
# PREPARE DATA
|
| 182 |
+
dataloader = torch.utils.data.DataLoader(test_data,
|
| 183 |
+
batch_size=args.batch_size,
|
| 184 |
+
shuffle=False,
|
| 185 |
+
num_workers=args.num_workers)
|
| 186 |
+
|
| 187 |
+
# VALIDATE
|
| 188 |
+
model.eval()
|
| 189 |
+
total_loss = 0.
|
| 190 |
+
with tqdm(total=len(test_data) // args.batch_size) as pbar, torch.no_grad():
|
| 191 |
+
for example_num, (x, targets) in enumerate(dataloader):
|
| 192 |
+
if args.cuda:
|
| 193 |
+
x = x.cuda()
|
| 194 |
+
for k in list(targets.keys()):
|
| 195 |
+
targets[k] = targets[k].cuda()
|
| 196 |
+
|
| 197 |
+
_, avg_loss = model_utils.compute_loss(model, x, targets, criterion)
|
| 198 |
+
|
| 199 |
+
total_loss += (1. / float(example_num + 1)) * (avg_loss - total_loss)
|
| 200 |
+
|
| 201 |
+
pbar.set_description("Current loss: {:.4f}".format(total_loss))
|
| 202 |
+
pbar.update(1)
|
| 203 |
+
|
| 204 |
+
return total_loss
|