import json
from pathlib import Path
from transformers import pipeline
from huggingface_hub import hf_hub_download
# Hugging Face Hub repo that hosts the fine-tuned SMS classifier artifacts.
HF_REPO = "cmeneses99/sms-classifier"
# Local cache directory for model files: three parents up from this file, plus "model".
MODEL_DIR = Path(__file__).parent.parent.parent / "model"
# Module-level singletons populated by load_model(); None / empty until it runs.
_classifier = None
_categories: list[str] = []
def _ensure_model() -> Path:
    """Ensure all required model files exist locally, downloading any that are missing.

    The original implementation treated the presence of ``config.json`` alone as
    proof the whole model was downloaded — but ``config.json`` is the first file
    fetched, so an interrupted download left a directory that looked complete
    while ``model.safetensors`` (and others) were missing, and every later run
    returned early without repairing it. We now check each required file
    individually and fetch only the missing ones.

    Returns:
        Path to the local model directory (``MODEL_DIR``).
    """
    required = [
        "config.json",
        "model.safetensors",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.txt",
        "label_map.json",
    ]
    missing = [name for name in required if not (MODEL_DIR / name).exists()]
    if not missing:
        return MODEL_DIR
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    for filename in missing:
        hf_hub_download(repo_id=HF_REPO, filename=filename, local_dir=str(MODEL_DIR))
    return MODEL_DIR
def load_model() -> None:
    """Populate the module-level classifier pipeline and category labels.

    Downloads the model files if needed, builds a CPU text-classification
    pipeline returning the top 3 predictions, and reads the ordered category
    names out of ``label_map.json``.
    """
    global _classifier, _categories
    path = _ensure_model()
    path_str = str(path)
    _classifier = pipeline(
        "text-classification",
        model=path_str,
        tokenizer=path_str,
        top_k=3,
        device=-1,  # -1 selects CPU
    )
    raw = (path / "label_map.json").read_text(encoding="utf-8")
    label_map: dict = json.loads(raw)
    _categories = list(label_map.values())
def get_classifier():
    """Return the text-classification pipeline, or None if load_model() has not run."""
    return _classifier
def get_categories() -> list[str]:
    """Return the category labels loaded by load_model() (empty list before loading)."""
    return _categories