| | import os
|
| | import pretty_midi
|
| | import mir_eval
|
| | import numpy as np
|
| | import pandas as pd
|
| |
|
| |
|
# ---------------------------------------------------------------------------
# Input locations (relative paths, resolved against the working directory).
# ---------------------------------------------------------------------------
# Ground-truth MIDI files; each one defines an excerpt to evaluate.
DIR_GT: str = "dataset_evaluación/midis_gt"
# CornetAI predictions; expected to share the ground-truth file's base name.
DIR_PRED_MY: str = "dataset_evaluación/resultados_midis/cornetai"
# Official Basic Pitch predictions; same naming scheme as above.
DIR_PRED_OFF: str = "dataset_evaluación/resultados_midis/official_bp"

# ---------------------------------------------------------------------------
# Evaluation parameters.
# ---------------------------------------------------------------------------
# Onset tolerance (seconds) forwarded to mir_eval as ``onset_tolerance``.
# NOTE(review): 150 ms is looser than mir_eval's default — confirm it matches
# the protocol being replicated.
ONSET_TOL: float = 0.150
# Forwarded to mir_eval as ``offset_ratio`` for the with-offset F-measure.
OFFSET_RATIO_VAL: float = 0.5
# Piano-roll sampling rate (frames per second) for frame-level accuracy.
FS: int = 100
|
| |
|
def get_frame_accuracy(pm_ref, pm_est, fs=None):
    """Compute frame-level accuracy (Acc) by comparing binary piano rolls.

    Both MIDI objects are rendered to piano rolls (pitch x frame) at ``fs``
    frames per second and binarised (any velocity > 0 counts as active).
    The shorter roll is right-padded with inactive frames so both span the
    same time range. Accuracy is ``TP / (TP + FP + FN)`` over all
    (pitch, frame) cells, expressed as a percentage.

    Args:
        pm_ref: Reference MIDI (ground truth); any object exposing
            ``get_piano_roll(fs=...)`` such as ``pretty_midi.PrettyMIDI``.
        pm_est: Estimated MIDI (model output), same interface as ``pm_ref``.
        fs: Piano-roll frame rate in Hz. When omitted, the module constant
            ``FS`` is used (the original behaviour).

    Returns:
        Accuracy in percent as a ``float`` in [0, 100]. Returns ``0.0`` when
        neither roll has any active cell, since the ratio is undefined there.
    """
    if fs is None:
        fs = FS

    ref_roll = pm_ref.get_piano_roll(fs=fs) > 0
    est_roll = pm_est.get_piano_roll(fs=fs) > 0

    # Each roll ends at its own last note-off, so lengths usually differ;
    # pad the shorter one with inactive (False) frames.
    n_frames = max(ref_roll.shape[1], est_roll.shape[1])
    ref_roll = np.pad(ref_roll, ((0, 0), (0, n_frames - ref_roll.shape[1])))
    est_roll = np.pad(est_roll, ((0, 0), (0, n_frames - est_roll.shape[1])))

    tp = np.count_nonzero(ref_roll & est_roll)
    fp = np.count_nonzero(~ref_roll & est_roll)
    fn = np.count_nonzero(ref_roll & ~est_roll)

    denom = tp + fp + fn
    return 100.0 * tp / denom if denom else 0.0
|
| |
|
def _notes_as_arrays(pm):
    """Collect every note of ``pm`` (all instruments) as mir_eval arrays."""
    # Use all instruments, matching get_frame_accuracy, whose piano roll
    # is built from the whole PrettyMIDI object rather than one track.
    notes = [note for inst in pm.instruments for note in inst.notes]
    # reshape keeps an empty note list as a valid (0, 2) interval array,
    # which mir_eval accepts (it warns and scores 0) instead of rejecting.
    intervals = np.array(
        [[note.start, note.end] for note in notes], dtype=float
    ).reshape(-1, 2)
    pitches = np.array(
        [pretty_midi.note_number_to_hz(note.pitch) for note in notes],
        dtype=float,
    )
    return intervals, pitches


