| | import os
|
| | import numpy as np
|
| | import tensorflow as tf
|
| | import pretty_midi
|
| | import librosa
|
| | from basic_pitch.models import model as bp_model
|
| | from basic_pitch import ICASSP_2022_MODEL_PATH
|
| | from sklearn.model_selection import train_test_split
|
| |
|
| |
|
# --- Dataset locations and output model file ---
PATH_DATASET_WAV = "dataset_wavs"    # directory containing the .wav clips
PATH_DATASET_MIDI = "dataset_midis"  # directory containing the matching .mid files
MODEL_SAVE_PATH = "CornetaAI.h5"     # where the best fine-tuned weights are checkpointed

# Filename substring that marks a clip as a "real" recording; everything
# else is treated as synthetic (see the split in __main__).
FILTRO_REAL = "ejercicio"

# --- Audio / annotation geometry ---
SAMPLE_RATE = 22050                         # Hz, rate passed to librosa.load
SAMPLES_PER_CLIP = 43844                    # fixed training window length in samples (~2 s)
HOP_LENGTH = 256                            # hop between annotation frames, in samples
ANNOTATIONS_FPS = SAMPLE_RATE / HOP_LENGTH  # annotation frames per second
N_FREQ_BINS = 88                            # note/onset pitch bins (MIDI 21..108, see get_data)
N_FREQ_BINS_CONTOUR = 264                   # contour bins: 3 per note bin (88 * 3)
TRAIN_DURATION = SAMPLES_PER_CLIP / SAMPLE_RATE  # clip duration in seconds

# --- Training hyperparameters ---
BATCH_SIZE = 8
EPOCHS = 50
LEARNING_RATE = 0.0003
|
| |
|
| |
|
def weighted_binary_crossentropy(y_true, y_pred, pos_weight=0.95, neg_weight=0.05):
    """Class-weighted binary cross-entropy.

    Up-weights the positive class heavily, which suits sparse targets such
    as the onset map where almost every cell is zero. The weights were
    hard-coded as 0.95/0.05; they are now keyword parameters with the same
    defaults, so Keras's two-argument loss invocation is unaffected.

    Args:
        y_true: Ground-truth tensor of 0/1 labels.
        y_pred: Predicted probabilities in [0, 1].
        pos_weight: Weight on the positive-class log term (default 0.95).
        neg_weight: Weight on the negative-class log term (default 0.05).

    Returns:
        Scalar: mean weighted cross-entropy over all elements.
    """
    # Clip predictions away from exactly 0/1 so the logs stay finite.
    epsilon = tf.keras.backend.epsilon()
    y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
    loss = -(pos_weight * y_true * tf.math.log(y_pred)
             + neg_weight * (1.0 - y_true) * tf.math.log(1.0 - y_pred))
    return tf.reduce_mean(loss)
