File size: 3,458 Bytes
4f8d622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import pretty_midi
import mir_eval
import numpy as np
import pandas as pd

# --- Configuration ---
# Input folders: the ground-truth MIDI files and the two sets of model
# predictions. Each prediction is looked up by the ground-truth file name
# (same stem, ".mid" extension).
DIR_GT: str = "dataset_evaluación/midis_gt"
DIR_PRED_MY: str = "dataset_evaluación/resultados_midis/cornetai"  # CornetAI predictions
DIR_PRED_OFF: str = "dataset_evaluación/resultados_midis/official_bp"  # official Basic Pitch predictions

# Note-matching tolerances forwarded to mir_eval.transcription.evaluate.
ONSET_TOL: float = 0.15  # onset_tolerance, in seconds (150 ms)
OFFSET_RATIO_VAL: float = 0.5  # offset_ratio: offset tolerance scales with the reference note duration
FS: int = 100  # piano-roll sampling rate in frames per second (10 ms hops)

def get_frame_accuracy(pm_ref, pm_est):
    """Return the frame-level accuracy (Acc), in percent, of *pm_est* vs *pm_ref*.

    Both objects are rendered with ``get_piano_roll(fs=FS)`` and binarized
    (any positive cell counts as an active pitch in that frame). The shorter
    roll is right-padded with inactive frames so both span the same time range.
    Accuracy is TP / (TP + FP + FN) over all (pitch, frame) cells, times 100.

    Args:
        pm_ref: reference MIDI object (a ``pretty_midi.PrettyMIDI`` in this file).
        pm_est: estimated MIDI object of the same kind.

    Returns:
        The accuracy percentage, or 0 when no cell is active in either roll.
    """
    active_ref = pm_ref.get_piano_roll(fs=FS) > 0
    active_est = pm_est.get_piano_roll(fs=FS) > 0

    # Pad along the time axis (axis 1) with False so both rolls have equal length.
    n_frames = max(active_ref.shape[1], active_est.shape[1])
    active_ref, active_est = (
        np.pad(roll, ((0, 0), (0, n_frames - roll.shape[1])))
        for roll in (active_ref, active_est)
    )

    hits = np.sum(active_ref & active_est)
    false_alarms = np.sum(~active_ref & active_est)
    misses = np.sum(active_ref & ~active_est)

    denom = hits + false_alarms + misses
    if denom <= 0:
        return 0
    return (hits / denom) * 100

def _collect_notes(pm):
    """Return ``(intervals, pitches_hz)`` for every non-drum note in *pm*.

    ``intervals`` is an ``(n, 2)`` float array of ``[start, end]`` note times
    and ``pitches_hz`` an ``(n,)`` array of frequencies from
    ``pretty_midi.note_number_to_hz``. With no notes the shapes are ``(0, 2)``
    and ``(0,)``, which mir_eval's interval validation accepts (a plain
    ``np.array([])`` of shape ``(0,)`` is rejected).
    """
    # Drum instruments are skipped to match PrettyMIDI.get_piano_roll, which the
    # frame-accuracy metric relies on; this keeps both metrics on the same notes.
    notes = [
        note
        for inst in pm.instruments
        if not inst.is_drum
        for note in inst.notes
    ]
    intervals = np.array([[n.start, n.end] for n in notes], dtype=float).reshape(-1, 2)
    pitches_hz = np.array(
        [pretty_midi.note_number_to_hz(n.pitch) for n in notes], dtype=float
    )
    return intervals, pitches_hz


def get_full_metrics(path_ref, path_est):
    """Score one estimated MIDI file against its reference.

    Args:
        path_ref: path to the ground-truth MIDI file.
        path_est: path to the predicted MIDI file.

    Returns:
        ``(acc, f_no_offset, f)``, all in percent: frame accuracy from
        get_frame_accuracy, and mir_eval's note F-measure without and with
        offset matching (tolerances ONSET_TOL / OFFSET_RATIO_VAL).
        If either file cannot be loaded or scored (e.g. the prediction is
        missing), a warning is printed and ``(0, 0, 0)`` is returned so the
        file still counts in the averages.
    """
    try:
        pm_ref = pretty_midi.PrettyMIDI(path_ref)
        pm_est = pretty_midi.PrettyMIDI(path_est)

        # 1. Note metrics (F_no and F)
        ref_int, ref_pit = _collect_notes(pm_ref)
        est_int, est_pit = _collect_notes(pm_est)
        sc = mir_eval.transcription.evaluate(
            ref_int, ref_pit, est_int, est_pit,
            onset_tolerance=ONSET_TOL, offset_ratio=OFFSET_RATIO_VAL,
        )

        # 2. Frame accuracy (Acc)
        acc = get_frame_accuracy(pm_ref, pm_est)
    except Exception as err:
        # Per-file boundary: keep the original "score as zero" policy, but make
        # the failure visible instead of silently biasing the mean.
        print(f"[WARN] Could not evaluate '{path_est}' against '{path_ref}': {err!r} -> scored as 0")
        return 0, 0, 0

    return acc, sc["F-measure_no_offset"] * 100, sc["F-measure"] * 100

def main():
    """Score both models against every ground-truth MIDI and print the averages.

    For each ``.mid`` file in DIR_GT, the prediction with the same file name is
    read from DIR_PRED_MY (CornetAI) and DIR_PRED_OFF (Basic Pitch) and scored
    with get_full_metrics. Per-file scores are saved to
    'Tabla_TFG_Estilo_Paper.csv' and the column means are printed as a table.
    """
    print("--- 📊 EVALUACIÓN FINAL TFG (ESTILO PAPER SPOTIFY) ---")
    res = []
    # os.listdir order is arbitrary; sorting keeps the CSV row order reproducible.
    gts = sorted(f for f in os.listdir(DIR_GT) if f.endswith(".mid"))
    if not gts:
        # With no rows the DataFrame has no columns and m['CAI_Acc'] below would
        # raise KeyError, so report the problem and stop.
        print(f"No .mid files found in '{DIR_GT}'; nothing to evaluate.")
        return

    for gt_file in gts:
        name = os.path.splitext(gt_file)[0]
        gt_path = os.path.join(DIR_GT, gt_file)
        c_acc, c_fno, c_f = get_full_metrics(gt_path, os.path.join(DIR_PRED_MY, name + ".mid"))
        o_acc, o_fno, o_f = get_full_metrics(gt_path, os.path.join(DIR_PRED_OFF, name + ".mid"))

        res.append({
            "Archivo": name,
            "CAI_Acc": c_acc, "CAI_Fno": c_fno, "CAI_F": c_f,
            "OFF_Acc": o_acc, "OFF_Fno": o_fno, "OFF_F": o_f,
        })

    df = pd.DataFrame(res)
    # Per-column mean over all files; the text column "Archivo" is skipped.
    m = df.mean(numeric_only=True)

    print("\n" + "="*50)
    print(f"{'Model':<20} | {'Acc':<8} | {'Fno':<8} | {'F':<8}")
    print("-" * 50)
    print(f"{'CornetAI (V3)':<20} | {m['CAI_Acc']:<8.2f} | {m['CAI_Fno']:<8.2f} | {m['CAI_F']:<8.2f}")
    print(f"{'Basic Pitch':<20} | {m['OFF_Acc']:<8.2f} | {m['OFF_Fno']:<8.2f} | {m['OFF_F']:<8.2f}")
    print("="*50)

    df.to_csv("Tabla_TFG_Estilo_Paper.csv", index=False)
    print("\n Tabla final generada: 'Tabla_TFG_Estilo_Paper.csv'")


if __name__ == "__main__":
    main()