guide / src /classifier /predict.py
sangram kumar yerra
Bug fix: Added logic when domain classify confidence is low
06254f4
Raw
History Blame Contribute Delete
6.01 kB
"""
DomainClassifier inference helper.
Loads a saved checkpoint and exposes classify() used by the CMA tool
classify_domain.
Graceful fallback: if no checkpoint exists at model_dir, a keyword-based
heuristic is used instead so the pipeline always returns a result. On that
path the returned DomainResult has confidence=0.0 (a sentinel signalling the
fallback, not a model probability) and low_confidence=True.
"""
from __future__ import annotations

import logging
import os
import re
from typing import Optional

from src.classifier.model import DOMAIN2ID, DOMAIN_LABELS, DomainClassifier, DomainResult
logger = logging.getLogger(__name__)

# Checkpoint location used whenever a caller does not pass model_dir.
_DEFAULT_MODEL_DIR = "models/domain_classifier"

# Module-wide DomainClassifier instance; stays None until classify() or
# init_classifier() successfully loads a checkpoint.
_classifier: Optional[DomainClassifier] = None

# Confidence gate for model predictions. When the top-class probability falls
# below this value the result is marked low_confidence=True, which tells the
# CMA agent to ask the user a clarifying domain question instead of carrying on
# automatically. Keyword-fallback results (confidence=0.0 sentinel) are always
# marked low_confidence=True, independent of this value.
DOMAIN_CONFIDENCE_THRESHOLD: float = 0.5
# ---------------------------------------------------------------------------
# Keyword fallback
# ---------------------------------------------------------------------------
# Keyword lists for the no-checkpoint fallback: each listed keyword found in
# the text adds one point to its domain. Multi-word phrases are allowed.
_KEYWORD_DOMAINS: dict[str, list[str]] = {
    "ecommerce": [
        "order", "delivery", "product", "refund",
        "return", "shipped", "package", "cart",
        "flipkart", "amazon", "myntra", "snapdeal",
        "meesho", "swiggy", "zomato", "cashback",
    ],
    "telecom": [
        "network", "sim", "call", "internet",
        "data", "recharge", "plan", "airtel",
        "jio", "vodafone", "bsnl", "trai",
        "porting", "prepaid", "postpaid", "broadband",
        "signal",
    ],
    "banking": [
        "account", "bank", "transaction", "debit",
        "neft", "rtgs", "imps", "loan",
        "emi", "cheque", "atm", "ifsc",
        "branch", "passbook", "fixed deposit", "savings",
        "current account", "sbi", "hdfc", "icici",
        "axis bank", "pnb",
    ],
    "cibil": [
        "credit score", "cibil", "credit report", "credit card",
        "credit rating", "debt", "collection", "default",
        "npa", "experian", "equifax", "credit bureau",
        "outstanding", "overdue",
    ],
    "insurance": [
        "policy", "claim", "premium", "insurance",
        "insurer", "coverage", "lic", "settle",
        "settlement", "nominee", "health insurance", "motor insurance",
        "irdai", "indemnity", "surveyor",
    ],
    # Catch-all with no keywords of its own; _keyword_classify seeds it with
    # one point so it wins when nothing else matches.
    "general": [],
}
def _keyword_classify(text: str) -> DomainResult:
    """Score *text* against ``_KEYWORD_DOMAINS`` and return the best-matching domain.

    Each keyword that occurs in the lower-cased text adds one point to its
    domain (a keyword counts once no matter how often it appears). The
    ``"general"`` domain starts with one point so it is chosen when nothing
    else matches.

    A keyword only matches where it begins a word. Plain substring matching let
    short keywords fire inside unrelated words ("lic" in "public", "emi" in
    "premium", "atm" in "treatment", "plan" in "explanation"). Trailing
    inflections such as "orders" or "refunded" still match.

    Args:
        text: Free-form text to classify.

    Returns:
        A DomainResult with:
        - ``confidence=0.0``, a sentinel marking the fallback path (not a model
          probability);
        - ``all_probs`` holding the keyword scores divided by their total,
          rounded to 4 places;
        - ``low_confidence=True`` always.
    """
    lower = text.lower()
    # "general" gets a free point so it is the default when no keyword hits.
    scores: dict[str, int] = {d: 1 if d == "general" else 0 for d in DOMAIN_LABELS}
    for domain, keywords in _KEYWORD_DOMAINS.items():
        for kw in keywords:
            # \b anchors the keyword to a word start; re.escape keeps any
            # punctuation inside a keyword literal. re caches compiled patterns.
            if re.search(r"\b" + re.escape(kw), lower):
                scores[domain] += 1
    # max() returns the first maximal key, so ties resolve by DOMAIN_LABELS order.
    best = max(scores, key=scores.__getitem__)
    total = sum(scores.values()) or 1  # guard against division by zero
    all_probs = {d: round(scores[d] / total, 4) for d in DOMAIN_LABELS}
    logger.debug("Keyword fallback scores: %s → %s", scores, best)
    return DomainResult(
        domain=best,
        confidence=0.0,  # sentinel: 0.0 signals keyword fallback, not a model score
        all_probs=all_probs,
        low_confidence=True,  # always uncertain when falling back to keywords
    )
