Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from core.device import DEVICE | |
class IntentClassifier(nn.Module):
    """Feed-forward intent classifier over precomputed sentence embeddings.

    The network is an MLP ``input_dim -> 512 -> 256 -> num_intents`` with
    ReLU activations and dropout (p=0.3) after each hidden layer. The module
    is moved to ``DEVICE`` on construction.
    """

    def __init__(self, input_dim: int, intent_labels: list):
        """
        Args:
            input_dim: dimension of sentence embeddings (size of the last axis
                of the tensor passed to ``forward``/``predict``).
            intent_labels: list of intent names; entry ``i`` names output
                logit ``i``.
        """
        super().__init__()
        self.intent_labels = intent_labels
        self.num_intents = len(intent_labels)
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, self.num_intents),
        )
        # Normalizes logits over the last axis; used only by ``predict``.
        self.softmax = nn.Softmax(dim=-1)
        self.to(DEVICE)

    def forward(self, sentence_embedding: torch.Tensor) -> torch.Tensor:
        """Return raw (unnormalized) intent logits for ``sentence_embedding``.

        The input is moved to ``DEVICE`` first, so callers may pass CPU tensors.
        """
        sentence_embedding = sentence_embedding.to(DEVICE)
        return self.classifier(sentence_embedding)

    def predict(
        self,
        sentence_embedding: torch.Tensor,
        threshold=0.4,
        *,
        fallback_intent: str = "question_general",
    ) -> dict:
        """Classify a single sentence embedding.

        Args:
            sentence_embedding: embedding of ONE sentence, e.g. shape
                ``(input_dim,)`` or ``(1, input_dim)``; ``.item()`` below
                raises if more than one prediction is produced.
            threshold: minimum softmax probability required to keep the
                model's top label.
            fallback_intent: label returned instead when the top probability
                is below ``threshold``.

        Returns:
            ``{"intent": label, "confidence": top_probability}``. The
            confidence is the model's top probability even when the label was
            replaced by ``fallback_intent``.
        """
        # nn.Module starts in training mode, where Dropout randomly zeroes
        # activations and makes predictions non-deterministic. Run inference
        # in eval mode and restore the caller's mode afterwards, so training
        # loops that call predict() are not affected.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(sentence_embedding)
                probs = self.softmax(logits)
                confidence, idx = torch.max(probs, dim=-1)
        finally:
            self.train(was_training)

        score = confidence.item()
        label = self.intent_labels[idx.item()]
        if score < threshold:
            label = fallback_intent
        return {
            "intent": label,
            "confidence": score,
        }
# ------------------------
# RULE-BASED OVERRIDE
# ------------------------
def _starts_with_word(text: str, word: str) -> bool:
    """Return True if *text* begins with *word* as a whole word.

    The character following *word* (if any) must not be a word character
    (alphanumeric or underscore). So "hi there" and "hi!" start with the word
    "hi", but "history" does not.
    """
    if not text.startswith(word):
        return False
    rest = text[len(word):]
    return not rest or not (rest[0].isalnum() or rest[0] == "_")


def override_intent(text: str, predicted: str) -> str:
    """Apply hard-coded rules that take precedence over the model's intent.

    Args:
        text: raw user utterance; matching is case-insensitive and ignores
            surrounding whitespace.
        predicted: intent label produced by the classifier.

    Returns:
        "self_identity" if the utterance contains "who are you";
        otherwise "greeting" if it starts with the word "hi", "hello" or
        "hey"; otherwise *predicted* unchanged (also for empty text).
    """
    if not text:
        return predicted
    normalized = text.lower().strip()
    # Identity questions take precedence over greetings ("hi, who are you?").
    if "who are you" in normalized:
        return "self_identity"
    # Require a word boundary after the greeting. A bare prefix test on "hi"
    # would also match "history", "highlight", "hire", ...
    if any(_starts_with_word(normalized, g) for g in ("hi", "hello", "hey")):
        return "greeting"
    return predicted