File size: 1,417 Bytes
f9ac587
 
 
 
 
 
84bb476
f9ac587
 
 
 
 
 
84bb476
f9ac587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84bb476
f9ac587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import json
from pathlib import Path
from transformers import pipeline
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository the model files are downloaded from.
HF_REPO = "cmeneses99/sms-classifier"
# Local model directory: a "model" folder two levels above the directory containing this file.
MODEL_DIR = Path(__file__).parent.parent.parent / "model"

# Module-level state filled in by load_model(); stays None / empty until it is called.
_classifier = None
_categories: list[str] = []


def _ensure_model() -> Path:
    """Make sure every required model file exists locally, downloading any that are missing.

    Files are fetched from the ``HF_REPO`` repository on the Hugging Face Hub into
    ``MODEL_DIR``. Only files absent from ``MODEL_DIR`` are downloaded, so an
    interrupted earlier download is completed on the next call instead of being
    mistaken for a finished one.

    Returns:
        ``MODEL_DIR``, the directory holding the model files.

    Any exception raised by ``hf_hub_download`` (e.g. a network failure) propagates.
    """
    required_files = (
        "config.json",
        "model.safetensors",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.txt",
        "label_map.json",
    )
    # Check every file, not just config.json: config.json is fetched first, so after a
    # partial download it exists while the weights or label map may not.
    missing = [name for name in required_files if not (MODEL_DIR / name).exists()]
    if not missing:
        return MODEL_DIR

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    for filename in missing:
        hf_hub_download(repo_id=HF_REPO, filename=filename, local_dir=str(MODEL_DIR))

    return MODEL_DIR


def load_model(*, top_k: int = 3, device: int = -1) -> None:
    """Load the classifier pipeline and category labels into module-level state.

    Model files are obtained through ``_ensure_model()``, which downloads them if
    needed. The pipeline and the labels are both built before either global is
    assigned. A failure part-way through (e.g. a missing or malformed
    ``label_map.json``) therefore leaves the previous state untouched, instead of a
    loaded classifier with no categories.

    Args:
        top_k: Number of highest-scoring labels the pipeline returns per input
            (default 3, as before).
        device: Device argument passed to ``transformers.pipeline``; the default
            ``-1`` runs on CPU.
    """
    global _classifier, _categories

    model_path = _ensure_model()

    classifier = pipeline(
        "text-classification",
        model=str(model_path),
        tokenizer=str(model_path),
        top_k=top_k,
        device=device,
    )

    with open(model_path / "label_map.json", encoding="utf-8") as f:
        label_map: dict = json.load(f)
    # NOTE(review): categories follow the value order of label_map.json as stored on
    # disk; confirm that order matches the model's label ids if callers rely on it.
    categories = list(label_map.values())

    # Publish both pieces of state together, only after everything loaded successfully.
    _classifier = classifier
    _categories = categories


def get_classifier():
    """Return the text-classification pipeline stored by load_model().

    Returns None until load_model() has assigned the module-level classifier.
    """
    return _classifier


def get_categories() -> list[str]:
    """Return the category labels loaded by load_model().

    A shallow copy is returned, so a caller that mutates the result cannot corrupt
    the module-level list seen by every other caller. The list is empty until
    load_model() has run.
    """
    return list(_categories)