"""
DomainClassifier model definition.
Architecture: distilbert-base-uncased + linear classification head
(DistilBertForSequenceClassification).
Task: Multi-class text classification (6 consumer-complaint domains).
Classes: ecommerce | telecom | banking | cibil | insurance | general
Input: Redacted complaint text (str, max 512 tokens after tokenisation).
Output: DomainResult(domain: str, confidence: float, all_probs: dict[str, float], low_confidence: bool)
Library: HuggingFace transformers.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Label constants — shared by model.py, train.py, predict.py
# ---------------------------------------------------------------------------
# Ordered list of domain names; a label's position in this list is its class id.
DOMAIN_LABELS: list[str] = [
    "ecommerce",
    "telecom",
    "banking",
    "cibil",
    "insurance",
    "general",
]
# Forward lookup: domain name -> integer class id.
DOMAIN2ID: dict[str, int] = {
    label: index for index, label in enumerate(DOMAIN_LABELS)
}
# Reverse lookup: integer class id -> domain name.
ID2DOMAIN: dict[int, str] = dict(enumerate(DOMAIN_LABELS))
# Number of target classes, derived from the label list.
NUM_CLASSES: int = len(DOMAIN_LABELS)
# ---------------------------------------------------------------------------
# Public output type
# ---------------------------------------------------------------------------
@dataclass
class DomainResult:
    """Outcome of classifying one complaint text.

    Holds the predicted domain label, the probability assigned to it, and
    the per-label probability mapping.
    """

    # Predicted domain label.
    domain: str
    # Probability associated with the predicted label.
    confidence: float
    # Mapping of {domain_label: probability}.
    all_probs: dict[str, float]
    # Defaults to False. Intended to be True when confidence falls below
    # DOMAIN_CONFIDENCE_THRESHOLD.
    # NOTE(review): that threshold is not defined in the visible part of this
    # module; presumably the caller sets this flag. Confirm.
    low_confidence: bool = False
# ---------------------------------------------------------------------------
# DomainClassifier
# ---------------------------------------------------------------------------
class DomainClassifier:
    """
    DistilBERT-based domain classifier.

    Loads a fine-tuned sequence-classification checkpoint (per the module
    notes, one produced by train.py) and classifies one complaint text per
    call. The model is moved to CUDA when available, otherwise to Apple MPS
    when available, otherwise it stays on CPU.
    """

    BASE_MODEL = "distilbert-base-uncased"

    def __init__(self, model_dir: str) -> None:
        """Load a fine-tuned checkpoint from *model_dir*.

        Args:
            model_dir: Location handed unchanged to the tokenizer's and the
                model's ``from_pretrained`` loaders.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        # Inference only: switch off training-time behaviour such as dropout.
        self.model.eval()
        self._device = self._select_device()
        self.model.to(self._device)

        # predict() maps class ids through ID2DOMAIN and reports exactly
        # NUM_CLASSES probabilities, so a checkpoint with a different head
        # size would yield wrong labels or index errors. Surface it early.
        checkpoint_labels = self.model.config.num_labels
        if checkpoint_labels != NUM_CLASSES:
            logger.warning(
                "Checkpoint %s has %d labels but %d domain labels are defined",
                model_dir,
                checkpoint_labels,
                NUM_CLASSES,
            )
        logger.info("DomainClassifier loaded from %s on %s", model_dir, self._device)

    @staticmethod
    def _select_device() -> torch.device:
        """Return the preferred device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # torch.backends.mps is absent on older PyTorch builds; look it up
        # defensively so CPU-only hosts on such builds do not crash here.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def predict(self, text: str) -> DomainResult:
        """Classify *text* and return a DomainResult.

        Args:
            text: Complaint text; tokenised input is truncated to 512 tokens.

        Returns:
            A DomainResult holding the highest-probability domain, its
            probability rounded to 4 decimals, and the rounded probability
            of every domain label. ``low_confidence`` keeps its default.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        encoded = {name: tensor.to(self._device) for name, tensor in encoded.items()}
        with torch.no_grad():
            # Single input, so take row 0 of the batch -> (num_classes,)
            logits = self.model(**encoded).logits[0]

        # Take the argmax on the probability tensor itself rather than
        # converting to a list and rebuilding a tensor from it.
        prob_tensor = torch.softmax(logits, dim=-1).cpu()
        pred_id: int = int(prob_tensor.argmax().item())
        probs: list[float] = prob_tensor.tolist()

        return DomainResult(
            domain=ID2DOMAIN[pred_id],
            confidence=round(probs[pred_id], 4),
            all_probs={ID2DOMAIN[i]: round(probs[i], 4) for i in range(NUM_CLASSES)},
        )
|