melodix-api / btc_toolkit /inference.py
GitHub Action
deploy from github actions
89322df
Raw
History Blame Contribute Delete
11.1 kB
"""
BTC Toolkit - Motor de Inferencia de Acordes
Carga el modelo BTC, preprocesa el audio y retorna los acordes detectados.
Flujo:
1. Cargar audio con librosa (extracción de CQT).
2. Pasar por el modelo Transformer.
3. Decodificar predicciones a nombres de acordes Mayor/Menor.
4. Agrupar acordes consecutivos iguales en segmentos de tiempo.
"""
import os
import logging
import numpy as np
import torch
import torch.nn.functional as F
from .model import BTCModel, CHORD_VOCAB
from .weights_manager import ensure_weights_available
logger = logging.getLogger(__name__)
# ==========================================
# PARÁMETROS DE AUDIO
# ==========================================
SAMPLE_RATE = 22050 # Hz
HOP_LENGTH = 512 # Frames CQT
N_BINS = 144 # Bins CQT (12 semitones x 12 octavas)
BINS_PER_OCTAVE = 36 # Resolución octava
SEGMENT_SECONDS = 30.0 # Longitud máxima por bloque (para VRAM limitada)
MIN_CHORD_DURATION = 0.5 # Duración mínima de un acorde para mostrarlo (segundos)
# ==========================================
# CLASE PRINCIPAL DE INFERENCIA
# ==========================================
class BTCChordRecognizer:
"""
Motor de reconocimiento de acordes usando BTC Transformer.
Diseñado para bajo uso de VRAM:
- Procesa el audio en bloques de 30 segundos.
- Libera la VRAM al finalizar con `unload()`.
"""
def __init__(self, device: str = "cuda"):
self._model: BTCModel | None = None
self._device_str = device
self._device: torch.device | None = None
def _resolve_device(self) -> torch.device:
"""Selecciona GPU si está disponible, sino CPU."""
if self._device_str == "cuda" and torch.cuda.is_available():
dev = torch.device("cuda")
vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
logger.info(f"[BTC] GPU detectada: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)")
else:
dev = torch.device("cpu")
logger.info("[BTC] Usando CPU para inferencia de acordes.")
return dev
def load_model(self) -> None:
"""Carga el modelo en memoria. Descarga pesos si no están disponibles."""
if self._model is not None:
return
self._device = self._resolve_device()
# Asegurar que los pesos estén disponibles (descarga si falta)
weights_path = ensure_weights_available()
# Instanciar arquitectura
self._model = BTCModel(n_freq=N_BINS)
self._model.to(self._device)
# Cargar pesos
try:
checkpoint = torch.load(weights_path, map_location=self._device, weights_only=False)
# Soporte para checkpoints con y sin wrapper 'state_dict'
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
state = checkpoint["state_dict"]
elif isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
state = checkpoint["model_state_dict"]
elif isinstance(checkpoint, dict) and "model" in checkpoint:
state = checkpoint["model"]
else:
state = checkpoint
# Carga flexible: ignora claves que no coinciden
missing, unexpected = self._model.load_state_dict(state, strict=False)
if missing:
logger.warning(f"[BTC] Parámetros faltantes al cargar pesos: {len(missing)} parámetros faltantes")
if unexpected:
logger.debug(f"[BTC] Parámetros inesperados ignorados: {len(unexpected)} parámetros inesperados")
self._model.eval()
logger.info("[BTC] Modelo BTC cargado y listo para inferencia.")
except Exception as e:
self._model = None
raise RuntimeError(f"[BTC] Error al cargar los pesos del modelo: {e}") from e
def unload(self) -> None:
"""Libera el modelo de la memoria GPU para dejar espacio a otros procesos."""
if self._model is not None:
self._model.cpu()
del self._model
self._model = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("[BTC] VRAM liberada correctamente.")
def recognize(self, audio_path: str, bpm: float = None, beats: list = None) -> list[dict]:
"""
Detecta acordes en un archivo de audio.
Args:
audio_path: Ruta al archivo de audio (.flac, .wav, .mp3).
Returns:
Lista de diccionarios con formato:
[{"t": tiempo_inicio_segundos, "ch": "NombreAcorde"}, ...]
"""
try:
import librosa
except ImportError:
raise ImportError("librosa es requerido para la extracción de características CQT.")
if not os.path.exists(audio_path):
logger.error(f"[BTC] Archivo no encontrado: {audio_path}")
return []
# Asegurar que el modelo esté cargado
if self._model is None:
self.load_model()
logger.info(f"[BTC] Analizando acordes en: {audio_path}")
try:
# ==========================================
# EXTRACCIÓN DE CARACTERÍSTICAS CQT
# ==========================================
y, sr = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)
# Transformada de Q constante (CQT) - representación frecuencial de alta resolución
# NOTA: Se ha omitido librosa.effects.harmonic por su altísimo costo en CPU
cqt = np.abs(librosa.cqt(
y,
sr=sr,
hop_length=HOP_LENGTH,
n_bins=N_BINS,
bins_per_octave=BINS_PER_OCTAVE,
fmin=librosa.note_to_hz('C1')
))
# Normalización logarítmica (aumenta contraste de frecuencias armónicas)
cqt = librosa.amplitude_to_db(cqt, ref=np.max)
cqt = (cqt - cqt.mean()) / (cqt.std() + 1e-8)
# Transponer: (freq, time) -> (time, freq)
cqt = cqt.T # shape: (T, N_BINS)
# ==========================================
# INFERENCIA POR BLOQUES (30s para VRAM 4GB)
# ==========================================
frames_per_second = sr / HOP_LENGTH
block_size = int(SEGMENT_SECONDS * frames_per_second)
all_probs = []
with torch.no_grad():
for start in range(0, len(cqt), block_size):
block = cqt[start: start + block_size]
block_tensor = torch.tensor(block, dtype=torch.float32).unsqueeze(0)
block_tensor = block_tensor.to(self._device)
logits = self._model(block_tensor)
probs = F.softmax(logits, dim=-1).squeeze(0).cpu() # (N_frames, 171)
all_probs.append(probs)
# Concatenar probabilidades de toda la canción
full_probs = torch.cat(all_probs, dim=0).numpy()
# ==========================================
# DECODIFICACIÓN, POOLING Y AGRUPACIÓN
# ==========================================
acordes_agrupados = _decode_and_group_chords(full_probs, frames_per_second, bpm, beats)
logger.info(f"[BTC] ✅ Reconocimiento completado: {len(acordes_agrupados)} segmentos de acordes detectados.")
return acordes_agrupados
except Exception as e:
logger.error(f"[BTC] Error durante la inferencia: {e}", exc_info=True)
return []
finally:
# Siempre liberar VRAM al terminar, incluso si hubo error
self.unload()
# ==========================================
# FUNCIÓN DE AGRUPACIÓN Y POOLING RÍTMICO
# ==========================================
def _decode_and_group_chords(probs: np.ndarray, fps: float, bpm: float = None, beats: list = None) -> list[dict]:
"""
Toma un tensor de probabilidades y extrae los acordes.
Estrategia: Emite un acorde por cada beat donde se detecta un cambio.
El frontend hereda el acorde del beat anterior para beats sin cambio,
garantizando que cada compás tenga un acorde visible.
Cuando el modelo predice 'N' (no-chord), se toma el mejor acorde no-N
para evitar compases vacíos.
"""
total_frames = len(probs)
N_IDX = 0 # 'N' es siempre el índice 0 en CHORD_VOCAB
# ── Determinar las fronteras de cada beat ──
if beats is not None and len(beats) > 1:
beat_times = list(beats)
elif bpm is not None and bpm > 40:
beat_duration = 60.0 / bpm
total_duration = total_frames / fps
beat_times = [i * beat_duration for i in range(int(total_duration / beat_duration) + 1)]
else:
# Sin BPM conocido, usar beats de 0.5s
total_duration = total_frames / fps
beat_times = [i * 0.5 for i in range(int(total_duration / 0.5) + 1)]
if not beat_times:
return []
# ── Poolear probabilidades por beat y obtener el mejor acorde no-N ──
beat_chords = []
for i, bt in enumerate(beat_times):
# Calcular rango de frames para este beat
next_bt = beat_times[i + 1] if i + 1 < len(beat_times) else total_frames / fps
start_frame = int(bt * fps)
end_frame = min(int(next_bt * fps), total_frames)
if start_frame >= total_frames or start_frame >= end_frame:
continue
# Promediar probabilidades en el rango del beat
beat_probs = np.mean(probs[start_frame:end_frame], axis=0)
# Obtener el mejor acorde, evitando 'N'
best_idx = np.argmax(beat_probs)
if best_idx == N_IDX:
# Buscar el mejor no-N
beat_probs_no_n = beat_probs.copy()
beat_probs_no_n[N_IDX] = -1
best_idx = np.argmax(beat_probs_no_n)
# Si el mejor no-N tiene probabilidad muy baja, usar N
if beat_probs[best_idx] < 0.02:
best_idx = N_IDX
chord_name = CHORD_VOCAB[best_idx]
beat_chords.append((round(bt, 3), chord_name))
# ── Emitir un acorde solo cuando cambia (optimización) ──
# El frontend hereda el acorde anterior para beats sin cambio,
# así que solo necesitamos emitir en los puntos de cambio.
result = []
prev_chord = None
for bt, chord in beat_chords:
if chord == 'N':
continue # Saltar los pocos frames que aún predicen N
if chord != prev_chord:
result.append({"t": round(bt, 2), "ch": chord})
prev_chord = chord
# Asegurar que siempre hay un acorde al inicio (t=0)
if result and result[0]["t"] > 0.5:
# Si el primer acorde está muy lejos del inicio, insertar uno al inicio
result.insert(0, {"t": 0.0, "ch": result[0]["ch"]})
return result