def get_full_metrics(path_ref, path_est):
    """Score one estimated MIDI file against its reference transcription.

    Args:
        path_ref: Path to the ground-truth MIDI file.
        path_est: Path to the model's predicted MIDI file.

    Returns:
        Tuple ``(acc, f_no_offset, f)``, all in percent:
        frame accuracy from ``get_frame_accuracy``, mir_eval's note
        F-measure ignoring offsets (``F-measure_no_offset``), and the note
        F-measure that also checks offsets (``F-measure``), computed with
        ``ONSET_TOL`` and ``OFFSET_RATIO_VAL``.
        If either file cannot be loaded or scored (e.g. the prediction is
        missing), a warning is printed and ``(0.0, 0.0, 0.0)`` is returned
        so the batch evaluation keeps running.
    """
    try:
        pm_ref = pretty_midi.PrettyMIDI(path_ref)
        pm_est = pretty_midi.PrettyMIDI(path_est)

        ref_int, ref_pit = _notes_as_arrays(pm_ref)
        est_int, est_pit = _notes_as_arrays(pm_est)

        scores = mir_eval.transcription.evaluate(
            ref_int, ref_pit, est_int, est_pit,
            onset_tolerance=ONSET_TOL, offset_ratio=OFFSET_RATIO_VAL,
        )
        acc = get_frame_accuracy(pm_ref, pm_est)
    except Exception as err:  # best effort: report, score zero, continue
        print(
            f"[WARN] Could not evaluate {path_est!r} against {path_ref!r}: "
            f"{type(err).__name__}: {err}"
        )
        return 0.0, 0.0, 0.0

    return acc, scores["F-measure_no_offset"] * 100, scores["F-measure"] * 100
|
| |
|
def main():
    """Evaluate both transcription models on every ground-truth MIDI file.

    For each ``.mid`` file in ``DIR_GT``, the file with the same base name
    is looked up in ``DIR_PRED_MY`` (CornetAI) and ``DIR_PRED_OFF``
    (official Basic Pitch) and scored with ``get_full_metrics``. Mean
    scores per model are printed as a table and the per-file results are
    written to ``Tabla_TFG_Estilo_Paper.csv`` in the working directory.
    """
    print("--- 📊 EVALUACIÓN FINAL TFG (ESTILO PAPER SPOTIFY) ---")

    # Sorted so the CSV row order is reproducible (os.listdir order is
    # arbitrary and platform-dependent).
    gts = sorted(f for f in os.listdir(DIR_GT) if f.endswith(".mid"))
    if not gts:
        # Without rows the mean Series is empty and the table lookups below
        # would raise KeyError.
        print(f"No .mid files found in {DIR_GT!r}; nothing to evaluate.")
        return

    res = []
    for gt_file in gts:
        name = os.path.splitext(gt_file)[0]
        gt_path = os.path.join(DIR_GT, gt_file)
        c_acc, c_fno, c_f = get_full_metrics(
            gt_path, os.path.join(DIR_PRED_MY, name + ".mid")
        )
        o_acc, o_fno, o_f = get_full_metrics(
            gt_path, os.path.join(DIR_PRED_OFF, name + ".mid")
        )
        res.append({
            "Archivo": name,
            "CAI_Acc": c_acc, "CAI_Fno": c_fno, "CAI_F": c_f,
            "OFF_Acc": o_acc, "OFF_Fno": o_fno, "OFF_F": o_f,
        })

    df = pd.DataFrame(res)
    m = df.mean(numeric_only=True)

    print("\n" + "=" * 50)
    print(f"{'Model':<20} | {'Acc':<8} | {'Fno':<8} | {'F':<8}")
    print("-" * 50)
    print(f"{'CornetAI (V3)':<20} | {m['CAI_Acc']:<8.2f} | {m['CAI_Fno']:<8.2f} | {m['CAI_F']:<8.2f}")
    print(f"{'Basic Pitch':<20} | {m['OFF_Acc']:<8.2f} | {m['OFF_Fno']:<8.2f} | {m['OFF_F']:<8.2f}")
    print("=" * 50)

    df.to_csv("Tabla_TFG_Estilo_Paper.csv", index=False)
    print("\n Tabla final generada: 'Tabla_TFG_Estilo_Paper.csv'")
|
| |
|
# Script entry point: run the evaluation only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    main()