Spaces:
Sleeping
Sleeping
File size: 6,010 Bytes
cbb1b1a 585a064 cbb1b1a 585a064 06254f4 585a064 06254f4 585a064 06254f4 585a064 cbb1b1a 585a064 cbb1b1a 06254f4 585a064 06254f4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | """
DomainClassifier inference helper.
Loads a saved checkpoint and exposes classify() used by the CMA tool
classify_domain.
Graceful fallback: if no checkpoint exists at model_dir, a keyword-based
heuristic is used instead so the pipeline always returns a result. The
returned DomainResult's confidence will be 0.0 to signal the fallback path.
"""
from __future__ import annotations
import logging
import os
from typing import Optional
from src.classifier.model import DOMAIN2ID, DOMAIN_LABELS, DomainClassifier, DomainResult
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Checkpoint directory used when callers do not pass model_dir explicitly.
_DEFAULT_MODEL_DIR = "models/domain_classifier"
# Cached DomainClassifier instance; stays None until init_classifier() or
# classify() constructs one from a checkpoint.
_classifier: Optional[DomainClassifier] = None
# Minimum model confidence required to trust the classification result.
# Below this threshold the result is flagged as low_confidence=True and the
# CMA agent is expected to ask the user a clarifying domain question instead
# of proceeding automatically. The keyword fallback (confidence=0.0 sentinel)
# always sets low_confidence=True regardless of this threshold.
DOMAIN_CONFIDENCE_THRESHOLD: float = 0.50
# ---------------------------------------------------------------------------
# Keyword fallback
# ---------------------------------------------------------------------------
# Keyword table for the heuristic fallback: maps each domain label to the
# lower-case words/phrases that _keyword_classify() looks for in the
# lower-cased input text. Each keyword found adds one point to its domain.
_KEYWORD_DOMAINS: dict[str, list[str]] = {
    "ecommerce": [
        "order", "delivery", "product", "refund", "return", "shipped",
        "package", "cart", "flipkart", "amazon", "myntra", "snapdeal",
        "meesho", "swiggy", "zomato", "cashback",
    ],
    "telecom": [
        "network", "sim", "call", "internet", "data", "recharge", "plan",
        "airtel", "jio", "vodafone", "bsnl", "trai", "porting", "prepaid",
        "postpaid", "broadband", "signal",
    ],
    "banking": [
        "account", "bank", "transaction", "debit", "neft", "rtgs", "imps",
        "loan", "emi", "cheque", "atm", "ifsc", "branch", "passbook",
        "fixed deposit", "savings", "current account", "sbi", "hdfc",
        "icici", "axis bank", "pnb",
    ],
    "cibil": [
        "credit score", "cibil", "credit report", "credit card", "credit rating",
        "debt", "collection", "default", "npa", "experian", "equifax",
        "credit bureau", "outstanding", "overdue",
    ],
    "insurance": [
        "policy", "claim", "premium", "insurance", "insurer", "coverage",
        "lic", "settle", "settlement", "nominee", "health insurance",
        "motor insurance", "irdai", "indemnity", "surveyor",
    ],
    "general": [],  # catch-all — no keywords; _keyword_classify() gives it a base score of 1
}
def _keyword_classify(text: str) -> DomainResult:
    """
    Score *text* against the keyword lists and return the best-matching domain.

    Every keyword found in the lower-cased text adds one point to its domain.
    A keyword only counts when it starts at a word boundary, so short keywords
    do not fire inside unrelated words ("emi" in "premium", "lic" in "policy",
    "atm" in "treatment"). The end of the match is left open so inflected
    forms such as "orders" or "refunded" still count.

    "general" starts with one point, so it is chosen when no keyword matches
    (provided DOMAIN_LABELS contains it).

    Returns:
        A DomainResult with ``confidence`` fixed at the 0.0 fallback sentinel,
        ``low_confidence`` always True, and ``all_probs`` holding each
        domain's share of the total score rounded to 4 decimals.
    """
    import re  # local import so the block does not depend on a file-level import

    lower = text.lower()
    scores: dict[str, int] = {d: 1 if d == "general" else 0 for d in DOMAIN_LABELS}

    # Iterate over the known labels (not the keyword table) so a label without
    # a keyword list keeps its base score and a keyword list for an unknown
    # label is ignored instead of raising KeyError in the fallback path.
    for domain in scores:
        for kw in _KEYWORD_DOMAINS.get(domain, ()):
            # \b anchors the keyword to the start of a word; re.escape keeps
            # multi-word phrases ("fixed deposit") literal.
            if re.search(r"\b" + re.escape(kw), lower):
                scores[domain] += 1

    # Ties resolve to the first label in DOMAIN_LABELS order (max() semantics).
    best = max(scores, key=lambda d: scores[d])
    total = sum(scores.values()) or 1  # avoid division by zero if every score is 0
    all_probs = {d: round(scores[d] / total, 4) for d in DOMAIN_LABELS}

    logger.debug("Keyword fallback scores: %s → %s", scores, best)
    return DomainResult(
        domain=best,
        confidence=0.0,  # sentinel: 0.0 signals keyword fallback, not a model score
        all_probs=all_probs,
        low_confidence=True,  # always uncertain when falling back to keywords
    )
