Spaces:
Running
Running
| """ | |
| Intent Classifier v5 — Fast Keyword Pre-Check + LLM Fallback Chain | |
| Architecture: | |
| Layer 0: Instant exact match (0ms) — greetings, single-char, test | |
| Layer 1: Fast keyword rules (0ms) — temporal/historical/other patterns | |
| ↳ Catches 80%+ of queries instantly, no API call needed | |
| Layer 2: Groq llama-3.1-8b-instant — 14,400 free RPD, ~50ms (PRIMARY) | |
| Layer 3: Gemini Flash fallback — 1,500 free RPD, ~200ms (FALLBACK 1) | |
| Layer 4: OpenRouter free router — free models pool, ~300ms (FALLBACK 2) | |
| Layer 5: HuggingFace Inference API — ~300 RPH, ~2s (FALLBACK 3) | |
| Layer 6: Safe default — NEWS_GENERAL, 0ms (ALWAYS WORKS) | |
| Layer 1 keyword rules cover: | |
| - Temporal: "today", "now", "breaking", "latest", "just happened", etc. | |
| - Historical: "history of", "background", "what caused", "explain", etc. | |
| - Other: greetings, identity questions, math, creative writing | |
| - Ethiopia-specific: "Abiy", "TPLF", "Fano", "Tigray" → NEWS_GENERAL fast path | |
| Why this matters: | |
| - Saves Groq API quota (14,400 RPD is finite) | |
| - Reduces latency from ~50ms → 0ms for common queries | |
| - Works offline / when all LLM providers are down | |
| - Handles Amharic/Arabic/Somali temporal words natively | |
| """ | |
| import logging | |
| import re | |
| import time | |
| import httpx | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, Optional, Tuple | |
| logger = logging.getLogger(__name__) | |
# ═══════════════════════════════════════════════════════════════════════════════
# LAYER 0: INSTANT EXACT MATCH — greetings, empty, test
# ═══════════════════════════════════════════════════════════════════════════════

# Queries whose stripped+lowercased text exactly equals one of these entries are
# classified OTHER immediately (confidence 1.0) — no regex, no LLM call.
_INSTANT_OTHER = {
    "hi", "hello", "hey", "thanks", "thank you", "bye", "goodbye",
    "ok", "okay", "yes", "no", "sure", "cool", "nice",
    "lol", "lmao", "haha", "omg", "wtf", "wow",
    ".", "..", "...", "?", "!", "test", "ping",
}

# ═══════════════════════════════════════════════════════════════════════════════
# LAYER 1: FAST KEYWORD RULES
# ═══════════════════════════════════════════════════════════════════════════════
# All patterns below are compiled once at import time and used with .search(),
# so a match anywhere in the query triggers the rule.

# ── Temporal signals → NEWS_TEMPORAL ─────────────────────────────────────────
# English
_TEMPORAL_EN = re.compile(
    r"\b("
    r"today|tonight|right now|just now|breaking|just happened|"
    r"this morning|this afternoon|this evening|this hour|"
    r"latest|current(ly)?|live|ongoing|unfolding|"
    r"yesterday|last night|"
    r"this week|this month|this year|"
    r"recent(ly)?|new(ly)?|fresh|"
    r"past (few )?(hours?|days?|weeks?)|"
    r"in the (last|past) \d+|"
    r"as of (today|now)|"
    r"update[sd]?|news flash|alert"
    r")\b",
    re.IGNORECASE
)

# Amharic temporal words (common ones)
_TEMPORAL_AM = re.compile(
    r"(ዛሬ|አሁን|ዘንድሮ|ቅርብ|አዲስ|ዜና|ዛሬ ምሽት|ዛሬ ጠዋት)",
    re.UNICODE
)

# Arabic temporal words
_TEMPORAL_AR = re.compile(
    r"(اليوم|الآن|عاجل|أخبار عاجلة|حديثاً|مؤخراً|هذا الأسبوع|هذا الشهر)",
    re.UNICODE
)

