CornetAI / src /evaluador.py
jmp684's picture
Upload 12 files
4f8d622 verified
import os
import warnings

import mir_eval
import numpy as np
import pandas as pd
import pretty_midi
# --- CONFIGURATION ---
# Directory listed for ground-truth .mid files, and the two directories where
# the same-named predictions of each model are looked up.
DIR_GT = "dataset_evaluación/midis_gt"
DIR_PRED_MY = "dataset_evaluación/resultados_midis/cornetai"
DIR_PRED_OFF = "dataset_evaluación/resultados_midis/official_bp"
# Tolerances
ONSET_TOL = 0.150  # forwarded to mir_eval as `onset_tolerance`
OFFSET_RATIO_VAL = 0.5  # forwarded to mir_eval as `offset_ratio`
FS = 100  # Sampling rate for the piano roll (10 ms hops)
def get_frame_accuracy(pm_ref, pm_est):
    """Return the frame-level accuracy (Acc), in percent, between two MIDIs.

    Both objects are rendered to piano rolls at ``FS`` frames per second and
    reduced to on/off masks.  The score is ``TP / (TP + FP + FN) * 100``
    counted over every (pitch, frame) cell; it is 0 when no cell is active in
    either roll.
    """
    # Any positive velocity counts as "note on".
    ref_on = pm_ref.get_piano_roll(fs=FS) > 0
    est_on = pm_est.get_piano_roll(fs=FS) > 0

    # Right-pad the shorter roll with silence so both cover the same frames.
    n_frames = max(ref_on.shape[1], est_on.shape[1])
    ref_on = np.pad(ref_on, ((0, 0), (0, n_frames - ref_on.shape[1])))
    est_on = np.pad(est_on, ((0, 0), (0, n_frames - est_on.shape[1])))

    # Cell-wise confusion counts.
    hits = np.sum(ref_on & est_on)
    false_alarms = np.sum(~ref_on & est_on)
    misses = np.sum(ref_on & ~est_on)

    total = hits + false_alarms + misses
    if total > 0:
        return (hits / total) * 100
    return 0
def _first_track_notes(pm):
    """Return ``(intervals, pitches_hz)`` arrays for the first instrument of *pm*.

    ``intervals`` has one ``[start, end]`` row per note and ``pitches_hz`` the
    matching pitches converted from MIDI note numbers to Hz.  When *pm* has no
    instruments the arrays are empty, shaped ``(0, 2)`` and ``(0,)``.
    """
    notes = pm.instruments[0].notes if pm.instruments else []
    # reshape keeps the two-column layout even when there are no notes.
    intervals = np.array([[n.start, n.end] for n in notes], dtype=float).reshape(-1, 2)
    pitches_hz = np.array(
        [pretty_midi.note_number_to_hz(n.pitch) for n in notes], dtype=float
    )
    return intervals, pitches_hz


def get_full_metrics(path_ref, path_est):
    """Score one estimated MIDI file against its reference.

    Args:
        path_ref: Path to the ground-truth MIDI file.
        path_est: Path to the predicted MIDI file.

    Returns:
        A tuple ``(acc, f_no_offset, f)`` of percentages: the frame accuracy
        from ``get_frame_accuracy`` and the ``F-measure_no_offset`` /
        ``F-measure`` note scores reported by ``mir_eval.transcription``.
        Note metrics use only the first instrument of each file.

        If anything fails (missing or unreadable file, invalid note data)
        the result is ``(0, 0, 0)`` and a ``RuntimeWarning`` describing the
        failure is emitted, so a broken input is visible instead of silently
        lowering the averages.
    """
    try:
        pm_ref = pretty_midi.PrettyMIDI(path_ref)
        pm_est = pretty_midi.PrettyMIDI(path_est)

        # 1. Note metrics (Fno and F)
        ref_int, ref_pit = _first_track_notes(pm_ref)
        est_int, est_pit = _first_track_notes(pm_est)
        sc = mir_eval.transcription.evaluate(
            ref_int, ref_pit, est_int, est_pit,
            onset_tolerance=ONSET_TOL, offset_ratio=OFFSET_RATIO_VAL,
        )

        # 2. Frame accuracy (Acc)
        acc = get_frame_accuracy(pm_ref, pm_est)
        return acc, sc['F-measure_no_offset'] * 100, sc['F-measure'] * 100
    except Exception as exc:
        # Best-effort by design: one bad file must not abort the whole run,
        # but the failure has to be reported rather than hidden.
        warnings.warn(
            f"No se pudo evaluar '{path_est}' contra '{path_ref}' "
            f"({type(exc).__name__}: {exc}); se puntúa como 0.",
            RuntimeWarning,
            stacklevel=2,
        )
        return 0, 0, 0
def main():
    """Evaluate both models on every ground-truth MIDI and report the results.

    For each ``.mid`` file in ``DIR_GT`` the same-named prediction is looked up
    in ``DIR_PRED_MY`` and ``DIR_PRED_OFF`` and scored with
    ``get_full_metrics``.  The per-file scores are written to
    ``Tabla_TFG_Estilo_Paper.csv`` and their means are printed as a table.
    When ``DIR_GT`` contains no ``.mid`` files a message is printed and
    nothing is written.
    """
    print("--- 📊 EVALUACIÓN FINAL TFG (ESTILO PAPER SPOTIFY) ---")

    # os.listdir returns entries in arbitrary order; sort so that the CSV rows
    # come out in the same order on every run.
    gt_files = sorted(f for f in os.listdir(DIR_GT) if f.endswith(".mid"))
    if not gt_files:
        # An empty DataFrame has no metric columns, so the summary below
        # would fail with KeyError; stop early instead.
        print(f"No se encontraron archivos .mid en '{DIR_GT}'; no hay nada que evaluar.")
        return

    rows = []
    for gt_file in gt_files:
        name = os.path.splitext(gt_file)[0]
        gt_path = os.path.join(DIR_GT, gt_file)
        c_acc, c_fno, c_f = get_full_metrics(gt_path, os.path.join(DIR_PRED_MY, name + ".mid"))
        o_acc, o_fno, o_f = get_full_metrics(gt_path, os.path.join(DIR_PRED_OFF, name + ".mid"))
        rows.append({
            "Archivo": name,
            "CAI_Acc": c_acc, "CAI_Fno": c_fno, "CAI_F": c_f,
            "OFF_Acc": o_acc, "OFF_Fno": o_fno, "OFF_F": o_f,
        })

    df = pd.DataFrame(rows)
    # Column-wise means over all files; "Archivo" is skipped as non-numeric.
    m = df.mean(numeric_only=True)

    print("\n" + "="*50)
    print(f"{'Model':<20} | {'Acc':<8} | {'Fno':<8} | {'F':<8}")
    print("-" * 50)
    print(f"{'CornetAI (V3)':<20} | {m['CAI_Acc']:<8.2f} | {m['CAI_Fno']:<8.2f} | {m['CAI_F']:<8.2f}")
    print(f"{'Basic Pitch':<20} | {m['OFF_Acc']:<8.2f} | {m['OFF_Fno']:<8.2f} | {m['OFF_F']:<8.2f}")
    print("="*50)

    df.to_csv("Tabla_TFG_Estilo_Paper.csv", index=False)
    print("\n Tabla final generada: 'Tabla_TFG_Estilo_Paper.csv'")


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()