Spaces:
Sleeping
Sleeping
| """ | |
| DomainClassifier model definition. | |
| Architecture: distilbert-base-uncased + linear classification head | |
| (DistilBertForSequenceClassification). | |
| Task: Multi-class text classification (6 consumer-complaint domains). | |
| Classes: ecommerce | telecom | banking | cibil | insurance | general | |
| Input: Redacted complaint text (str, max 512 tokens after tokenisation). | |
| Output: DomainResult(domain: str, confidence: float, all_probs: dict[str, float]) | |
| Library: HuggingFace transformers. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Label constants — shared by model.py, train.py, predict.py | |
| # --------------------------------------------------------------------------- | |
# Canonical label order: a label's position in this list is its class id.
DOMAIN_LABELS: list[str] = [
    "ecommerce",
    "telecom",
    "banking",
    "cibil",
    "insurance",
    "general",
]
# Class id -> label, numbered by position in DOMAIN_LABELS.
ID2DOMAIN: dict[int, str] = dict(enumerate(DOMAIN_LABELS))
# Label -> class id (the inverse of ID2DOMAIN).
DOMAIN2ID: dict[str, int] = {label: idx for idx, label in ID2DOMAIN.items()}
# Number of target classes handled by the classifier.
NUM_CLASSES: int = len(ID2DOMAIN)
| # --------------------------------------------------------------------------- | |
| # Public output type | |
| # --------------------------------------------------------------------------- | |
@dataclass
class DomainResult:
    """Classification output for a single complaint.

    Declared as a dataclass so that instances can be created with keyword
    arguments, which is how ``DomainClassifier.predict`` builds its result.
    Without the generated ``__init__`` that call raises ``TypeError``.

    Attributes:
        domain: Predicted domain label.
        confidence: Probability assigned to the predicted domain.
        all_probs: Probability for every domain label, keyed by label.
        low_confidence: Low-confidence marker; defaults to ``False``.
    """

    domain: str
    confidence: float
    all_probs: dict[str, float]  # {domain_label: probability}
    # True when confidence < DOMAIN_CONFIDENCE_THRESHOLD.
    # NOTE(review): the threshold is not defined in this module and
    # DomainClassifier.predict never sets this flag -- presumably the caller
    # applies the threshold; confirm.
    low_confidence: bool = False
| # --------------------------------------------------------------------------- | |
| # DomainClassifier | |
| # --------------------------------------------------------------------------- | |
class DomainClassifier:
    """
    Domain classifier backed by a fine-tuned sequence-classification checkpoint.

    The tokenizer and model are loaded from a checkpoint directory (per the
    module docstring, a DistilBertForSequenceClassification produced by
    train.py). The model is switched to eval mode and moved to CUDA when
    available, otherwise to Apple MPS when available, otherwise to the CPU.
    """

    # Name of the base architecture; not referenced inside this class.
    BASE_MODEL = "distilbert-base-uncased"

    def __init__(self, model_dir: str) -> None:
        """Load the tokenizer and fine-tuned model stored in *model_dir*.

        Args:
            model_dir: Checkpoint location handed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        self.model.eval()

        # Device preference order: CUDA, then MPS, then CPU.
        if torch.cuda.is_available():
            device_name = "cuda"
        elif torch.backends.mps.is_available():
            device_name = "mps"
        else:
            device_name = "cpu"
        self._device = torch.device(device_name)

        self.model.to(self._device)
        logger.info("DomainClassifier loaded from %s on %s", model_dir, self._device)

    def predict(self, text: str) -> DomainResult:
        """Classify *text* and return the predicted domain with probabilities.

        The text is tokenised with truncation at 512 tokens, run through the
        model without gradient tracking, and the logits are turned into
        probabilities with softmax. Probabilities are rounded to 4 decimals.

        Args:
            text: Complaint text to classify.

        Returns:
            A DomainResult carrying the top label, its probability, and the
            probability of each of the NUM_CLASSES labels. ``low_confidence``
            is not set here, so it keeps its default.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        # The input tensors must live on the same device as the model.
        on_device = {name: tensor.to(self._device) for name, tensor in encoded.items()}

        with torch.no_grad():
            output = self.model(**on_device)
        scores = output.logits[0]  # (num_classes,) -- single-item batch

        probabilities: list[float] = torch.softmax(scores, dim=-1).cpu().tolist()
        best_id: int = int(torch.tensor(probabilities).argmax().item())

        label = ID2DOMAIN[best_id]
        top_probability = round(probabilities[best_id], 4)
        per_domain = {
            ID2DOMAIN[class_id]: round(probabilities[class_id], 4)
            for class_id in range(NUM_CLASSES)
        }
        return DomainResult(
            domain=label,
            confidence=top_probability,
            all_probs=per_domain,
        )