File size: 5,628 Bytes
a63c61f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import logging
import re
import threading

logger = logging.getLogger(__name__)

# ── Instant keyword shortcuts ─────────────────────────────────────────────────

# Whole-message matches: a query that equals one of these (after it has been
# stripped and lower-cased by the classifier) is treated as small talk.
_SMALL_TALK_EXACT = {
    # greetings / sign-offs
    "hi",
    "hello",
    "hey",
    "thanks",
    "thank you",
    "bye",
    "goodbye",
    "good morning",
    "good afternoon",
    "good evening",
    "sup",
    "yo",
    "hello there",
    "hey there",
    "hi there",
    "greetings",
    # frustration / profanity
    "wtf",
    "lol",
    "lmao",
    "omg",
    "damn",
    "shit",
    "fuck",
    "for fuck sake",
    "for fucks sake",
    "oh my god",
    "are you kidding",
    "seriously",
    "come on",
    "ugh",
    "argh",
}

# Leading-text matches: a query that starts with one of these is treated as
# small talk. The greeting entries keep their trailing space so that they only
# match when the greeting is followed by more text.
_SMALL_TALK_PREFIX = (
    # questions about the assistant itself
    "how are you",
    "what are you",
    "who are you",
    "what can you do",
    # banter
    "tell me a joke",
    "make me laugh",
    "what's up",
    "whats up",
    # frustration / profanity
    "for fuck",
    "for fucks",
    "what the fuck",
    "what the hell",
    "are you serious",
    "you must be",
    # greeting followed by more text
    "hello ",
    "hi ",
    "hey ",
)

# Temporal patterns → always NEWS (user is asking about time-scoped news).
# The alternatives are listed one per entry and joined with "|" inside a single
# word-bounded group; matching is case-insensitive.
_TEMPORAL_PATTERNS = re.compile(
    r"\b("
    + "|".join(
        (
            # relative days
            r"today|yesterday|tomorrow|tonight",
            # "this week", "last friday", "next year", "past 3 days"
            r"this (week|month|year|morning|evening|afternoon)",
            r"last (week|month|year|night|monday|tuesday|wednesday|thursday|friday|saturday|sunday)",
            r"next (week|month|year)",
            r"past (\d+ )?(day|days|week|weeks|month|months|year|years)",
            # recency words
            r"recent(ly)?|latest|breaking|just (now|happened|announced)",
            # weekday and month names
            r"(monday|tuesday|wednesday|thursday|friday|saturday|sunday)",
            r"january|february|march|april|may|june|july|august|september|october|november|december",
            # any four-digit number (intended for years like 2024, 2025)
            r"\d{4}",
            # ordinal like 1st, 2nd
            r"\d+(st|nd|rd|th)",
        )
    )
    + r")\b",
    flags=re.IGNORECASE,
)


