Musombi commited on
Commit
7f8e444
·
verified ·
1 Parent(s): 8e5185b

Update language/intent.py

Browse files
Files changed (1) hide show
  1. language/intent.py +32 -10
language/intent.py CHANGED
@@ -2,10 +2,11 @@ import torch
2
  import torch.nn as nn
3
  from core.device import DEVICE
4
 
 
5
  class IntentClassifier(nn.Module):
6
  def __init__(self, input_dim: int, intent_labels: list):
7
  """
8
- input_dim: dimension of sentence embeddings (HIDDEN_DIM)
9
  intent_labels: list of intent names
10
  """
11
  super().__init__()
@@ -13,23 +14,25 @@ class IntentClassifier(nn.Module):
13
  self.intent_labels = intent_labels
14
  self.num_intents = len(intent_labels)
15
 
16
- self.classifier = nn.Linear(input_dim, self.num_intents)
 
 
 
 
 
 
 
 
 
17
  self.softmax = nn.Softmax(dim=-1)
18
 
19
  self.to(DEVICE)
20
 
21
  def forward(self, sentence_embedding: torch.Tensor) -> torch.Tensor:
22
- """
23
- sentence_embedding: (batch_size, input_dim)
24
- returns logits: (batch_size, num_intents)
25
- """
26
  sentence_embedding = sentence_embedding.to(DEVICE)
27
  return self.classifier(sentence_embedding)
28
 
29
- def predict(self, sentence_embedding: torch.Tensor) -> dict:
30
- """
31
- Returns intent label + confidence
32
- """
33
  with torch.no_grad():
34
  logits = self.forward(sentence_embedding)
35
  probs = self.softmax(logits)
@@ -37,7 +40,26 @@ class IntentClassifier(nn.Module):
37
  confidence, idx = torch.max(probs, dim=-1)
38
  label = self.intent_labels[idx.item()]
39
 
 
 
 
40
  return {
41
  "intent": label,
42
  "confidence": confidence.item()
43
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch.nn as nn
3
  from core.device import DEVICE
4
 
5
+
6
  class IntentClassifier(nn.Module):
7
  def __init__(self, input_dim: int, intent_labels: list):
8
  """
9
+ input_dim: dimension of sentence embeddings
10
  intent_labels: list of intent names
11
  """
12
  super().__init__()
 
14
  self.intent_labels = intent_labels
15
  self.num_intents = len(intent_labels)
16
 
17
+ self.classifier = nn.Sequential(
18
+ nn.Linear(input_dim, 512),
19
+ nn.ReLU(),
20
+ nn.Dropout(0.3),
21
+ nn.Linear(512, 256),
22
+ nn.ReLU(),
23
+ nn.Dropout(0.3),
24
+ nn.Linear(256, self.num_intents)
25
+ )
26
+
27
  self.softmax = nn.Softmax(dim=-1)
28
 
29
  self.to(DEVICE)
30
 
31
  def forward(self, sentence_embedding: torch.Tensor) -> torch.Tensor:
 
 
 
 
32
  sentence_embedding = sentence_embedding.to(DEVICE)
33
  return self.classifier(sentence_embedding)
34
 
35
+ def predict(self, sentence_embedding: torch.Tensor, threshold=0.4) -> dict:
 
 
 
36
  with torch.no_grad():
37
  logits = self.forward(sentence_embedding)
38
  probs = self.softmax(logits)
 
40
  confidence, idx = torch.max(probs, dim=-1)
41
  label = self.intent_labels[idx.item()]
42
 
43
+ if confidence.item() < threshold:
44
+ label = "question_general"
45
+
46
  return {
47
  "intent": label,
48
  "confidence": confidence.item()
49
  }
50
+
51
+
52
+ # ------------------------
53
+ # RULE-BASED OVERRIDE
54
+ # ------------------------
55
+
56
+ def override_intent(text: str, predicted: str) -> str:
57
+ text = text.lower().strip()
58
+
59
+ if "who are you" in text:
60
+ return "self_identity"
61
+
62
+ if text.startswith(("hi", "hello", "hey")):
63
+ return "greeting"
64
+
65
+ return predicted