import re

import torch
import torch.nn as nn

from core.device import DEVICE


class IntentClassifier(nn.Module):
    """MLP head mapping a sentence embedding to one of a fixed set of intents.

    The network is Linear(input_dim, 512) -> ReLU -> Dropout(0.3) ->
    Linear(512, 256) -> ReLU -> Dropout(0.3) -> Linear(256, num_intents).
    The module is moved to ``DEVICE`` on construction.
    """

    def __init__(self, input_dim: int, intent_labels: list):
        """
        input_dim: dimension of sentence embeddings
        intent_labels: list of intent names; entry i names output logit i

        Raises ValueError if intent_labels is empty (a zero-output classifier
        cannot produce a prediction).
        """
        super().__init__()
        if not intent_labels:
            raise ValueError("intent_labels must contain at least one intent")

        self.intent_labels = list(intent_labels)
        self.num_intents = len(self.intent_labels)

        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, self.num_intents),
        )

        self.softmax = nn.Softmax(dim=-1)
        self.to(DEVICE)

    def forward(self, sentence_embedding: torch.Tensor) -> torch.Tensor:
        """Return raw intent logits (last dim = num_intents) for the embedding(s)."""
        sentence_embedding = sentence_embedding.to(DEVICE)
        return self.classifier(sentence_embedding)

    def predict(
        self,
        sentence_embedding: torch.Tensor,
        threshold=0.4,
        fallback_label="question_general",
    ) -> dict:
        """Classify a single sentence embedding.

        The forward pass runs in eval mode (dropout disabled) and without
        gradient tracking; the module's previous train/eval mode is restored
        afterwards, so calling this mid-training does not disturb training.

        Args:
            sentence_embedding: embedding of ONE sentence, shape (input_dim,)
                or (1, input_dim); ``.item()`` below needs exactly one result.
            threshold: minimum top-class softmax probability; below it the
                returned intent is ``fallback_label``.
            fallback_label: intent reported for low-confidence predictions.

        Returns:
            {"intent": str, "confidence": float}, where confidence is the top
            softmax probability (reported even when the fallback is used).
        """
        was_training = self.training
        # Without eval(), Dropout(0.3) stays active and predictions are random.
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(sentence_embedding)
                probs = self.softmax(logits)
                confidence, idx = torch.max(probs, dim=-1)
        finally:
            self.train(was_training)

        confidence_value = confidence.item()
        label = self.intent_labels[idx.item()]
        if confidence_value < threshold:
            label = fallback_label

        return {
            "intent": label,
            "confidence": confidence_value,
        }


# ------------------------
# RULE-BASED OVERRIDE
# ------------------------

# Whole-word matches: a bare prefix test would turn "history", "his",
# "highlight", ... into greetings, and "who are your ..." into self_identity.
_SELF_IDENTITY_PATTERN = re.compile(r"\bwho are you\b")
_GREETING_PATTERN = re.compile(r"^(?:hi|hello|hey)\b")


def override_intent(text: str, predicted: str) -> str:
    """Override the model's intent with hard-coded rules for a few phrases.

    Matching is case-insensitive and ignores surrounding whitespace.
    "who are you" anywhere (as whole words) -> "self_identity"; text starting
    with the word "hi", "hello" or "hey" -> "greeting"; otherwise ``predicted``
    is returned unchanged.
    """
    normalized = text.lower().strip()

    if _SELF_IDENTITY_PATTERN.search(normalized):
        return "self_identity"

    if _GREETING_PATTERN.match(normalized):
        return "greeting"

    return predicted