# ---------------------------------------------------------------------------
# Checkpoint detection
# ---------------------------------------------------------------------------
def _checkpoint_exists(model_dir: str) -> bool:
    """Return True when *model_dir* holds a regular file named ``config.json``."""
    config_path = os.path.join(model_dir, "config.json")
    return os.path.isfile(config_path)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def init_classifier(model_dir: str = _DEFAULT_MODEL_DIR) -> DomainClassifier:
    """
    Explicitly initialise (or reload) the module-level DomainClassifier.

    Unlike classify(), this never falls back to the keyword heuristic: a
    missing checkpoint is reported to the caller, who can decide whether to
    abort or fall back. For transparent fallback use classify() directly.

    Args:
        model_dir: Directory expected to contain the saved checkpoint.

    Returns:
        The newly constructed classifier, which also replaces the cached
        module-level instance.

    Raises:
        FileNotFoundError: If no checkpoint is detected in *model_dir*.
    """
    global _classifier
    if not _checkpoint_exists(model_dir):
        raise FileNotFoundError(
            f"No DomainClassifier checkpoint found at '{model_dir}'. "
            f"Run: python -m src.classifier.train --cfpb_csv <path> --output_dir {model_dir}"
        )
    logger.info("Loading DomainClassifier from %s …", model_dir)
    _classifier = DomainClassifier(model_dir)
    return _classifier
def classify(text: str, model_dir: str = _DEFAULT_MODEL_DIR) -> DomainResult:
    """
    Classify *text* and return a DomainResult.

    Two sources of low_confidence=True:
      1. No checkpoint exists → keyword fallback is used (confidence=0.0 sentinel).
      2. Model exists but top-class probability < DOMAIN_CONFIDENCE_THRESHOLD.

    Callers (and the CMA agent) must check low_confidence and ask the user a
    clarifying domain question rather than proceeding with an uncertain result.

    The classifier is constructed on first use and cached at module level;
    once an instance is cached, *model_dir* is not consulted again here.

    Args:
        text: The text to classify.
        model_dir: Checkpoint directory used only when no classifier is
            cached yet.

    Returns:
        The model's prediction, or the keyword-fallback result when no
        classifier is cached and no checkpoint is detected in *model_dir*.
    """
    global _classifier
    if _classifier is None:
        if not _checkpoint_exists(model_dir):
            # Nothing is cached on this path, so the check (and warning)
            # repeats on every call until a checkpoint appears.
            logger.warning(
                "No DomainClassifier checkpoint at '%s' — using keyword fallback. "
                "Train with: python -m src.classifier.train "
                "--cfpb_csv data/raw/complaints.csv --output_dir %s",
                model_dir, model_dir,
            )
            return _keyword_classify(text)
        _classifier = DomainClassifier(model_dir)

    result = _classifier.predict(text)
    if result.confidence < DOMAIN_CONFIDENCE_THRESHOLD:
        logger.info(
            "DomainClassifier low confidence (%.2f < %.2f) for domain '%s' — "
            "flagging for user clarification.",
            result.confidence, DOMAIN_CONFIDENCE_THRESHOLD, result.domain,
        )
        # Only ever raises the flag; a result at or above the threshold keeps
        # whatever low_confidence value predict() returned.
        result.low_confidence = True
    return result
|