Spaces:
Running
on
Zero
Running
on
Zero
PRamoneda
commited on
Commit
·
a5af45b
1
Parent(s):
45e5657
gpu to cpu
Browse files- __pycache__/get_difficulty.cpython-310.pyc +0 -0
- __pycache__/model.cpython-310.pyc +0 -0
- get_difficulty.py +13 -21
- model.py +0 -121
- temp.mid +0 -0
__pycache__/get_difficulty.cpython-310.pyc
CHANGED
|
Binary files a/__pycache__/get_difficulty.cpython-310.pyc and b/__pycache__/get_difficulty.cpython-310.pyc differ
|
|
|
__pycache__/model.cpython-310.pyc
CHANGED
|
Binary files a/__pycache__/model.cpython-310.pyc and b/__pycache__/model.cpython-310.pyc differ
|
|
|
get_difficulty.py
CHANGED
|
@@ -32,18 +32,16 @@ def get_cqt_from_mp3(mp3_path):
|
|
| 32 |
log_cqt = log_cqt.T # shape (T, 88)
|
| 33 |
log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
|
| 34 |
cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
|
| 35 |
-
# pdb.set_trace()
|
| 36 |
print(f"cqt shape: {log_cqt.shape}")
|
| 37 |
return cqt_tensor
|
| 38 |
|
| 39 |
def get_pianoroll_from_mp3(mp3_path):
|
| 40 |
audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
|
| 41 |
-
transcriptor = PianoTranscription(device=
|
| 42 |
midi_path = "temp.mid"
|
| 43 |
transcriptor.transcribe(audio, midi_path)
|
| 44 |
midi_data = pretty_midi.PrettyMIDI(midi_path)
|
| 45 |
|
| 46 |
-
# Create pianoroll and onset matrix
|
| 47 |
fs = 5 # original frames per second
|
| 48 |
piano_roll = midi_data.get_piano_roll(fs=fs)[21:109].T # shape: (T, 88)
|
| 49 |
piano_roll = piano_roll / 127
|
|
@@ -64,6 +62,8 @@ def get_pianoroll_from_mp3(mp3_path):
|
|
| 64 |
return out_tensor.transpose(2, 3)
|
| 65 |
|
| 66 |
def predict_difficulty(mp3_path, model_name, rep):
|
|
|
|
|
|
|
| 67 |
if "only_cqt" in rep:
|
| 68 |
only_cqt, only_pr = True, False
|
| 69 |
rep_clean = "multimodal5"
|
|
@@ -74,18 +74,17 @@ def predict_difficulty(mp3_path, model_name, rep):
|
|
| 74 |
only_cqt = only_pr = False
|
| 75 |
rep_clean = rep
|
| 76 |
|
| 77 |
-
model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False, only_cqt=only_cqt, only_pr=only_pr)
|
| 78 |
-
checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location=
|
| 79 |
for i in range(5)]
|
| 80 |
|
| 81 |
-
|
| 82 |
if rep == "cqt5":
|
| 83 |
-
inp_data = get_cqt_from_mp3(mp3_path)
|
| 84 |
elif rep == "pianoroll5":
|
| 85 |
-
inp_data = get_pianoroll_from_mp3(mp3_path)
|
| 86 |
elif rep_clean == "multimodal5":
|
| 87 |
-
x1 = get_pianoroll_from_mp3(mp3_path)
|
| 88 |
-
x2 = get_cqt_from_mp3(mp3_path)
|
| 89 |
inp_data = [x1, x2]
|
| 90 |
else:
|
| 91 |
raise ValueError(f"Representation {rep} not supported")
|
|
@@ -93,23 +92,16 @@ def predict_difficulty(mp3_path, model_name, rep):
|
|
| 93 |
preds = []
|
| 94 |
for cheks in checkpoint:
|
| 95 |
model.load_state_dict(cheks["model_state_dict"])
|
| 96 |
-
model
|
| 97 |
with torch.inference_mode():
|
| 98 |
logits = model(inp_data, None)
|
| 99 |
pred = prediction2label(logits).item()
|
| 100 |
preds.append(pred)
|
| 101 |
|
| 102 |
return mean(preds)
|
| 103 |
-
# return preds
|
| 104 |
|
| 105 |
if __name__ == "__main__":
|
| 106 |
mp3_path = "yt_audio.mp3"
|
| 107 |
-
model_name = ""
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
# pred_pr = predict_difficulty(mp3_path, model_name="audio_midi_pianoroll_ps_5_v4", rep="pianoroll5")
|
| 112 |
-
# print(f"Predicción dificultad PR: {pred_pr}")
|
| 113 |
-
|
| 114 |
-
pred_multi = predict_difficulty(mp3_path, model_name="audio_midi_multi_ps_v5", rep="multimodal5")
|
| 115 |
-
print(f"Predicción dificultad multimodal: {pred_multi}")
|
|
|
|
| 32 |
log_cqt = log_cqt.T # shape (T, 88)
|
| 33 |
log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
|
| 34 |
cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
|
|
|
|
| 35 |
print(f"cqt shape: {log_cqt.shape}")
|
| 36 |
return cqt_tensor
|
| 37 |
|
| 38 |
def get_pianoroll_from_mp3(mp3_path):
|
| 39 |
audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
|
| 40 |
+
transcriptor = PianoTranscription(device="cuda" if torch.cuda.is_available() else "cpu")
|
| 41 |
midi_path = "temp.mid"
|
| 42 |
transcriptor.transcribe(audio, midi_path)
|
| 43 |
midi_data = pretty_midi.PrettyMIDI(midi_path)
|
| 44 |
|
|
|
|
| 45 |
fs = 5 # original frames per second
|
| 46 |
piano_roll = midi_data.get_piano_roll(fs=fs)[21:109].T # shape: (T, 88)
|
| 47 |
piano_roll = piano_roll / 127
|
|
|
|
| 62 |
return out_tensor.transpose(2, 3)
|
| 63 |
|
| 64 |
def predict_difficulty(mp3_path, model_name, rep):
|
| 65 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 66 |
+
|
| 67 |
if "only_cqt" in rep:
|
| 68 |
only_cqt, only_pr = True, False
|
| 69 |
rep_clean = "multimodal5"
|
|
|
|
| 74 |
only_cqt = only_pr = False
|
| 75 |
rep_clean = rep
|
| 76 |
|
| 77 |
+
model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False, only_cqt=only_cqt, only_pr=only_pr).to(device)
|
| 78 |
+
checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location=device, weights_only=False)
|
| 79 |
for i in range(5)]
|
| 80 |
|
|
|
|
| 81 |
if rep == "cqt5":
|
| 82 |
+
inp_data = get_cqt_from_mp3(mp3_path).to(device)
|
| 83 |
elif rep == "pianoroll5":
|
| 84 |
+
inp_data = get_pianoroll_from_mp3(mp3_path).to(device)
|
| 85 |
elif rep_clean == "multimodal5":
|
| 86 |
+
x1 = get_pianoroll_from_mp3(mp3_path).to(device)
|
| 87 |
+
x2 = get_cqt_from_mp3(mp3_path).to(device)
|
| 88 |
inp_data = [x1, x2]
|
| 89 |
else:
|
| 90 |
raise ValueError(f"Representation {rep} not supported")
|
|
|
|
| 92 |
preds = []
|
| 93 |
for cheks in checkpoint:
|
| 94 |
model.load_state_dict(cheks["model_state_dict"])
|
| 95 |
+
model.eval()
|
| 96 |
with torch.inference_mode():
|
| 97 |
logits = model(inp_data, None)
|
| 98 |
pred = prediction2label(logits).item()
|
| 99 |
preds.append(pred)
|
| 100 |
|
| 101 |
return mean(preds)
|
|
|
|
| 102 |
|
| 103 |
if __name__ == "__main__":
|
| 104 |
mp3_path = "yt_audio.mp3"
|
| 105 |
+
model_name = "audio_midi_multi_ps_v5"
|
| 106 |
+
pred_multi = predict_difficulty(mp3_path, model_name=model_name, rep="multimodal5")
|
| 107 |
+
print(f"Predicción dificultad multimodal: {pred_multi}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.py
CHANGED
|
@@ -212,127 +212,6 @@ class AudioModel(nn.Module):
|
|
| 212 |
return x
|
| 213 |
|
| 214 |
|
| 215 |
-
def get_mse_macro(y_true, y_pred):
|
| 216 |
-
mse_each_class = []
|
| 217 |
-
for true_class in set(y_true):
|
| 218 |
-
tt, pp = zip(*[[tt, pp] for tt, pp in zip(y_true, y_pred) if tt == true_class])
|
| 219 |
-
mse_each_class.append(mean_squared_error(y_true=tt, y_pred=pp))
|
| 220 |
-
return mean(mse_each_class)
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
def get_cqt(rep, k):
|
| 224 |
-
inp_data = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
|
| 225 |
-
inp_data = torch.tensor(inp_data, dtype=torch.float32).cpu()
|
| 226 |
-
inp_data = inp_data.unsqueeze(0).unsqueeze(0).transpose(2, 3)
|
| 227 |
-
return inp_data
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
def get_pianoroll(rep, k):
|
| 231 |
-
inp_pr = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
|
| 232 |
-
inp_on = utils.load_binary(f"../videos_download/{rep}/{k}_onset.bin")
|
| 233 |
-
inp_pr = torch.from_numpy(inp_pr).float().cpu()
|
| 234 |
-
inp_on = torch.from_numpy(inp_on).float().cpu()
|
| 235 |
-
inp_data = torch.stack([inp_pr, inp_on], dim=1)
|
| 236 |
-
inp_data = inp_data.unsqueeze(0).permute(0, 1, 2, 3)
|
| 237 |
-
return inp_data
|
| 238 |
-
|
| 239 |
-
def compute_model_basic(model_name, rep, modality_dropout, only_cqt=False, only_pr=False):
|
| 240 |
-
seed = 42
|
| 241 |
-
np.random.seed(seed)
|
| 242 |
-
torch.manual_seed(seed)
|
| 243 |
-
if torch.cuda.is_available():
|
| 244 |
-
torch.cuda.manual_seed(seed)
|
| 245 |
-
data = utils.load_json("../videos_download/split_audio.json")
|
| 246 |
-
mse, acc = [], []
|
| 247 |
-
predictions = []
|
| 248 |
-
if only_cqt:
|
| 249 |
-
cache_name = model_name + "_cqt"
|
| 250 |
-
elif only_pr:
|
| 251 |
-
cache_name = model_name + "_pr"
|
| 252 |
-
else:
|
| 253 |
-
cache_name = model_name
|
| 254 |
-
if not os.path.exists(f"cache/{cache_name}.json"):
|
| 255 |
-
for split in range(5):
|
| 256 |
-
#load_model
|
| 257 |
-
model = AudioModel(11, rep, modality_dropout, only_cqt, only_pr)
|
| 258 |
-
checkpoint = torch.load(f"models/{model_name}/checkpoint_{split}.pth", map_location='cpu')
|
| 259 |
-
# print(checkpoint["epoch"])
|
| 260 |
-
# print(checkpoint.keys())
|
| 261 |
-
|
| 262 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 263 |
-
model = model.cpu()
|
| 264 |
-
pred_labels, true_labels = [], []
|
| 265 |
-
predictions_split = {}
|
| 266 |
-
model.eval()
|
| 267 |
-
with torch.inference_mode():
|
| 268 |
-
for k, ps in data[str(split)]["test"].items():
|
| 269 |
-
# computar el modelo
|
| 270 |
-
if "cqt" in rep:
|
| 271 |
-
inp_data = get_cqt(rep, k)
|
| 272 |
-
elif "pianoroll" in rep:
|
| 273 |
-
inp_data = get_pianoroll(rep, k)
|
| 274 |
-
elif rep == "multimodal5":
|
| 275 |
-
x1 = get_pianoroll("pianoroll5", k)
|
| 276 |
-
x2 = get_cqt("cqt5", k)[:, :, :x1.shape[2]]
|
| 277 |
-
inp_data = [x1, x2]
|
| 278 |
-
log_prob = model(inp_data, None)
|
| 279 |
-
pred = prediction2label(log_prob).cpu().tolist()[0]
|
| 280 |
-
print(k, ps, pred)
|
| 281 |
-
predictions_split[k] = {
|
| 282 |
-
"true": ps,
|
| 283 |
-
"pred": pred
|
| 284 |
-
}
|
| 285 |
-
|
| 286 |
-
true_labels.append(ps)
|
| 287 |
-
pred_labels.append(pred)
|
| 288 |
-
|
| 289 |
-
predictions.append(predictions_split)
|
| 290 |
-
mse.append(get_mse_macro(true_labels, pred_labels))
|
| 291 |
-
acc.append(balanced_accuracy_score(true_labels, pred_labels))
|
| 292 |
-
# with one decimal
|
| 293 |
-
print(f"mse: {mean(mse):.1f}({stdev(mse):.1f})", end=" ")
|
| 294 |
-
print(f"acc: {mean(acc)*100:.1f}({stdev(acc)*100:.1f})")
|
| 295 |
-
utils.save_json({
|
| 296 |
-
"mse": mse,
|
| 297 |
-
"acc": acc,
|
| 298 |
-
"predictions": predictions
|
| 299 |
-
}, f"cache/{cache_name}.json")
|
| 300 |
-
else:
|
| 301 |
-
data = utils.load_json(f"cache/{cache_name}.json")
|
| 302 |
-
tau_c, mse, acc = [], [], []
|
| 303 |
-
for i in range(5):
|
| 304 |
-
pred, true = [], []
|
| 305 |
-
for k, dd in data["predictions"][i].items():
|
| 306 |
-
pred.append(dd["pred"])
|
| 307 |
-
true.append(dd["true"])
|
| 308 |
-
tau_c.append(kendalltau(x=true, y=pred).statistic)
|
| 309 |
-
mse.append(get_mse_macro(true, pred))
|
| 310 |
-
acc.append(balanced_accuracy_score(true, pred))
|
| 311 |
-
print(model_name, end="// ")
|
| 312 |
-
print(f"& {mean(mse):.2f}({stdev(mse):.2f})", end=" ")
|
| 313 |
-
print(f"& {mean(acc) * 100:.1f}({stdev(acc) * 100:.2f})", end=" ")
|
| 314 |
-
print(f"& {mean(tau_c):.3f}({stdev(tau_c):.3f})")
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
def compute_ensemble(truncate=False):
|
| 318 |
-
round_func = lambda x: math.ceil(x) if truncate else math.floor(x)
|
| 319 |
-
data_pr = utils.load_json(f"cache/audio_midi_cqt5_ps_v5.json")
|
| 320 |
-
data_cqt = utils.load_json(f"cache/audio_midi_pianoroll_ps_5_v4.json")
|
| 321 |
-
tau_c, mse, acc = [], [], []
|
| 322 |
-
for i in range(5):
|
| 323 |
-
pred, true = [], []
|
| 324 |
-
for k, dd in data_pr["predictions"][i].items():
|
| 325 |
-
cqt_pred = data_cqt["predictions"][i][k]
|
| 326 |
-
pred.append(round_func((dd["pred"] + cqt_pred["pred"])/2))
|
| 327 |
-
true.append(dd["true"])
|
| 328 |
-
tau_c.append(kendalltau(x=true, y=pred).statistic)
|
| 329 |
-
mse.append(get_mse_macro(true, pred))
|
| 330 |
-
acc.append(balanced_accuracy_score(true, pred))
|
| 331 |
-
print("ensemble", end="// ")
|
| 332 |
-
print(f"& {mean(mse):.2f}({stdev(mse):.2f})", end=" ")
|
| 333 |
-
print(f"& {mean(acc) * 100:.1f}({stdev(acc) * 100:.2f})", end=" ")
|
| 334 |
-
print(f"& {mean(tau_c):.3f}({stdev(tau_c):.3f})")
|
| 335 |
-
|
| 336 |
|
| 337 |
def load_json(name_file):
|
| 338 |
with open(name_file, 'r') as fp:
|
|
|
|
| 212 |
return x
|
| 213 |
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
def load_json(name_file):
|
| 217 |
with open(name_file, 'r') as fp:
|
temp.mid
CHANGED
|
Binary files a/temp.mid and b/temp.mid differ
|
|
|