File size: 6,653 Bytes
4f8d622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import logging
import os

import librosa
import numpy as np
import pretty_midi
import tensorflow as tf

from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.models import model as bp_model
from sklearn.model_selection import train_test_split

# --- PATH CONFIGURATION ---
PATH_DATASET_WAV = "dataset_wavs"    # directory containing the .wav clips
PATH_DATASET_MIDI = "dataset_midis"  # directory containing the matching .mid annotations
MODEL_SAVE_PATH = "CornetaAI.h5"     # where the best checkpoint is written

# Filename substring that marks a recording as "real" (everything else is synthetic)
FILTRO_REAL = "ejercicio" 

# Technical parameters
SAMPLE_RATE = 22050                  # audio sample rate in Hz
SAMPLES_PER_CLIP = 43844             # samples per training crop (~1.99 s at 22050 Hz)
HOP_LENGTH = 256                     # samples between annotation frames
ANNOTATIONS_FPS = SAMPLE_RATE / HOP_LENGTH  # annotation frames per second
N_FREQ_BINS = 88       
N_FREQ_BINS_CONTOUR = 264 
TRAIN_DURATION = SAMPLES_PER_CLIP / SAMPLE_RATE  # crop duration in seconds

# Tuned hyperparameters
BATCH_SIZE = 8
EPOCHS = 50 
LEARNING_RATE = 0.0003 

# --- CLASS-WEIGHTED LOSS (95/5 split between positives and negatives) ---
def weighted_binary_crossentropy(y_true, y_pred):
    """Binary cross-entropy weighting positives at 0.95 and negatives at 0.05.

    Predictions are clipped away from 0 and 1 before taking logs so the
    loss never produces inf/NaN. Returns the mean loss over all elements.
    """
    eps = tf.keras.backend.epsilon()
    clipped = tf.clip_by_value(y_pred, eps, 1.0 - eps)
    positive_term = 0.95 * y_true * tf.math.log(clipped)
    negative_term = 0.05 * (1.0 - y_true) * tf.math.log(1.0 - clipped)
    return tf.reduce_mean(-(positive_term + negative_term))

def get_data(wav, midi, frames_output, augment=False):
    """Build one (audio, targets) training example from a WAV/MIDI pair.

    Takes a random crop of SAMPLES_PER_CLIP samples, optionally augments it
    (pitch shift of +/-0.5 semitones plus light Gaussian noise), peak-normalizes
    it, and rasterizes the MIDI notes overlapping the crop into Basic
    Pitch-style target matrices.

    Args:
        wav: path to the audio file.
        midi: path to the annotation MIDI file (notes are read from the
            first instrument track).
        frames_output: number of annotation frames the model emits per clip.
        augment: apply pitch-shift / noise augmentation when True.

    Returns:
        Tuple (audio, targets) where audio has shape (SAMPLES_PER_CLIP, 1)
        and targets maps "note"/"onset" (frames x 88) and "contour"
        (frames x 264) to float32 matrices, or (None, None) if the pair
        could not be processed.
    """
    try:
        audio, _ = librosa.load(wav, sr=SAMPLE_RATE)
        # Pad short files so a full-length crop is always possible.
        if len(audio) < SAMPLES_PER_CLIP:
            audio = np.pad(audio, (0, SAMPLES_PER_CLIP - len(audio)))

        # Random crop start, inclusive of the last valid position.
        max_start = len(audio) - SAMPLES_PER_CLIP
        start_sample = np.random.randint(0, max_start + 1)
        audio_crop = audio[start_sample : start_sample + SAMPLES_PER_CLIP]

        if augment:
            n_steps = np.random.uniform(-0.5, 0.5)
            audio_crop = librosa.effects.pitch_shift(audio_crop, sr=SAMPLE_RATE, n_steps=n_steps)
            audio_crop += np.random.normal(0, 0.001, audio_crop.shape)

        # Peak-normalize; skip silent crops to avoid dividing by zero.
        peak = np.max(np.abs(audio_crop))
        if peak > 0:
            audio_crop = audio_crop / peak

        pm = pretty_midi.PrettyMIDI(midi)
        start_time = start_sample / SAMPLE_RATE
        end_time = start_time + TRAIN_DURATION

        targets = {
            "note": np.zeros((frames_output, N_FREQ_BINS), dtype=np.float32),
            "onset": np.zeros((frames_output, N_FREQ_BINS), dtype=np.float32),
            "contour": np.zeros((frames_output, N_FREQ_BINS_CONTOUR), dtype=np.float32)
        }

        for note in pm.instruments[0].notes:
            # Skip notes that fall entirely outside the cropped window.
            if note.end < start_time or note.start > end_time:
                continue
            rel_start = max(0.0, note.start - start_time)
            rel_end = min(TRAIN_DURATION, note.end - start_time)
            s, e = int(rel_start * ANNOTATIONS_FPS), int(rel_end * ANNOTATIONS_FPS)
            p = note.pitch - 21  # map MIDI pitch to the 88-key bin index (21 = A0)
            if 0 <= p < N_FREQ_BINS:
                # Clamp frame indices into the valid output range.
                s, e = max(0, min(s, frames_output - 1)), max(0, min(e, frames_output))
                if s < e:
                    targets["note"][s:e, p] = 1.0
                    # Mark an onset only for notes starting inside the crop,
                    # smeared over a 3-frame window for timing tolerance.
                    if note.start >= start_time:
                        targets["onset"][max(0, s-1):min(frames_output, s+2), p] = 1.0
                    # Contour: 3 bins per note; center bin hot, neighbors at 0.5.
                    c = p * 3 + 1
                    targets["contour"][s:e, c] = 1.0
                    if c > 0: targets["contour"][s:e, c-1] = 0.5
                    if c < 263: targets["contour"][s:e, c+1] = 0.5
        return audio_crop[:, np.newaxis], targets
    except Exception:
        # Report the failing pair instead of silently swallowing the error;
        # callers already treat (None, None) as "skip this example".
        logging.exception("Failed to build example from %s / %s", wav, midi)
        return None, None

