"""Fine-tune the Basic Pitch transcription model on trumpet recordings.

Training mixes synthetic clips with a small set of real recordings
(filenames containing ``FILTRO_REAL``); validation uses real clips only.
The CQT front-end layers are frozen and the onset head is trained with a
heavily positive-weighted cross-entropy, since onset frames are sparse.
"""

import os

import librosa
import numpy as np
import pretty_midi
import tensorflow as tf
from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.models import model as bp_model
from sklearn.model_selection import train_test_split

# --- PATH CONFIGURATION ---
PATH_DATASET_WAV = "dataset_wavs"
PATH_DATASET_MIDI = "dataset_midis"
MODEL_SAVE_PATH = "CornetaAI.h5"

# Filename substring that identifies REAL (non-synthetic) recordings.
FILTRO_REAL = "ejercicio"

# Technical parameters.
SAMPLE_RATE = 22050
SAMPLES_PER_CLIP = 43844
HOP_LENGTH = 256
ANNOTATIONS_FPS = SAMPLE_RATE / HOP_LENGTH   # label frames per second
N_FREQ_BINS = 88                              # MIDI pitches 21..108 (piano range)
N_FREQ_BINS_CONTOUR = 264                     # 3 contour bins per semitone
TRAIN_DURATION = SAMPLES_PER_CLIP / SAMPLE_RATE  # crop length in seconds

# Optimized hyperparameters.
BATCH_SIZE = 8
EPOCHS = 50
LEARNING_RATE = 0.0003


# --- BALANCED LOSS FUNCTION (95/5) ---
def weighted_binary_crossentropy(y_true, y_pred):
    """Binary cross-entropy with a 95/5 positive/negative weighting.

    Onset targets are extremely sparse, so positives are strongly
    up-weighted to keep the network from predicting all zeros.
    """
    epsilon = tf.keras.backend.epsilon()
    y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
    loss = -(0.95 * y_true * tf.math.log(y_pred)
             + 0.05 * (1.0 - y_true) * tf.math.log(1.0 - y_pred))
    return tf.reduce_mean(loss)


def get_data(wav, midi, frames_output, augment=False):
    """Load one (audio, targets) training example.

    Takes a random ``SAMPLES_PER_CLIP``-long crop from *wav* and rasterizes
    the matching note/onset/contour target matrices from *midi*. With
    ``augment=True`` a small random pitch shift (+-0.5 semitone) and light
    noise are applied for robustness.

    Returns:
        ``(audio, targets)`` where ``audio`` has shape
        ``(SAMPLES_PER_CLIP, 1)`` and ``targets`` maps ``"note"``,
        ``"onset"``, ``"contour"`` to float32 matrices, or ``(None, None)``
        if loading fails (best-effort: the generator drops the example).
    """
    try:
        audio, _ = librosa.load(wav, sr=SAMPLE_RATE)
        if len(audio) < SAMPLES_PER_CLIP:
            audio = np.pad(audio, (0, SAMPLES_PER_CLIP - len(audio)))
        max_start = len(audio) - SAMPLES_PER_CLIP
        start_sample = np.random.randint(0, max_start + 1)
        audio_crop = audio[start_sample : start_sample + SAMPLES_PER_CLIP]

        if augment:
            n_steps = np.random.uniform(-0.5, 0.5)
            audio_crop = librosa.effects.pitch_shift(
                audio_crop, sr=SAMPLE_RATE, n_steps=n_steps
            )
            audio_crop += np.random.normal(0, 0.001, audio_crop.shape)

        # Peak-normalize, guarding against all-silent crops.
        peak = np.max(np.abs(audio_crop))
        if peak > 0:
            audio_crop = audio_crop / peak

        pm = pretty_midi.PrettyMIDI(midi)
        start_time = start_sample / SAMPLE_RATE
        end_time = start_time + TRAIN_DURATION

        targets = {
            "note": np.zeros((frames_output, N_FREQ_BINS), dtype=np.float32),
            "onset": np.zeros((frames_output, N_FREQ_BINS), dtype=np.float32),
            "contour": np.zeros(
                (frames_output, N_FREQ_BINS_CONTOUR), dtype=np.float32
            ),
        }

        # NOTE(review): only the first MIDI instrument is rasterized —
        # assumes single-track exercise files; confirm against the dataset.
        for note in pm.instruments[0].notes:
            if note.end < start_time or note.start > end_time:
                continue  # note lies entirely outside the crop
            rel_start = max(0.0, note.start - start_time)
            rel_end = min(TRAIN_DURATION, note.end - start_time)
            s = int(rel_start * ANNOTATIONS_FPS)
            e = int(rel_end * ANNOTATIONS_FPS)
            p = note.pitch - 21  # MIDI pitch 21 (A0) maps to bin 0
            if 0 <= p < N_FREQ_BINS:
                s = max(0, min(s, frames_output - 1))
                e = max(0, min(e, frames_output))
                if s < e:
                    targets["note"][s:e, p] = 1.0
                    # Only mark an onset when the attack itself is inside
                    # the crop (not when the note started before it).
                    if note.start >= start_time:
                        onset_lo = max(0, s - 1)
                        onset_hi = min(frames_output, s + 2)
                        targets["onset"][onset_lo:onset_hi, p] = 1.0
                    # Contour: 3 bins per semitone — center bin hot,
                    # neighbours at half weight.
                    c = p * 3 + 1
                    targets["contour"][s:e, c] = 1.0
                    if c > 0:
                        targets["contour"][s:e, c - 1] = 0.5
                    # Bound derived from the constant instead of the
                    # hard-coded 263, so the constant stays the single
                    # source of truth (same value today).
                    if c < N_FREQ_BINS_CONTOUR - 1:
                        targets["contour"][s:e, c + 1] = 0.5

        return audio_crop[:, np.newaxis], targets
    except Exception:
        # Best-effort loader: a corrupt/missing file just yields no example.
        return None, None


class MasterGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(audio_batch, targets_batch)`` pairs.

    Examples that fail to load are silently dropped from the batch, so a
    batch may be smaller than ``BATCH_SIZE``.
    """

    def __init__(self, wavs, midis, frames_output, augment=False):
        # super().__init__() is required by Keras 3's Sequence; harmless
        # (no-op) under Keras 2.
        super().__init__()
        self.wavs = wavs
        self.midis = midis
        self.frames_output = frames_output
        self.augment = augment
        self.indices = np.arange(len(self.wavs))
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.wavs) / BATCH_SIZE))

    def __getitem__(self, idx):
        batch_indices = self.indices[idx * BATCH_SIZE : (idx + 1) * BATCH_SIZE]
        audios, notes, onsets, contours = [], [], [], []
        for i in batch_indices:
            a, t = get_data(
                self.wavs[i],
                self.midis[i],
                self.frames_output,
                augment=self.augment,
            )
            if a is not None:
                audios.append(a)
                notes.append(t["note"])
                onsets.append(t["onset"])
                contours.append(t["contour"])
        return np.array(audios), {
            "note": np.array(notes),
            "onset": np.array(onsets),
            "contour": np.array(contours),
        }

    def on_epoch_end(self):
        # Reshuffle so batch composition differs between epochs.
        np.random.shuffle(self.indices)


if __name__ == "__main__":
    print("--- FINE-TUNING: VALIDACIÓN REAL PURA (25 Train / 10 Val) ---")

    # 1. Load WAV files and classify them as real vs. synthetic, keeping
    #    only those with a matching MIDI annotation.
    all_wavs = [f for f in os.listdir(PATH_DATASET_WAV) if f.endswith(".wav")]
    wav_real, midi_real, wav_sint, midi_sint = [], [], [], []
    for w in all_wavs:
        m = w.replace(".wav", ".mid")
        if os.path.exists(os.path.join(PATH_DATASET_MIDI, m)):
            path_w = os.path.join(PATH_DATASET_WAV, w)
            path_m = os.path.join(PATH_DATASET_MIDI, m)
            if FILTRO_REAL.lower() in w.lower():
                wav_real.append(path_w)
                midi_real.append(path_m)
            else:
                wav_sint.append(path_w)
                midi_sint.append(path_m)

    # 2. Split the real clips: 10 held out for validation, rest for training.
    tr_rw, val_w, tr_rm, val_m = train_test_split(
        wav_real, midi_real, test_size=10, random_state=42
    )

    # 3. Training mix: all synthetic clips + the real training split.
    train_w, train_m = wav_sint + tr_rw, midi_sint + tr_rm

    # 4. Model setup: start from the pretrained ICASSP 2022 weights and
    #    freeze only the CQT front-end layers.
    model = bp_model()
    model.load_weights(ICASSP_2022_MODEL_PATH)
    # Probe the model once to learn how many output frames it produces
    # for a clip of SAMPLES_PER_CLIP samples.
    frames_out = model(np.zeros((1, SAMPLES_PER_CLIP, 1))).get("note").shape[1]
    for layer in model.layers:
        layer.trainable = "cqt" not in layer.name

    model.compile(
        optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
        loss={
            "note": "binary_crossentropy",
            "onset": weighted_binary_crossentropy,
            "contour": "binary_crossentropy",
        },
        loss_weights={"note": 1.0, "onset": 1.5, "contour": 0.5},
    )

    train_gen = MasterGenerator(train_w, train_m, frames_out, augment=True)
    val_gen = MasterGenerator(val_w, val_m, frames_out, augment=False)

    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=8, restore_best_weights=True
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.5, patience=3
        ),
        tf.keras.callbacks.ModelCheckpoint(
            MODEL_SAVE_PATH, monitor="val_loss", save_best_only=True
        ),
    ]

    print(f"📊 Train: {len(train_w)} (500 sint + 25 real) | Val: {len(val_w)} (10 real).")
    model.fit(train_gen, validation_data=val_gen, epochs=EPOCHS, callbacks=callbacks)