| |
| """ |
| Inferencia para MIA (RustyLinux/MiaMotion) |
| - Carga el modelo desde archivos locales si existen (best_model.pt, config.json). |
| - Si no están en local, los descarga desde el Hugging Face Hub. |
| - Expone `predict(text: str)` para usar desde scripts, Spaces (Gradio) o servicios. |
| """ |
|
|
import json
import os
from pathlib import Path
from typing import Any, Dict, Tuple

import torch

try:
    from huggingface_hub import hf_hub_download
except Exception:
    hf_hub_download = None

from emotion_classifier_model import EmotionClassifier
|
|
| |
# Hugging Face Hub repository used as a fallback source for the checkpoint/config.
REPO_ID = "RustyLinux/MiaMotion"

# Local artifact locations. These are relative paths, so they are resolved
# against the current working directory, not the directory of this file.
LOCAL_CKPT = Path("best_model.pt")
LOCAL_CFG = Path("config.json")

# Use CUDA when PyTorch reports it is available; otherwise run on the CPU.
_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Cached model instance; stays None until _load_model() builds it.
_model = None
|
|
|
|
def _resolve_paths() -> Tuple[str, str]:
    """Return ``(ckpt_path, cfg_path)`` for the model checkpoint and its config.

    Local files (``LOCAL_CKPT`` and ``LOCAL_CFG``) are used only when *both*
    exist. This keeps the checkpoint and config from the same source.
    Otherwise both files are fetched from the Hugging Face Hub repository
    ``REPO_ID``.

    Returns:
        A tuple of two filesystem path strings: checkpoint first, config second.

    Raises:
        RuntimeError: if the local files are missing and ``huggingface_hub``
            could not be imported, so nothing can be downloaded.
    """
    if LOCAL_CKPT.exists() and LOCAL_CFG.exists():
        return str(LOCAL_CKPT.resolve()), str(LOCAL_CFG.resolve())

    # hf_hub_download is bound to None at import time when huggingface_hub
    # is unavailable.
    if hf_hub_download is None:
        raise RuntimeError(
            "No se encontraron 'best_model.pt' y 'config.json' en local, "
            "y 'huggingface_hub' no está instalado para descargarlos."
        )

    # hf_hub_download returns the local (cached) path of the downloaded file.
    ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="best_model.pt")
    cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    return ckpt_path, cfg_path
|
|
|
|
def _read_config(cfg_path: str) -> Dict[str, Any]:
    """Parse and return the UTF-8 JSON config file at *cfg_path*."""
    with open(cfg_path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def _build_classifier(cfg: Dict[str, Any]) -> EmotionClassifier:
    """Instantiate EmotionClassifier from *cfg*; checkpoint weights are loaded separately.

    Any key missing from *cfg* falls back to the hard-coded default below.
    """
    return EmotionClassifier(
        model_name=cfg.get("base_model_id", "dccuchile/bert-base-spanish-wwm-cased"),
        max_length=cfg.get("max_length", 128),
        hidden1=cfg.get("hidden1", 128),
        hidden2=cfg.get("hidden2", 64),
        num_classes=cfg.get("num_classes", 6),
        dropout=cfg.get("dropout", 0.3),
        device=_device,
        pretrained_encoder=cfg.get("pretrained_encoder", "beto"),
    )


def _read_state_dict(ckpt_path: str) -> Any:
    """Load the checkpoint at *ckpt_path* onto ``_device`` and return its state dict.

    Accepts either a bare state dict or a training checkpoint dict that wraps
    it under the ``"model_state_dict"`` key.
    """
    # NOTE(security): torch.load unpickles the file. With weights_only=False
    # (PyTorch's default before 2.6), a malicious checkpoint can execute
    # arbitrary code, and this file may be downloaded from the Hub. Consider
    # passing weights_only=True once the checkpoint is confirmed to hold only
    # tensors and plain containers.
    checkpoint = torch.load(ckpt_path, map_location=_device)
    is_wrapped = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
    return checkpoint["model_state_dict"] if is_wrapped else checkpoint


def _load_model() -> EmotionClassifier:
    """Return the shared EmotionClassifier, building it on the first call.

    The first call resolves the checkpoint and config paths (local files or
    Hugging Face Hub). It builds the classifier from the config, loads the
    weights, switches to eval mode, and caches the result in the module-level
    ``_model``. Later calls return the cached instance.

    NOTE(review): the lazy initialization is not guarded by a lock. Concurrent
    first calls from several threads may each build a model.
    """
    global _model
    if _model is None:
        ckpt_path, cfg_path = _resolve_paths()
        classifier = _build_classifier(_read_config(cfg_path))
        classifier.load_state_dict(_read_state_dict(ckpt_path))
        classifier.eval()
        _model = classifier
    return _model
|
|
|
|
def predict(text: str, return_probs: bool = False) -> Any:
    """Predict the emotion label for *text*.

    The model is loaded lazily through ``_load_model()`` on first use.

    Args:
        text: input sentence to classify.
        return_probs: when True, also return the per-class probabilities.

    Returns:
        The predicted label alone when ``return_probs`` is False. Otherwise a
        ``(label, probs)`` tuple, where ``probs`` is the model's probability
        output converted with ``.tolist()``. The ordering presumably follows
        ``model.label_map``; confirm against EmotionClassifier.
    """
    classifier = _load_model()
    if not return_probs:
        return classifier.predict_single(text)
    label, prob_values = classifier.predict_single(text, return_probs=True)
    return label, prob_values.tolist()
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: classify two sample sentences and print the
    # (label, probabilities) result for each.
    for sample_text in (
        "Estoy muy contento con los resultados",
        "Tengo miedo de lo que pueda pasar",
    ):
        print(predict(sample_text, return_probs=True))
|
|