import os import sys import numpy as np import pretty_midi import mir_eval import matplotlib.pyplot as plt from basic_pitch.inference import predict # --- CONFIGURACIÓN --- MODEL_PATH = "CornetAI_SavedModel" ONSET_TOL = 0.150 CEILING_F1 = 0.858 def get_friendly_score(f1_value): score = (f1_value / CEILING_F1) * 10 return min(10.0, round(score, 1)) def midi_a_espanol_corneta(midi_num): """Mapeo específico simplificado para corneta.""" n = midi_num % 12 mapping = {7: 'Sol', 8: 'La', 0: 'Do', 1: 'Re', 4: 'Mi'} nombres_base = ['Do', 'Re', 'Re', 'Re', 'Mi', 'Fa', 'Fa', 'Sol', 'La', 'La', 'Si', 'Si'] return mapping.get(n, nombres_base[n]) def plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, output_filename): """Genera la comparativa y la guarda como imagen con el informe incluido.""" fig = plt.figure(figsize=(15, 8)) ax = fig.add_axes([0.1, 0.1, 0.6, 0.8]) # 1. MIDI DE REFERENCIA for i, note in enumerate(pm_ref.instruments[0].notes): ax.barh(note.pitch, note.end - note.start, left=note.start, height=0.4, color='green', alpha=0.3, edgecolor='black', linewidth=1.5, label="Partitura (Referencia)" if i == 0 else "") # 2. MIDI DE EJECUCIÓN for i, note in enumerate(pm_est.instruments[0].notes): ax.barh(note.pitch, note.end - note.start, left=note.start, height=0.4, color='blue', alpha=0.5, label="Tu Ejecución (CornetAI)" if i == 0 else "") # Obtener todas las notas MIDI únicas y crear etiquetas en español all_pitches = set() for note in pm_ref.instruments[0].notes: all_pitches.add(note.pitch) for note in pm_est.instruments[0].notes: all_pitches.add(note.pitch) all_pitches = sorted(all_pitches) ytick_labels = [midi_a_espanol_corneta(p) for p in all_pitches] ax.set_yticks(all_pitches) ax.set_yticklabels(ytick_labels) ax.set_xlabel("Tiempo (segundos)") ax.set_ylabel("Nota") ax.set_title(f"Informe de Evaluación CornetAI - {output_filename}") ax.legend(loc='upper left') ax.grid(axis='x', linestyle='--', alpha=0.3) # 3. PANEL DE TEXTO (INFORME) info_text = f"PUNTUACIÓN: {puntuacion}/10\n\n" info_text += "CORRECCIONES:\n" info_text += fallos_texto fig.text(0.72, 0.85, " RESULTADOS", fontsize=16, fontweight='bold', color='darkblue') fig.text(0.72, 0.5, info_text, fontsize=12, va='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) img_name = output_filename.replace(".wav", "_resultado.png") plt.savefig(img_name, dpi=300, bbox_inches='tight') print(f"✅ Imagen de resultados guardada como: {img_name}") plt.show() def evaluar_ejecucion(audio_path, midi_gt_path): if not os.path.exists(audio_path) or not os.path.exists(midi_gt_path): print("❌ Error: Archivos no encontrados.") return print(f"Analizando interpretación...") try: _, midi_data, _ = predict(audio_path, model_or_model_path=MODEL_PATH) except Exception as e: print(f"Error en la inferencia: {e}") return pm_ref = pretty_midi.PrettyMIDI(midi_gt_path) pm_est = midi_data ref_notes = pm_ref.instruments[0].notes est_notes = pm_est.instruments[0].notes if not est_notes: print("No se han detectado notas.") return ref_int = np.array([[n.start, n.end] for n in ref_notes]) ref_pit = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in ref_notes]) est_int = np.array([[n.start, n.end] for n in est_notes]) est_pit = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in est_notes]) metrics = mir_eval.transcription.evaluate( ref_int, ref_pit, est_int, est_pit, onset_tolerance=ONSET_TOL, offset_ratio=None ) puntuacion = get_friendly_score(metrics['F-measure_no_offset']) # Identificar notas falladas basándose en el onset (150ms) fallados_idx = [] for i, ref_note in enumerate(ref_notes): ref_onset = ref_note.start # Buscar si hay alguna nota estimada dentro del margen de tolerancia tiene_match_temporal = False for est_note in est_notes: est_onset = est_note.start if abs(ref_onset - est_onset) <= ONSET_TOL: tiene_match_temporal = True break if not tiene_match_temporal: fallados_idx.append(i) #Lista de fallos fallos_lista = [] print("\n" + "="*45) print(f"EVALUACIÓN DE CORNETAI") print("="*45) print(f">> NOTA FINAL: {puntuacion} / 10 <<") print("="*45) if fallados_idx: for idx in fallados_idx[:8]: #Máximo 8 fallos listados nota_es = midi_a_espanol_corneta(ref_notes[idx].pitch) tiempo = round(ref_notes[idx].start, 2) fallos_lista.append(f"- Revisa {nota_es} ({tiempo}s)") print(f" - Revisa el {nota_es} en el segundo {tiempo}") if len(fallados_idx) > 6: fallos_lista.append(f"... y {len(fallados_idx)-6} más.") else: fallos_lista.append("¡Interpretación Perfecta!") fallos_texto = "\n".join(fallos_lista) plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, os.path.basename(audio_path)) if __name__ == "__main__": if len(sys.argv) == 3: evaluar_ejecucion(sys.argv[1], sys.argv[2]) else: print("Uso: python evaluador_individual.py ")