# CornetAI / src / entrenar.py
import os
import numpy as np
import tensorflow as tf
import pretty_midi
import librosa
from basic_pitch.models import model as bp_model
from basic_pitch import ICASSP_2022_MODEL_PATH
from sklearn.model_selection import train_test_split
# --- PATH CONFIGURATION ---
PATH_DATASET_WAV = "dataset_wavs"
PATH_DATASET_MIDI = "dataset_midis"
MODEL_SAVE_PATH = "CornetaAI.h5"
# Substring that identifies real (recorded) files; everything else is treated as synthetic
FILTRO_REAL = "ejercicio"
# Technical parameters
SAMPLE_RATE = 22050
SAMPLES_PER_CLIP = 43844
HOP_LENGTH = 256
ANNOTATIONS_FPS = SAMPLE_RATE / HOP_LENGTH
N_FREQ_BINS = 88
N_FREQ_BINS_CONTOUR = 264
TRAIN_DURATION = SAMPLES_PER_CLIP / SAMPLE_RATE
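# Note: SAMPLES_PER_CLIP = 2 * 22050 - 256 = 43844, i.e. two seconds of audio
# minus one hop, which matches Basic Pitch's expected input length. At
# ANNOTATIONS_FPS ≈ 86.13 frames/s that is roughly 171-172 annotation frames;
# the exact count is probed from the model itself in the main block below.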
# Tuned hyperparameters
BATCH_SIZE = 8
EPOCHS = 50
LEARNING_RATE = 0.0003

# --- WEIGHTED BINARY CROSS-ENTROPY LOSS (95/5) ---
def weighted_binary_crossentropy(y_true, y_pred):
    """Binary cross-entropy weighting positives at 0.95 and negatives at 0.05."""
    epsilon = tf.keras.backend.epsilon()
    y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
    loss = -(0.95 * y_true * tf.math.log(y_pred) + 0.05 * (1.0 - y_true) * tf.math.log(1.0 - y_pred))
    return tf.reduce_mean(loss)
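
# Why 95/5: active (positive) frames are a small fraction of each clip, so plain
# BCE lets the model score well by predicting zeros everywhere. Weighting errors
# on positives 19x harder (0.95 / 0.05) counteracts that imbalance. A quick
# sanity check with illustrative values (not part of the training run):
#   y_true = tf.constant([[1.0, 0.0]]); y_pred = tf.constant([[0.5, 0.5]])
#   weighted_binary_crossentropy(y_true, y_pred)  # ~0.347: the positive term
#   # contributes 0.95*log(2), the negative one only 0.05*log(2)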

def get_data(wav, midi, frames_output, augment=False):
    """Load an audio/MIDI pair and build frame-level targets, with optional
    random pitch shifting for robustness."""
    try:
        audio, _ = librosa.load(wav, sr=SAMPLE_RATE)
        if len(audio) < SAMPLES_PER_CLIP:
            audio = np.pad(audio, (0, SAMPLES_PER_CLIP - len(audio)))
        # Random crop so each epoch sees a different window of the recording
        max_start = len(audio) - SAMPLES_PER_CLIP
        start_sample = np.random.randint(0, max_start + 1)
        audio_crop = audio[start_sample : start_sample + SAMPLES_PER_CLIP]
        if augment:
            # Small random pitch shift (±0.5 semitone) plus light Gaussian noise
            n_steps = np.random.uniform(-0.5, 0.5)
            audio_crop = librosa.effects.pitch_shift(audio_crop, sr=SAMPLE_RATE, n_steps=n_steps)
            audio_crop += np.random.normal(0, 0.001, audio_crop.shape)
        # Peak-normalize to [-1, 1]
        if np.max(np.abs(audio_crop)) > 0:
            audio_crop = audio_crop / np.max(np.abs(audio_crop))
        pm = pretty_midi.PrettyMIDI(midi)
        start_time = start_sample / SAMPLE_RATE
        end_time = start_time + TRAIN_DURATION
        targets = {
            "note": np.zeros((frames_output, N_FREQ_BINS), dtype=np.float32),
            "onset": np.zeros((frames_output, N_FREQ_BINS), dtype=np.float32),
            "contour": np.zeros((frames_output, N_FREQ_BINS_CONTOUR), dtype=np.float32)
        }
        for note in pm.instruments[0].notes:
            if note.end < start_time or note.start > end_time:
                continue
            # Clamp the note to the crop window and convert times to frame indices
            rel_start = max(0.0, note.start - start_time)
            rel_end = min(TRAIN_DURATION, note.end - start_time)
            s, e = int(rel_start * ANNOTATIONS_FPS), int(rel_end * ANNOTATIONS_FPS)
            p = note.pitch - 21  # MIDI 21 (A0) maps to bin 0 of the 88 piano keys
            if 0 <= p < N_FREQ_BINS:
                s, e = max(0, min(s, frames_output - 1)), max(0, min(e, frames_output))
                if s < e:
                    targets["note"][s:e, p] = 1.0
                    # Mark a 3-frame onset window only if the note actually
                    # starts inside the crop (not clipped at the window edge)
                    if note.start >= start_time:
                        targets["onset"][max(0, s - 1):min(frames_output, s + 2), p] = 1.0
                    # Contour has 3 bins per semitone: center bin at full
                    # intensity, immediate neighbors at half intensity
                    c = p * 3 + 1
                    targets["contour"][s:e, c] = 1.0
                    if c > 0:
                        targets["contour"][s:e, c - 1] = 0.5
                    if c < N_FREQ_BINS_CONTOUR - 1:
                        targets["contour"][s:e, c + 1] = 0.5
        return audio_crop[:, np.newaxis], targets
    except Exception:
        return None, None
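
# Example call (hypothetical file names; the real frame count is probed from the
# model in the main block, 172 is used here only for illustration):
#   audio, targets = get_data("dataset_wavs/ejercicio_01.wav",
#                             "dataset_midis/ejercicio_01.mid", 172)
#   # audio.shape            -> (43844, 1)
#   # targets["note"].shape  -> (172, 88); targets["contour"].shape -> (172, 264)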

class MasterGenerator(tf.keras.utils.Sequence):
    """Keras Sequence that yields (audio, targets) batches from file lists."""
    def __init__(self, wavs, midis, frames_output, augment=False):
        self.wavs, self.midis = wavs, midis
        self.frames_output, self.augment = frames_output, augment
        self.indices = np.arange(len(self.wavs))
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.wavs) / BATCH_SIZE))

    def __getitem__(self, idx):
        batch_indices = self.indices[idx * BATCH_SIZE : (idx + 1) * BATCH_SIZE]
        audios, notes, onsets, contours = [], [], [], []
        for i in batch_indices:
            a, t = get_data(self.wavs[i], self.midis[i], self.frames_output, augment=self.augment)
            if a is not None:  # skip files that failed to load or parse
                audios.append(a)
                notes.append(t["note"])
                onsets.append(t["onset"])
                contours.append(t["contour"])
        return np.array(audios), {"note": np.array(notes), "onset": np.array(onsets), "contour": np.array(contours)}

    def on_epoch_end(self):
        # Reshuffle so batch composition differs between epochs
        np.random.shuffle(self.indices)
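
# Note: Keras calls __getitem__ once per batch and on_epoch_end after each epoch.
# If every file in a batch failed to load, the arrays returned above would be
# empty and model.fit would fail; with a clean dataset that path is never hit.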

if __name__ == "__main__":
    print("--- FINE-TUNING: PURELY REAL VALIDATION (25 train / 10 val) ---")
    # 1. Load the file lists and classify them as real vs. synthetic
    all_wavs = [f for f in os.listdir(PATH_DATASET_WAV) if f.endswith(".wav")]
    wav_real, midi_real, wav_sint, midi_sint = [], [], [], []
    for w in all_wavs:
        m = w.replace(".wav", ".mid")
        if os.path.exists(os.path.join(PATH_DATASET_MIDI, m)):
            path_w = os.path.join(PATH_DATASET_WAV, w)
            path_m = os.path.join(PATH_DATASET_MIDI, m)
            if FILTRO_REAL.lower() in w.lower():
                wav_real.append(path_w)
                midi_real.append(path_m)
            else:
                wav_sint.append(path_w)
                midi_sint.append(path_m)
    # 2. Split the real recordings: 25 for training, 10 held out for validation
    tr_rw, val_w, tr_rm, val_m = train_test_split(wav_real, midi_real, test_size=10, random_state=42)
    # 3. Training mix: all synthetic clips (500) plus the 25 real ones
    train_w, train_m = wav_sint + tr_rw, midi_sint + tr_rm
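    # With test_size=10 and a fixed random_state, the real-data split is
    # reproducible across runs (25 train / 10 val, assuming 35 real clips in
    # total, as the banner above indicates).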
    # 4. Configure the model: load the ICASSP 2022 weights, then freeze the CQT layers
    model = bp_model()
    model.load_weights(ICASSP_2022_MODEL_PATH)
    # Probe the model with a dummy clip to read its output frame count
    frames_out = model(np.zeros((1, SAMPLES_PER_CLIP, 1))).get('note').shape[1]
    for layer in model.layers:
        layer.trainable = 'cqt' not in layer.name
    model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
                  loss={"note": "binary_crossentropy", "onset": weighted_binary_crossentropy, "contour": "binary_crossentropy"},
                  loss_weights={"note": 1.0, "onset": 1.5, "contour": 0.5})
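    # Rationale: layers whose names contain 'cqt' (the spectral front end) stay
    # frozen so the pretrained representation is preserved; only the downstream
    # heads are fine-tuned. The onset head gets both the weighted loss above and
    # a 1.5x loss weight, presumably because onsets are the sparsest target.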
    train_gen = MasterGenerator(train_w, train_m, frames_out, augment=True)
    val_gen = MasterGenerator(val_w, val_m, frames_out, augment=False)
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True),
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
        tf.keras.callbacks.ModelCheckpoint(MODEL_SAVE_PATH, monitor='val_loss', save_best_only=True)
    ]
    print(f"📊 Train: {len(train_w)} (500 synthetic + 25 real) | Val: {len(val_w)} (10 real).")
    model.fit(train_gen, validation_data=val_gen, epochs=EPOCHS, callbacks=callbacks)
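    # EarlyStopping restores the best weights in memory, and ModelCheckpoint has
    # already written the best epoch to MODEL_SAVE_PATH, so no extra save is needed.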