Spaces:
Running
Running
File size: 4,526 Bytes
440bac0 9982dba 440bac0 71e5d2c 440bac0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 | """
BTC Toolkit - Descarga Automática de Pesos del Modelo
Gestiona la descarga de los pesos pre-entrenados del modelo BTC desde GitHub.
Incluye:
- Descarga automática solo si no existen localmente.
- Reporte de éxito/fallo al logger.
- Verificación de integridad con SHA256.
"""
import os
import hashlib
import logging
import urllib.request
logger = logging.getLogger(__name__)
# ==========================================
# CONFIGURACIÓN DE PESOS
# ==========================================
# Directorio donde se guardan los pesos (relativo a este archivo)
_MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
# URL pública del checkpoint pre-entrenado oficial.
# Apunta al checkpoint completo de 12.2 MB alojado en Hugging Face (amaai-lab/music2emo)
BTC_WEIGHTS_URL = "https://huggingface.co/amaai-lab/music2emo/resolve/main/inference/data/btc_model.pt"
# Nombre local del archivo de pesos
BTC_WEIGHTS_FILENAME = "btc_chord.pth"
# Hash SHA256 del archivo esperado (para verificar integridad)
BTC_WEIGHTS_SHA256 = "71c2c5db17e8c43b8a9a9da5db36ef2d667158c07a214eba16344c154c00bf54"
def get_weights_path() -> str:
"""Retorna la ruta absoluta donde deben estar los pesos."""
return os.path.join(_MODELS_DIR, BTC_WEIGHTS_FILENAME)
def _verify_sha256(filepath: str, expected_hash: str) -> bool:
"""Verifica la integridad del archivo descargado."""
sha256 = hashlib.sha256()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
sha256.update(chunk)
return sha256.hexdigest() == expected_hash
def ensure_weights_available() -> str:
"""
Verifica si los pesos del modelo están disponibles localmente.
Si no existen, los descarga automáticamente.
Returns:
str: Ruta absoluta al archivo de pesos.
Raises:
RuntimeError: Si la descarga falla.
"""
weights_path = get_weights_path()
# Crear directorio si no existe
os.makedirs(_MODELS_DIR, exist_ok=True)
# Si ya existen, no hace nada
if os.path.exists(weights_path):
logger.debug(f"[BTC] Pesos del modelo encontrados localmente: {weights_path}")
return weights_path
# --- Descarga automática ---
logger.info(f"[BTC] Pesos no encontrados. Iniciando descarga automática...")
logger.info(f"[BTC] URL: {BTC_WEIGHTS_URL}")
logger.info(f"[BTC] Destino: {weights_path}")
try:
def reporthook(count, block_size, total_size):
"""Callback de progreso durante la descarga."""
if total_size > 0:
downloaded = count * block_size
percent = min(100, int(downloaded * 100 / total_size))
mb_downloaded = downloaded / (1024 * 1024)
mb_total = total_size / (1024 * 1024)
logger.debug(f"[BTC] Descargando: {percent}% ({mb_downloaded:.1f} MB / {mb_total:.1f} MB)")
urllib.request.urlretrieve(BTC_WEIGHTS_URL, weights_path, reporthook)
# Verificar que el archivo se descargó correctamente
if not os.path.exists(weights_path) or os.path.getsize(weights_path) == 0:
raise RuntimeError("El archivo descargado está vacío o no existe.")
file_size_mb = os.path.getsize(weights_path) / (1024 * 1024)
# Verificar integridad si hay hash disponible
if BTC_WEIGHTS_SHA256:
if _verify_sha256(weights_path, BTC_WEIGHTS_SHA256):
logger.info(f"[BTC] ✅ Pesos descargados y verificados correctamente ({file_size_mb:.1f} MB)")
else:
os.remove(weights_path)
raise RuntimeError("La verificación de integridad del archivo falló (SHA256 no coincide). El archivo fue eliminado.")
else:
logger.info(f"[BTC] ✅ Pesos descargados correctamente ({file_size_mb:.1f} MB) → {weights_path}")
return weights_path
except Exception as e:
# Si falló la descarga, limpiar archivo parcial
if os.path.exists(weights_path):
try:
os.remove(weights_path)
except:
pass
error_msg = (
f"[BTC] ❌ ERROR: No se pudieron descargar los pesos del modelo.\n"
f" Razón: {str(e)}\n"
f" URL intentada: {BTC_WEIGHTS_URL}\n"
f" Verifica tu conexión a internet o descarga el archivo manualmente en: {weights_path}"
)
logger.error(error_msg)
raise RuntimeError(error_msg) from e
|