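"""Inference service for intent classification.

Wraps the shared classifier behind an LRU cache keyed by normalized
message text, so repeated messages skip the model entirely.
"""
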
from ..core.cache import LRUCache
from ..core.model_loader import get_classifier
from ..core.schemas import PredictResponse, IntentPrediction

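# Process-wide cache, keyed by normalized message text so equivalent
# messages share one entry. 512 entries keeps the footprint bounded.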
_cache = LRUCache(max_size=512)


def get_cache() -> LRUCache:
    """Return the module-level cache (useful for tests or dependency injection)."""
    return _cache


def run_inference(texts: list[str]) -> list[PredictResponse]:
    """Run the model on a list of texts and return structured responses."""
    classifier = get_classifier()
    raw = classifier(texts, batch_size=16)
    responses = []
    for text, result in zip(texts, raw):
        # The classifier may yield one dict per text or a list of dicts per
        # text depending on its configuration; normalize to a list either way.
        preds = result if isinstance(result, list) else [result]
        # Keep only the three highest-confidence labels, best first, so
        # top_3 lives up to its name even if the model returns every label.
        top_3 = [
            IntentPrediction(category=r["label"], confidence=round(r["score"], 4))
            for r in sorted(preds, key=lambda p: p["score"], reverse=True)[:3]
        ]
        responses.append(PredictResponse(text=text, prediction=top_3[0], top_3=top_3))
    return responses


def classify_one(normalized_text: str, original_text: str) -> PredictResponse:
    """Classify a single message, using the cache when available."""
    cached = _cache.get(normalized_text)
    if cached is not None:
        # Return a copy so the cached entry is never mutated; restore this
        # caller's original (un-normalized) text and flag the hit.
        return cached.model_copy(update={"text": original_text, "cached": True})
    response = run_inference([normalized_text])[0]
    response.text = original_text
    _cache.set(normalized_text, response)
    return response


def classify_many(normalized_texts: list[str], original_texts: list[str]) -> tuple[list[PredictResponse], int]:
    """Classify a batch of messages. Returns (results, from_cache_count)."""
    results: list[PredictResponse | None] = [None] * len(normalized_texts)
    from_cache = 0
    pending: list[int] = []

    # First pass: serve what we can from the cache, marking hits the same
    # way classify_one does.
    for i, key in enumerate(normalized_texts):
        cached = _cache.get(key)
        if cached is not None:
            results[i] = cached.model_copy(update={"text": original_texts[i], "cached": True})
            from_cache += 1
        else:
            pending.append(i)

    # Second pass: one batched inference call covering only the cache misses.
    if pending:
        inferred = run_inference([normalized_texts[i] for i in pending])
        for i, response in zip(pending, inferred):
            response.text = original_texts[i]
            _cache.set(normalized_texts[i], response)
            results[i] = response

    # Every slot has been filled by one of the two passes above.
    return results, from_cache
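

# Usage sketch (illustrative only; the real normalizer lives elsewhere in
# this project and may differ from the simple strip/lower shown here):
#
#     texts = ["What is my balance?", "Cancel my order"]
#     normalized = [t.strip().lower() for t in texts]
#     results, hits = classify_many(normalized, texts)
#     results[0].prediction.category   # top intent label for the first text
#     # A repeat call with the same inputs would now report hits > 0.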