"""
DomainClassifier model definition.

Architecture: distilbert-base-uncased + linear classification head
              (DistilBertForSequenceClassification).
Task:         Multi-class text classification (6 consumer-complaint domains).
Classes:      ecommerce | telecom | banking | cibil | insurance | general
Input:        Redacted complaint text (str, max 512 tokens after tokenisation).
Output:       DomainResult(domain: str, confidence: float, all_probs: dict[str, float])
Library:      HuggingFace transformers.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Label constants — shared by model.py, train.py, predict.py
# ---------------------------------------------------------------------------

DOMAIN_LABELS: list[str] = [
    "ecommerce",
    "telecom",
    "banking",
    "cibil",
    "insurance",
    "general",
]

# Label <-> class-index mappings; the index is the position in DOMAIN_LABELS.
DOMAIN2ID: dict[str, int] = {d: i for i, d in enumerate(DOMAIN_LABELS)}
ID2DOMAIN: dict[int, str] = {i: d for d, i in DOMAIN2ID.items()}
NUM_CLASSES: int = len(DOMAIN_LABELS)


# ---------------------------------------------------------------------------
# Public output type
# ---------------------------------------------------------------------------

@dataclass
class DomainResult:
    """Classification output for a single complaint."""

    domain: str  # predicted label, one of DOMAIN_LABELS
    confidence: float  # probability of the predicted label, rounded to 4 dp
    all_probs: dict[str, float]  # {domain_label: probability}
    # True when confidence < DOMAIN_CONFIDENCE_THRESHOLD.
    # NOTE(review): neither the threshold nor any code setting this flag is
    # visible in this module; presumably the caller fills it in — confirm.
    low_confidence: bool = False


# ---------------------------------------------------------------------------
# DomainClassifier
# ---------------------------------------------------------------------------

class DomainClassifier:
    """
    DistilBERT-based domain classifier.

    Loads a fine-tuned DistilBertForSequenceClassification checkpoint produced
    by train.py. Runs inference on CPU or GPU automatically.
    """

    BASE_MODEL = "distilbert-base-uncased"
    # Upper bound on tokens fed to the model; longer inputs are truncated.
    MAX_TOKENS: int = 512

    def __init__(self, model_dir: str) -> None:
        """Load a fine-tuned checkpoint from *model_dir*.

        Args:
            model_dir: Directory (or hub id) passed to ``from_pretrained`` for
                both the tokenizer and the sequence-classification model.

        Raises:
            ValueError: If the checkpoint's classification head does not have
                exactly ``NUM_CLASSES`` outputs, since ``predict`` maps output
                indices onto ``DOMAIN_LABELS`` by position.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)

        # Fail fast on a mismatched checkpoint instead of raising IndexError /
        # KeyError (or silently truncating all_probs) on the first prediction.
        num_labels = self.model.config.num_labels
        if num_labels != NUM_CLASSES:
            raise ValueError(
                f"Checkpoint at {model_dir!r} has {num_labels} labels; "
                f"expected {NUM_CLASSES} ({', '.join(DOMAIN_LABELS)})"
            )

        self.model.eval()
        self._device = self._select_device()
        self.model.to(self._device)
        logger.info("DomainClassifier loaded from %s on %s", model_dir, self._device)

    @staticmethod
    def _select_device() -> torch.device:
        """Return the preferred available device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # ``torch.backends.mps`` is absent on older PyTorch builds, so look it
        # up defensively rather than assuming the attribute exists.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def predict(self, text: str) -> DomainResult:
        """Classify *text* and return a DomainResult.

        Args:
            text: Complaint text; tokenised with truncation to ``MAX_TOKENS``.

        Returns:
            A ``DomainResult`` holding the arg-max label, its probability and
            the full per-label distribution (probabilities rounded to 4 dp).
            ``low_confidence`` is left at its default here.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_TOKENS,
        )
        encoded = {name: tensor.to(self._device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = self.model(**encoded).logits[0]  # (num_classes,)

        # Take the arg-max on the CPU tensor directly; there is no need to
        # round-trip through a Python list and back into a tensor.
        probs_cpu = torch.softmax(logits, dim=-1).cpu()
        pred_id: int = int(torch.argmax(probs_cpu).item())
        probs: list[float] = probs_cpu.tolist()

        return DomainResult(
            domain=ID2DOMAIN[pred_id],
            confidence=round(probs[pred_id], 4),
            all_probs={
                label: round(prob, 4) for label, prob in zip(DOMAIN_LABELS, probs)
            },
        )