Spaces:
Sleeping
Sleeping
| """ | |
| Intent Classifier for Hybrid RAG + FSM Chatbot | |
| Detects user intent to route between scenario flows and RAG queries | |
| """ | |
| from typing import Dict, Optional, List | |
| import re | |
class IntentClassifier:
    """
    Classify user intent using keyword matching.

    Routes each message to either:
    - Scenario flows (scripted conversations)
    - RAG queries (knowledge retrieval)

    Matching is a case-insensitive substring test against two pattern sets:
    ``scenario_patterns`` (scenario id -> trigger phrases, loaded from JSON
    files) and ``general_patterns`` (phrases that mark a general question).
    """

    def __init__(self, scenarios_dir: str = "scenarios"):
        """
        Initialize with auto-loading triggers from scenario JSON files.

        Args:
            scenarios_dir: Directory containing scenario JSON files
        """
        # Auto-load scenario patterns from JSON files
        self.scenario_patterns = self._load_scenario_patterns(scenarios_dir)

        # General question patterns (RAG)
        self.general_patterns = [
            # Location
            "ở đâu", "địa điểm", "location", "where",
            "chỗ nào", "tổ chức tại",
            # Time
            "mấy giờ", "khi nào", "when", "time",
            "bao giờ", "thời gian", "ngày nào",
            # Info
            "thông tin", "info", "information",
            "chi tiết", "details", "về",
            # Parking
            "đậu xe", "parking", "gửi xe",
            # Contact
            "liên hệ", "contact", "số điện thoại"
        ]

    @staticmethod
    def _clean_patterns(patterns) -> List[str]:
        """
        Return a new list holding only the usable patterns.

        A bare string is treated as a single pattern (not as a sequence of
        characters). Non-string and empty/whitespace-only entries are dropped
        because an empty pattern would match every message.
        """
        if isinstance(patterns, str):
            patterns = [patterns]
        elif not patterns:
            return []
        return [p for p in patterns if isinstance(p, str) and p.strip()]

    def _load_scenario_patterns(self, scenarios_dir: str) -> dict:
        """
        Auto-load triggers from all scenario JSON files.

        Files are read in sorted filename order so that the resulting dict
        (and therefore scenario priority in ``classify``) is deterministic.
        A file that cannot be read or parsed is reported and skipped.

        Returns:
            {"scenario_id": ["trigger1", "trigger2", ...]}
        """
        import json
        import os

        patterns = {}
        # isdir (not exists): a regular file at this path cannot be listed.
        if not os.path.isdir(scenarios_dir):
            print(f"⚠ Scenarios directory not found: {scenarios_dir}")
            return patterns

        for filename in sorted(os.listdir(scenarios_dir)):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(scenarios_dir, filename)
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    scenario = json.load(f)
                scenario_id = scenario.get('scenario_id')
                triggers = self._clean_patterns(scenario.get('triggers', []))
                if scenario_id and triggers:
                    patterns[scenario_id] = triggers
                    print(f"✓ Loaded triggers for: {scenario_id} ({len(triggers)} patterns)")
            except Exception as e:
                # Deliberately broad: loading is best-effort, one malformed
                # file must not prevent the other scenarios from loading.
                print(f"⚠ Error loading {filename}: {e}")
        return patterns

    def classify(
        self,
        message: str,
        conversation_state: Optional[Dict] = None
    ) -> str:
        """
        Classify user intent.

        Args:
            message: User message (``None`` is treated as an empty message)
            conversation_state: Current conversation state (optional)
                {
                    "active_scenario": "price_inquiry" | null,
                    "scenario_step": 3,
                    "scenario_data": {...}
                }

        Returns:
            Intent string:
            - "scenario:<scenario_id>" - Start new scenario
            - "scenario:continue" - Continue current scenario
            - "rag:general" - General RAG query
            - "rag:with_resume" - RAG query but resume scenario after
        """
        message_lower = (message or "").lower().strip()
        state = conversation_state or {}

        if state.get("active_scenario") is not None:
            # User is mid-scenario: an off-topic general question is answered
            # via RAG and the scenario resumes afterwards; anything else is
            # taken as the answer to the current scenario step.
            if self._is_general_question(message_lower):
                return "rag:with_resume"
            return "scenario:continue"

        # Not in a scenario - the first scenario with a matching trigger wins.
        for scenario_id, patterns in self.scenario_patterns.items():
            if self._matches_any_pattern(message_lower, patterns):
                return f"scenario:{scenario_id}"

        # Default: general RAG query
        return "rag:general"

    def _is_general_question(self, message: str) -> bool:
        """
        Check if message is a general question (should use RAG).
        """
        return self._matches_any_pattern(message, self.general_patterns)

    def _matches_any_pattern(self, message: str, patterns: List[str]) -> bool:
        """
        Check if message contains any pattern in list.

        The comparison is a case-insensitive substring test. Non-string and
        empty/whitespace-only patterns never match.
        """
        haystack = message.lower()
        return any(
            isinstance(pattern, str)
            and pattern.strip()
            and pattern.lower() in haystack
            for pattern in patterns
        )

    def get_scenario_type(self, intent: str) -> Optional[str]:
        """
        Extract scenario type from intent string.

        Args:
            intent: "scenario:price_inquiry" or "scenario:continue"

        Returns:
            "price_inquiry", or None when the intent is not a scenario intent,
            is "scenario:continue", or carries no scenario id.
        """
        prefix = "scenario:"
        if not isinstance(intent, str) or not intent.startswith(prefix):
            return None
        scenario_type = intent[len(prefix):]
        if not scenario_type or scenario_type == "continue":
            return None
        return scenario_type

    def add_scenario_pattern(self, scenario_id: str, patterns: List[str]):
        """
        Dynamically add new scenario patterns.

        The patterns are copied, so the caller's list is never stored or
        mutated by later additions.
        """
        self.scenario_patterns.setdefault(scenario_id, []).extend(
            self._clean_patterns(patterns)
        )

    def add_general_pattern(self, patterns: List[str]):
        """
        Dynamically add new general question patterns.
        """
        self.general_patterns.extend(self._clean_patterns(patterns))
# Demo: run a few sample messages through the classifier and print the result
if __name__ == "__main__":
    demo = IntentClassifier()

    # (message, conversation state) pairs; None means "no active scenario"
    samples = (
        ("giá vé bao nhiêu?", None),
        ("sự kiện ở đâu?", None),
        ("đặt vé cho tôi", None),
        ("A show", {"active_scenario": "price_inquiry", "scenario_step": 1}),
        ("sự kiện mấy giờ?", {"active_scenario": "price_inquiry", "scenario_step": 3}),
    )

    print("Intent Classification Test:", "-" * 50, sep="\n")

    for user_text, convo_state in samples:
        detected = demo.classify(user_text, convo_state)
        # The trailing empty item yields the blank line between reports.
        print(
            f"Message: {user_text}",
            f"State: {convo_state}",
            f"Intent: {detected}",
            "",
            sep="\n",
        )