"""Query-complexity classifier backed by a fine-tuned DistilBERT model.

Lazily loads a tokenizer/model pair from the Hugging Face Hub on first
use and exposes ``classify_query`` to score a query into one of four
complexity levels (L0 direct ... L3 complex).
"""

import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

# Hugging Face Hub repo hosting the fine-tuned classifier weights.
MODEL_ID = "Nottybro/acra-classifier"
# Index i of this list is the human-readable label for class index i.
LABEL_NAMES = ["L0_direct", "L1_single_hop", "L2_multi_hop", "L3_complex"]

# Module-level singletons, populated lazily by _load().
_tok = None
_mdl = None


def _load():
    """Load the tokenizer and model once; later calls are no-ops."""
    global _tok, _mdl
    if _mdl is not None:
        return  # already initialized
    print(f"Loading classifier from {MODEL_ID}...")
    _tok = DistilBertTokenizerFast.from_pretrained(MODEL_ID)
    _mdl = DistilBertForSequenceClassification.from_pretrained(MODEL_ID)
    _mdl.eval()  # inference mode: disables dropout etc.


def warm_up():
    """Force the (possibly slow) model download/load and run one dummy inference."""
    _load()
    classify_query("what is the capital of france")
    print("Classifier warm ✓")


def classify_query(query: str) -> dict:
    """Classify *query* into a complexity level L0..L3.

    Returns a dict with the predicted level index, its label, the
    model's confidence for that class, and rounded per-class scores.
    """
    _load()
    encoded = _tok(
        query,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = _mdl(**encoded).logits
        # Batch of one -> squeeze to a flat (num_classes,) probability vector.
        probs = torch.softmax(logits, dim=-1).squeeze()
    best = int(probs.argmax())
    per_class = {f"L{i}": round(p.item(), 4) for i, p in enumerate(probs)}
    return {
        "level": best,
        "label": LABEL_NAMES[best],
        "confidence": round(probs[best].item(), 4),
        "scores": per_class,
    }