File size: 5,789 Bytes
4f8d622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import sys
import numpy as np
import pretty_midi
import mir_eval
import matplotlib.pyplot as plt
from basic_pitch.inference import predict

# --- CONFIGURACIÓN ---
MODEL_PATH = "CornetAI_SavedModel" 
ONSET_TOL = 0.150  
CEILING_F1 = 0.858 

def get_friendly_score(f1_value):
    score = (f1_value / CEILING_F1) * 10
    return min(10.0, round(score, 1))

def midi_a_espanol_corneta(midi_num):
    """Mapeo específico simplificado para corneta."""
    n = midi_num % 12
    mapping = {7: 'Sol', 8: 'La', 0: 'Do', 1: 'Re', 4: 'Mi'}
    nombres_base = ['Do', 'Re', 'Re', 'Re', 'Mi', 'Fa', 'Fa', 'Sol', 'La', 'La', 'Si', 'Si']
    return mapping.get(n, nombres_base[n])

def plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, output_filename):
    """Genera la comparativa y la guarda como imagen con el informe incluido."""
    fig = plt.figure(figsize=(15, 8))
    ax = fig.add_axes([0.1, 0.1, 0.6, 0.8])
    
    # 1. MIDI DE REFERENCIA
    for i, note in enumerate(pm_ref.instruments[0].notes):
        ax.barh(note.pitch, note.end - note.start, left=note.start, 
                 height=0.4, color='green', alpha=0.3, 
                 edgecolor='black', linewidth=1.5,
                 label="Partitura (Referencia)" if i == 0 else "")
                 
    # 2. MIDI DE EJECUCIÓN
    for i, note in enumerate(pm_est.instruments[0].notes):
        ax.barh(note.pitch, note.end - note.start, left=note.start, 
                 height=0.4, color='blue', alpha=0.5, 
                 label="Tu Ejecución (CornetAI)" if i == 0 else "")
    
    # Obtener todas las notas MIDI únicas y crear etiquetas en español
    all_pitches = set()
    for note in pm_ref.instruments[0].notes:
        all_pitches.add(note.pitch)
    for note in pm_est.instruments[0].notes:
        all_pitches.add(note.pitch)
    
    all_pitches = sorted(all_pitches)
    ytick_labels = [midi_a_espanol_corneta(p) for p in all_pitches]
    
    ax.set_yticks(all_pitches)
    ax.set_yticklabels(ytick_labels)
                 
    ax.set_xlabel("Tiempo (segundos)")
    ax.set_ylabel("Nota")
    ax.set_title(f"Informe de Evaluación CornetAI - {output_filename}")
    ax.legend(loc='upper left')
    ax.grid(axis='x', linestyle='--', alpha=0.3)

    # 3. PANEL DE TEXTO (INFORME)
    info_text = f"PUNTUACIÓN: {puntuacion}/10\n\n"
    info_text += "CORRECCIONES:\n"
    info_text += fallos_texto
    
    fig.text(0.72, 0.85, " RESULTADOS", fontsize=16, fontweight='bold', color='darkblue')
    fig.text(0.72, 0.5, info_text, fontsize=12, va='center', 
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    img_name = output_filename.replace(".wav", "_resultado.png")
    plt.savefig(img_name, dpi=300, bbox_inches='tight')
    print(f"✅ Imagen de resultados guardada como: {img_name}")
    
    plt.show()

def evaluar_ejecucion(audio_path, midi_gt_path):
    if not os.path.exists(audio_path) or not os.path.exists(midi_gt_path):
        print("❌ Error: Archivos no encontrados.")
        return

    print(f"Analizando interpretación...")
    
    try:
        _, midi_data, _ = predict(audio_path, model_or_model_path=MODEL_PATH)
    except Exception as e:
        print(f"Error en la inferencia: {e}")
        return

    pm_ref = pretty_midi.PrettyMIDI(midi_gt_path)
    pm_est = midi_data 
    
    ref_notes = pm_ref.instruments[0].notes
    est_notes = pm_est.instruments[0].notes
    
    if not est_notes:
        print("No se han detectado notas.")
        return

    ref_int = np.array([[n.start, n.end] for n in ref_notes])
    ref_pit = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in ref_notes])
    est_int = np.array([[n.start, n.end] for n in est_notes])
    est_pit = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in est_notes])
    
    metrics = mir_eval.transcription.evaluate(
        ref_int, ref_pit, est_int, est_pit,
        onset_tolerance=ONSET_TOL, offset_ratio=None
    )
    
    puntuacion = get_friendly_score(metrics['F-measure_no_offset'])
    
    # Identificar notas falladas basándose en el onset (150ms)
    fallados_idx = []
    for i, ref_note in enumerate(ref_notes):
        ref_onset = ref_note.start
        # Buscar si hay alguna nota estimada dentro del margen de tolerancia
        tiene_match_temporal = False
        for est_note in est_notes:
            est_onset = est_note.start
            if abs(ref_onset - est_onset) <= ONSET_TOL:
                tiene_match_temporal = True
                break
        
        if not tiene_match_temporal:
            fallados_idx.append(i)
    #Lista de fallos
    fallos_lista = []
    print("\n" + "="*45)
    print(f"EVALUACIÓN DE CORNETAI")
    print("="*45)
    print(f">> NOTA FINAL: {puntuacion} / 10 <<")
    print("="*45)

    if fallados_idx:
        for idx in fallados_idx[:8]: #Máximo 8 fallos listados
            nota_es = midi_a_espanol_corneta(ref_notes[idx].pitch)
            tiempo = round(ref_notes[idx].start, 2)
            fallos_lista.append(f"- Revisa {nota_es} ({tiempo}s)")
            print(f"   - Revisa el {nota_es} en el segundo {tiempo}")
        
        if len(fallados_idx) > 6:
            fallos_lista.append(f"... y {len(fallados_idx)-6} más.")
    else:
        fallos_lista.append("¡Interpretación Perfecta!")

    fallos_texto = "\n".join(fallos_lista)

    plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, os.path.basename(audio_path))

if __name__ == "__main__":
    if len(sys.argv) == 3:
        evaluar_ejecucion(sys.argv[1], sys.argv[2])
    else:
        print("Uso: python evaluador_individual.py <audio.wav> <referencia.mid>")