# Somali temporal words
_TEMPORAL_SO = re.compile(r"(maanta|hadda|wararka|cusub)", re.IGNORECASE | re.UNICODE)

# Swahili temporal words
_TEMPORAL_SW = re.compile(r"(leo|sasa|habari za leo|mpya|hivi karibuni)", re.IGNORECASE | re.UNICODE)

# ── Historical signals → NEWS_HISTORICAL ─────────────────────────────────────
# NOTE(review): "explain", "describe" and "who/what is" are broad; they only
# fire after the OTHER and TEMPORAL checks in _fast_classify, which limits
# false positives — confirm that ordering is preserved if rules are edited.
_HISTORICAL = re.compile(
    r"\b("
    r"history (of|behind)|historical(ly)?|"
    r"background (of|on|to)|context (of|behind)|"
    r"what caused|root cause|origin(s)? of|"
    r"explain|overview|summary of|"
    r"who (is|was|are|were)|what (is|was|are|were)|"
    r"tell me about|describe|"
    r"in \d{4}|since \d{4}|before \d{4}|"
    r"decade(s)?|century|centuries|"
    r"long.?term|over the years|traditionally|"
    r"founded|established|created|formed"
    r")\b",
    re.IGNORECASE
)

# ── Other signals → OTHER ─────────────────────────────────────────────────────
# Questions about the assistant itself (identity, capabilities).
_OTHER_IDENTITY = re.compile(
    r"\b("
    r"who are you|what are you|are you (an? )?ai|"
    r"what (model|llm|ai) are you|"
    r"who (made|built|created|trained) you|"
    r"your (name|purpose|capabilities)|"
    r"can you (help|do|write|make|create|generate)|"
    r"how (do you|does this) work"
    r")\b",
    re.IGNORECASE
)

# Creative-writing / math / recommendation requests — off-topic for news search.
_OTHER_CREATIVE = re.compile(
    r"\b("
    r"write (a |an )?(poem|story|essay|letter|email|code|script)|"
    r"make (a |an )?(joke|list|plan|recipe)|"
    r"translate (this|to|into)|"
    r"calculate|solve|compute|"
    r"what is \d|how many|how much|"
    r"recommend|suggest|give me (a |an )?(list|idea)"
    r")\b",
    re.IGNORECASE
)

# ── Ethiopia/Africa fast-path → NEWS_GENERAL (skip LLM entirely) ─────────────
# NOTE(review): "abiy ahmed?" makes only the trailing "d" optional; the bare
# "abiy" alternative already matches such queries — confirm intended spelling.
_ETHIOPIA_ENTITIES = re.compile(
    r"\b("
    r"ethiopia(n)?|addis ababa|addis|"
    r"tigray|amhara|oromia|oromo|afar|somali region|sidama|"
    r"abiy ahmed?|abiy|"
    r"tplf|fano|olf|oneg|endf|"
    r"gerd|renaissance dam|nile dam|"
    r"mekelle|bahir dar|gondar|hawassa|dire dawa|"
    r"africa(n)?|horn of africa|east africa|"
    r"sudan|somalia|eritrea|kenya|djibouti"
    r")\b",
    re.IGNORECASE
)

