Spaces:
Running
Running
| """ | |
| BTC Toolkit - Descarga Automática de Pesos del Modelo | |
| Gestiona la descarga de los pesos pre-entrenados del modelo BTC desde GitHub. | |
| Incluye: | |
| - Descarga automática solo si no existen localmente. | |
| - Reporte de éxito/fallo al logger. | |
| - Verificación de integridad con SHA256. | |
| """ | |
| import os | |
| import hashlib | |
| import logging | |
| import urllib.request | |
| logger = logging.getLogger(__name__) | |
| # ========================================== | |
| # CONFIGURACIÓN DE PESOS | |
| # ========================================== | |
| # Directorio donde se guardan los pesos (relativo a este archivo) | |
| _MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models") | |
| # URL pública del checkpoint pre-entrenado oficial. | |
| # Apunta al checkpoint completo de 12.2 MB alojado en Hugging Face (amaai-lab/music2emo) | |
| BTC_WEIGHTS_URL = "https://huggingface.co/amaai-lab/music2emo/resolve/main/inference/data/btc_model.pt" | |
| # Nombre local del archivo de pesos | |
| BTC_WEIGHTS_FILENAME = "btc_chord.pth" | |
| # Hash SHA256 del archivo esperado (para verificar integridad) | |
| BTC_WEIGHTS_SHA256 = "71c2c5db17e8c43b8a9a9da5db36ef2d667158c07a214eba16344c154c00bf54" | |
| def get_weights_path() -> str: | |
| """Retorna la ruta absoluta donde deben estar los pesos.""" | |
| return os.path.join(_MODELS_DIR, BTC_WEIGHTS_FILENAME) | |
| def _verify_sha256(filepath: str, expected_hash: str) -> bool: | |
| """Verifica la integridad del archivo descargado.""" | |
| sha256 = hashlib.sha256() | |
| with open(filepath, "rb") as f: | |
| for chunk in iter(lambda: f.read(8192), b""): | |
| sha256.update(chunk) | |
| return sha256.hexdigest() == expected_hash | |
| def ensure_weights_available() -> str: | |
| """ | |
| Verifica si los pesos del modelo están disponibles localmente. | |
| Si no existen, los descarga automáticamente. | |
| Returns: | |
| str: Ruta absoluta al archivo de pesos. | |
| Raises: | |
| RuntimeError: Si la descarga falla. | |
| """ | |
| weights_path = get_weights_path() | |
| # Crear directorio si no existe | |
| os.makedirs(_MODELS_DIR, exist_ok=True) | |
| # Si ya existen, no hace nada | |
| if os.path.exists(weights_path): | |
| logger.debug(f"[BTC] Pesos del modelo encontrados localmente: {weights_path}") | |
| return weights_path | |
| # --- Descarga automática --- | |
| logger.info(f"[BTC] Pesos no encontrados. Iniciando descarga automática...") | |
| logger.info(f"[BTC] URL: {BTC_WEIGHTS_URL}") | |
| logger.info(f"[BTC] Destino: {weights_path}") | |
| try: | |
| def reporthook(count, block_size, total_size): | |
| """Callback de progreso durante la descarga.""" | |
| if total_size > 0: | |
| downloaded = count * block_size | |
| percent = min(100, int(downloaded * 100 / total_size)) | |
| mb_downloaded = downloaded / (1024 * 1024) | |
| mb_total = total_size / (1024 * 1024) | |
| logger.debug(f"[BTC] Descargando: {percent}% ({mb_downloaded:.1f} MB / {mb_total:.1f} MB)") | |
| urllib.request.urlretrieve(BTC_WEIGHTS_URL, weights_path, reporthook) | |
| # Verificar que el archivo se descargó correctamente | |
| if not os.path.exists(weights_path) or os.path.getsize(weights_path) == 0: | |
| raise RuntimeError("El archivo descargado está vacío o no existe.") | |
| file_size_mb = os.path.getsize(weights_path) / (1024 * 1024) | |
| # Verificar integridad si hay hash disponible | |
| if BTC_WEIGHTS_SHA256: | |
| if _verify_sha256(weights_path, BTC_WEIGHTS_SHA256): | |
| logger.info(f"[BTC] ✅ Pesos descargados y verificados correctamente ({file_size_mb:.1f} MB)") | |
| else: | |
| os.remove(weights_path) | |
| raise RuntimeError("La verificación de integridad del archivo falló (SHA256 no coincide). El archivo fue eliminado.") | |
| else: | |
| logger.info(f"[BTC] ✅ Pesos descargados correctamente ({file_size_mb:.1f} MB) → {weights_path}") | |
| return weights_path | |
| except Exception as e: | |
| # Si falló la descarga, limpiar archivo parcial | |
| if os.path.exists(weights_path): | |
| try: | |
| os.remove(weights_path) | |
| except: | |
| pass | |
| error_msg = ( | |
| f"[BTC] ❌ ERROR: No se pudieron descargar los pesos del modelo.\n" | |
| f" Razón: {str(e)}\n" | |
| f" URL intentada: {BTC_WEIGHTS_URL}\n" | |
| f" Verifica tu conexión a internet o descarga el archivo manualmente en: {weights_path}" | |
| ) | |
| logger.error(error_msg) | |
| raise RuntimeError(error_msg) from e | |