class IntentClassifier:
    """
    Local zero-shot intent classifier that labels a query as NEWS or OTHER.

    - Uses the `transformers` zero-shot-classification pipeline with the
      checkpoint named by ``MODEL_NAME``; inference runs in-process on CPU.
    - Lazy-loaded on first use so startup is not blocked.
    - Thread-safe one-time load (guarded by a lock).

    Classification priority:
    0. Blank / missing query → NEWS (the default, without touching the model)
    1. Small-talk exact/prefix match → OTHER (instant)
    2. Temporal pattern match → NEWS (instant, handles "this week", "yesterday", etc.)
    3. Zero-shot NLI model → NEWS or OTHER
    4. Keyword fallback if the model failed to load or inference failed
    5. Default → NEWS (always prefer RAG over hallucination)
    """
    MODEL_NAME = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"

    # Candidate labels handed to the zero-shot pipeline. The winning label is
    # compared against _NEWS_LABEL by equality: a substring test for "news"
    # would also match _OTHER_LABEL, which ends in "unrelated to news".
    _NEWS_LABEL = "news, current events, politics, economy, sports, technology, world affairs"
    _OTHER_LABEL = "small talk, greeting, joke, or general question unrelated to news"

    # Substrings that mark a query as news when the model is unavailable.
    _NEWS_SIGNALS = (
        "latest", "recent", "news", "update", "development", "what happened",
        "who is", "what is", "when did", "why did", "how did", "report",
        "conflict", "election", "economy", "war", "crisis", "deal",
        "agreement", "president", "minister", "market", "price",
        "attack", "protest", "africa", "ethiopia",
    )

    def __init__(self):
        """Create an unloaded classifier; the model is loaded on first use."""
        self._pipe = None                # pipeline object once loaded
        self._lock = threading.Lock()    # serialises the one-time load
        self._load_failed = False        # set once so a failed load is not retried

    @staticmethod
    def _strip_edge_punctuation(text: str) -> str:
        """Return *text* without leading/trailing whitespace and sentence punctuation."""
        return text.strip(" \t\r\n!?.,;:\u2026")

    @classmethod
    def _intent_from_label(cls, top_label: str) -> str:
        """Map the pipeline's top-ranked candidate label to 'NEWS' or 'OTHER'."""
        return "NEWS" if top_label == cls._NEWS_LABEL else "OTHER"

    def _load(self):
        """
        Load the zero-shot pipeline once.

        Returns immediately if the pipeline is already loaded or a previous
        attempt failed. Otherwise takes the lock, re-checks both flags (another
        thread may have finished while this one waited) and builds the
        pipeline. Any exception is logged and recorded in ``_load_failed``; it
        is never raised to the caller.
        """
        if self._pipe is not None or self._load_failed:
            return
        with self._lock:
            if self._pipe is not None or self._load_failed:
                return
            try:
                # Imported here so that importing this module does not require
                # (or pay for) transformers until a model call is needed.
                from transformers import pipeline
                logger.info("Loading intent classifier: %s ...", self.MODEL_NAME)
                self._pipe = pipeline(
                    "zero-shot-classification",
                    model=self.MODEL_NAME,
                    device=-1,          # CPU
                    multi_label=False,
                )
                logger.info("✅ Intent classifier loaded.")
            except Exception as e:
                # Broad on purpose: whatever goes wrong, classification must
                # degrade to keyword matching instead of failing the request.
                logger.error(
                    "❌ Failed to load intent classifier: %s. Falling back to keyword matching.",
                    e,
                )
                self._load_failed = True

    def classify(self, query: str) -> str:
        """
        Classify *query* as news-related or not.

        Args:
            query: The raw user message. ``None`` is treated like an empty
                string.

        Returns:
            ``'NEWS'`` or ``'OTHER'``. Never raises for model problems: load
            and inference failures fall through to the keyword fallback and
            the NEWS default.
        """
        text = query or ""
        q = text.strip().lower()

        # 0. Nothing to classify: return the default without loading the model
        if not q:
            logger.debug("Intent: NEWS (blank query, default)")
            return "NEWS"

        # 1. Instant small-talk shortcuts. The exact match is also tried with
        #    edge punctuation removed so "thanks!" or "seriously?" still hit.
        if q in _SMALL_TALK_EXACT or self._strip_edge_punctuation(q) in _SMALL_TALK_EXACT:
            logger.debug("Intent: OTHER (small-talk exact) — '%s'", q)
            return "OTHER"
        if any(q.startswith(prefix) for prefix in _SMALL_TALK_PREFIX):
            logger.debug("Intent: OTHER (small-talk prefix) — '%s'", q)
            return "OTHER"

        # 2. Temporal pattern → always NEWS
        if _TEMPORAL_PATTERNS.search(text):
            logger.debug("Intent: NEWS (temporal pattern) — '%s'", text[:60])
            return "NEWS"

        # 3. Zero-shot NLI model
        self._load()
        if self._pipe is not None:
            try:
                result = self._pipe(
                    text,
                    candidate_labels=[self._NEWS_LABEL, self._OTHER_LABEL],
                    hypothesis_template="This message is about {}.",
                )
                top_label = result["labels"][0]
                score = result["scores"][0]
            except Exception as e:
                logger.warning("Intent classifier inference failed: %s. Defaulting to NEWS.", e)
            else:
                intent = self._intent_from_label(top_label)
                logger.debug("Intent: %s (DeBERTa score=%.2f) — '%s'", intent, score, text[:60])
                return intent

        # 4. Keyword fallback (model unavailable or inference failed)
        if any(signal in q for signal in self._NEWS_SIGNALS):
            logger.debug("Intent: NEWS (keyword fallback) — '%s'", text[:60])
            return "NEWS"

        # 5. Default — always prefer RAG over hallucination
        logger.debug("Intent: NEWS (default) — '%s'", text[:60])
        return "NEWS"


# Module-level singleton — shared across all requests.
# Constructing it only initialises state; the model is not loaded here.
intent_classifier = IntentClassifier()