File size: 5,789 Bytes
4f8d622 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | import os
import sys
import numpy as np
import pretty_midi
import mir_eval
import matplotlib.pyplot as plt
from basic_pitch.inference import predict
# --- CONFIGURACIÓN ---
MODEL_PATH = "CornetAI_SavedModel"
ONSET_TOL = 0.150
CEILING_F1 = 0.858
def get_friendly_score(f1_value):
score = (f1_value / CEILING_F1) * 10
return min(10.0, round(score, 1))
def midi_a_espanol_corneta(midi_num):
"""Mapeo específico simplificado para corneta."""
n = midi_num % 12
mapping = {7: 'Sol', 8: 'La', 0: 'Do', 1: 'Re', 4: 'Mi'}
nombres_base = ['Do', 'Re', 'Re', 'Re', 'Mi', 'Fa', 'Fa', 'Sol', 'La', 'La', 'Si', 'Si']
return mapping.get(n, nombres_base[n])
def plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, output_filename):
"""Genera la comparativa y la guarda como imagen con el informe incluido."""
fig = plt.figure(figsize=(15, 8))
ax = fig.add_axes([0.1, 0.1, 0.6, 0.8])
# 1. MIDI DE REFERENCIA
for i, note in enumerate(pm_ref.instruments[0].notes):
ax.barh(note.pitch, note.end - note.start, left=note.start,
height=0.4, color='green', alpha=0.3,
edgecolor='black', linewidth=1.5,
label="Partitura (Referencia)" if i == 0 else "")
# 2. MIDI DE EJECUCIÓN
for i, note in enumerate(pm_est.instruments[0].notes):
ax.barh(note.pitch, note.end - note.start, left=note.start,
height=0.4, color='blue', alpha=0.5,
label="Tu Ejecución (CornetAI)" if i == 0 else "")
# Obtener todas las notas MIDI únicas y crear etiquetas en español
all_pitches = set()
for note in pm_ref.instruments[0].notes:
all_pitches.add(note.pitch)
for note in pm_est.instruments[0].notes:
all_pitches.add(note.pitch)
all_pitches = sorted(all_pitches)
ytick_labels = [midi_a_espanol_corneta(p) for p in all_pitches]
ax.set_yticks(all_pitches)
ax.set_yticklabels(ytick_labels)
ax.set_xlabel("Tiempo (segundos)")
ax.set_ylabel("Nota")
ax.set_title(f"Informe de Evaluación CornetAI - {output_filename}")
ax.legend(loc='upper left')
ax.grid(axis='x', linestyle='--', alpha=0.3)
# 3. PANEL DE TEXTO (INFORME)
info_text = f"PUNTUACIÓN: {puntuacion}/10\n\n"
info_text += "CORRECCIONES:\n"
info_text += fallos_texto
fig.text(0.72, 0.85, " RESULTADOS", fontsize=16, fontweight='bold', color='darkblue')
fig.text(0.72, 0.5, info_text, fontsize=12, va='center',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
img_name = output_filename.replace(".wav", "_resultado.png")
plt.savefig(img_name, dpi=300, bbox_inches='tight')
print(f"✅ Imagen de resultados guardada como: {img_name}")
plt.show()
def evaluar_ejecucion(audio_path, midi_gt_path):
if not os.path.exists(audio_path) or not os.path.exists(midi_gt_path):
print("❌ Error: Archivos no encontrados.")
return
print(f"Analizando interpretación...")
try:
_, midi_data, _ = predict(audio_path, model_or_model_path=MODEL_PATH)
except Exception as e:
print(f"Error en la inferencia: {e}")
return
pm_ref = pretty_midi.PrettyMIDI(midi_gt_path)
pm_est = midi_data
ref_notes = pm_ref.instruments[0].notes
est_notes = pm_est.instruments[0].notes
if not est_notes:
print("No se han detectado notas.")
return
ref_int = np.array([[n.start, n.end] for n in ref_notes])
ref_pit = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in ref_notes])
est_int = np.array([[n.start, n.end] for n in est_notes])
est_pit = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in est_notes])
metrics = mir_eval.transcription.evaluate(
ref_int, ref_pit, est_int, est_pit,
onset_tolerance=ONSET_TOL, offset_ratio=None
)
puntuacion = get_friendly_score(metrics['F-measure_no_offset'])
# Identificar notas falladas basándose en el onset (150ms)
fallados_idx = []
for i, ref_note in enumerate(ref_notes):
ref_onset = ref_note.start
# Buscar si hay alguna nota estimada dentro del margen de tolerancia
tiene_match_temporal = False
for est_note in est_notes:
est_onset = est_note.start
if abs(ref_onset - est_onset) <= ONSET_TOL:
tiene_match_temporal = True
break
if not tiene_match_temporal:
fallados_idx.append(i)
#Lista de fallos
fallos_lista = []
print("\n" + "="*45)
print(f"EVALUACIÓN DE CORNETAI")
print("="*45)
print(f">> NOTA FINAL: {puntuacion} / 10 <<")
print("="*45)
if fallados_idx:
for idx in fallados_idx[:8]: #Máximo 8 fallos listados
nota_es = midi_a_espanol_corneta(ref_notes[idx].pitch)
tiempo = round(ref_notes[idx].start, 2)
fallos_lista.append(f"- Revisa {nota_es} ({tiempo}s)")
print(f" - Revisa el {nota_es} en el segundo {tiempo}")
if len(fallados_idx) > 6:
fallos_lista.append(f"... y {len(fallados_idx)-6} más.")
else:
fallos_lista.append("¡Interpretación Perfecta!")
fallos_texto = "\n".join(fallos_lista)
plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, os.path.basename(audio_path))
if __name__ == "__main__":
if len(sys.argv) == 3:
evaluar_ejecucion(sys.argv[1], sys.argv[2])
else:
print("Uso: python evaluador_individual.py <audio.wav> <referencia.mid>") |