Spaces:
Running
Running
| """ | |
| BTC Toolkit - Motor de Inferencia de Acordes | |
| Carga el modelo BTC, preprocesa el audio y retorna los acordes detectados. | |
| Flujo: | |
| 1. Cargar audio con librosa (extracción de CQT). | |
| 2. Pasar por el modelo Transformer. | |
| 3. Decodificar predicciones a nombres de acordes Mayor/Menor. | |
| 4. Agrupar acordes consecutivos iguales en segmentos de tiempo. | |
| """ | |
| import os | |
| import logging | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from .model import BTCModel, CHORD_VOCAB | |
| from .weights_manager import ensure_weights_available | |
| logger = logging.getLogger(__name__) | |
| # ========================================== | |
| # PARÁMETROS DE AUDIO | |
| # ========================================== | |
| SAMPLE_RATE = 22050 # Hz | |
| HOP_LENGTH = 512 # Frames CQT | |
| N_BINS = 144 # Bins CQT (12 semitones x 12 octavas) | |
| BINS_PER_OCTAVE = 36 # Resolución octava | |
| SEGMENT_SECONDS = 30.0 # Longitud máxima por bloque (para VRAM limitada) | |
| MIN_CHORD_DURATION = 0.5 # Duración mínima de un acorde para mostrarlo (segundos) | |
| # ========================================== | |
| # CLASE PRINCIPAL DE INFERENCIA | |
| # ========================================== | |
| class BTCChordRecognizer: | |
| """ | |
| Motor de reconocimiento de acordes usando BTC Transformer. | |
| Diseñado para bajo uso de VRAM: | |
| - Procesa el audio en bloques de 30 segundos. | |
| - Libera la VRAM al finalizar con `unload()`. | |
| """ | |
| def __init__(self, device: str = "cuda"): | |
| self._model: BTCModel | None = None | |
| self._device_str = device | |
| self._device: torch.device | None = None | |
| def _resolve_device(self) -> torch.device: | |
| """Selecciona GPU si está disponible, sino CPU.""" | |
| if self._device_str == "cuda" and torch.cuda.is_available(): | |
| dev = torch.device("cuda") | |
| vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3) | |
| logger.info(f"[BTC] GPU detectada: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)") | |
| else: | |
| dev = torch.device("cpu") | |
| logger.info("[BTC] Usando CPU para inferencia de acordes.") | |
| return dev | |
| def load_model(self) -> None: | |
| """Carga el modelo en memoria. Descarga pesos si no están disponibles.""" | |
| if self._model is not None: | |
| return | |
| self._device = self._resolve_device() | |
| # Asegurar que los pesos estén disponibles (descarga si falta) | |
| weights_path = ensure_weights_available() | |
| # Instanciar arquitectura | |
| self._model = BTCModel(n_freq=N_BINS) | |
| self._model.to(self._device) | |
| # Cargar pesos | |
| try: | |
| checkpoint = torch.load(weights_path, map_location=self._device, weights_only=False) | |
| # Soporte para checkpoints con y sin wrapper 'state_dict' | |
| if isinstance(checkpoint, dict) and "state_dict" in checkpoint: | |
| state = checkpoint["state_dict"] | |
| elif isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: | |
| state = checkpoint["model_state_dict"] | |
| elif isinstance(checkpoint, dict) and "model" in checkpoint: | |
| state = checkpoint["model"] | |
| else: | |
| state = checkpoint | |
| # Carga flexible: ignora claves que no coinciden | |
| missing, unexpected = self._model.load_state_dict(state, strict=False) | |
| if missing: | |
| logger.warning(f"[BTC] Parámetros faltantes al cargar pesos: {len(missing)} parámetros faltantes") | |
| if unexpected: | |
| logger.debug(f"[BTC] Parámetros inesperados ignorados: {len(unexpected)} parámetros inesperados") | |
| self._model.eval() | |
| logger.info("[BTC] Modelo BTC cargado y listo para inferencia.") | |
| except Exception as e: | |
| self._model = None | |
| raise RuntimeError(f"[BTC] Error al cargar los pesos del modelo: {e}") from e | |
| def unload(self) -> None: | |
| """Libera el modelo de la memoria GPU para dejar espacio a otros procesos.""" | |
| if self._model is not None: | |
| self._model.cpu() | |
| del self._model | |
| self._model = None | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| logger.info("[BTC] VRAM liberada correctamente.") | |
| def recognize(self, audio_path: str, bpm: float = None, beats: list = None) -> list[dict]: | |
| """ | |
| Detecta acordes en un archivo de audio. | |
| Args: | |
| audio_path: Ruta al archivo de audio (.flac, .wav, .mp3). | |
| Returns: | |
| Lista de diccionarios con formato: | |
| [{"t": tiempo_inicio_segundos, "ch": "NombreAcorde"}, ...] | |
| """ | |
| try: | |
| import librosa | |
| except ImportError: | |
| raise ImportError("librosa es requerido para la extracción de características CQT.") | |
| if not os.path.exists(audio_path): | |
| logger.error(f"[BTC] Archivo no encontrado: {audio_path}") | |
| return [] | |
| # Asegurar que el modelo esté cargado | |
| if self._model is None: | |
| self.load_model() | |
| logger.info(f"[BTC] Analizando acordes en: {audio_path}") | |
| try: | |
| # ========================================== | |
| # EXTRACCIÓN DE CARACTERÍSTICAS CQT | |
| # ========================================== | |
| y, sr = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True) | |
| # Transformada de Q constante (CQT) - representación frecuencial de alta resolución | |
| # NOTA: Se ha omitido librosa.effects.harmonic por su altísimo costo en CPU | |
| cqt = np.abs(librosa.cqt( | |
| y, | |
| sr=sr, | |
| hop_length=HOP_LENGTH, | |
| n_bins=N_BINS, | |
| bins_per_octave=BINS_PER_OCTAVE, | |
| fmin=librosa.note_to_hz('C1') | |
| )) | |
| # Normalización logarítmica (aumenta contraste de frecuencias armónicas) | |
| cqt = librosa.amplitude_to_db(cqt, ref=np.max) | |
| cqt = (cqt - cqt.mean()) / (cqt.std() + 1e-8) | |
| # Transponer: (freq, time) -> (time, freq) | |
| cqt = cqt.T # shape: (T, N_BINS) | |
| # ========================================== | |
| # INFERENCIA POR BLOQUES (30s para VRAM 4GB) | |
| # ========================================== | |
| frames_per_second = sr / HOP_LENGTH | |
| block_size = int(SEGMENT_SECONDS * frames_per_second) | |
| all_probs = [] | |
| with torch.no_grad(): | |
| for start in range(0, len(cqt), block_size): | |
| block = cqt[start: start + block_size] | |
| block_tensor = torch.tensor(block, dtype=torch.float32).unsqueeze(0) | |
| block_tensor = block_tensor.to(self._device) | |
| logits = self._model(block_tensor) | |
| probs = F.softmax(logits, dim=-1).squeeze(0).cpu() # (N_frames, 171) | |
| all_probs.append(probs) | |
| # Concatenar probabilidades de toda la canción | |
| full_probs = torch.cat(all_probs, dim=0).numpy() | |
| # ========================================== | |
| # DECODIFICACIÓN, POOLING Y AGRUPACIÓN | |
| # ========================================== | |
| acordes_agrupados = _decode_and_group_chords(full_probs, frames_per_second, bpm, beats) | |
| logger.info(f"[BTC] ✅ Reconocimiento completado: {len(acordes_agrupados)} segmentos de acordes detectados.") | |
| return acordes_agrupados | |
| except Exception as e: | |
| logger.error(f"[BTC] Error durante la inferencia: {e}", exc_info=True) | |
| return [] | |
| finally: | |
| # Siempre liberar VRAM al terminar, incluso si hubo error | |
| self.unload() | |
| # ========================================== | |
| # FUNCIÓN DE AGRUPACIÓN Y POOLING RÍTMICO | |
| # ========================================== | |
| def _decode_and_group_chords(probs: np.ndarray, fps: float, bpm: float = None, beats: list = None) -> list[dict]: | |
| """ | |
| Toma un tensor de probabilidades y extrae los acordes. | |
| Estrategia: Emite un acorde por cada beat donde se detecta un cambio. | |
| El frontend hereda el acorde del beat anterior para beats sin cambio, | |
| garantizando que cada compás tenga un acorde visible. | |
| Cuando el modelo predice 'N' (no-chord), se toma el mejor acorde no-N | |
| para evitar compases vacíos. | |
| """ | |
| total_frames = len(probs) | |
| N_IDX = 0 # 'N' es siempre el índice 0 en CHORD_VOCAB | |
| # ── Determinar las fronteras de cada beat ── | |
| if beats is not None and len(beats) > 1: | |
| beat_times = list(beats) | |
| elif bpm is not None and bpm > 40: | |
| beat_duration = 60.0 / bpm | |
| total_duration = total_frames / fps | |
| beat_times = [i * beat_duration for i in range(int(total_duration / beat_duration) + 1)] | |
| else: | |
| # Sin BPM conocido, usar beats de 0.5s | |
| total_duration = total_frames / fps | |
| beat_times = [i * 0.5 for i in range(int(total_duration / 0.5) + 1)] | |
| if not beat_times: | |
| return [] | |
| # ── Poolear probabilidades por beat y obtener el mejor acorde no-N ── | |
| beat_chords = [] | |
| for i, bt in enumerate(beat_times): | |
| # Calcular rango de frames para este beat | |
| next_bt = beat_times[i + 1] if i + 1 < len(beat_times) else total_frames / fps | |
| start_frame = int(bt * fps) | |
| end_frame = min(int(next_bt * fps), total_frames) | |
| if start_frame >= total_frames or start_frame >= end_frame: | |
| continue | |
| # Promediar probabilidades en el rango del beat | |
| beat_probs = np.mean(probs[start_frame:end_frame], axis=0) | |
| # Obtener el mejor acorde, evitando 'N' | |
| best_idx = np.argmax(beat_probs) | |
| if best_idx == N_IDX: | |
| # Buscar el mejor no-N | |
| beat_probs_no_n = beat_probs.copy() | |
| beat_probs_no_n[N_IDX] = -1 | |
| best_idx = np.argmax(beat_probs_no_n) | |
| # Si el mejor no-N tiene probabilidad muy baja, usar N | |
| if beat_probs[best_idx] < 0.02: | |
| best_idx = N_IDX | |
| chord_name = CHORD_VOCAB[best_idx] | |
| beat_chords.append((round(bt, 3), chord_name)) | |
| # ── Emitir un acorde solo cuando cambia (optimización) ── | |
| # El frontend hereda el acorde anterior para beats sin cambio, | |
| # así que solo necesitamos emitir en los puntos de cambio. | |
| result = [] | |
| prev_chord = None | |
| for bt, chord in beat_chords: | |
| if chord == 'N': | |
| continue # Saltar los pocos frames que aún predicen N | |
| if chord != prev_chord: | |
| result.append({"t": round(bt, 2), "ch": chord}) | |
| prev_chord = chord | |
| # Asegurar que siempre hay un acorde al inicio (t=0) | |
| if result and result[0]["t"] > 0.5: | |
| # Si el primer acorde está muy lejos del inicio, insertar uno al inicio | |
| result.insert(0, {"t": 0.0, "ch": result[0]["ch"]}) | |
| return result | |