# ChatbotRAG / intent_classifier.py
# Uploaded by minhvtt ("Upload 26 files"), commit 75033ed (verified), 7.14 kB
"""
Intent Classifier for Hybrid RAG + FSM Chatbot
Detects user intent to route between scenario flows and RAG queries
"""
import re
import unicodedata
from typing import Dict, Optional, List
class IntentClassifier:
    """
    Classify user intent using keyword matching.

    Routes each message to either:
    - Scenario flows (scripted conversations), or
    - RAG queries (knowledge retrieval).

    Patterns match as whole words/phrases, case-insensitively, after Unicode
    NFC normalization of both the message and the pattern.
    """

    def __init__(self, scenarios_dir: str = "scenarios"):
        """
        Initialize the classifier, auto-loading triggers from scenario JSON files.

        Args:
            scenarios_dir: Directory containing scenario JSON files.
        """
        # Auto-load scenario patterns from JSON files
        self.scenario_patterns = self._load_scenario_patterns(scenarios_dir)
        # General question patterns (RAG)
        self.general_patterns = [
            # Location
            "ở đâu", "địa điểm", "location", "where",
            "chỗ nào", "tổ chức tại",
            # Time
            "mấy giờ", "khi nào", "when", "time",
            "bao giờ", "thời gian", "ngày nào",
            # Info
            # NOTE(review): "về" is a very common word and may route ordinary
            # mid-scenario answers to RAG -- confirm this is intended.
            "thông tin", "info", "information",
            "chi tiết", "details", "về",
            # Parking
            "đậu xe", "parking", "gửi xe",
            # Contact
            "liên hệ", "contact", "số điện thoại",
        ]

    def _load_scenario_patterns(self, scenarios_dir: str) -> Dict[str, List[str]]:
        """
        Auto-load triggers from all scenario JSON files in ``scenarios_dir``.

        Each ``*.json`` file is expected to be an object with a ``scenario_id``
        and a ``triggers`` list. Files are read in sorted filename order. The
        resulting dict order is therefore stable across filesystems. That
        order decides which scenario wins in ``classify`` when several match.
        Files that cannot be read or parsed, or that have the wrong shape, are
        reported and skipped. A later file with the same ``scenario_id``
        replaces an earlier one.

        Args:
            scenarios_dir: Directory to scan (non-recursively).

        Returns:
            {"scenario_id": ["trigger1", "trigger2", ...]}
        """
        import json
        import os

        patterns: Dict[str, List[str]] = {}
        # isdir (not exists): a regular file at this path would make listdir raise
        if not os.path.isdir(scenarios_dir):
            print(f"⚠ Scenarios directory not found: {scenarios_dir}")
            return patterns

        # sorted(): os.listdir order is arbitrary, and order sets match priority
        for filename in sorted(os.listdir(scenarios_dir)):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(scenarios_dir, filename)
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    scenario = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"⚠ Error loading {filename}: {e}")
                continue

            if not isinstance(scenario, dict):
                print(f"⚠ Error loading {filename}: expected a JSON object")
                continue

            scenario_id = scenario.get('scenario_id')
            triggers = scenario.get('triggers', [])
            # A bare string would otherwise be iterated character by character,
            # turning every single letter into a trigger.
            if isinstance(triggers, str):
                triggers = [triggers]
            if not isinstance(triggers, list):
                print(f"⚠ Error loading {filename}: 'triggers' must be a list")
                continue
            # Drop non-string and blank entries (a blank trigger matches nothing useful)
            triggers = [t for t in triggers if isinstance(t, str) and t.strip()]

            if scenario_id and triggers:
                patterns[scenario_id] = triggers
                print(f"✓ Loaded triggers for: {scenario_id} ({len(triggers)} patterns)")
        return patterns

    def classify(
        self,
        message: str,
        conversation_state: Optional[Dict] = None
    ) -> str:
        """
        Classify user intent.

        Args:
            message: User message (``None`` is treated as an empty message).
            conversation_state: Current conversation state (optional), e.g.
                {
                    "active_scenario": "price_inquiry" | None,
                    "scenario_step": 3,
                    "scenario_data": {...}
                }

        Returns:
            Intent string:
            - "scenario:<scenario_id>" - Start new scenario
            - "scenario:continue" - Continue current scenario
            - "rag:general" - General RAG query
            - "rag:with_resume" - RAG query but resume scenario after
        """
        message_lower = (message or "").lower().strip()
        state = conversation_state or {}

        # Check if in active scenario
        if state.get("active_scenario") is not None:
            # Mid-scenario: an off-topic question goes to RAG (and the scenario
            # resumes afterwards); anything else is an answer to the scenario.
            if self._is_general_question(message_lower):
                return "rag:with_resume"
            return "scenario:continue"

        # Not in scenario - the first scenario whose triggers match wins
        for scenario_id, patterns in self.scenario_patterns.items():
            if self._matches_any_pattern(message_lower, patterns):
                return f"scenario:{scenario_id}"

        # Default: general RAG query
        return "rag:general"

    def _is_general_question(self, message: str) -> bool:
        """
        Check if message is a general question (should use RAG).
        """
        return self._matches_any_pattern(message, self.general_patterns)

    @staticmethod
    def _normalize(text: str) -> str:
        """Return ``text`` case-folded, NFC-normalized, with whitespace collapsed."""
        # NFC after casefold: keeps Vietnamese letters precomposed so they are
        # single \w characters (combining marks are not \w in ``re``).
        return " ".join(unicodedata.normalize("NFC", text.casefold()).split())

    def _matches_any_pattern(self, message: str, patterns: List[str]) -> bool:
        """
        Check whether ``message`` contains any pattern as a whole word/phrase.

        Matching is case-insensitive and Unicode-normalized on both sides, so
        decomposed (NFD) input still matches composed triggers. A pattern only
        matches when it is not glued to adjacent letters/digits: "time" does
        not fire on "sometimes". Empty and non-string patterns are ignored.
        """
        text = self._normalize(message)
        for pattern in patterns:
            if not isinstance(pattern, str):
                continue
            needle = self._normalize(pattern)
            if not needle:
                # An empty needle would otherwise match every message.
                continue
            # Lookarounds rather than \b: they also behave correctly when the
            # pattern itself begins or ends with punctuation.
            if re.search(rf'(?<!\w){re.escape(needle)}(?!\w)', text):
                return True
        return False

    def get_scenario_type(self, intent: str) -> Optional[str]:
        """
        Extract the scenario type from an intent string.

        Args:
            intent: e.g. "scenario:price_inquiry" or "scenario:continue"

        Returns:
            "price_inquiry" for a start-scenario intent; None for
            "scenario:continue", an empty scenario id, or non-scenario intents.
        """
        prefix = "scenario:"
        if not intent.startswith(prefix):
            return None
        scenario_type = intent[len(prefix):]
        if not scenario_type or scenario_type == "continue":
            return None
        return scenario_type

    def add_scenario_pattern(self, scenario_id: str, patterns: List[str]):
        """
        Dynamically add new scenario patterns.

        The patterns are copied into the classifier's own list, so later
        additions never mutate the caller's list.
        """
        self.scenario_patterns.setdefault(scenario_id, []).extend(patterns)

    def add_general_pattern(self, patterns: List[str]):
        """
        Dynamically add new general question patterns.
        """
        self.general_patterns.extend(patterns)
# Example usage: run this module directly for a quick classification demo
if __name__ == "__main__":
    demo_classifier = IntentClassifier()

    # (message, conversation_state) pairs to classify
    samples = [
        ("giá vé bao nhiêu?", None),
        ("sự kiện ở đâu?", None),
        ("đặt vé cho tôi", None),
        ("A show", {"active_scenario": "price_inquiry", "scenario_step": 1}),
        ("sự kiện mấy giờ?", {"active_scenario": "price_inquiry", "scenario_step": 3}),
    ]

    print("Intent Classification Test:", "-" * 50, sep="\n")
    for sample_message, sample_state in samples:
        predicted = demo_classifier.classify(sample_message, sample_state)
        print(f"Message: {sample_message}")
        print(f"State: {sample_state}")
        print(f"Intent: {predicted}")
        print()