melodix-api / btc_toolkit /weights_manager.py
GitHub Action
deploy from github actions
71e5d2c
Raw
History Blame Contribute Delete
4.53 kB
"""
BTC Toolkit - Descarga Automática de Pesos del Modelo
Gestiona la descarga de los pesos pre-entrenados del modelo BTC desde GitHub.
Incluye:
- Descarga automática solo si no existen localmente.
- Reporte de éxito/fallo al logger.
- Verificación de integridad con SHA256.
"""
import os
import hashlib
import logging
import urllib.request
logger = logging.getLogger(__name__)
# ==========================================
# CONFIGURACIÓN DE PESOS
# ==========================================
# Directorio donde se guardan los pesos (relativo a este archivo)
_MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
# URL pública del checkpoint pre-entrenado oficial.
# Apunta al checkpoint completo de 12.2 MB alojado en Hugging Face (amaai-lab/music2emo)
BTC_WEIGHTS_URL = "https://huggingface.co/amaai-lab/music2emo/resolve/main/inference/data/btc_model.pt"
# Nombre local del archivo de pesos
BTC_WEIGHTS_FILENAME = "btc_chord.pth"
# Hash SHA256 del archivo esperado (para verificar integridad)
BTC_WEIGHTS_SHA256 = "71c2c5db17e8c43b8a9a9da5db36ef2d667158c07a214eba16344c154c00bf54"
def get_weights_path() -> str:
"""Retorna la ruta absoluta donde deben estar los pesos."""
return os.path.join(_MODELS_DIR, BTC_WEIGHTS_FILENAME)
def _verify_sha256(filepath: str, expected_hash: str) -> bool:
"""Verifica la integridad del archivo descargado."""
sha256 = hashlib.sha256()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
sha256.update(chunk)
return sha256.hexdigest() == expected_hash
def ensure_weights_available() -> str:
"""
Verifica si los pesos del modelo están disponibles localmente.
Si no existen, los descarga automáticamente.
Returns:
str: Ruta absoluta al archivo de pesos.
Raises:
RuntimeError: Si la descarga falla.
"""
weights_path = get_weights_path()
# Crear directorio si no existe
os.makedirs(_MODELS_DIR, exist_ok=True)
# Si ya existen, no hace nada
if os.path.exists(weights_path):
logger.debug(f"[BTC] Pesos del modelo encontrados localmente: {weights_path}")
return weights_path
# --- Descarga automática ---
logger.info(f"[BTC] Pesos no encontrados. Iniciando descarga automática...")
logger.info(f"[BTC] URL: {BTC_WEIGHTS_URL}")
logger.info(f"[BTC] Destino: {weights_path}")
try:
def reporthook(count, block_size, total_size):
"""Callback de progreso durante la descarga."""
if total_size > 0:
downloaded = count * block_size
percent = min(100, int(downloaded * 100 / total_size))
mb_downloaded = downloaded / (1024 * 1024)
mb_total = total_size / (1024 * 1024)
logger.debug(f"[BTC] Descargando: {percent}% ({mb_downloaded:.1f} MB / {mb_total:.1f} MB)")
urllib.request.urlretrieve(BTC_WEIGHTS_URL, weights_path, reporthook)
# Verificar que el archivo se descargó correctamente
if not os.path.exists(weights_path) or os.path.getsize(weights_path) == 0:
raise RuntimeError("El archivo descargado está vacío o no existe.")
file_size_mb = os.path.getsize(weights_path) / (1024 * 1024)
# Verificar integridad si hay hash disponible
if BTC_WEIGHTS_SHA256:
if _verify_sha256(weights_path, BTC_WEIGHTS_SHA256):
logger.info(f"[BTC] ✅ Pesos descargados y verificados correctamente ({file_size_mb:.1f} MB)")
else:
os.remove(weights_path)
raise RuntimeError("La verificación de integridad del archivo falló (SHA256 no coincide). El archivo fue eliminado.")
else:
logger.info(f"[BTC] ✅ Pesos descargados correctamente ({file_size_mb:.1f} MB) → {weights_path}")
return weights_path
except Exception as e:
# Si falló la descarga, limpiar archivo parcial
if os.path.exists(weights_path):
try:
os.remove(weights_path)
except:
pass
error_msg = (
f"[BTC] ❌ ERROR: No se pudieron descargar los pesos del modelo.\n"
f" Razón: {str(e)}\n"
f" URL intentada: {BTC_WEIGHTS_URL}\n"
f" Verifica tu conexión a internet o descarga el archivo manualmente en: {weights_path}"
)
logger.error(error_msg)
raise RuntimeError(error_msg) from e