Spaces:
Sleeping
Sleeping
"""
DomainClassifier inference helper.

Loads a saved checkpoint and exposes classify(), which backs the CMA tool
classify_domain.

Graceful fallback: if no checkpoint exists at model_dir, a keyword-based
heuristic is used instead so the pipeline always returns a result. A
fallback DomainResult has confidence=0.0 (a sentinel, not a model score)
and low_confidence=True.
"""
from __future__ import annotations

import logging
import os
import re
from typing import Optional

from src.classifier.model import DOMAIN2ID, DOMAIN_LABELS, DomainClassifier, DomainResult
logger = logging.getLogger(__name__)

# Checkpoint directory used when callers do not pass model_dir explicitly.
_DEFAULT_MODEL_DIR = "models/domain_classifier"

# Module-level classifier instance. It starts empty and is filled either by
# init_classifier() or by the first classify() call that finds a checkpoint.
_classifier: Optional[DomainClassifier] = None

# Top-class probability below which a model prediction is not trusted:
# classify() then sets low_confidence=True so the CMA agent asks the user a
# clarifying domain question instead of proceeding automatically. Results
# from the keyword fallback (confidence=0.0 sentinel) carry
# low_confidence=True no matter what this value is.
DOMAIN_CONFIDENCE_THRESHOLD: float = 0.50
# ---------------------------------------------------------------------------
# Keyword fallback
# ---------------------------------------------------------------------------

# Lower-case trigger words/phrases per domain, consumed by _keyword_classify():
# every keyword found in the lower-cased input adds one point to its domain.
# "general" deliberately has no keywords - it is the catch-all, and
# _keyword_classify() gives it one point up front instead.
_KEYWORD_DOMAINS: dict[str, list[str]] = dict(
    ecommerce=[
        "order", "delivery", "product", "refund", "return",
        "shipped", "package", "cart", "flipkart", "amazon",
        "myntra", "snapdeal", "meesho", "swiggy", "zomato", "cashback",
    ],
    telecom=[
        "network", "sim", "call", "internet", "data",
        "recharge", "plan", "airtel", "jio", "vodafone", "bsnl",
        "trai", "porting", "prepaid", "postpaid", "broadband", "signal",
    ],
    banking=[
        "account", "bank", "transaction", "debit", "neft",
        "rtgs", "imps", "loan", "emi", "cheque", "atm", "ifsc",
        "branch", "passbook", "fixed deposit", "savings",
        "current account", "sbi", "hdfc", "icici", "axis bank", "pnb",
    ],
    cibil=[
        "credit score", "cibil", "credit report",
        "credit card", "credit rating", "debt", "collection",
        "default", "npa", "experian", "equifax", "credit bureau",
        "outstanding", "overdue",
    ],
    insurance=[
        "policy", "claim", "premium", "insurance", "insurer",
        "coverage", "lic", "settle", "settlement", "nominee",
        "health insurance", "motor insurance", "irdai",
        "indemnity", "surveyor",
    ],
    general=[],
)
def _keyword_matches(keyword: str, text: str) -> bool:
    """Return True if *keyword* occurs in *text* as a whole word or phrase.

    Matching is anchored on word boundaries so short keywords do not fire
    inside unrelated words (plain substring tests would count "lic" in
    "policy", "emi" in "premium", "npa" in "unpaid", "sim" in "similar").
    A small set of inflectional suffixes is accepted so "orders",
    "refunded" or "claims" still count, and the words of a multi-word
    keyword may be separated by any run of whitespace. *text* is expected
    to be lower-cased already, matching the lower-case keyword lists.
    """
    body = r"\s+".join(re.escape(part) for part in keyword.split())
    if not body:
        # An empty pattern would match at any word boundary.
        return False
    # re caches recently used compiled patterns, so repeated calls with the
    # same keyword do not recompile it every time.
    return re.search(rf"\b{body}(?:s|es|ed|d|ing)?\b", text) is not None