# ── Conflict/humanitarian fast-path → NEWS_GENERAL ───────────────────────────
_NEWS_TOPICS = re.compile(
    r"\b("
    r"conflict|war|fighting|clashes?|attack(s|ed)?|killed|casualties|"
    r"peace (talks?|deal|agreement|process)|ceasefire|"
    r"election(s)?|vote|voting|ballot|"
    r"government|minister|president|prime minister|parliament|"
    r"economy|economic|inflation|gdp|trade|investment|"
    r"humanitarian|refugee(s)?|displaced|famine|drought|flood|"
    r"protest(s|ers)?|demonstration|rally|"
    r"military|troops|soldiers?|forces?|"
    r"news|report(s|ed)?|update(s)?"
    r")\b",
    re.IGNORECASE
)
def _fast_classify(query: str) -> Optional[Tuple[str, float, str]]:
    """
    Layer 1: fast keyword-based classification.

    Args:
        query: Raw user query, any supported language.

    Returns:
        A ``(intent, confidence, reason)`` tuple when a keyword rule fires,
        or ``None`` when no rule matches and the LLM chain should decide.

    Priority order (first match wins):
      1. OTHER (identity/creative) — highest priority, avoid wasting search
      2. NEWS_TEMPORAL — temporal signals are unambiguous
      3. NEWS_HISTORICAL — historical signals are fairly unambiguous
      4. NEWS_GENERAL — Ethiopia/Africa entities or news topics
      5. None — uncertain, let the LLM decide
    """
    # All patterns are either re.IGNORECASE or script-based (Amharic/Arabic),
    # so the raw stripped query is searched directly — no lowercasing needed.
    q = query.strip()

    # 1. OTHER: questions about the assistant itself.
    if _OTHER_IDENTITY.search(q):
        return ("OTHER", 0.95, "identity_pattern")

    # 2. OTHER: creative / math / off-topic requests.
    if _OTHER_CREATIVE.search(q):
        return ("OTHER", 0.90, "creative_pattern")

    # 3. NEWS_TEMPORAL: multilingual temporal signals (EN/AM/AR/SO/SW).
    # Checked before HISTORICAL so that a query carrying both kinds of
    # signal resolves to temporal.
    if (_TEMPORAL_EN.search(q) or _TEMPORAL_AM.search(q) or
            _TEMPORAL_AR.search(q) or _TEMPORAL_SO.search(q) or
            _TEMPORAL_SW.search(q)):
        return ("NEWS_TEMPORAL", 0.92, "temporal_keyword")

    # 4. NEWS_HISTORICAL: history/background/context signals.
    if _HISTORICAL.search(q):
        return ("NEWS_HISTORICAL", 0.88, "historical_keyword")

    # 5. NEWS_GENERAL: Ethiopia/Africa entity fast path (skip the LLM).
    if _ETHIOPIA_ENTITIES.search(q):
        return ("NEWS_GENERAL", 0.85, "ethiopia_entity")

    # 6. NEWS_GENERAL: generic news-topic keywords.
    if _NEWS_TOPICS.search(q):
        return ("NEWS_GENERAL", 0.80, "news_topic_keyword")

    # 7. Uncertain — fall through to the LLM chain.
    return None
# ═══════════════════════════════════════════════════════════════════════════════
# LLM CLASSIFICATION PROMPT
# ═══════════════════════════════════════════════════════════════════════════════

