""" BTC Toolkit - Descarga Automática de Pesos del Modelo Gestiona la descarga de los pesos pre-entrenados del modelo BTC desde GitHub. Incluye: - Descarga automática solo si no existen localmente. - Reporte de éxito/fallo al logger. - Verificación de integridad con SHA256. """ import os import hashlib import logging import urllib.request logger = logging.getLogger(__name__) # ========================================== # CONFIGURACIÓN DE PESOS # ========================================== # Directorio donde se guardan los pesos (relativo a este archivo) _MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models") # URL pública del checkpoint pre-entrenado oficial. # Apunta al checkpoint completo de 12.2 MB alojado en Hugging Face (amaai-lab/music2emo) BTC_WEIGHTS_URL = "https://huggingface.co/amaai-lab/music2emo/resolve/main/inference/data/btc_model.pt" # Nombre local del archivo de pesos BTC_WEIGHTS_FILENAME = "btc_chord.pth" # Hash SHA256 del archivo esperado (para verificar integridad) BTC_WEIGHTS_SHA256 = "71c2c5db17e8c43b8a9a9da5db36ef2d667158c07a214eba16344c154c00bf54" def get_weights_path() -> str: """Retorna la ruta absoluta donde deben estar los pesos.""" return os.path.join(_MODELS_DIR, BTC_WEIGHTS_FILENAME) def _verify_sha256(filepath: str, expected_hash: str) -> bool: """Verifica la integridad del archivo descargado.""" sha256 = hashlib.sha256() with open(filepath, "rb") as f: for chunk in iter(lambda: f.read(8192), b""): sha256.update(chunk) return sha256.hexdigest() == expected_hash def ensure_weights_available() -> str: """ Verifica si los pesos del modelo están disponibles localmente. Si no existen, los descarga automáticamente. Returns: str: Ruta absoluta al archivo de pesos. Raises: RuntimeError: Si la descarga falla. """ weights_path = get_weights_path() # Crear directorio si no existe os.makedirs(_MODELS_DIR, exist_ok=True) # Si ya existen, no hace nada if os.path.exists(weights_path): logger.debug(f"[BTC] Pesos del modelo encontrados localmente: {weights_path}") return weights_path # --- Descarga automática --- logger.info(f"[BTC] Pesos no encontrados. Iniciando descarga automática...") logger.info(f"[BTC] URL: {BTC_WEIGHTS_URL}") logger.info(f"[BTC] Destino: {weights_path}") try: def reporthook(count, block_size, total_size): """Callback de progreso durante la descarga.""" if total_size > 0: downloaded = count * block_size percent = min(100, int(downloaded * 100 / total_size)) mb_downloaded = downloaded / (1024 * 1024) mb_total = total_size / (1024 * 1024) logger.debug(f"[BTC] Descargando: {percent}% ({mb_downloaded:.1f} MB / {mb_total:.1f} MB)") urllib.request.urlretrieve(BTC_WEIGHTS_URL, weights_path, reporthook) # Verificar que el archivo se descargó correctamente if not os.path.exists(weights_path) or os.path.getsize(weights_path) == 0: raise RuntimeError("El archivo descargado está vacío o no existe.") file_size_mb = os.path.getsize(weights_path) / (1024 * 1024) # Verificar integridad si hay hash disponible if BTC_WEIGHTS_SHA256: if _verify_sha256(weights_path, BTC_WEIGHTS_SHA256): logger.info(f"[BTC] ✅ Pesos descargados y verificados correctamente ({file_size_mb:.1f} MB)") else: os.remove(weights_path) raise RuntimeError("La verificación de integridad del archivo falló (SHA256 no coincide). El archivo fue eliminado.") else: logger.info(f"[BTC] ✅ Pesos descargados correctamente ({file_size_mb:.1f} MB) → {weights_path}") return weights_path except Exception as e: # Si falló la descarga, limpiar archivo parcial if os.path.exists(weights_path): try: os.remove(weights_path) except: pass error_msg = ( f"[BTC] ❌ ERROR: No se pudieron descargar los pesos del modelo.\n" f" Razón: {str(e)}\n" f" URL intentada: {BTC_WEIGHTS_URL}\n" f" Verifica tu conexión a internet o descarga el archivo manualmente en: {weights_path}" ) logger.error(error_msg) raise RuntimeError(error_msg) from e