|
| |
|
def get_data(wav, midi, frames_output, augment=False):
    """Load one audio/MIDI pair and build frame-level training targets.

    Loads the .wav at SAMPLE_RATE, takes a random SAMPLES_PER_CLIP-long
    crop, optionally applies random pitch shift + additive noise, then
    rasterizes the notes of the first MIDI instrument into note/onset/
    contour target matrices aligned with that crop.

    Args:
        wav: Path to the input .wav file.
        midi: Path to the matching .mid file.
        frames_output: Number of annotation frames the model outputs per clip.
        augment: If True, apply random pitch shifting and light noise.

    Returns:
        (audio, targets) where audio has shape (SAMPLES_PER_CLIP, 1) and
        targets is a dict with "note"/"onset" (frames_output x 88) and
        "contour" (frames_output x 264) float32 arrays; (None, None) if
        anything fails while loading or parsing.
    """
    try:
        audio, _ = librosa.load(wav, sr=SAMPLE_RATE)
        # Pad short files so a full-length crop is always possible.
        if len(audio) < SAMPLES_PER_CLIP:
            audio = np.pad(audio, (0, SAMPLES_PER_CLIP - len(audio)))

        # Pick a random crop window of exactly SAMPLES_PER_CLIP samples.
        max_start = len(audio) - SAMPLES_PER_CLIP
        start_sample = np.random.randint(0, max_start + 1)
        audio_crop = audio[start_sample : start_sample + SAMPLES_PER_CLIP]

        if augment:
            # Random pitch shift within +/- half a semitone, plus faint noise.
            n_steps = np.random.uniform(-0.5, 0.5)
            audio_crop = librosa.effects.pitch_shift(audio_crop, sr=SAMPLE_RATE, n_steps=n_steps)
            audio_crop += np.random.normal(0, 0.001, audio_crop.shape)

        # Peak-normalize; skip silent clips to avoid division by zero.
        if np.max(np.abs(audio_crop)) > 0:
            audio_crop = audio_crop / np.max(np.abs(audio_crop))

        pm = pretty_midi.PrettyMIDI(midi)
        # Time span (seconds) covered by the crop within the full recording.
        start_time, end_time = start_sample / SAMPLE_RATE, (start_sample / SAMPLE_RATE) + TRAIN_DURATION

        targets = {
            "note": np.zeros((frames_output, N_FREQ_BINS), dtype=np.float32),
            "onset": np.zeros((frames_output, N_FREQ_BINS), dtype=np.float32),
            "contour": np.zeros((frames_output, N_FREQ_BINS_CONTOUR), dtype=np.float32)
        }

        # NOTE(review): only instruments[0] is rasterized — assumes
        # single-instrument MIDI files; confirm against the dataset.
        for note in pm.instruments[0].notes:
            if note.end < start_time or note.start > end_time: continue
            # Note times relative to the crop, clamped to [0, TRAIN_DURATION].
            rel_start, rel_end = max(0.0, note.start - start_time), min(TRAIN_DURATION, note.end - start_time)
            s, e = int(rel_start * ANNOTATIONS_FPS), int(rel_end * ANNOTATIONS_FPS)
            p = note.pitch - 21  # MIDI pitch -> piano bin (A0 = MIDI 21 -> bin 0)
            if 0 <= p < N_FREQ_BINS:
                # Clamp frame indices into the output grid.
                s, e = max(0, min(s, frames_output - 1)), max(0, min(e, frames_output))
                if s < e:
                    targets["note"][s:e, p] = 1.0
                    # Onset only for notes that actually start inside the
                    # crop; paint a 3-frame band around the start frame.
                    if note.start >= start_time:
                        targets["onset"][max(0, s-1):min(frames_output, s+2), p] = 1.0
                    # Contour: center bin of the 3-bins-per-semitone grid,
                    # with soft 0.5 activations on both neighbors.
                    c = p * 3 + 1
                    targets["contour"][s:e, c] = 1.0
                    if c > 0: targets["contour"][s:e, c-1] = 0.5
                    if c < 263: targets["contour"][s:e, c+1] = 0.5
        return audio_crop[:, np.newaxis], targets
    except Exception: return None, None
