"""Evaluate MIDI transcriptions against ground-truth MIDI files.

For every ground-truth file in DIR_GT, the same-named prediction from two
systems (CornetAI in DIR_PRED_MY, the official Basic Pitch model in
DIR_PRED_OFF) is scored with three metrics, all reported as percentages:

* Acc: frame-level accuracy TP / (TP + FP + FN) on binarized piano rolls.
* Fno: note-level F-measure ignoring offsets (mir_eval.transcription).
* F:   note-level F-measure including offsets.

Per-file scores are saved to a CSV file and the per-model means are printed.
"""
import os
import sys

import mir_eval
import numpy as np
import pandas as pd
import pretty_midi

# --- CONFIGURATION ---
DIR_GT = "dataset_evaluación/midis_gt"
DIR_PRED_MY = "dataset_evaluación/resultados_midis/cornetai"
DIR_PRED_OFF = "dataset_evaluación/resultados_midis/official_bp"

# Note-matching tolerances forwarded to mir_eval.transcription.evaluate.
ONSET_TOL = 0.150  # Onset tolerance, in seconds.
OFFSET_RATIO_VAL = 0.5  # Offset tolerance as a fraction of the reference note duration.
FS = 100  # Piano-roll sampling rate in frames per second (10 ms hops).

# Extensions recognised as MIDI files (compared case-insensitively).
MIDI_EXTENSIONS = (".mid", ".midi")


def get_frame_accuracy(pm_ref, pm_est):
    """Return the frame-level accuracy (Acc) between two PrettyMIDI objects, in percent.

    Both files are rendered to piano rolls sampled at FS frames per second and
    binarized (any velocity > 0 counts as active). The shorter roll is padded
    with inactive frames so both cover the same time span. Every
    (pitch, frame) cell is classified as TP / FP / FN, and the result is
    TP / (TP + FP + FN) * 100. Returns 0.0 when neither roll has an active cell.

    NOTE(review): get_piano_roll runs with its default pedal_threshold, so
    notes held by a sustain pedal (CC64) are lengthened in the roll. The
    note-level metrics use raw note offsets instead -- confirm this is intended.
    """
    roll_ref = pm_ref.get_piano_roll(fs=FS) > 0
    roll_est = pm_est.get_piano_roll(fs=FS) > 0

    # Align lengths: frames past the end of a file are treated as silent.
    n_frames = max(roll_ref.shape[1], roll_est.shape[1])
    roll_ref = np.pad(roll_ref, ((0, 0), (0, n_frames - roll_ref.shape[1])))
    roll_est = np.pad(roll_est, ((0, 0), (0, n_frames - roll_est.shape[1])))

    # Count TP, FP and FN over all (pitch, frame) cells.
    tp = np.count_nonzero(roll_ref & roll_est)
    fp = np.count_nonzero(~roll_ref & roll_est)
    fn = np.count_nonzero(roll_ref & ~roll_est)
    total = tp + fp + fn
    return 100.0 * tp / total if total > 0 else 0.0


def _notes_for_mir_eval(pm):
    """Return (intervals, pitches_hz) for the pitched notes of *pm*.

    Notes from every non-drum instrument are pooled. This matches what
    PrettyMIDI.get_piano_roll renders for the frame metric, where drum tracks
    contribute nothing.

    Notes with end <= start are skipped for two reasons. They occupy no
    piano-roll frames. mir_eval also rejects non-positive durations, and one
    such note would otherwise invalidate the whole file.

    intervals has shape (n, 2), holding [start, end] in seconds. The reshape
    keeps that 2-D shape when n == 0, which mir_eval requires.
    """
    notes = [
        note
        for inst in pm.instruments
        if not inst.is_drum
        for note in inst.notes
        if note.end > note.start
    ]
    intervals = np.array([[n.start, n.end] for n in notes], dtype=float).reshape(-1, 2)
    pitches = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in notes], dtype=float)
    return intervals, pitches


