File size: 6,010 Bytes
cbb1b1a
 
 
585a064
 
 
 
 
 
cbb1b1a
 
585a064
 
 
 
 
 
 
 
 
 
 
 
 
06254f4
 
 
 
 
 
 
585a064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06254f4
585a064
06254f4
585a064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbb1b1a
585a064
 
 
cbb1b1a
06254f4
 
 
 
 
 
585a064
 
 
 
 
 
 
 
 
 
 
 
06254f4
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
DomainClassifier inference helper.

Loads a saved checkpoint and exposes classify() used by the CMA tool
classify_domain.

Graceful fallback: if no checkpoint exists at model_dir, a keyword-based
heuristic is used instead so the pipeline always returns a result.  The
returned DomainResult's confidence will be 0.0 to signal the fallback path.
"""

from __future__ import annotations

import logging
import os
import re
from typing import Optional

from src.classifier.model import DOMAIN2ID, DOMAIN_LABELS, DomainClassifier, DomainResult

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Checkpoint directory used when a caller does not pass model_dir explicitly
# (a relative path).
_DEFAULT_MODEL_DIR = "models/domain_classifier"
# Module-wide cached classifier instance.  Stays None until init_classifier()
# or classify() successfully constructs a DomainClassifier.
_classifier: Optional[DomainClassifier] = None

# Minimum model confidence required to trust the classification result.
# Below this threshold the result is flagged as low_confidence=True and the
# CMA agent is expected to ask the user a clarifying domain question instead
# of proceeding automatically.  The keyword fallback (confidence=0.0 sentinel)
# always sets low_confidence=True regardless of this threshold.
DOMAIN_CONFIDENCE_THRESHOLD: float = 0.50

# ---------------------------------------------------------------------------
# Keyword fallback
# ---------------------------------------------------------------------------

# Maps each domain label to the keywords/phrases that vote for it in
# _keyword_classify().  Entries are written in lower case because the input
# text is lower-cased before it is compared; multi-word phrases are allowed.
_KEYWORD_DOMAINS: dict[str, list[str]] = {
    "ecommerce": [
        "order", "delivery", "product", "refund", "return", "shipped",
        "package", "cart", "flipkart", "amazon", "myntra", "snapdeal",
        "meesho", "swiggy", "zomato", "cashback",
    ],
    "telecom": [
        "network", "sim", "call", "internet", "data", "recharge", "plan",
        "airtel", "jio", "vodafone", "bsnl", "trai", "porting", "prepaid",
        "postpaid", "broadband", "signal",
    ],
    "banking": [
        "account", "bank", "transaction", "debit", "neft", "rtgs", "imps",
        "loan", "emi", "cheque", "atm", "ifsc", "branch", "passbook",
        "fixed deposit", "savings", "current account", "sbi", "hdfc",
        "icici", "axis bank", "pnb",
    ],
    "cibil": [
        "credit score", "cibil", "credit report", "credit card", "credit rating",
        "debt", "collection", "default", "npa", "experian", "equifax",
        "credit bureau", "outstanding", "overdue",
    ],
    "insurance": [
        "policy", "claim", "premium", "insurance", "insurer", "coverage",
        "lic", "settle", "settlement", "nominee", "health insurance",
        "motor insurance", "irdai", "indemnity", "surveyor",
    ],
    "general": [],  # catch-all — gets 1 point unconditionally in _keyword_classify()
}


def _keyword_classify(text: str) -> DomainResult:
    """Score *text* against the keyword lists and return the best-matching domain.

    Every keyword found in the lower-cased text adds one point to its domain
    (at most one point per keyword, no matter how often it occurs).  A keyword
    only counts when it begins at the start of a word, so short keywords do
    not fire inside unrelated words ("emi" in "premium", "lic" in "policy",
    "atm" in "treatment"), while inflected forms such as "orders" or
    "refunded" still match.  The "general" label, when present in
    DOMAIN_LABELS, starts with one point so it is selected when nothing else
    scores.

    Returns:
        A DomainResult whose confidence is the 0.0 fallback sentinel, whose
        low_confidence flag is always True, and whose all_probs holds each
        label's share of the total score rounded to 4 decimal places.
    """
    lower = text.lower()
    scores: dict[str, int] = {d: 1 if d == "general" else 0 for d in DOMAIN_LABELS}
    for domain, keywords in _KEYWORD_DOMAINS.items():
        if domain not in scores:
            # Keyword list for a label outside DOMAIN_LABELS: it could never
            # be reported through all_probs, so skip it instead of raising.
            continue
        # (?<!\w) requires the keyword to start a word; the end is left open
        # so plurals and verb forms ("claims", "refunded") keep matching.
        scores[domain] += sum(
            1
            for kw in keywords
            if re.search(r"(?<!\w)" + re.escape(kw), lower)
        )

    # max() keeps the first maximum, so ties go to the label that appears
    # earliest in DOMAIN_LABELS.
    best = max(scores, key=lambda d: scores[d])
    total = sum(scores.values()) or 1  # guard against division by zero
    all_probs = {d: round(scores[d] / total, 4) for d in DOMAIN_LABELS}

    logger.debug("Keyword fallback scores: %s → %s", scores, best)
    return DomainResult(
        domain=best,
        confidence=0.0,   # sentinel: 0.0 signals keyword fallback, not a model score
        all_probs=all_probs,
        low_confidence=True,  # always uncertain when falling back to keywords
    )


# ---------------------------------------------------------------------------
# Checkpoint detection
# ---------------------------------------------------------------------------

def _checkpoint_exists(model_dir: str) -> bool:
    """Return True when *model_dir* contains a regular file named config.json."""
    config_path = os.path.join(model_dir, "config.json")
    return os.path.isfile(config_path)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def init_classifier(model_dir: str = _DEFAULT_MODEL_DIR) -> DomainClassifier:
    """
    Load the checkpoint in *model_dir* into the module-level classifier slot.

    A classifier that was cached earlier is replaced by the newly constructed
    one, which is also returned.

    Raises FileNotFoundError when *model_dir* has no checkpoint, leaving the
    abort-or-fall-back decision to the caller.  classify() performs the
    fallback transparently instead.
    """
    global _classifier
    if _checkpoint_exists(model_dir):
        logger.info("Loading DomainClassifier from %s …", model_dir)
        loaded = DomainClassifier(model_dir)
        _classifier = loaded
        return loaded

    missing_msg = (
        f"No DomainClassifier checkpoint found at '{model_dir}'. "
        f"Run: python -m src.classifier.train --cfpb_csv <path> --output_dir {model_dir}"
    )
    raise FileNotFoundError(missing_msg)


def classify(text: str, model_dir: str = _DEFAULT_MODEL_DIR) -> DomainResult:
    """
    Classify *text* and return a DomainResult.

    Two sources of low_confidence=True:
      1. No checkpoint exists → keyword fallback is used (confidence=0.0 sentinel).
      2. Model exists but top-class probability < DOMAIN_CONFIDENCE_THRESHOLD.

    The model is constructed lazily on the first call that finds a checkpoint
    and is then cached at module level.  While a classifier is cached,
    *model_dir* is not consulted; call init_classifier() to load a different
    directory.  As long as no checkpoint is found, every call logs a warning
    and goes through the keyword fallback.

    Callers (and the CMA agent) must check low_confidence and ask the user a
    clarifying domain question rather than proceeding with an uncertain result.
    """
    global _classifier
    if _classifier is None:
        if not _checkpoint_exists(model_dir):
            logger.warning(
                "No DomainClassifier checkpoint at '%s' — using keyword fallback. "
                "Train with: python -m src.classifier.train "
                "--cfpb_csv data/raw/complaints.csv --output_dir %s",
                model_dir, model_dir,
            )
            return _keyword_classify(text)
        # First successful load: keep the instance for all later calls.
        _classifier = DomainClassifier(model_dir)

    result = _classifier.predict(text)

    if result.confidence < DOMAIN_CONFIDENCE_THRESHOLD:
        logger.info(
            "DomainClassifier low confidence (%.2f < %.2f) for domain '%s' — "
            "flagging for user clarification.",
            result.confidence, DOMAIN_CONFIDENCE_THRESHOLD, result.domain,
        )
        # Flag the model's own result rather than replacing it, so the
        # predicted domain and probabilities remain available to the caller.
        result.low_confidence = True

    return result