# Single prompt shared by every LLM provider in the fallback chain; the
# {query} placeholder is filled via str.format. The model is asked to reply
# with exactly one category name, but _parse_intent() tolerates extra tokens.
_CLASSIFY_PROMPT = """You are an intent classifier for ARKI AI, a news assistant focused on Ethiopia and Africa.
Classify the user query into EXACTLY ONE of these categories:
NEWS_TEMPORAL — asks about current/recent/today's events, breaking news, latest updates
NEWS_HISTORICAL — asks about past events, history, background, context, analysis
NEWS_GENERAL — asks about news topics without a specific time reference (people, places, conflicts, politics, economy, humanitarian)
OTHER — identity questions ("who are you"), math, greetings, creative writing, off-topic requests
Rules:
- Single words like "ethiopia", "amhara", "conflict", "news" → NEWS_GENERAL
- Single words like "today", "now", "breaking", "latest" → NEWS_TEMPORAL
- Vague queries about a news topic → NEWS_GENERAL (search and find nothing > refuse)
- Questions about AI identity, capabilities, or the system → OTHER
- Math, recipes, poems, games → OTHER
- When in doubt between NEWS types → NEWS_GENERAL
Reply with ONLY the category name. Nothing else.
Query: {query}
Category:"""
| # ═══════════════════════════════════════════════════════════════════════════════ | |
| # DATA CLASS | |
| # ═══════════════════════════════════════════════════════════════════════════════ | |
@dataclass
class IntentResult:
    """
    Structured result of one intent classification.

    Bug fix: the class previously declared only bare field annotations with no
    @dataclass decorator, so it had no generated __init__ — yet it is
    instantiated with keyword arguments (see IntentClassifierV2._result), which
    raised TypeError at runtime. `dataclass` was already imported for this.
    """
    intent: str                 # NEWS_TEMPORAL | NEWS_HISTORICAL | NEWS_GENERAL | OTHER
    confidence: float           # 0.0 – 1.0
    method: str                 # instant | keyword | llm_groq | llm_gemini | llm_openrouter | llm_hf | default
    inference_time_ms: float    # wall-clock classification time
    query_complexity: str       # empty | vague | simple | medium | complex
    sub_type: str               # general | conflict | humanitarian | identity | creative | off_topic
    should_use_live: bool       # route to live/web search
    should_use_db: bool         # route to the news database
    metadata: Dict[str, Any]    # provider-specific extras (may be empty)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (e.g. for JSON/API responses)."""
        return {
            "intent": self.intent,
            "confidence": self.confidence,
            "method": self.method,
            "inference_time_ms": self.inference_time_ms,
            "query_complexity": self.query_complexity,
            "sub_type": self.sub_type,
            "should_use_live": self.should_use_live,
            "should_use_db": self.should_use_db,
            "metadata": self.metadata,
        }
