# acra-api / classifier_inference.py
# Deployed by Nottybro (commit 90f2a5f, verified)
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
# Hugging Face Hub repo the DistilBERT sequence classifier is loaded from.
MODEL_ID = "Nottybro/acra-classifier"
# Human-readable labels, index-aligned with the model's output logits
# (logit index i corresponds to complexity level i).
LABEL_NAMES = ["L0_direct", "L1_single_hop", "L2_multi_hop", "L3_complex"]
# Lazily-initialised singletons, populated by _load() on first use.
_tok = None
_mdl = None
def _load():
    """Lazily initialise the tokenizer/model singletons (no-op after first call)."""
    global _tok, _mdl
    if _mdl is not None:
        return
    print(f"Loading classifier from {MODEL_ID}...")
    _tok = DistilBertTokenizerFast.from_pretrained(MODEL_ID)
    _mdl = DistilBertForSequenceClassification.from_pretrained(MODEL_ID)
    # Inference-only module: disable dropout etc. via eval mode.
    _mdl.eval()
def warm_up():
    """Pre-load the model and run one throwaway inference so the first
    real request does not pay the cold-start cost."""
    _load()
    _ = classify_query("what is the capital of france")
    print("Classifier warm ✓")
def classify_query(query: str) -> dict:
    """Classify *query* into one of the LABEL_NAMES complexity levels.

    Returns a dict with the predicted level index ("level"), its name
    ("label"), the softmax probability of that prediction ("confidence",
    rounded to 4 places), and the full per-level distribution ("scores").
    """
    _load()
    inputs = _tok(query, max_length=128, padding="max_length",
                  truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = _mdl(**inputs).logits
    # Drop the batch dimension and move to plain Python floats.
    scores = torch.softmax(logits, dim=-1).squeeze().tolist()
    best = max(range(len(scores)), key=scores.__getitem__)
    return {
        "level": best,
        "label": LABEL_NAMES[best],
        "confidence": round(scores[best], 4),
        "scores": {f"L{i}": round(s, 4) for i, s in enumerate(scores)},
    }