""" BTC Toolkit - Motor de Inferencia de Acordes Carga el modelo BTC, preprocesa el audio y retorna los acordes detectados. Flujo: 1. Cargar audio con librosa (extracción de CQT). 2. Pasar por el modelo Transformer. 3. Decodificar predicciones a nombres de acordes Mayor/Menor. 4. Agrupar acordes consecutivos iguales en segmentos de tiempo. """ import os import logging import numpy as np import torch import torch.nn.functional as F from .model import BTCModel, CHORD_VOCAB from .weights_manager import ensure_weights_available logger = logging.getLogger(__name__) # ========================================== # PARÁMETROS DE AUDIO # ========================================== SAMPLE_RATE = 22050 # Hz HOP_LENGTH = 512 # Frames CQT N_BINS = 144 # Bins CQT (12 semitones x 12 octavas) BINS_PER_OCTAVE = 36 # Resolución octava SEGMENT_SECONDS = 30.0 # Longitud máxima por bloque (para VRAM limitada) MIN_CHORD_DURATION = 0.5 # Duración mínima de un acorde para mostrarlo (segundos) # ========================================== # CLASE PRINCIPAL DE INFERENCIA # ========================================== class BTCChordRecognizer: """ Motor de reconocimiento de acordes usando BTC Transformer. Diseñado para bajo uso de VRAM: - Procesa el audio en bloques de 30 segundos. - Libera la VRAM al finalizar con `unload()`. """ def __init__(self, device: str = "cuda"): self._model: BTCModel | None = None self._device_str = device self._device: torch.device | None = None def _resolve_device(self) -> torch.device: """Selecciona GPU si está disponible, sino CPU.""" if self._device_str == "cuda" and torch.cuda.is_available(): dev = torch.device("cuda") vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3) logger.info(f"[BTC] GPU detectada: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)") else: dev = torch.device("cpu") logger.info("[BTC] Usando CPU para inferencia de acordes.") return dev def load_model(self) -> None: """Carga el modelo en memoria. Descarga pesos si no están disponibles.""" if self._model is not None: return self._device = self._resolve_device() # Asegurar que los pesos estén disponibles (descarga si falta) weights_path = ensure_weights_available() # Instanciar arquitectura self._model = BTCModel(n_freq=N_BINS) self._model.to(self._device) # Cargar pesos try: checkpoint = torch.load(weights_path, map_location=self._device, weights_only=False) # Soporte para checkpoints con y sin wrapper 'state_dict' if isinstance(checkpoint, dict) and "state_dict" in checkpoint: state = checkpoint["state_dict"] elif isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: state = checkpoint["model_state_dict"] elif isinstance(checkpoint, dict) and "model" in checkpoint: state = checkpoint["model"] else: state = checkpoint # Carga flexible: ignora claves que no coinciden missing, unexpected = self._model.load_state_dict(state, strict=False) if missing: logger.warning(f"[BTC] Parámetros faltantes al cargar pesos: {len(missing)} parámetros faltantes") if unexpected: logger.debug(f"[BTC] Parámetros inesperados ignorados: {len(unexpected)} parámetros inesperados") self._model.eval() logger.info("[BTC] Modelo BTC cargado y listo para inferencia.") except Exception as e: self._model = None raise RuntimeError(f"[BTC] Error al cargar los pesos del modelo: {e}") from e def unload(self) -> None: """Libera el modelo de la memoria GPU para dejar espacio a otros procesos.""" if self._model is not None: self._model.cpu() del self._model self._model = None if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info("[BTC] VRAM liberada correctamente.") def recognize(self, audio_path: str, bpm: float = None, beats: list = None) -> list[dict]: """ Detecta acordes en un archivo de audio. Args: audio_path: Ruta al archivo de audio (.flac, .wav, .mp3). Returns: Lista de diccionarios con formato: [{"t": tiempo_inicio_segundos, "ch": "NombreAcorde"}, ...] """ try: import librosa except ImportError: raise ImportError("librosa es requerido para la extracción de características CQT.") if not os.path.exists(audio_path): logger.error(f"[BTC] Archivo no encontrado: {audio_path}") return [] # Asegurar que el modelo esté cargado if self._model is None: self.load_model() logger.info(f"[BTC] Analizando acordes en: {audio_path}") try: # ========================================== # EXTRACCIÓN DE CARACTERÍSTICAS CQT # ========================================== y, sr = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True) # Transformada de Q constante (CQT) - representación frecuencial de alta resolución # NOTA: Se ha omitido librosa.effects.harmonic por su altísimo costo en CPU cqt = np.abs(librosa.cqt( y, sr=sr, hop_length=HOP_LENGTH, n_bins=N_BINS, bins_per_octave=BINS_PER_OCTAVE, fmin=librosa.note_to_hz('C1') )) # Normalización logarítmica (aumenta contraste de frecuencias armónicas) cqt = librosa.amplitude_to_db(cqt, ref=np.max) cqt = (cqt - cqt.mean()) / (cqt.std() + 1e-8) # Transponer: (freq, time) -> (time, freq) cqt = cqt.T # shape: (T, N_BINS) # ========================================== # INFERENCIA POR BLOQUES (30s para VRAM 4GB) # ========================================== frames_per_second = sr / HOP_LENGTH block_size = int(SEGMENT_SECONDS * frames_per_second) all_probs = [] with torch.no_grad(): for start in range(0, len(cqt), block_size): block = cqt[start: start + block_size] block_tensor = torch.tensor(block, dtype=torch.float32).unsqueeze(0) block_tensor = block_tensor.to(self._device) logits = self._model(block_tensor) probs = F.softmax(logits, dim=-1).squeeze(0).cpu() # (N_frames, 171) all_probs.append(probs) # Concatenar probabilidades de toda la canción full_probs = torch.cat(all_probs, dim=0).numpy() # ========================================== # DECODIFICACIÓN, POOLING Y AGRUPACIÓN # ========================================== acordes_agrupados = _decode_and_group_chords(full_probs, frames_per_second, bpm, beats) logger.info(f"[BTC] ✅ Reconocimiento completado: {len(acordes_agrupados)} segmentos de acordes detectados.") return acordes_agrupados except Exception as e: logger.error(f"[BTC] Error durante la inferencia: {e}", exc_info=True) return [] finally: # Siempre liberar VRAM al terminar, incluso si hubo error self.unload() # ========================================== # FUNCIÓN DE AGRUPACIÓN Y POOLING RÍTMICO # ========================================== def _decode_and_group_chords(probs: np.ndarray, fps: float, bpm: float = None, beats: list = None) -> list[dict]: """ Toma un tensor de probabilidades y extrae los acordes. Estrategia: Emite un acorde por cada beat donde se detecta un cambio. El frontend hereda el acorde del beat anterior para beats sin cambio, garantizando que cada compás tenga un acorde visible. Cuando el modelo predice 'N' (no-chord), se toma el mejor acorde no-N para evitar compases vacíos. """ total_frames = len(probs) N_IDX = 0 # 'N' es siempre el índice 0 en CHORD_VOCAB # ── Determinar las fronteras de cada beat ── if beats is not None and len(beats) > 1: beat_times = list(beats) elif bpm is not None and bpm > 40: beat_duration = 60.0 / bpm total_duration = total_frames / fps beat_times = [i * beat_duration for i in range(int(total_duration / beat_duration) + 1)] else: # Sin BPM conocido, usar beats de 0.5s total_duration = total_frames / fps beat_times = [i * 0.5 for i in range(int(total_duration / 0.5) + 1)] if not beat_times: return [] # ── Poolear probabilidades por beat y obtener el mejor acorde no-N ── beat_chords = [] for i, bt in enumerate(beat_times): # Calcular rango de frames para este beat next_bt = beat_times[i + 1] if i + 1 < len(beat_times) else total_frames / fps start_frame = int(bt * fps) end_frame = min(int(next_bt * fps), total_frames) if start_frame >= total_frames or start_frame >= end_frame: continue # Promediar probabilidades en el rango del beat beat_probs = np.mean(probs[start_frame:end_frame], axis=0) # Obtener el mejor acorde, evitando 'N' best_idx = np.argmax(beat_probs) if best_idx == N_IDX: # Buscar el mejor no-N beat_probs_no_n = beat_probs.copy() beat_probs_no_n[N_IDX] = -1 best_idx = np.argmax(beat_probs_no_n) # Si el mejor no-N tiene probabilidad muy baja, usar N if beat_probs[best_idx] < 0.02: best_idx = N_IDX chord_name = CHORD_VOCAB[best_idx] beat_chords.append((round(bt, 3), chord_name)) # ── Emitir un acorde solo cuando cambia (optimización) ── # El frontend hereda el acorde anterior para beats sin cambio, # así que solo necesitamos emitir en los puntos de cambio. result = [] prev_chord = None for bt, chord in beat_chords: if chord == 'N': continue # Saltar los pocos frames que aún predicen N if chord != prev_chord: result.append({"t": round(bt, 2), "ch": chord}) prev_chord = chord # Asegurar que siempre hay un acorde al inicio (t=0) if result and result[0]["t"] > 0.5: # Si el primer acorde está muy lejos del inicio, insertar uno al inicio result.insert(0, {"t": 0.0, "ch": result[0]["ch"]}) return result