"""
DomainClassifier inference helper.

Loads a saved checkpoint and exposes classify() used by the CMA tool
classify_domain.

Graceful fallback: if no checkpoint exists at model_dir, a keyword-based
heuristic is used instead so the pipeline always returns a result.
The returned DomainResult's confidence will be 0.0 to signal the fallback path.
"""

from __future__ import annotations

import logging
import os
from typing import Optional

from src.classifier.model import DOMAIN2ID, DOMAIN_LABELS, DomainClassifier, DomainResult

logger = logging.getLogger(__name__)

# Default checkpoint directory used by init_classifier() and classify() when
# the caller does not pass model_dir.
_DEFAULT_MODEL_DIR = "models/domain_classifier"

# Module-level cache for the loaded model. Stays None until init_classifier()
# is called or a classify() call finds a checkpoint and loads it.
_classifier: Optional[DomainClassifier] = None

# Minimum model confidence required to trust the classification result.
# Below this threshold the result is flagged as low_confidence=True and the
# CMA agent is expected to ask the user a clarifying domain question instead
# of proceeding automatically. The keyword fallback (confidence=0.0 sentinel)
# always sets low_confidence=True regardless of this threshold.
DOMAIN_CONFIDENCE_THRESHOLD: float = 0.50

# ---------------------------------------------------------------------------
# Keyword fallback
# ---------------------------------------------------------------------------

# Lower-case keywords and phrases per domain. _keyword_classify() tests each
# entry with a plain substring check against the lower-cased input text.
_KEYWORD_DOMAINS: dict[str, list[str]] = {
    "ecommerce": [
        "order", "delivery", "product", "refund", "return", "shipped",
        "package", "cart", "flipkart", "amazon", "myntra", "snapdeal",
        "meesho", "swiggy", "zomato", "cashback",
    ],
    "telecom": [
        "network", "sim", "call", "internet", "data", "recharge", "plan",
        "airtel", "jio", "vodafone", "bsnl", "trai", "porting", "prepaid",
        "postpaid", "broadband", "signal",
    ],
    "banking": [
        "account", "bank", "transaction", "debit", "neft", "rtgs", "imps",
        "loan", "emi", "cheque", "atm", "ifsc", "branch", "passbook",
        "fixed deposit", "savings", "current account", "sbi", "hdfc",
        "icici", "axis bank", "pnb",
    ],
    "cibil": [
        "credit score", "cibil", "credit report", "credit card",
        "credit rating", "debt", "collection", "default", "npa", "experian",
        "equifax", "credit bureau", "outstanding", "overdue",
    ],
    "insurance": [
        "policy", "claim", "premium", "insurance", "insurer", "coverage",
        "lic", "settle", "settlement", "nominee", "health insurance",
        "motor insurance", "irdai", "indemnity", "surveyor",
    ],
    "general": [],  # catch-all — scores 1 point unconditionally below
}


def _keyword_classify(text: str) -> DomainResult:
    """Score *text* against the keyword lists and return the best-matching domain.

    Every keyword that occurs in the lower-cased text adds one point to its
    domain; "general" starts with one point and has no keywords of its own.
    The domain with the highest score wins.

    Args:
        text: Free-form text to classify.

    Returns:
        A DomainResult with confidence fixed at 0.0 (the keyword-fallback
        sentinel), low_confidence=True, and all_probs holding each label's
        share of the total score rounded to four decimals.
    """
    lower = text.lower()
    # Seed "general" with 1 so it is the winner when no keyword matches.
    scores: dict[str, int] = {d: 1 if d == "general" else 0 for d in DOMAIN_LABELS}
    for domain, keywords in _KEYWORD_DOMAINS.items():
        for kw in keywords:
            # NOTE(review): this is a substring test, not a word match, so short
            # keywords also hit inside longer words (e.g. "emi" in "premium",
            # "lic" in "policy").
            if kw in lower:
                scores[domain] += 1

    # max() returns the first maximal key, so ties follow DOMAIN_LABELS order.
    best = max(scores, key=lambda d: scores[d])
    # Guard the division in case every score is zero.
    total = sum(scores.values()) or 1
    all_probs = {d: round(scores[d] / total, 4) for d in DOMAIN_LABELS}

    logger.debug("Keyword fallback scores: %s → %s", scores, best)
    return DomainResult(
        domain=best,
        confidence=0.0,  # sentinel: 0.0 signals keyword fallback, not a model score
        all_probs=all_probs,
        low_confidence=True,  # always uncertain when falling back to keywords
    )