def get_full_metrics(path_ref, path_est):
    """Score one estimated MIDI file against its reference.

    Args:
        path_ref: Path of the ground-truth MIDI file.
        path_est: Path of the transcribed (estimated) MIDI file.

    Returns:
        Tuple (acc, f_no_offset, f_with_offset), each in percent.
        A missing or unreadable file, or a scoring failure, yields (0, 0, 0),
        so the file still counts against the model's average. The reason is
        printed to stderr rather than silently swallowed.
    """
    if not os.path.isfile(path_est):
        print(f"[AVISO] Predicción no encontrada: {path_est}", file=sys.stderr)
        return 0, 0, 0
    try:
        pm_ref = pretty_midi.PrettyMIDI(path_ref)
        pm_est = pretty_midi.PrettyMIDI(path_est)

        # 1. Note metrics (Fno and F).
        ref_int, ref_pit = _notes_for_mir_eval(pm_ref)
        est_int, est_pit = _notes_for_mir_eval(pm_est)
        scores = mir_eval.transcription.evaluate(
            ref_int, ref_pit, est_int, est_pit,
            onset_tolerance=ONSET_TOL, offset_ratio=OFFSET_RATIO_VAL,
        )

        # 2. Frame accuracy (Acc).
        acc = get_frame_accuracy(pm_ref, pm_est)
    except Exception as err:  # Broad on purpose: any parse/scoring failure scores 0, but is reported.
        print(
            f"[AVISO] Error evaluando {path_est}: {type(err).__name__}: {err}",
            file=sys.stderr,
        )
        return 0, 0, 0
    return acc, scores["F-measure_no_offset"] * 100, scores["F-measure"] * 100


def _prediction_path(pred_dir, name):
    """Return the path of the prediction for *name* inside *pred_dir*.

    The first existing file among name + each extension in MIDI_EXTENSIONS is
    returned. If none exists, the conventional name + ".mid" path is returned
    so the caller can report it as missing.
    """
    for ext in MIDI_EXTENSIONS:
        candidate = os.path.join(pred_dir, name + ext)
        if os.path.isfile(candidate):
            return candidate
    return os.path.join(pred_dir, name + ".mid")


def main():
    """Evaluate both models on every ground-truth file, print the means and save a CSV."""
    print("--- 📊 EVALUACIÓN FINAL TFG (ESTILO PAPER SPOTIFY) ---")
    # Sorted so the CSV row order does not depend on os.listdir ordering.
    gts = sorted(f for f in os.listdir(DIR_GT) if f.lower().endswith(MIDI_EXTENSIONS))
    if not gts:
        # An empty DataFrame has no metric columns, so the summary below would raise KeyError.
        print(f"No se encontraron ficheros MIDI en '{DIR_GT}'.")
        return

    res = []
    for gt_file in gts:
        name = os.path.splitext(gt_file)[0]
        gt_path = os.path.join(DIR_GT, gt_file)
        c_acc, c_fno, c_f = get_full_metrics(gt_path, _prediction_path(DIR_PRED_MY, name))
        o_acc, o_fno, o_f = get_full_metrics(gt_path, _prediction_path(DIR_PRED_OFF, name))
        res.append({
            "Archivo": name,
            "CAI_Acc": c_acc, "CAI_Fno": c_fno, "CAI_F": c_f,
            "OFF_Acc": o_acc, "OFF_Fno": o_fno, "OFF_F": o_f,
        })

    df = pd.DataFrame(res)
    m = df.mean(numeric_only=True)

    print("\n" + "=" * 50)
    print(f"{'Model':<20} | {'Acc':<8} | {'Fno':<8} | {'F':<8}")
    print("-" * 50)
    print(f"{'CornetAI (V3)':<20} | {m['CAI_Acc']:<8.2f} | {m['CAI_Fno']:<8.2f} | {m['CAI_F']:<8.2f}")
    print(f"{'Basic Pitch':<20} | {m['OFF_Acc']:<8.2f} | {m['OFF_Fno']:<8.2f} | {m['OFF_F']:<8.2f}")
    print("=" * 50)

    df.to_csv("Tabla_TFG_Estilo_Paper.csv", index=False)
    print("\n Tabla final generada: 'Tabla_TFG_Estilo_Paper.csv'")


if __name__ == "__main__":
    main()