|
| |
|
class MasterGenerator(tf.keras.utils.Sequence):
    """Keras Sequence yielding (audio, targets) batches via get_data.

    Shuffles sample order at the end of every epoch and silently skips
    samples whose audio/MIDI pair fails to load (get_data returns None).
    """

    def __init__(self, wavs, midis, frames_output, augment=False):
        """
        Args:
            wavs: List of .wav file paths.
            midis: List of matching .mid file paths (same order as wavs).
            frames_output: Annotation frames the model outputs per clip.
            augment: Forwarded to get_data (random pitch shift + noise).
        """
        # Fix: newer tf.keras (Keras 3 PyDataset) requires the base-class
        # initializer to be called; harmless on older versions.
        super().__init__()
        self.wavs, self.midis, self.frames_output, self.augment = wavs, midis, frames_output, augment
        self.indices = np.arange(len(self.wavs))
        self.on_epoch_end()

    def __len__(self):
        # Batches per epoch; the last batch may be partial.
        return int(np.ceil(len(self.wavs) / BATCH_SIZE))

    def __getitem__(self, idx):
        """Build batch `idx`; unloadable samples are dropped from the batch."""
        batch_indices = self.indices[idx * BATCH_SIZE : (idx + 1) * BATCH_SIZE]
        audios, notes, onsets, contours = [], [], [], []
        for i in batch_indices:
            a, t = get_data(self.wavs[i], self.midis[i], self.frames_output, augment=self.augment)
            if a is not None:
                audios.append(a)
                notes.append(t["note"])
                onsets.append(t["onset"])
                contours.append(t["contour"])
        return np.array(audios), {"note": np.array(notes), "onset": np.array(onsets), "contour": np.array(contours)}

    def on_epoch_end(self):
        # Reshuffle so batch composition differs between epochs.
        np.random.shuffle(self.indices)
|
| |
|
if __name__ == "__main__":
    print("--- FINE-TUNING: VALIDACIÓN REAL PURA (25 Train / 10 Val) ---")

    # Collect every wav that has a matching .mid, separating "real" takes
    # (filename contains FILTRO_REAL, case-insensitive) from synthetic ones.
    all_wavs = [f for f in os.listdir(PATH_DATASET_WAV) if f.endswith(".wav")]
    wav_real, midi_real, wav_sint, midi_sint = [], [], [], []

    for w in all_wavs:
        m = w.replace(".wav", ".mid")
        if os.path.exists(os.path.join(PATH_DATASET_MIDI, m)):
            path_w, path_m = os.path.join(PATH_DATASET_WAV, w), os.path.join(PATH_DATASET_MIDI, m)
            if FILTRO_REAL.lower() in w.lower():
                wav_real.append(path_w); midi_real.append(path_m)
            else:
                wav_sint.append(path_w); midi_sint.append(path_m)

    # Validation is 10 *real* clips only; fail early with a clear message
    # instead of an opaque sklearn error when there aren't enough of them.
    if len(wav_real) <= 10:
        raise RuntimeError(
            f"Need more than 10 real clips to hold out a 10-clip validation set; found {len(wav_real)}."
        )
    tr_rw, val_w, tr_rm, val_m = train_test_split(wav_real, midi_real, test_size=10, random_state=42)

    # Train on all synthetic clips plus the remaining real ones.
    train_w, train_m = wav_sint + tr_rw, midi_sint + tr_rm

    # Load the pretrained Basic Pitch model and probe how many output
    # frames it produces for one clip-length input.
    model = bp_model()
    model.load_weights(ICASSP_2022_MODEL_PATH)
    frames_out = model(np.zeros((1, SAMPLES_PER_CLIP, 1))).get('note').shape[1]

    # Freeze only the CQT front-end layers; fine-tune everything else.
    for l in model.layers: l.trainable = 'cqt' not in l.name
    model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
                  loss={"note": "binary_crossentropy", "onset": weighted_binary_crossentropy, "contour": "binary_crossentropy"},
                  loss_weights={"note": 1.0, "onset": 1.5, "contour": 0.5})

    train_gen = MasterGenerator(train_w, train_m, frames_out, augment=True)
    val_gen = MasterGenerator(val_w, val_m, frames_out, augment=False)

    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True),
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
        tf.keras.callbacks.ModelCheckpoint(MODEL_SAVE_PATH, monitor='val_loss', save_best_only=True)
    ]

    # Fix: report the actual dataset composition instead of the previously
    # hard-coded "(500 sint + 25 real)".
    print(f"📊 Train: {len(train_w)} ({len(wav_sint)} sint + {len(tr_rw)} real) | Val: {len(val_w)} real.")
    model.fit(train_gen, validation_data=val_gen, epochs=EPOCHS, callbacks=callbacks)