File size: 3,194 Bytes
6a31658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# -*- coding: utf-8 -*-
"""
Inferencia para MIA (RustyLinux/MiaMotion)
- Carga el modelo desde archivos locales si existen (best_model.pt, config.json).
- Si no están en local, los descarga desde el Hugging Face Hub.
- Expone `predict(text: str)` para usar desde scripts, Spaces (Gradio) o servicios.
"""

import os
import json
from pathlib import Path
from typing import Any, Dict

import torch

try:
    from huggingface_hub import hf_hub_download
except Exception:
    hf_hub_download = None  # si no está instalado, funcionará en local con archivos presentes

from emotion_classifier_model import EmotionClassifier

# ---------------- Config ----------------
REPO_ID = "RustyLinux/MiaMotion"  # tu repo en el Hub

LOCAL_CKPT = Path("best_model.pt")
LOCAL_CFG  = Path("config.json")

_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_model = None  # cache global


def _resolve_paths() -> (str, str):
    """
    Retorna (ckpt_path, cfg_path). Busca primero local, si no, descarga del Hub.
    """
    if LOCAL_CKPT.exists() and LOCAL_CFG.exists():
        return str(LOCAL_CKPT.resolve()), str(LOCAL_CFG.resolve())

    if hf_hub_download is None:
        raise RuntimeError(
            "No se encontraron 'best_model.pt' y 'config.json' en local, "
            "y 'huggingface_hub' no está instalado para descargarlos."
        )

    ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="best_model.pt")
    cfg_path  = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    return ckpt_path, cfg_path


def _load_model() -> EmotionClassifier:
    global _model
    if _model is not None:
        return _model

    ckpt_path, cfg_path = _resolve_paths()

    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    model = EmotionClassifier(
        model_name=cfg.get("base_model_id", "dccuchile/bert-base-spanish-wwm-cased"),
        max_length=cfg.get("max_length", 128),
        hidden1=cfg.get("hidden1", 128),
        hidden2=cfg.get("hidden2", 64),
        num_classes=cfg.get("num_classes", 6),
        dropout=cfg.get("dropout", 0.3),
        device=_device,
        pretrained_encoder=cfg.get("pretrained_encoder", "beto"),
    )

    state = torch.load(ckpt_path, map_location=_device)
    if isinstance(state, dict) and "model_state_dict" in state:
        model.load_state_dict(state["model_state_dict"])
    else:
        # por si guardaste el state_dict directo
        model.load_state_dict(state)

    model.eval()
    _model = model
    return _model


def predict(text: str, return_probs: bool = False) -> Any:
    """
    Predice la emoción para un texto.
    - return_probs=True: devuelve (label:str, probs:list[float]) en el orden de model.label_map
    - return_probs=False: devuelve solo label:str
    """
    model = _load_model()
    if return_probs:
        label, probs = model.predict_single(text, return_probs=True)
        return label, probs.tolist()
    return model.predict_single(text)


if __name__ == "__main__":
    # Pruebas rápidas
    print(predict("Estoy muy contento con los resultados", return_probs=True))
    print(predict("Tengo miedo de lo que pueda pasar", return_probs=True))