File size: 1,144 Bytes
90f2a5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

# Identifier handed to from_pretrained() for both the tokenizer and the model
# (looks like a Hugging Face Hub repo id — confirm if loading from elsewhere).
MODEL_ID: str = "Nottybro/acra-classifier"

# Display name for each classifier output index: classify_query() maps the
# argmax index i to LABEL_NAMES[i]. Assumes this order matches the label order
# the model was trained with — TODO confirm against the model's config.
LABEL_NAMES = [
    "L0_direct",
    "L1_single_hop",
    "L2_multi_hop",
    "L3_complex",
]

# Lazily-created singletons; both stay None until _load() fills them in.
_tok = _mdl = None

def _load():
    """Populate the module-level tokenizer and model singletons if still empty.

    The first call prints a progress line, builds the tokenizer and the
    sequence-classification model from ``MODEL_ID``, and switches the model
    to evaluation mode. Once ``_mdl`` is set, later calls return immediately
    without doing anything.
    """
    global _tok, _mdl
    if _mdl is not None:
        return  # already loaded by an earlier call

    print(f"Loading classifier from {MODEL_ID}...")
    _tok = DistilBertTokenizerFast.from_pretrained(MODEL_ID)
    _mdl = DistilBertForSequenceClassification.from_pretrained(MODEL_ID)
    # Evaluation mode: changes the behaviour of layers such as dropout so
    # inference is deterministic.
    _mdl.eval()

def warm_up():
    """Eagerly load the classifier and push one throwaway query through it.

    Because ``_load()`` is lazy, calling this early (e.g. at startup) moves
    the loading cost out of the first real ``classify_query`` call. It also
    exercises the full tokenize -> forward path once. The classification
    result is discarded. Prints a confirmation line when finished.
    """
    _load()
    probe_query = "what is the capital of france"
    classify_query(probe_query)  # result intentionally ignored
    print("Classifier warm ✓")

def classify_query(query: str, *, max_length: int = 128) -> dict:
    """Classify a single query into one of the ``LABEL_NAMES`` levels.

    Loads the tokenizer and model on first use (via ``_load``).

    Args:
        query: The text to classify.
        max_length: Maximum number of tokens (special tokens included) fed to
            the model; longer queries are truncated. Defaults to 128, the
            value previously hard-coded here.

    Returns:
        A dict with:
          - ``"level"``: index of the highest-probability class (int).
          - ``"label"``: the ``LABEL_NAMES`` entry for that index.
          - ``"confidence"``: that class's softmax probability, rounded to
            4 decimal places.
          - ``"scores"``: ``{"L0": p0, "L1": p1, ...}``, every class
            probability rounded to 4 decimal places.
    """
    _load()
    # A single sequence needs no padding. The old padding="max_length" grew
    # every query to max_length tokens. The attention mask excludes padded
    # positions, so they only added compute; the probabilities stay the same
    # up to floating-point noise.
    enc = _tok(query, max_length=max_length, truncation=True,
               return_tensors="pt")
    with torch.inference_mode():
        logits = _mdl(**enc).logits  # shape (batch=1, num_labels)
        # Take row 0 explicitly. squeeze() would also drop the label axis if
        # it ever had size 1, breaking argmax/enumerate below.
        probs = torch.softmax(logits, dim=-1)[0]
    level = int(probs.argmax())
    return {
        "level":      level,
        "label":      LABEL_NAMES[level],
        "confidence": round(probs[level].item(), 4),
        "scores":     {f"L{i}": round(p.item(), 4) for i, p in enumerate(probs)},
    }