# ---------------------------------------------------------------------------
# Checkpoint detection
# ---------------------------------------------------------------------------
def _checkpoint_exists(model_dir: str) -> bool:
    """Return True when *model_dir* holds a regular file named ``config.json``.

    The presence of that file is the marker this module uses to decide whether
    a saved checkpoint can be loaded. A directory named ``config.json`` does
    not count.
    """
    config_path = os.path.join(model_dir, "config.json")
    return os.path.isfile(config_path)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def init_classifier(model_dir: str = _DEFAULT_MODEL_DIR) -> DomainClassifier:
    """
    Load (or reload) the module-level DomainClassifier from *model_dir*.

    Unlike classify(), this never falls back to keywords. A missing checkpoint
    raises, so the caller can choose between aborting and degrading. Use
    classify() directly if you want the transparent fallback.

    Args:
        model_dir: Directory expected to contain the saved checkpoint
            (detected by the presence of its ``config.json``).

    Returns:
        The newly constructed DomainClassifier. It is also stored as the
        module-level instance that classify() reuses.

    Raises:
        FileNotFoundError: if *model_dir* contains no checkpoint.
    """
    global _classifier
    if _checkpoint_exists(model_dir):
        logger.info("Loading DomainClassifier from %s …", model_dir)
        _classifier = DomainClassifier(model_dir)
        return _classifier
    raise FileNotFoundError(
        f"No DomainClassifier checkpoint found at '{model_dir}'. "
        f"Run: python -m src.classifier.train --cfpb_csv <path> --output_dir {model_dir}"
    )
def classify(text: str, model_dir: str = _DEFAULT_MODEL_DIR) -> DomainResult:
    """
    Classify *text* and return a DomainResult.

    The model is loaded lazily, on the first call that finds a checkpoint, and
    is then cached module-wide. *model_dir* is only consulted while nothing is
    cached, so later calls with a different *model_dir* keep using the model
    that was loaded first. Call init_classifier() to reload explicitly.

    Two sources of low_confidence=True:
    1. No checkpoint exists → keyword fallback is used (confidence=0.0 sentinel).
    2. Model exists but top-class probability < DOMAIN_CONFIDENCE_THRESHOLD.

    Callers (and the CMA agent) must check low_confidence and ask the user a
    clarifying domain question rather than proceeding with an uncertain result.

    Args:
        text: Input text to classify.
        model_dir: Checkpoint directory to load from if no model is cached yet.

    Returns:
        Either the model's DomainResult (flagged low_confidence when below the
        threshold) or the keyword-fallback DomainResult.

    Only a missing checkpoint triggers the keyword fallback. Errors raised while
    loading or running DomainClassifier propagate to the caller unchanged.
    """
    classifier = _classifier
    if classifier is None:
        if not _checkpoint_exists(model_dir):
            # No trained model: degrade to keywords rather than fail. This
            # warning repeats on every call until a checkpoint appears.
            logger.warning(
                "No DomainClassifier checkpoint at '%s' — using keyword fallback. "
                "Train with: python -m src.classifier.train "
                "--cfpb_csv data/raw/complaints.csv --output_dir %s",
                model_dir, model_dir,
            )
            return _keyword_classify(text)
        # Share the loading path (and its log line) with init_classifier(),
        # which also stores the instance in the module-level cache.
        classifier = init_classifier(model_dir)

    result = classifier.predict(text)
    if result.confidence < DOMAIN_CONFIDENCE_THRESHOLD:
        logger.info(
            "DomainClassifier low confidence (%.2f < %.2f) for domain '%s' — "
            "flagging for user clarification.",
            result.confidence, DOMAIN_CONFIDENCE_THRESHOLD, result.domain,
        )
        # NOTE(review): mutates the returned object in place; assumes
        # DomainResult is mutable (e.g. not a frozen dataclass). Confirm.
        result.low_confidence = True
    # At or above the threshold, low_confidence is left as predict() set it.
    return result