# sms-classifier-api/app/core/model_loader.py
# Commit 84bb476 (cmeneses99): "Refactor: reorganize into core/, api/, web/, templates/"
import json
from pathlib import Path
from transformers import pipeline
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository id that the model files are downloaded from.
HF_REPO = "cmeneses99/sms-classifier"
# Local directory for the model files: "model" under the directory reached by
# three `.parent` steps from this file's path.
MODEL_DIR = Path(__file__).parent.parent.parent / "model"
# Module-level state filled in by load_model(); None / empty until it has run.
_classifier = None
_categories: list[str] = []
def _ensure_model() -> Path:
    """Return the local model directory, downloading any missing files first.

    Each required file is checked individually. A directory left incomplete
    by an interrupted download is therefore repaired on the next call rather
    than being mistaken for a complete model just because ``config.json``
    happens to exist. Files already present on disk are not requested again.

    Returns:
        ``MODEL_DIR``, the directory that holds the model files.

    Raises:
        Whatever ``hf_hub_download`` raises; download errors are not caught
        here.
    """
    required_files = (
        "config.json",
        "model.safetensors",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.txt",
        "label_map.json",
    )
    missing = [name for name in required_files if not (MODEL_DIR / name).exists()]
    if not missing:
        # Everything is already on disk: skip the Hub entirely.
        return MODEL_DIR

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    for filename in missing:
        hf_hub_download(repo_id=HF_REPO, filename=filename, local_dir=str(MODEL_DIR))
    return MODEL_DIR
def load_model() -> None:
    """Build the text-classification pipeline and read the category labels.

    Both results are stored in the module-level ``_classifier`` and
    ``_categories`` variables; nothing is returned.
    """
    global _classifier, _categories

    model_dir = _ensure_model()
    model_location = str(model_dir)

    # The pipeline is assigned before the label map is read, so a failure
    # while reading labels still leaves the classifier set.
    _classifier = pipeline(
        "text-classification",
        model=model_location,
        tokenizer=model_location,
        top_k=3,
        device=-1,
    )

    labels_file = model_dir / "label_map.json"
    labels: dict = json.loads(labels_file.read_text(encoding="utf-8"))
    _categories = [*labels.values()]
def get_classifier():
    """Return the module-level classifier, or None if load_model() has not set it."""
    return _classifier
def get_categories() -> list[str]:
    """Return the module-level category list (empty until load_model() fills it)."""
    return _categories