| | import os
|
| | import sys
|
| | import numpy as np
|
| | import pretty_midi
|
| | import mir_eval
|
| | import matplotlib.pyplot as plt
|
| | from basic_pitch.inference import predict
|
| |
|
| |
|
# Directory containing the fine-tuned Basic Pitch SavedModel used for inference.
MODEL_PATH = "CornetAI_SavedModel"

# Maximum onset deviation (seconds) for a note to count as matched, both in
# mir_eval scoring and in the missed-note report.
ONSET_TOL = 0.150

# F1 value treated as a perfect score when converting to a 0-10 grade.
# NOTE(review): presumably the system's best F1 on a flawless take — confirm.
CEILING_F1 = 0.858
|
| |
|
def get_friendly_score(f1_value, ceiling=None):
    """Map a transcription F1 score onto a friendly 0-10 grade.

    The grade is the F1 relative to an attainable ceiling (the best F1 the
    system is expected to reach on a perfect performance), scaled to 10 and
    clamped so it never exceeds 10.

    Args:
        f1_value: F-measure (typically in [0, 1]) from mir_eval.
        ceiling: reference F1 treated as a perfect 10. Defaults to the
            module-level CEILING_F1 so existing callers are unaffected.

    Returns:
        Grade rounded to one decimal place, capped at 10.0.
    """
    if ceiling is None:
        ceiling = CEILING_F1
    score = (f1_value / ceiling) * 10
    return min(10.0, round(score, 1))
|
| |
|
def midi_a_espanol_corneta(midi_num):
    """Return the simplified Spanish note name for a MIDI note number.

    Maps the pitch class (midi_num % 12) to the closest cornet note name;
    accidentals collapse onto a neighbouring natural (e.g. classes 1-3 all
    map to 'Re').

    Args:
        midi_num: MIDI note number (the octave is ignored).

    Returns:
        Note name such as 'Do', 'Re', 'Mi', 'Fa', 'Sol', 'La' or 'Si'.
    """
    # The original also consulted a {7:'Sol', 8:'La', 0:'Do', 1:'Re', 4:'Mi'}
    # override dict first, but every entry matched this table exactly, so the
    # single lookup is equivalent.
    nombres_base = ['Do', 'Re', 'Re', 'Re', 'Mi', 'Fa', 'Fa', 'Sol', 'La', 'La', 'Si', 'Si']
    return nombres_base[midi_num % 12]
