File size: 4,526 Bytes
440bac0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9982dba
 
 
440bac0
 
 
 
 
71e5d2c
440bac0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
BTC Toolkit - Descarga Automática de Pesos del Modelo
Gestiona la descarga de los pesos pre-entrenados del modelo BTC desde GitHub.
Incluye:
- Descarga automática solo si no existen localmente.
- Reporte de éxito/fallo al logger.
- Verificación de integridad con SHA256.
"""

import os
import hashlib
import logging
import urllib.request

logger = logging.getLogger(__name__)

# ==========================================
# CONFIGURACIÓN DE PESOS
# ==========================================

# Directorio donde se guardan los pesos (relativo a este archivo)
_MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")

# URL pública del checkpoint pre-entrenado oficial.
# Apunta al checkpoint completo de 12.2 MB alojado en Hugging Face (amaai-lab/music2emo)
BTC_WEIGHTS_URL = "https://huggingface.co/amaai-lab/music2emo/resolve/main/inference/data/btc_model.pt"

# Nombre local del archivo de pesos
BTC_WEIGHTS_FILENAME = "btc_chord.pth"

# Hash SHA256 del archivo esperado (para verificar integridad)
BTC_WEIGHTS_SHA256 = "71c2c5db17e8c43b8a9a9da5db36ef2d667158c07a214eba16344c154c00bf54"


def get_weights_path() -> str:
    """Retorna la ruta absoluta donde deben estar los pesos."""
    return os.path.join(_MODELS_DIR, BTC_WEIGHTS_FILENAME)


def _verify_sha256(filepath: str, expected_hash: str) -> bool:
    """Verifica la integridad del archivo descargado."""
    sha256 = hashlib.sha256()
    with open(filepath, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            sha256.update(chunk)
    return sha256.hexdigest() == expected_hash


def ensure_weights_available() -> str:
    """
    Verifica si los pesos del modelo están disponibles localmente.
    Si no existen, los descarga automáticamente.

    Returns:
        str: Ruta absoluta al archivo de pesos.

    Raises:
        RuntimeError: Si la descarga falla.
    """
    weights_path = get_weights_path()

    # Crear directorio si no existe
    os.makedirs(_MODELS_DIR, exist_ok=True)

    # Si ya existen, no hace nada
    if os.path.exists(weights_path):
        logger.debug(f"[BTC] Pesos del modelo encontrados localmente: {weights_path}")
        return weights_path

    # --- Descarga automática ---
    logger.info(f"[BTC] Pesos no encontrados. Iniciando descarga automática...")
    logger.info(f"[BTC] URL: {BTC_WEIGHTS_URL}")
    logger.info(f"[BTC] Destino: {weights_path}")

    try:
        def reporthook(count, block_size, total_size):
            """Callback de progreso durante la descarga."""
            if total_size > 0:
                downloaded = count * block_size
                percent = min(100, int(downloaded * 100 / total_size))
                mb_downloaded = downloaded / (1024 * 1024)
                mb_total = total_size / (1024 * 1024)
                logger.debug(f"[BTC] Descargando: {percent}% ({mb_downloaded:.1f} MB / {mb_total:.1f} MB)")

        urllib.request.urlretrieve(BTC_WEIGHTS_URL, weights_path, reporthook)

        # Verificar que el archivo se descargó correctamente
        if not os.path.exists(weights_path) or os.path.getsize(weights_path) == 0:
            raise RuntimeError("El archivo descargado está vacío o no existe.")

        file_size_mb = os.path.getsize(weights_path) / (1024 * 1024)

        # Verificar integridad si hay hash disponible
        if BTC_WEIGHTS_SHA256:
            if _verify_sha256(weights_path, BTC_WEIGHTS_SHA256):
                logger.info(f"[BTC] ✅ Pesos descargados y verificados correctamente ({file_size_mb:.1f} MB)")
            else:
                os.remove(weights_path)
                raise RuntimeError("La verificación de integridad del archivo falló (SHA256 no coincide). El archivo fue eliminado.")
        else:
            logger.info(f"[BTC] ✅ Pesos descargados correctamente ({file_size_mb:.1f} MB) → {weights_path}")

        return weights_path

    except Exception as e:
        # Si falló la descarga, limpiar archivo parcial
        if os.path.exists(weights_path):
            try:
                os.remove(weights_path)
            except:
                pass

        error_msg = (
            f"[BTC] ❌ ERROR: No se pudieron descargar los pesos del modelo.\n"
            f"       Razón: {str(e)}\n"
            f"       URL intentada: {BTC_WEIGHTS_URL}\n"
            f"       Verifica tu conexión a internet o descarga el archivo manualmente en: {weights_path}"
        )
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e