| # ═══════════════════════════════════════════════════════════════════════════════ | |
| # CLASSIFIER | |
| # ═══════════════════════════════════════════════════════════════════════════════ | |
class IntentClassifierV2:
    """
    Intent classifier v5: fast keyword pre-check + LLM fallback chain.

    Layer 0: Instant exact match (0ms)
    Layer 1: Keyword rules (0ms) — handles ~80% of queries
    Layer 2: Groq 8B (50ms)
    Layer 3: Gemini Flash (200ms)
    Layer 4: OpenRouter (300ms)
    Layer 5: HuggingFace (2s)
    Layer 6: Default NEWS_GENERAL (0ms)
    """

    GROQ_URL = "https://api.groq.com/openai/v1/chat/completions"
    GROQ_MODEL = "llama-3.1-8b-instant"

    GEMINI_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"

    OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
    OPENROUTER_MODEL = "openrouter/auto"

    HF_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-3B-Instruct/v1/chat/completions"

    # The only intent names accepted back from an LLM reply.
    VALID_INTENTS = {"NEWS_TEMPORAL", "NEWS_HISTORICAL", "NEWS_GENERAL", "OTHER"}

    def __init__(self):
        # Provider credentials; any subset may be missing — the fallback chain
        # simply skips providers whose key was not loaded.
        self._groq_key: Optional[str] = None
        self._gemini_key: Optional[str] = None
        self._openrouter_key: Optional[str] = None
        self._hf_token: Optional[str] = None
        # Shared connection pool; per-request timeouts override this default.
        self._client = httpx.Client(timeout=5.0)
        self._metrics: Dict[str, Any] = {
            "total": 0,
            "by_intent": {},
            "by_method": {},
            "total_ms": 0.0,
            "keyword_hits": 0,  # how many queries handled by keyword layer
            "llm_calls": 0,     # how many queries needed LLM
        }
        self._load_keys()

    def _load_keys(self):
        """Load provider keys from settings, ignoring empty/placeholder values."""
        try:
            from src.core.config import settings
            key = settings.GROQ_API_KEY
            if key and key not in ("", "your-groq-api-key-here"):
                self._groq_key = key
            gem = settings.GEMINI_API_KEY
            if gem and gem not in ("", "your-gemini-api-key-here"):
                self._gemini_key = gem
            # OPENROUTER_API_KEY may not exist on settings at all — getattr
            # with a default, plus a belt-and-braces except.
            try:
                or_key = getattr(settings, "OPENROUTER_API_KEY", "")
                if or_key and or_key not in ("", "your-openrouter-api-key-here"):
                    self._openrouter_key = or_key
            except Exception:
                pass
            hf = settings.HF_TOKEN
            if hf and hf not in ("", "your-hf-token-here"):
                self._hf_token = hf
            providers = ["Keyword"]
            if self._groq_key: providers.append("Groq")
            if self._gemini_key: providers.append("Gemini")
            if self._openrouter_key: providers.append("OpenRouter")
            if self._hf_token: providers.append("HuggingFace")
            providers.append("Default")
            logger.info(f"✅ Intent classifier v5 providers: {' → '.join(providers)}")
        except Exception as e:
            # Missing settings module/attributes: keep running with the
            # keyword layer and the safe default only.
            logger.error(f"Intent classifier: failed to load keys: {e}")

    # ── Public API ────────────────────────────────────────────────────────────
    def classify(self, query: str) -> "IntentResult":
        """
        Classify *query* through the layered chain; never raises.

        Returns an IntentResult; on total provider failure falls back to
        NEWS_GENERAL with confidence 0.50.
        """
        t0 = time.time()
        q = query.strip()
        ql = q.lower()
        complexity = self._complexity(q)

        # ── Layer 0: Instant exact match ──────────────────────────────────────
        if ql in _INSTANT_OTHER:
            return self._result("OTHER", 1.0, "instant", t0, complexity, "identity")

        # ── Layer 1: Fast keyword rules ───────────────────────────────────────
        fast = _fast_classify(q)
        if fast:
            intent, confidence, reason = fast
            self._metrics["keyword_hits"] += 1
            logger.debug(f"[Intent] Keyword rule: '{q[:50]}' → {intent} ({reason})")
            return self._result(intent, confidence, f"keyword:{reason}", t0, complexity,
                                self._sub_type(q, intent))

        # ── Layers 2-5: LLM providers, in decreasing quota/speed order ────────
        self._metrics["llm_calls"] += 1

        if self._groq_key:
            intent = self._call_openai_compat(
                url=self.GROQ_URL, api_key=self._groq_key,
                model=self.GROQ_MODEL, query=q, provider="groq"
            )
            if intent:
                return self._result(intent, 0.97, "llm_groq", t0, complexity,
                                    self._sub_type(q, intent))

        if self._gemini_key:
            intent = self._call_gemini(q)
            if intent:
                return self._result(intent, 0.95, "llm_gemini", t0, complexity,
                                    self._sub_type(q, intent))

        if self._openrouter_key:
            intent = self._call_openai_compat(
                url=self.OPENROUTER_URL, api_key=self._openrouter_key,
                model=self.OPENROUTER_MODEL, query=q, provider="openrouter",
                extra_headers={
                    "HTTP-Referer": "https://arki-ai.com",
                    "X-Title": "ARKI AI Intent Classifier",
                }
            )
            if intent:
                return self._result(intent, 0.93, "llm_openrouter", t0, complexity,
                                    self._sub_type(q, intent))

        if self._hf_token:
            intent = self._call_openai_compat(
                url=self.HF_URL, api_key=self._hf_token,
                model="meta-llama/Llama-3.2-3B-Instruct",
                query=q, provider="huggingface", timeout=8.0
            )
            if intent:
                return self._result(intent, 0.90, "llm_hf", t0, complexity,
                                    self._sub_type(q, intent))

        # ── Layer 6: Safe default — always works ──────────────────────────────
        logger.warning(f"[Intent] All providers failed for '{q[:50]}' — defaulting to NEWS_GENERAL")
        return self._result("NEWS_GENERAL", 0.50, "default", t0, complexity, "general")

    # ── Provider calls ────────────────────────────────────────────────────────
    def _call_openai_compat(
        self,
        url: str,
        api_key: str,
        model: str,
        query: str,
        provider: str,
        extra_headers: Optional[Dict] = None,
        timeout: float = 4.0,
    ) -> Optional[str]:
        """
        POST the classification prompt to an OpenAI-compatible chat endpoint
        (Groq, OpenRouter, HF router). Returns the parsed intent or None on
        any failure — errors are logged, never raised, so the chain continues.
        """
        try:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            if extra_headers:
                headers.update(extra_headers)
            response = self._client.post(
                url, headers=headers,
                json={
                    "model": model,
                    "messages": [{"role": "user", "content": _CLASSIFY_PROMPT.format(query=query)}],
                    # One category word is all we need; deterministic decoding.
                    "max_tokens": 20,
                    "temperature": 0.0,
                },
                timeout=timeout,
            )
            if response.status_code == 200:
                content = (
                    response.json().get("choices", [{}])[0]
                    .get("message", {}).get("content", "").strip()
                )
                intent = self._parse_intent(content)
                if intent:
                    logger.debug(f"[Intent] {provider}: '{query[:40]}' → {intent}")
                    return intent
                logger.warning(f"[Intent] {provider}: unexpected response: '{content}'")
            elif response.status_code == 429:
                logger.warning(f"[Intent] {provider} rate limited")
            elif response.status_code == 503:
                logger.warning(f"[Intent] {provider} unavailable (503)")
            else:
                logger.warning(f"[Intent] {provider} returned {response.status_code}")
        except httpx.TimeoutException:
            logger.warning(f"[Intent] {provider} timeout ({timeout}s)")
        except Exception as e:
            logger.error(f"[Intent] {provider} error: {e}")
        return None

    def _call_gemini(self, query: str) -> Optional[str]:
        """
        POST the classification prompt to Gemini's generateContent REST API.
        Returns the parsed intent or None on any failure (logged, not raised).
        """
        try:
            # Gemini authenticates via a key query parameter, not a header.
            url = f"{self.GEMINI_URL}?key={self._gemini_key}"
            response = self._client.post(
                url,
                json={
                    "contents": [{"parts": [{"text": _CLASSIFY_PROMPT.format(query=query)}]}],
                    "generationConfig": {"maxOutputTokens": 20, "temperature": 0.0},
                },
                timeout=4.0,
            )
            if response.status_code == 200:
                content = (
                    response.json().get("candidates", [{}])[0]
                    .get("content", {}).get("parts", [{}])[0]
                    .get("text", "").strip()
                )
                intent = self._parse_intent(content)
                if intent:
                    logger.debug(f"[Intent] gemini: '{query[:40]}' → {intent}")
                    return intent
            elif response.status_code == 429:
                logger.warning("[Intent] Gemini rate limited")
            else:
                logger.warning(f"[Intent] Gemini returned {response.status_code}")
        except httpx.TimeoutException:
            logger.warning("[Intent] Gemini timeout (4s)")
        except Exception as e:
            logger.error(f"[Intent] Gemini error: {e}")
        return None

    # ── Helpers ───────────────────────────────────────────────────────────────
    def _parse_intent(self, raw: str) -> Optional[str]:
        """Extract a valid intent name from a raw LLM reply, or None."""
        cleaned = raw.strip().upper().replace(".", "").replace(":", "")
        words = cleaned.split()
        first_word = words[0] if words else ""
        if first_word in self.VALID_INTENTS:
            return first_word
        # Fallback: accept the intent anywhere in the reply ("Category OTHER").
        for intent in self.VALID_INTENTS:
            if intent in cleaned:
                return intent
        return None

    def _sub_type(self, query: str, intent: str) -> str:
        """Coarse topical sub-type attached to the result for routing/analytics."""
        if intent == "OTHER":
            # Regex patterns are case-insensitive, so search the raw query.
            if _OTHER_IDENTITY.search(query):
                return "identity"
            if _OTHER_CREATIVE.search(query):
                return "creative"
            return "off_topic"
        ql = query.lower()
        # Substring checks, first match wins — order encodes priority.
        if any(w in ql for w in ("clash", "attack", "killed", "battle", "fano", "tplf", "military", "conflict", "war")):
            return "conflict"
        if any(w in ql for w in ("displaced", "refugee", "aid", "humanitarian", "famine", "drought")):
            return "humanitarian"
        if any(w in ql for w in ("election", "vote", "government", "minister", "president", "parliament")):
            return "political"
        if any(w in ql for w in ("economy", "economic", "inflation", "trade", "investment", "gdp")):
            return "economic"
        return "general"

    def _complexity(self, query: str) -> str:
        """Bucket the query by whitespace-token count."""
        n = len(query.split())
        if n == 0: return "empty"
        if n == 1: return "vague"
        if n <= 4: return "simple"
        if n <= 12: return "medium"
        return "complex"

    def _result(
        self,
        intent: str,
        confidence: float,
        method: str,
        t0: float,
        complexity: str,
        sub_type: str,
        metadata: Optional[Dict] = None,
    ) -> "IntentResult":
        """Build the IntentResult, updating aggregate metrics as a side effect."""
        ms = (time.time() - t0) * 1000
        self._metrics["total"] += 1
        self._metrics["by_intent"][intent] = self._metrics["by_intent"].get(intent, 0) + 1
        self._metrics["by_method"][method] = self._metrics["by_method"].get(method, 0) + 1
        self._metrics["total_ms"] += ms
        logger.debug(
            f"[Intent] {intent} conf={confidence:.2f} method={method} "
            f"sub={sub_type} complexity={complexity} time={ms:.1f}ms"
        )
        return IntentResult(
            intent=intent,
            confidence=confidence,
            method=method,
            inference_time_ms=ms,
            query_complexity=complexity,
            sub_type=sub_type,
            # Only temporal queries need live sources; all NEWS intents hit the DB.
            should_use_live=(intent == "NEWS_TEMPORAL"),
            should_use_db=(intent in ("NEWS_TEMPORAL", "NEWS_HISTORICAL", "NEWS_GENERAL")),
            metadata=metadata or {},
        )

    def get_metrics(self) -> Dict[str, Any]:
        """Return a metrics snapshot (guards against division by zero)."""
        total = self._metrics["total"] or 1
        kw_pct = (self._metrics["keyword_hits"] / total) * 100
        return {
            **self._metrics,
            "avg_ms": self._metrics["total_ms"] / total,
            "keyword_hit_rate_pct": round(kw_pct, 1),
        }

    def close(self) -> None:
        """Release the underlying HTTP connection pool (new, optional to call)."""
        try:
            self._client.close()
        except Exception:
            pass
# ═══════════════════════════════════════════════════════════════════════════════
# SINGLETONS
# ═══════════════════════════════════════════════════════════════════════════════
# Shared module-level instance — import this rather than constructing your own
# so metrics and provider quota accounting accumulate in one place.
intent_classifier_v2 = IntentClassifierV2()
class IntentClassifier:
    """
    Legacy binary facade over the v5 classifier.

    Collapses every NEWS_* intent to the single label "NEWS" and passes
    "OTHER" through unchanged, preserving the old two-way API.
    """

    def __init__(self):
        # Delegate to the shared v5 singleton instead of building a new chain.
        self._v2 = intent_classifier_v2

    def classify(self, query: str) -> str:
        """Return "NEWS" or "OTHER" for *query*."""
        full_result = self._v2.classify(query)
        if full_result.intent == "OTHER":
            return "OTHER"
        return "NEWS"
# Backward-compatible binary singleton for callers still on the NEWS/OTHER API.
intent_classifier = IntentClassifier()