mvi-ai-engine / language /intent.py
Musombi's picture
Update language/intent.py
37d8e7d
import re

import torch
import torch.nn as nn

from core.device import DEVICE
class IntentClassifier(nn.Module):
    """MLP head that maps a sentence embedding to one of a fixed set of intents.

    Architecture: ``Linear(input_dim, 512) -> ReLU -> Dropout(0.3) ->
    Linear(512, 256) -> ReLU -> Dropout(0.3) -> Linear(256, num_intents)``.
    The module is moved to ``DEVICE`` on construction.
    """

    def __init__(self, input_dim: int, intent_labels: list):
        """
        Args:
            input_dim: dimension of sentence embeddings (size of the last
                axis of the tensor given to ``forward`` / ``predict``).
            intent_labels: list of intent names; entry ``i`` is the label of
                output logit ``i``.

        Raises:
            ValueError: if ``intent_labels`` is empty.
        """
        super().__init__()
        self.intent_labels = intent_labels
        self.num_intents = len(intent_labels)
        if self.num_intents == 0:
            # A zero-width output layer cannot produce a prediction.
            raise ValueError("intent_labels must contain at least one intent")
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, self.num_intents),
        )
        self.softmax = nn.Softmax(dim=-1)
        self.to(DEVICE)

    def forward(self, sentence_embedding: torch.Tensor) -> torch.Tensor:
        """Return raw (un-normalised) intent logits for ``sentence_embedding``.

        The input is moved to the device holding the model's weights, so the
        call keeps working if the module is later moved with ``.to(...)``.
        """
        device = next(self.parameters()).device
        return self.classifier(sentence_embedding.to(device))

    def predict(
        self,
        sentence_embedding: torch.Tensor,
        threshold=0.4,
        *,
        fallback_intent: str = "question_general",
    ) -> dict:
        """Classify a single sentence embedding.

        Args:
            sentence_embedding: one embedding. ``.item()`` is applied to the
                arg-max, so the input must reduce to a single prediction,
                e.g. shape ``(input_dim,)`` or ``(1, input_dim)``.
            threshold: if the top softmax probability is below this value,
                ``fallback_intent`` is returned instead of the top label.
            fallback_intent: label used for low-confidence predictions.

        Returns:
            ``{"intent": <label>, "confidence": <top probability as float>}``.
            The confidence is always the model's top probability, even when
            the fallback label is used.
        """
        was_training = self.training
        # torch.no_grad() does not disable dropout; eval mode is required for
        # deterministic inference. Restore the caller's mode afterwards.
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(sentence_embedding)
                probs = self.softmax(logits)
                confidence, idx = torch.max(probs, dim=-1)
        finally:
            self.train(was_training)

        confidence_value = confidence.item()
        label = self.intent_labels[idx.item()]
        if confidence_value < threshold:
            label = fallback_intent
        return {
            "intent": label,
            "confidence": confidence_value,
        }
# ------------------------
# RULE-BASED OVERRIDE
# ------------------------
# Whole-phrase match so that e.g. "who are your parents" is not treated as a
# question about the assistant itself.
_SELF_IDENTITY_PATTERN = re.compile(r"\bwho are you\b")
# The greeting must be a complete first word: "hi there" / "hey!" match, but
# "history", "hike" or "heyday" do not.
_GREETING_PATTERN = re.compile(r"^(?:hi|hello|hey)\b")


def override_intent(text: str, predicted: str) -> str:
    """Apply hand-written rules that take precedence over the model's intent.

    Args:
        text: raw user utterance.
        predicted: intent label produced by the classifier.

    Returns:
        ``"self_identity"`` if the text contains the phrase "who are you";
        otherwise ``"greeting"`` if its first word is "hi", "hello" or "hey";
        otherwise ``predicted`` unchanged. Matching is case-insensitive and
        ignores leading, trailing and repeated whitespace.
    """
    # Lowercase and collapse whitespace so "Who   are you" still matches.
    normalized = " ".join(text.lower().split())
    if _SELF_IDENTITY_PATTERN.search(normalized):
        return "self_identity"
    if _GREETING_PATTERN.match(normalized):
        return "greeting"
    return predicted