File size: 3,499 Bytes
cbb1b1a
 
 
585a064
 
cbb1b1a
 
 
 
 
 
 
585a064
 
 
cbb1b1a
 
585a064
 
 
 
 
 
 
 
cbb1b1a
585a064
 
 
cbb1b1a
585a064
 
 
 
 
 
 
 
cbb1b1a
 
 
 
 
 
06254f4
 
585a064
cbb1b1a
585a064
 
 
cbb1b1a
 
585a064
 
cbb1b1a
585a064
 
 
 
 
 
 
cbb1b1a
585a064
 
 
d199df8
 
 
 
 
585a064
 
cbb1b1a
 
 
585a064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
DomainClassifier model definition.

Architecture: distilbert-base-uncased + linear classification head
              (DistilBertForSequenceClassification).
Task:         Multi-class text classification (6 consumer-complaint domains).
Classes:      ecommerce | telecom | banking | cibil | insurance | general
Input:        Redacted complaint text (str, max 512 tokens after tokenisation).
Output:       DomainResult(domain: str, confidence: float,
              all_probs: dict[str, float], low_confidence: bool = False)
Library:      HuggingFace transformers.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Label constants — shared by model.py, train.py, predict.py
# ---------------------------------------------------------------------------

# Order matters: a label's position in this list is its integer class id.
DOMAIN_LABELS: list[str] = [
    "ecommerce",
    "telecom",
    "banking",
    "cibil",
    "insurance",
    "general",
]

NUM_CLASSES: int = len(DOMAIN_LABELS)

# id -> label comes straight from list position; label -> id is its inverse.
ID2DOMAIN: dict[int, str] = dict(enumerate(DOMAIN_LABELS))
DOMAIN2ID: dict[str, int] = {label: idx for idx, label in ID2DOMAIN.items()}


# ---------------------------------------------------------------------------
# Public output type
# ---------------------------------------------------------------------------

@dataclass
class DomainResult:
    """
    Result of classifying one complaint.

    Attributes:
        domain: Label of the predicted domain.
        confidence: Probability associated with ``domain``.
        all_probs: Mapping of domain label to probability.
        low_confidence: Flag intended to be True when ``confidence`` is below
            DOMAIN_CONFIDENCE_THRESHOLD (that threshold is not defined in
            this block); defaults to False.
    """

    domain: str
    confidence: float
    all_probs: dict[str, float]
    low_confidence: bool = False


# ---------------------------------------------------------------------------
# DomainClassifier
# ---------------------------------------------------------------------------

class DomainClassifier:
    """
    DistilBERT-based domain classifier.

    Loads a fine-tuned DistilBertForSequenceClassification checkpoint produced
    by train.py (tokenizer and model are both read from the same directory)
    and runs single-text inference on CUDA, Apple MPS, or CPU — the first of
    those that is available.

    Note: ``predict`` does not apply any confidence threshold, so the returned
    ``DomainResult.low_confidence`` is always left at its default value.
    """

    BASE_MODEL = "distilbert-base-uncased"

    def __init__(self, model_dir: str) -> None:
        """
        Load a fine-tuned checkpoint from *model_dir*.

        Args:
            model_dir: Directory (or identifier) passed to ``from_pretrained``
                for both the tokenizer and the classification model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        # Inference only: switch off dropout and other train-time behaviour.
        self.model.eval()
        self._device = self._select_device()
        self.model.to(self._device)
        logger.info("DomainClassifier loaded from %s on %s", model_dir, self._device)

    @staticmethod
    def _select_device() -> torch.device:
        """Return the preferred inference device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # torch builds without the MPS backend module have no
        # ``torch.backends.mps`` attribute; look it up defensively so those
        # builds fall through to CPU instead of raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def predict(self, text: str) -> DomainResult:
        """
        Classify *text* and return a DomainResult.

        Args:
            text: Complaint text; tokenised with truncation to 512 tokens.

        Returns:
            A DomainResult whose ``domain`` is the highest-probability label,
            ``confidence`` is that probability rounded to 4 decimals, and
            ``all_probs`` maps every label in ``ID2DOMAIN`` to its rounded
            probability.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        # Input tensors must live on the same device as the model.
        encoded = {name: tensor.to(self._device) for name, tensor in encoded.items()}

        with torch.no_grad():
            # Single input, so take row 0 of the batch: shape (num_classes,).
            logits = self.model(**encoded).logits[0]

        probs: list[float] = torch.softmax(logits, dim=-1).cpu().tolist()
        # Index of the largest probability (first one wins on ties); no need
        # to rebuild a tensor from the Python list just to take an argmax.
        pred_id: int = max(range(len(probs)), key=probs.__getitem__)

        return DomainResult(
            domain=ID2DOMAIN[pred_id],
            confidence=round(probs[pred_id], 4),
            all_probs={ID2DOMAIN[i]: round(probs[i], 4) for i in range(NUM_CLASSES)},
        )