|
| |
|
def plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, output_filename):
    """Render reference vs. transcribed piano rolls and save a PNG report.

    Args:
        pm_ref: pretty_midi.PrettyMIDI with the ground-truth score (the first
            instrument's notes are drawn).
        pm_est: pretty_midi.PrettyMIDI with the transcription (first instrument).
        puntuacion: friendly 0-10 score shown in the side panel.
        fallos_texto: pre-formatted multi-line string listing corrections.
        output_filename: audio file name; the image name is derived from it.
    """
    fig = plt.figure(figsize=(15, 8))
    # Leave the right ~30% of the figure free for the results text panel.
    ax = fig.add_axes([0.1, 0.1, 0.6, 0.8])

    # Reference notes: translucent green bars with a black outline.
    for i, note in enumerate(pm_ref.instruments[0].notes):
        ax.barh(note.pitch, note.end - note.start, left=note.start,
                height=0.4, color='green', alpha=0.3,
                edgecolor='black', linewidth=1.5,
                label="Partitura (Referencia)" if i == 0 else "")

    # Estimated notes: translucent blue bars drawn on top so overlaps show.
    for i, note in enumerate(pm_est.instruments[0].notes):
        ax.barh(note.pitch, note.end - note.start, left=note.start,
                height=0.4, color='blue', alpha=0.5,
                label="Tu Ejecución (CornetAI)" if i == 0 else "")

    # Label the y-axis only with the pitches that actually occur in either roll.
    all_pitches = sorted(
        {note.pitch for note in pm_ref.instruments[0].notes}
        | {note.pitch for note in pm_est.instruments[0].notes}
    )
    ax.set_yticks(all_pitches)
    ax.set_yticklabels([midi_a_espanol_corneta(p) for p in all_pitches])

    ax.set_xlabel("Tiempo (segundos)")
    ax.set_ylabel("Nota")
    ax.set_title(f"Informe de Evaluación CornetAI - {output_filename}")
    ax.legend(loc='upper left')
    ax.grid(axis='x', linestyle='--', alpha=0.3)

    info_text = f"PUNTUACIÓN: {puntuacion}/10\n\n"
    info_text += "CORRECCIONES:\n"
    info_text += fallos_texto

    fig.text(0.72, 0.85, " RESULTADOS", fontsize=16, fontweight='bold', color='darkblue')
    fig.text(0.72, 0.5, info_text, fontsize=12, va='center',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Fix: splitext works for any audio extension (.mp3, .flac, ...); the old
    # replace(".wav", ...) silently saved the PNG under the unchanged audio
    # name whenever the input was not a .wav file.
    img_name = os.path.splitext(output_filename)[0] + "_resultado.png"
    plt.savefig(img_name, dpi=300, bbox_inches='tight')
    print(f"✅ Imagen de resultados guardada como: {img_name}")

    plt.show()
|
| |
|
def evaluar_ejecucion(audio_path, midi_gt_path):
    """Transcribe the recording, score it against the reference MIDI and report.

    Runs Basic Pitch inference on *audio_path*, computes an onset-only
    transcription F-measure against *midi_gt_path* with mir_eval, prints a
    0-10 grade plus per-note corrections, and renders the comparison image.

    Args:
        audio_path: path to the performance recording (e.g. a .wav file).
        midi_gt_path: path to the ground-truth MIDI score.
    """
    if not os.path.exists(audio_path) or not os.path.exists(midi_gt_path):
        print("❌ Error: Archivos no encontrados.")
        return

    print("Analizando interpretación...")

    try:
        _, midi_data, _ = predict(audio_path, model_or_model_path=MODEL_PATH)
    except Exception as e:
        # Inference can fail on unreadable audio or a missing model; report and bail.
        print(f"Error en la inferencia: {e}")
        return

    pm_ref = pretty_midi.PrettyMIDI(midi_gt_path)
    pm_est = midi_data

    ref_notes = pm_ref.instruments[0].notes
    est_notes = pm_est.instruments[0].notes

    if not est_notes:
        print("No se han detectado notas.")
        return

    # mir_eval expects (onset, offset) interval arrays and pitches in Hz.
    ref_int = np.array([[n.start, n.end] for n in ref_notes])
    ref_pit = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in ref_notes])
    est_int = np.array([[n.start, n.end] for n in est_notes])
    est_pit = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in est_notes])

    # offset_ratio=None: only onsets (and pitch) are scored; releases are ignored.
    metrics = mir_eval.transcription.evaluate(
        ref_int, ref_pit, est_int, est_pit,
        onset_tolerance=ONSET_TOL, offset_ratio=None
    )

    puntuacion = get_friendly_score(metrics['F-measure_no_offset'])

    # Reference notes with no estimated onset within tolerance.
    # NOTE(review): this heuristic matches timing only (pitch is ignored), so
    # it can disagree with the pitch-aware mir_eval F-measure above — confirm
    # that is intended.
    fallados_idx = []
    for i, ref_note in enumerate(ref_notes):
        tiene_match_temporal = any(
            abs(ref_note.start - est_note.start) <= ONSET_TOL
            for est_note in est_notes
        )
        if not tiene_match_temporal:
            fallados_idx.append(i)

    fallos_lista = []
    print("\n" + "=" * 45)
    print("EVALUACIÓN DE CORNETAI")
    print("=" * 45)
    print(f">> NOTA FINAL: {puntuacion} / 10 <<")
    print("=" * 45)

    # Fix: the display cap and the "... y N más." count now use the same
    # limit (the original sliced [:8] but reported len(fallados_idx) - 6).
    max_mostrados = 8
    if fallados_idx:
        for idx in fallados_idx[:max_mostrados]:
            nota_es = midi_a_espanol_corneta(ref_notes[idx].pitch)
            tiempo = round(ref_notes[idx].start, 2)
            fallos_lista.append(f"- Revisa {nota_es} ({tiempo}s)")
            print(f" - Revisa el {nota_es} en el segundo {tiempo}")

        if len(fallados_idx) > max_mostrados:
            fallos_lista.append(f"... y {len(fallados_idx) - max_mostrados} más.")
    else:
        fallos_lista.append("¡Interpretación Perfecta!")

    fallos_texto = "\n".join(fallos_lista)

    plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, os.path.basename(audio_path))
|
| |
|
if __name__ == "__main__":
    # CLI entry point: expects exactly an audio file and a reference MIDI file.
    args = sys.argv[1:]
    if len(args) != 2:
        print("Uso: python evaluador_individual.py <audio.wav> <referencia.mid>")
    else:
        audio, referencia = args
        evaluar_ejecucion(audio, referencia)