|
|
|
|
|
""" |
|
|
Inferencia para MIA (RustyLinux/MiaMotion) |
|
|
- Carga el modelo desde archivos locales si existen (best_model.pt, config.json). |
|
|
- Si no están en local, los descarga desde el Hugging Face Hub. |
|
|
- Expone `predict(text: str)` para usar desde scripts, Spaces (Gradio) o servicios. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict |
|
|
|
|
|
import torch |
|
|
|
|
|
try: |
|
|
from huggingface_hub import hf_hub_download |
|
|
except Exception: |
|
|
hf_hub_download = None |
|
|
|
|
|
from emotion_classifier_model import EmotionClassifier |
|
|
|
|
|
|
|
|
REPO_ID = "RustyLinux/MiaMotion" |
|
|
|
|
|
LOCAL_CKPT = Path("best_model.pt") |
|
|
LOCAL_CFG = Path("config.json") |
|
|
|
|
|
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
_model = None |
|
|
|
|
|
|
|
|
def _resolve_paths() -> "tuple[str, str]":
    """Return ``(ckpt_path, cfg_path)`` for the model weights and config.

    Looks for ``best_model.pt`` and ``config.json`` locally first; if either
    is missing, downloads both from the Hugging Face Hub (``REPO_ID``).

    Returns:
        Tuple of path strings ``(ckpt_path, cfg_path)``.

    Raises:
        RuntimeError: if the files are not found locally and
            ``huggingface_hub`` is not installed to download them.
    """
    # Fast path: both artifacts already on disk next to the script.
    if LOCAL_CKPT.exists() and LOCAL_CFG.exists():
        return str(LOCAL_CKPT.resolve()), str(LOCAL_CFG.resolve())

    # hf_hub_download is None when the optional import at module top failed.
    if hf_hub_download is None:
        raise RuntimeError(
            "No se encontraron 'best_model.pt' y 'config.json' en local, "
            "y 'huggingface_hub' no está instalado para descargarlos."
        )

    # hf_hub_download caches downloads and returns the local cache path.
    ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="best_model.pt")
    cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    return ckpt_path, cfg_path
|
|
|
|
|
|
|
|
def _load_model() -> EmotionClassifier:
    """Build the classifier once and return the cached instance thereafter."""
    global _model
    if _model is not None:
        return _model

    checkpoint_file, config_file = _resolve_paths()

    with open(config_file, "r", encoding="utf-8") as fh:
        config = json.load(fh)

    # Every hyperparameter falls back to the training defaults when absent
    # from config.json.
    classifier = EmotionClassifier(
        model_name=config.get("base_model_id", "dccuchile/bert-base-spanish-wwm-cased"),
        max_length=config.get("max_length", 128),
        hidden1=config.get("hidden1", 128),
        hidden2=config.get("hidden2", 64),
        num_classes=config.get("num_classes", 6),
        dropout=config.get("dropout", 0.3),
        device=_device,
        pretrained_encoder=config.get("pretrained_encoder", "beto"),
    )

    checkpoint = torch.load(checkpoint_file, map_location=_device)
    # Full training checkpoints wrap the weights under "model_state_dict";
    # a bare state dict is loaded as-is.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        classifier.load_state_dict(checkpoint["model_state_dict"])
    else:
        classifier.load_state_dict(checkpoint)

    classifier.eval()
    _model = classifier
    return _model
|
|
|
|
|
|
|
|
def predict(text: str, return_probs: bool = False) -> Any:
    """Predict the emotion for a single text.

    - ``return_probs=True``: returns ``(label, probs)`` where ``probs`` is a
      list of floats ordered according to ``model.label_map``.
    - ``return_probs=False``: returns only the label string.
    """
    classifier = _load_model()
    if not return_probs:
        return classifier.predict_single(text)
    label, probs = classifier.predict_single(text, return_probs=True)
    return label, probs.tolist()
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
print(predict("Estoy muy contento con los resultados", return_probs=True)) |
|
|
print(predict("Tengo miedo de lo que pueda pasar", return_probs=True)) |
|
|
|