CornetAI / src /calificador.py
jmp684's picture
Upload 12 files
4f8d622 verified
import os
import sys
import numpy as np
import pretty_midi
import mir_eval
import matplotlib.pyplot as plt
from basic_pitch.inference import predict
# --- CONFIGURACIÓN ---
MODEL_PATH = "CornetAI_SavedModel"
ONSET_TOL = 0.150
CEILING_F1 = 0.858
def get_friendly_score(f1_value):
score = (f1_value / CEILING_F1) * 10
return min(10.0, round(score, 1))
def midi_a_espanol_corneta(midi_num):
"""Mapeo específico simplificado para corneta."""
n = midi_num % 12
mapping = {7: 'Sol', 8: 'La', 0: 'Do', 1: 'Re', 4: 'Mi'}
nombres_base = ['Do', 'Re', 'Re', 'Re', 'Mi', 'Fa', 'Fa', 'Sol', 'La', 'La', 'Si', 'Si']
return mapping.get(n, nombres_base[n])
def plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, output_filename):
"""Genera la comparativa y la guarda como imagen con el informe incluido."""
fig = plt.figure(figsize=(15, 8))
ax = fig.add_axes([0.1, 0.1, 0.6, 0.8])
# 1. MIDI DE REFERENCIA
for i, note in enumerate(pm_ref.instruments[0].notes):
ax.barh(note.pitch, note.end - note.start, left=note.start,
height=0.4, color='green', alpha=0.3,
edgecolor='black', linewidth=1.5,
label="Partitura (Referencia)" if i == 0 else "")
# 2. MIDI DE EJECUCIÓN
for i, note in enumerate(pm_est.instruments[0].notes):
ax.barh(note.pitch, note.end - note.start, left=note.start,
height=0.4, color='blue', alpha=0.5,
label="Tu Ejecución (CornetAI)" if i == 0 else "")
# Obtener todas las notas MIDI únicas y crear etiquetas en español
all_pitches = set()
for note in pm_ref.instruments[0].notes:
all_pitches.add(note.pitch)
for note in pm_est.instruments[0].notes:
all_pitches.add(note.pitch)
all_pitches = sorted(all_pitches)
ytick_labels = [midi_a_espanol_corneta(p) for p in all_pitches]
ax.set_yticks(all_pitches)
ax.set_yticklabels(ytick_labels)
ax.set_xlabel("Tiempo (segundos)")
ax.set_ylabel("Nota")
ax.set_title(f"Informe de Evaluación CornetAI - {output_filename}")
ax.legend(loc='upper left')
ax.grid(axis='x', linestyle='--', alpha=0.3)
# 3. PANEL DE TEXTO (INFORME)
info_text = f"PUNTUACIÓN: {puntuacion}/10\n\n"
info_text += "CORRECCIONES:\n"
info_text += fallos_texto
fig.text(0.72, 0.85, " RESULTADOS", fontsize=16, fontweight='bold', color='darkblue')
fig.text(0.72, 0.5, info_text, fontsize=12, va='center',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
img_name = output_filename.replace(".wav", "_resultado.png")
plt.savefig(img_name, dpi=300, bbox_inches='tight')
print(f"✅ Imagen de resultados guardada como: {img_name}")
plt.show()
def evaluar_ejecucion(audio_path, midi_gt_path):
if not os.path.exists(audio_path) or not os.path.exists(midi_gt_path):
print("❌ Error: Archivos no encontrados.")
return
print(f"Analizando interpretación...")
try:
_, midi_data, _ = predict(audio_path, model_or_model_path=MODEL_PATH)
except Exception as e:
print(f"Error en la inferencia: {e}")
return
pm_ref = pretty_midi.PrettyMIDI(midi_gt_path)
pm_est = midi_data
ref_notes = pm_ref.instruments[0].notes
est_notes = pm_est.instruments[0].notes
if not est_notes:
print("No se han detectado notas.")
return
ref_int = np.array([[n.start, n.end] for n in ref_notes])
ref_pit = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in ref_notes])
est_int = np.array([[n.start, n.end] for n in est_notes])
est_pit = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in est_notes])
metrics = mir_eval.transcription.evaluate(
ref_int, ref_pit, est_int, est_pit,
onset_tolerance=ONSET_TOL, offset_ratio=None
)
puntuacion = get_friendly_score(metrics['F-measure_no_offset'])
# Identificar notas falladas basándose en el onset (150ms)
fallados_idx = []
for i, ref_note in enumerate(ref_notes):
ref_onset = ref_note.start
# Buscar si hay alguna nota estimada dentro del margen de tolerancia
tiene_match_temporal = False
for est_note in est_notes:
est_onset = est_note.start
if abs(ref_onset - est_onset) <= ONSET_TOL:
tiene_match_temporal = True
break
if not tiene_match_temporal:
fallados_idx.append(i)
#Lista de fallos
fallos_lista = []
print("\n" + "="*45)
print(f"EVALUACIÓN DE CORNETAI")
print("="*45)
print(f">> NOTA FINAL: {puntuacion} / 10 <<")
print("="*45)
if fallados_idx:
for idx in fallados_idx[:8]: #Máximo 8 fallos listados
nota_es = midi_a_espanol_corneta(ref_notes[idx].pitch)
tiempo = round(ref_notes[idx].start, 2)
fallos_lista.append(f"- Revisa {nota_es} ({tiempo}s)")
print(f" - Revisa el {nota_es} en el segundo {tiempo}")
if len(fallados_idx) > 6:
fallos_lista.append(f"... y {len(fallados_idx)-6} más.")
else:
fallos_lista.append("¡Interpretación Perfecta!")
fallos_texto = "\n".join(fallos_lista)
plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, os.path.basename(audio_path))
if __name__ == "__main__":
if len(sys.argv) == 3:
evaluar_ejecucion(sys.argv[1], sys.argv[2])
else:
print("Uso: python evaluador_individual.py <audio.wav> <referencia.mid>")