class MasterGenerator(tf.keras.utils.Sequence):
    """Keras Sequence yielding shuffled (audio, targets) batches via get_data.

    Examples that fail to load (get_data returns None) are dropped from the
    batch, so a batch may be smaller than BATCH_SIZE. Indices are reshuffled
    at the end of every epoch.
    """

    def __init__(self, wavs, midis, frames_output, augment=False):
        # Newer Keras versions require Sequence subclasses to call super().
        super().__init__()
        self.wavs, self.midis = wavs, midis
        self.frames_output = frames_output
        self.augment = augment
        self.indices = np.arange(len(self.wavs))
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch (last batch may be partial).
        return int(np.ceil(len(self.wavs) / BATCH_SIZE))

    def __getitem__(self, idx):
        batch_indices = self.indices[idx * BATCH_SIZE : (idx + 1) * BATCH_SIZE]
        audios, notes, onsets, contours = [], [], [], []
        for i in batch_indices:
            a, t = get_data(self.wavs[i], self.midis[i], self.frames_output, augment=self.augment)
            if a is not None:
                audios.append(a)
                notes.append(t["note"])
                onsets.append(t["onset"])
                contours.append(t["contour"])
        return np.array(audios), {"note": np.array(notes), "onset": np.array(onsets), "contour": np.array(contours)}

    def on_epoch_end(self):
        np.random.shuffle(self.indices)

if __name__ == "__main__":
    print("--- FINE-TUNING: VALIDACIÓN REAL PURA (25 Train / 10 Val) ---")

    # 1. Collect WAVs that have a matching MIDI and classify each pair as
    #    real (filename contains FILTRO_REAL) or synthetic.
    all_wavs = [f for f in os.listdir(PATH_DATASET_WAV) if f.endswith(".wav")]
    wav_real, midi_real, wav_sint, midi_sint = [], [], [], []

    for w in all_wavs:
        m = w.replace(".wav", ".mid")
        if os.path.exists(os.path.join(PATH_DATASET_MIDI, m)):
            path_w, path_m = os.path.join(PATH_DATASET_WAV, w), os.path.join(PATH_DATASET_MIDI, m)
            if FILTRO_REAL.lower() in w.lower():
                wav_real.append(path_w); midi_real.append(path_m)
            else:
                wav_sint.append(path_w); midi_sint.append(path_m)

    # 2. Split the real recordings: 10 held out for validation, rest train.
    #    Fail early with a clear message instead of an opaque sklearn error.
    if len(wav_real) <= 10:
        raise RuntimeError(
            f"Need more than 10 real files matching '{FILTRO_REAL}', found {len(wav_real)}."
        )
    tr_rw, val_w, tr_rm, val_m = train_test_split(wav_real, midi_real, test_size=10, random_state=42)

    # 3. Training mix: all synthetic clips plus the real training split.
    train_w, train_m = wav_sint + tr_rw, midi_sint + tr_rm

    # 4. Build Basic Pitch, load ICASSP 2022 pretrained weights, and probe
    #    the model once to discover how many output frames it produces.
    model = bp_model()
    model.load_weights(ICASSP_2022_MODEL_PATH)
    frames_out = model(np.zeros((1, SAMPLES_PER_CLIP, 1))).get('note').shape[1]

    # Freeze only the CQT front-end layers; fine-tune everything else.
    for l in model.layers: l.trainable = 'cqt' not in l.name
    model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
                  loss={"note": "binary_crossentropy", "onset": weighted_binary_crossentropy, "contour": "binary_crossentropy"},
                  loss_weights={"note": 1.0, "onset": 1.5, "contour": 0.5})

    train_gen = MasterGenerator(train_w, train_m, frames_out, augment=True)
    val_gen = MasterGenerator(val_w, val_m, frames_out, augment=False)

    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True),
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
        tf.keras.callbacks.ModelCheckpoint(MODEL_SAVE_PATH, monitor='val_loss', save_best_only=True)
    ]

    # Report actual split sizes (the old message hard-coded 500/25 counts).
    print(f"📊 Train: {len(train_w)} ({len(wav_sint)} sint + {len(tr_rw)} real) | Val: {len(val_w)} real.")
    model.fit(train_gen, validation_data=val_gen, epochs=EPOCHS, callbacks=callbacks)