def _keyword_classify(text: str) -> DomainResult:
    """Score *text* against the keyword lists and return the best-matching domain.

    Every domain in DOMAIN_LABELS starts at 0, except "general", which starts
    at 1 so it wins when no keyword matches. Each keyword from
    _KEYWORD_DOMAINS found in the lower-cased text (see _keyword_matches)
    adds one point to its domain. all_probs holds each domain's share of the
    total score, rounded to 4 decimal places. On a tie, the domain that comes
    first in DOMAIN_LABELS wins (max() keeps the first maximum).

    Returns:
        DomainResult with confidence=0.0 (a sentinel marking the keyword
        fallback, not a model probability) and low_confidence=True.
    """
    lower = text.lower()
    scores: dict[str, int] = {d: 1 if d == "general" else 0 for d in DOMAIN_LABELS}
    for domain, keywords in _KEYWORD_DOMAINS.items():
        scores[domain] += sum(_keyword_matches(kw, lower) for kw in keywords)
    best = max(scores, key=scores.__getitem__)
    total = sum(scores.values()) or 1  # guard against division by zero
    all_probs = {d: round(scores[d] / total, 4) for d in DOMAIN_LABELS}
    logger.debug("Keyword fallback scores: %s -> %s", scores, best)
    return DomainResult(
        domain=best,
        confidence=0.0,  # sentinel: 0.0 signals keyword fallback, not a model score
        all_probs=all_probs,
        low_confidence=True,  # always uncertain when falling back to keywords
    )
# ---------------------------------------------------------------------------
# Checkpoint detection
# ---------------------------------------------------------------------------


def _checkpoint_exists(model_dir: str) -> bool:
    """Report whether *model_dir* looks like a saved checkpoint.

    A regular file named ``config.json`` inside the directory is the only
    marker checked; no other checkpoint files are inspected here.
    """
    config_path = os.path.join(model_dir, "config.json")
    return os.path.isfile(config_path)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def init_classifier(model_dir: str = _DEFAULT_MODEL_DIR) -> DomainClassifier:
    """
    Explicitly initialise (or reload) the module-level DomainClassifier.

    A new DomainClassifier is always constructed from *model_dir* and
    replaces any previously loaded instance; later classify() calls use it.

    Args:
        model_dir: Checkpoint directory (must contain a config.json file).

    Returns:
        The newly loaded DomainClassifier.

    Raises:
        FileNotFoundError: If *model_dir* contains no checkpoint, so callers
            can decide whether to abort or fall back. For transparent
            fallback use classify() directly.
    """
    global _classifier
    if not _checkpoint_exists(model_dir):
        raise FileNotFoundError(
            f"No DomainClassifier checkpoint found at '{model_dir}'. "
            f"Run: python -m src.classifier.train --cfpb_csv <path> --output_dir {model_dir}"
        )
    logger.info("Loading DomainClassifier from %s ...", model_dir)
    _classifier = DomainClassifier(model_dir)
    return _classifier
def classify(text: str, model_dir: str = _DEFAULT_MODEL_DIR) -> DomainResult:
    """
    Classify *text* and return a DomainResult.

    The model is loaded lazily: if no classifier has been loaded yet (by
    init_classifier() or an earlier call), it is loaded from *model_dir*.
    Once a classifier is loaded it is reused and *model_dir* is ignored;
    call init_classifier() to switch checkpoints.

    Two sources of low_confidence=True:
      1. No checkpoint exists -> keyword fallback is used (confidence=0.0
         sentinel). The checkpoint check and its warning repeat on every
         such call, because nothing is cached on this path.
      2. Model exists but top-class probability < DOMAIN_CONFIDENCE_THRESHOLD.

    Callers (and the CMA agent) must check low_confidence and ask the user a
    clarifying domain question rather than proceeding with an uncertain result.
    """
    classifier = _classifier
    if classifier is None:
        if not _checkpoint_exists(model_dir):
            logger.warning(
                "No DomainClassifier checkpoint at '%s' - using keyword fallback. "
                "Train with: python -m src.classifier.train "
                "--cfpb_csv data/raw/complaints.csv --output_dir %s",
                model_dir, model_dir,
            )
            return _keyword_classify(text)
        # Share the explicit loader so both entry points load identically.
        classifier = init_classifier(model_dir)
    result = classifier.predict(text)
    if result.confidence < DOMAIN_CONFIDENCE_THRESHOLD:
        logger.info(
            "DomainClassifier low confidence (%.2f < %.2f) for domain '%s' - "
            "flagging for user clarification.",
            result.confidence, DOMAIN_CONFIDENCE_THRESHOLD, result.domain,
        )
        # Only ever raises the flag; a low_confidence already set by the
        # model's predict() is never cleared here.
        result.low_confidence = True
    return result