import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification


MODEL_ID = "Nottybro/acra-classifier"

# Query complexity levels, from direct lookup (L0) to complex multi-hop reasoning (L3).
LABEL_NAMES = ["L0_direct", "L1_single_hop", "L2_multi_hop", "L3_complex"]

# Lazily-initialized tokenizer and model singletons, populated on first use by _load().
_tok = None
_mdl = None


def _load():
    """Load the tokenizer and model once; subsequent calls are no-ops."""
    global _tok, _mdl
    if _mdl is None:
        print(f"Loading classifier from {MODEL_ID}...")
        _tok = DistilBertTokenizerFast.from_pretrained(MODEL_ID)
        _mdl = DistilBertForSequenceClassification.from_pretrained(MODEL_ID)
        _mdl.eval()  # inference mode: disables dropout


def warm_up():
    """Run one throwaway query so the first real request doesn't pay the load cost."""
    _load()
    classify_query("what is the capital of france")
    print("Classifier warm ✓")


def classify_query(query: str) -> dict:
    """Classify a query into one of the four complexity levels in LABEL_NAMES."""
    _load()
    enc = _tok(query, max_length=128, padding="max_length",
               truncation=True, return_tensors="pt")
    with torch.no_grad():
        # Softmax over the four class logits; squeeze drops the batch dimension.
        probs = torch.softmax(_mdl(**enc).logits, dim=-1).squeeze()
    level = int(probs.argmax())
    return {
        "level": level,
        "label": LABEL_NAMES[level],
        "confidence": round(probs[level].item(), 4),
        "scores": {f"L{i}": round(p.item(), 4) for i, p in enumerate(probs)},
    }
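

# Minimal usage sketch (assumption: the module is run directly as a script;
# the sample query below is illustrative only, not from a real workload):
if __name__ == "__main__":
    warm_up()
    print(classify_query("who directed the film that won the 1995 best picture award"))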