# ---------------------------------------------------------------------------
# Checkpoint detection
# ---------------------------------------------------------------------------

def _checkpoint_exists(model_dir: str) -> bool:
    """Return True if *model_dir* contains a ``config.json`` file."""
    return os.path.isfile(os.path.join(model_dir, "config.json"))


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def init_classifier(model_dir: str = _DEFAULT_MODEL_DIR) -> DomainClassifier:
    """
    Explicitly initialise (or reload) the module-level DomainClassifier.

    Raises FileNotFoundError if *model_dir* contains no checkpoint, so callers
    can decide whether to abort or fall back. For transparent fallback use
    classify() directly.

    Args:
        model_dir: Directory expected to hold the saved checkpoint.

    Returns:
        The newly constructed DomainClassifier, which also replaces any
        previously cached module-level instance.

    Raises:
        FileNotFoundError: If no ``config.json`` is found in *model_dir*.
    """
    global _classifier
    if not _checkpoint_exists(model_dir):
        # Keep this training hint in sync with the one logged by classify():
        # --cfpb_csv needs its CSV path, otherwise the suggested command is
        # not runnable as written.
        raise FileNotFoundError(
            f"No DomainClassifier checkpoint found at '{model_dir}'. "
            f"Run: python -m src.classifier.train "
            f"--cfpb_csv data/raw/complaints.csv --output_dir {model_dir}"
        )
    logger.info("Loading DomainClassifier from %s …", model_dir)
    _classifier = DomainClassifier(model_dir)
    return _classifier


def classify(text: str, model_dir: str = _DEFAULT_MODEL_DIR) -> DomainResult:
    """
    Classify *text* and return a DomainResult.

    Two sources of low_confidence=True:
      1. No checkpoint exists → keyword fallback is used (confidence=0.0 sentinel).
      2. Model exists but top-class probability < DOMAIN_CONFIDENCE_THRESHOLD.

    Callers (and the CMA agent) must check low_confidence and ask the user a
    clarifying domain question rather than proceeding with an uncertain result.

    Args:
        text: Free-form text to classify.
        model_dir: Checkpoint directory. Only consulted while no classifier is
            cached; once a model has been loaded it is reused and this
            argument is ignored.

    Returns:
        The model's DomainResult (with low_confidence forced to True when the
        confidence is below the threshold), or the keyword-fallback result
        when no checkpoint is available.
    """
    global _classifier
    if _classifier is None:
        if not _checkpoint_exists(model_dir):
            logger.warning(
                "No DomainClassifier checkpoint at '%s' — using keyword fallback. "
                "Train with: python -m src.classifier.train "
                "--cfpb_csv data/raw/complaints.csv --output_dir %s",
                model_dir,
                model_dir,
            )
            return _keyword_classify(text)
        # Lazy load on first use; the instance is cached for later calls.
        _classifier = DomainClassifier(model_dir)

    result = _classifier.predict(text)
    if result.confidence < DOMAIN_CONFIDENCE_THRESHOLD:
        logger.info(
            "DomainClassifier low confidence (%.2f < %.2f) for domain '%s' — "
            "flagging for user clarification.",
            result.confidence,
            DOMAIN_CONFIDENCE_THRESHOLD,
            result.domain,
        )
        result.low_confidence = True
    return result