# NOTE: removed Hugging Face Spaces page scaffolding ("Spaces: / Sleeping")
# captured when this file was scraped; it is not part of the module source.
"""
Base Multi-Agent Chatbot - Abstract base class with sophisticated query analysis

This module extracts the core multi-agent logic from MultiAgentRAGChatbot:
- Sophisticated LLM-based query analysis
- Filter extraction and validation
- Query rewriting
- Conversation management
- Main agent, RAG agent, Response agent logic

Subclasses only need to implement:
- _perform_retrieval(): The actual retrieval mechanism (text-based RAG vs visual search)
"""
# Standard library
import re
import json
import time
import logging
import traceback
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass
from typing import Dict, List, Any, Optional, TypedDict, Union
from abc import ABC, abstractmethod

# Third-party: LangGraph / LangChain orchestration primitives
from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

# Project-local configuration and LLM adapter factory
from src.llm.adapters import get_llm_client
from src.config.paths import PROJECT_DIR, CONVERSATIONS_DIR
from src.config.loader import load_config

# Module-wide logging setup.
# NOTE(review): basicConfig at import time configures the whole process --
# confirm this is intended for a library module.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@dataclass
class QueryContext:
    """Context extracted from the conversation for one user query.

    Fix: the ``@dataclass`` decorator was missing.  The class is instantiated
    with keyword arguments (see ``_analyze_query_context``) and relies on
    ``__post_init__`` for normalization -- without the decorator construction
    raised ``TypeError`` and ``__post_init__`` never ran.
    """
    # Flags indicating which filter dimensions the query mentions
    has_district: bool = False
    has_source: bool = False
    has_year: bool = False
    # Values extracted by LLM analysis; each may be a single value or a list
    extracted_district: Optional[Union[str, List[str]]] = None
    extracted_source: Optional[Union[str, List[str]]] = None
    extracted_year: Optional[Union[str, List[str]]] = None
    # Filters supplied explicitly by the UI (take priority over extraction)
    ui_filters: Optional[Dict[str, List[str]]] = None
    confidence_score: float = 0.0
    # Whether the main agent should ask a follow-up instead of running RAG
    needs_follow_up: bool = False
    follow_up_question: Optional[str] = None

    def __post_init__(self):
        # Normalize extracted names to title case for Qdrant compatibility.
        self._process_multiple("extracted_source")
        self._process_multiple("extracted_district")

    def _process_multiple(self, key):
        """Title-case a single value or every element of a list; keep None as None."""
        value = self.__dict__[key]
        if isinstance(value, list):
            self.__dict__[key] = [d.title() for d in value]
        else:
            self.__dict__[key] = value.title() if value else None
class MultiAgentState(TypedDict):
    """State for the multi-agent conversation flow (shared LangGraph state)."""
    conversation_id: str
    # Full message history (HumanMessage / AIMessage objects)
    messages: List[Any]
    # Raw user query for this turn
    current_query: str
    # Output of the main agent's LLM analysis (None until analyzed)
    query_context: Optional[QueryContext]
    # Query rewritten by the RAG agent for vector search
    rag_query: Optional[str]
    # Filters built from UI selections / extracted context
    rag_filters: Optional[Dict[str, Any]]
    retrieved_documents: Optional[List[Any]]
    # When set, the graph routes to END and this text is returned to the user
    final_response: Optional[str]
    # Human-readable trace of what each agent did this turn
    agent_logs: List[str]
    conversation_context: Dict[str, Any]
    session_start_time: float
    last_ai_message_time: float
class BaseMultiAgentChatbot(ABC):
    """
    Abstract base class for multi-agent chatbots.
    Provides all the sophisticated logic from MultiAgentRAGChatbot:
    - LLM-based query analysis
    - Filter extraction and validation
    - Query rewriting
    - Main agent, RAG agent, Response agent
    Subclasses only need to implement:
    - _perform_retrieval(): The actual retrieval mechanism
    """
    def __init__(self, config_path: str = "src/config/settings.yaml"):
        """Initialize the base multi-agent chatbot.

        Args:
            config_path: Path to the YAML settings file consumed by ``load_config``.

        Raises:
            RuntimeError: if neither the configured conversations directory nor
                the local ``conversations`` fallback can be created.
        """
        self.config = load_config(config_path)
        # Get LLM provider from config
        reader_config = self.config.get("reader", {})
        default_type = reader_config.get("default_type", "INF_PROVIDERS")
        provider_name = default_type.lower()
        self.llm_adapter = get_llm_client(provider_name, self.config)
        # Create LangChain-compatible wrapper so the rest of this module can use
        # the familiar ``.invoke(messages) -> response.content`` shape.
        class LLMWrapper:
            def __init__(self, adapter):
                self.adapter = adapter
            def invoke(self, messages):
                # Convert LangChain message objects into {"role", "content"}
                # dicts for the adapter.
                if isinstance(messages, list):
                    formatted_messages = []
                    for msg in messages:
                        if hasattr(msg, 'content'):
                            # Only HumanMessage maps to "user"; every other
                            # message class (AI/System) is sent as "assistant".
                            role = "user" if msg.__class__.__name__ == "HumanMessage" else "assistant"
                            formatted_messages.append({"role": role, "content": msg.content})
                        else:
                            formatted_messages.append({"role": "user", "content": str(msg)})
                else:
                    formatted_messages = [{"role": "user", "content": str(messages)}]
                response = self.adapter.generate(formatted_messages)
                # Minimal object exposing ``.content`` like a LangChain response.
                class MockResponse:
                    def __init__(self, content):
                        self.content = content
                return MockResponse(response.content)
        self.llm = LLMWrapper(self.llm_adapter)
        # Load dynamic data (filter options)
        self._load_dynamic_data()
        # Build the multi-agent graph
        self.graph = self._build_graph()
        # Conversations directory: try the configured location first, fall back
        # to a local ./conversations directory if that is not writable.
        self.conversations_dir = CONVERSATIONS_DIR
        try:
            self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
        except (PermissionError, OSError) as e:
            logger.warning(f"Could not create conversations directory at {self.conversations_dir}: {e}")
            self.conversations_dir = Path("conversations")
            try:
                self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
            except (PermissionError, OSError) as e2:
                logger.error(f"Could not create conversations directory at {self.conversations_dir}: {e2}")
                raise RuntimeError(f"Failed to create conversations directory: {e2}")
        logger.info("🤖 Base Multi-Agent Chatbot initialized")
| def _load_dynamic_data(self): | |
| """Load dynamic data from filter_options.json""" | |
| try: | |
| fo = PROJECT_DIR / "src" / "config" / "filter_options.json" | |
| if fo.exists(): | |
| with open(fo) as f: | |
| data = json.load(f) | |
| self.year_whitelist = [str(y).strip() for y in data.get("years", [])] | |
| self.source_whitelist = [str(s).strip() for s in data.get("sources", [])] | |
| self.district_whitelist = [str(d).strip() for d in data.get("districts", [])] | |
| else: | |
| self.year_whitelist = ['2018', '2019', '2020', '2021', '2022', '2023', '2024'] | |
| self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency'] | |
| self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala'] | |
| except Exception as e: | |
| logger.warning(f"Could not load filter options: {e}") | |
| self.year_whitelist = ['2018', '2019', '2020', '2021', '2022', '2023', '2024'] | |
| self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency'] | |
| self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala'] | |
| # Enrich district list | |
| try: | |
| from add_district_metadata import DistrictMetadataProcessor | |
| proc = DistrictMetadataProcessor() | |
| names = set() | |
| for key, mapping in proc.district_mappings.items(): | |
| if getattr(mapping, 'is_district', True): | |
| names.add(mapping.name) | |
| if names: | |
| merged = list(self.district_whitelist) | |
| for n in sorted(names): | |
| if n not in merged: | |
| merged.append(n) | |
| self.district_whitelist = merged | |
| logger.info(f"🧭 District whitelist enriched: {len(self.district_whitelist)} entries") | |
| except Exception as e: | |
| logger.info(f"ℹ️ Could not enrich districts: {e}") | |
| # Calculate current year dynamically | |
| self.current_year = str(datetime.now().year) | |
| self.previous_year = str(datetime.now().year - 1) | |
| logger.info(f"📊 ACTUAL FILTER VALUES:") | |
| logger.info(f" Years: {self.year_whitelist}") | |
| logger.info(f" Sources: {self.source_whitelist}") | |
| logger.info(f" Districts: {len(self.district_whitelist)} districts (first 30: {self.district_whitelist[:30]})") | |
| def _normalize_district_name(self, district: str) -> Optional[str]: | |
| """Normalize district name with fuzzy matching - ALWAYS returns title case for Qdrant compatibility""" | |
| if not district: | |
| return None | |
| district = district.strip() | |
| district_title = district.title() | |
| # Check if district exists in whitelist (case-insensitive) | |
| district_lower = district.lower() | |
| whitelist_lower = {d.lower(): d for d in self.district_whitelist} | |
| # Direct match (case-insensitive) - always return title case | |
| if district_lower in whitelist_lower: | |
| return district_title # Return title case, not whitelist value | |
| # Remove "District" suffix and try again | |
| district_name = district.replace(" District", "").replace(" district", "").strip() | |
| district_name_lower = district_name.lower() | |
| district_name_title = district_name.title() | |
| if district_name_lower in whitelist_lower: | |
| return district_name_title # Return title case | |
| # Common misspellings and abbreviations - return correct case | |
| misspelling_map = { | |
| "kalagala": "Kalangala", | |
| "kalangala": "Kalangala", | |
| "gulu": "Gulu", | |
| "kampala": "Kampala", | |
| "padr": "Pader", | |
| "padre": "Pader", | |
| "pader": "Pader", | |
| "kcc": "Kcca", # Match whitelist format | |
| "kcca": "Kcca", # Match whitelist format | |
| "kimboga": "Kiboga", | |
| "kiboga": "Kiboga", | |
| "jinja": "Jinja", | |
| "mbale": "Mbale", | |
| "mbarara": "Mbarara", | |
| "soroti": "Soroti", | |
| "lira": "Lira", | |
| "arua": "Arua", | |
| "masaka": "Masaka", | |
| "fort portal": "Fort Portal", | |
| "fortportal": "Fort Portal", | |
| } | |
| if district_name_lower in misspelling_map: | |
| return misspelling_map[district_name_lower] # Already title case | |
| # Fuzzy matching (case-insensitive) - return title case | |
| for whitelist_district in self.district_whitelist: | |
| if district_name_lower == whitelist_district.lower(): | |
| return district_name_title # Return title case | |
| if len(district_name) >= 4 and len(whitelist_district) >= 4: | |
| if district_name_lower in whitelist_district.lower() or whitelist_district.lower() in district_name_lower: | |
| min_len = min(len(district_name), len(whitelist_district)) | |
| max_len = max(len(district_name), len(whitelist_district)) | |
| if min_len / max_len >= 0.8: | |
| return district_name_title # Return title case | |
| # Last resort: if input looks valid, return title case anyway | |
| # This handles cases where whitelist might be incomplete | |
| if len(district_name) >= 3: | |
| return district_name_title | |
| return None | |
| def _build_graph(self) -> StateGraph: | |
| """Build the multi-agent LangGraph""" | |
| graph = StateGraph(MultiAgentState) | |
| # Add nodes for each agent | |
| graph.add_node("main_agent", self._main_agent) | |
| graph.add_node("rag_agent", self._rag_agent) | |
| graph.add_node("response_agent", self._response_agent) | |
| # Define the flow | |
| graph.set_entry_point("main_agent") | |
| # Main agent decides next step | |
| graph.add_conditional_edges( | |
| "main_agent", | |
| self._should_call_rag, | |
| { | |
| "follow_up": END, | |
| "call_rag": "rag_agent" | |
| } | |
| ) | |
| # RAG agent calls response agent | |
| graph.add_edge("rag_agent", "response_agent") | |
| # Response agent returns to main agent | |
| graph.add_edge("response_agent", "main_agent") | |
| return graph.compile() | |
| def _should_call_rag(self, state: MultiAgentState) -> str: | |
| """Determine if we should call RAG or ask follow-up""" | |
| if state.get("final_response"): | |
| return "follow_up" | |
| context = state["query_context"] | |
| if context and context.needs_follow_up: | |
| return "follow_up" | |
| return "call_rag" | |
| def _main_agent(self, state: MultiAgentState) -> MultiAgentState: | |
| """Main Agent: Handles conversation flow and follow-ups""" | |
| logger.info("🎯 MAIN AGENT: Starting analysis") | |
| if state.get("final_response"): | |
| logger.info("🎯 MAIN AGENT: Final response already exists, ending") | |
| return state | |
| query = state["current_query"] | |
| messages = state["messages"] | |
| logger.info(f"🎯 MAIN AGENT: Extracting UI filters from query") | |
| ui_filters = self._extract_ui_filters(query) | |
| logger.info(f"🎯 MAIN AGENT: UI filters extracted: {ui_filters}") | |
| # Analyze query context using LLM | |
| logger.info(f"🎯 MAIN AGENT: Analyzing query context") | |
| context = self._analyze_query_context(query, messages, ui_filters) | |
| state["agent_logs"].append(f"MAIN AGENT: Context analyzed - district={context.has_district}, source={context.has_source}, year={context.has_year}") | |
| logger.info(f"🎯 MAIN AGENT: Context analysis complete") | |
| state["query_context"] = context | |
| # If follow-up needed, generate response | |
| if context.needs_follow_up: | |
| logger.info(f"🎯 MAIN AGENT: Follow-up needed, generating question") | |
| response = context.follow_up_question | |
| state["final_response"] = response | |
| state["last_ai_message_time"] = time.time() | |
| else: | |
| logger.info("🎯 MAIN AGENT: No follow-up needed, proceeding to RAG") | |
| return state | |
| def _rag_agent(self, state: MultiAgentState) -> MultiAgentState: | |
| """RAG Agent: Rewrites queries and applies filters""" | |
| logger.info("🔍 RAG AGENT: Starting query rewriting and filter preparation") | |
| context = state["query_context"] | |
| messages = state["messages"] | |
| # Rewrite query for RAG | |
| logger.info(f"🔍 RAG AGENT: Rewriting query for optimal retrieval") | |
| rag_query = self._rewrite_query_for_rag(messages, context) | |
| logger.info(f"🔍 RAG AGENT: Query rewritten: '{rag_query}'") | |
| # Build filters | |
| logger.info(f"🔍 RAG AGENT: Building filters from context: {context}") | |
| filters = self._build_filters(context) | |
| logger.info(f"🔍 RAG AGENT: Filters built: {filters}") | |
| state["agent_logs"].append(f"RAG AGENT: Query='{rag_query}', Filters={filters}") | |
| state["rag_query"] = rag_query | |
| state["rag_filters"] = filters | |
| return state | |
    def _response_agent(self, state: MultiAgentState) -> MultiAgentState:
        """Response Agent: runs retrieval and generates the final answer.

        Delegates the actual document fetch to the subclass hook
        ``_perform_retrieval``; if the best similarity score is too low the
        documents are discarded and the answer comes from LLM knowledge only.
        Any exception degrades to a fixed apology message.
        """
        logger.info("📝 RESPONSE AGENT: Starting document retrieval and answer generation")
        rag_query = state["rag_query"]
        filters = state["rag_filters"]
        logger.info(f"📝 RESPONSE AGENT: Calling retrieval with query: '{rag_query}'")
        logger.info(f"📝 RESPONSE AGENT: Using filters: {filters}")
        try:
            # Call subclass-specific retrieval method
            result = self._perform_retrieval(rag_query, filters)
            state["retrieved_documents"] = result.sources
            state["agent_logs"].append(f"RESPONSE AGENT: Retrieved {len(result.sources)} documents")
            logger.info(f"📝 RESPONSE AGENT: Retrieved {len(result.sources)} documents")
            # Check highest similarity score across retrieved documents.
            # NOTE(review): the conditional expression below parses as
            # (reranked_score or original_score) if hasattr(doc, 'metadata')
            # else doc.score -- so a reranked_score of exactly 0 falls through
            # to original_score.  Confirm that is the intended precedence.
            highest_score = 0.0
            if result.sources:
                for doc in result.sources:
                    score = getattr(doc, 'metadata', {}).get('reranked_score') or getattr(doc, 'metadata', {}).get('original_score', 0.0) if hasattr(doc, 'metadata') else getattr(doc, 'score', 0.0)
                    if score > highest_score:
                        highest_score = score
            logger.info(f"📝 RESPONSE AGENT: Highest similarity score: {highest_score:.4f}")
            # If highest score is too low (<= 0.15), treat the documents as
            # irrelevant and answer from the LLM's own knowledge.
            if highest_score <= 0.15:
                logger.warning(f"⚠️ RESPONSE AGENT: Low similarity score, using LLM knowledge only")
                response = self._generate_conversational_response_without_docs(
                    state["current_query"],
                    state["messages"]
                )
            else:
                # Generate conversational response with documents
                response = self._generate_conversational_response(
                    state["current_query"],
                    result.sources,
                    result.answer,
                    state["messages"],
                    filters  # Pass filters for coverage validation
                )
            state["final_response"] = response
            state["last_ai_message_time"] = time.time()
            logger.info(f"📝 RESPONSE AGENT: Answer generation complete")
        except Exception as e:
            # Any retrieval/generation failure degrades to an apology message.
            logger.error(f"❌ RESPONSE AGENT ERROR: {e}")
            traceback.print_exc()
            state["final_response"] = "I apologize, but I encountered an error while retrieving information. Please try again."
            state["last_ai_message_time"] = time.time()
        return state
| def _perform_retrieval(self, query: str, filters: Dict[str, Any]) -> Any: | |
| """ | |
| Perform retrieval - must be implemented by subclasses. | |
| Args: | |
| query: The rewritten query | |
| filters: The filters to apply | |
| Returns: | |
| Result object with .sources and .answer attributes | |
| """ | |
| pass | |
| def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]: | |
| """Extract UI filters from query""" | |
| filters = {} | |
| if "FILTER CONTEXT:" in query: | |
| filter_section = query.split("FILTER CONTEXT:")[1] | |
| if "USER QUERY:" in filter_section: | |
| filter_section = filter_section.split("USER QUERY:")[0] | |
| filter_section = filter_section.strip() | |
| if "Sources:" in filter_section: | |
| sources_line = [line for line in filter_section.split('\n') if line.strip().startswith('Sources:')][0] | |
| sources_str = sources_line.split("Sources:")[1].strip() | |
| if sources_str and sources_str != "None": | |
| filters["sources"] = [s.strip() for s in sources_str.split(",")] | |
| if "Years:" in filter_section: | |
| years_line = [line for line in filter_section.split('\n') if line.strip().startswith('Years:')][0] | |
| years_str = years_line.split("Years:")[1].strip() | |
| if years_str and years_str != "None": | |
| filters["years"] = [y.strip() for y in years_str.split(",")] | |
| if "Districts:" in filter_section: | |
| districts_line = [line for line in filter_section.split('\n') if line.strip().startswith('Districts:')][0] | |
| districts_str = districts_line.split("Districts:")[1].strip() | |
| if districts_str and districts_str != "None": | |
| filters["districts"] = [d.strip() for d in districts_str.split(",")] | |
| if "Filenames:" in filter_section: | |
| filenames_line = [line for line in filter_section.split('\n') if line.strip().startswith('Filenames:')][0] | |
| filenames_str = filenames_line.split("Filenames:")[1].strip() | |
| if filenames_str and filenames_str != "None": | |
| filters["filenames"] = [f.strip() for f in filenames_str.split(",")] | |
| return filters | |
    def _analyze_query_context(self, query: str, messages: List[Any], ui_filters: Dict[str, List[str]]) -> QueryContext:
        """Analyze query context using LLM - EXACT COPY FROM v1.

        Builds a large system prompt containing the filter whitelists, asks the
        LLM for a JSON verdict, validates every extracted value against the
        whitelists, then applies deterministic overrides (filename filters and
        a 2+-pieces-of-info rule) on top of the LLM's follow-up decision.
        Falls back to a permissive UI-filters-only context on any failure.
        """
        logger.info(f"🔍 QUERY ANALYSIS: '{query[:50]}...' | UI filters: {ui_filters}")
        # Build conversation context from the last 6 messages (3 exchanges).
        conversation_context = ""
        for msg in messages[-6:]:
            if isinstance(msg, HumanMessage):
                conversation_context += f"User: {msg.content}\n"
            elif isinstance(msg, AIMessage):
                conversation_context += f"Assistant: {msg.content}\n"
        # Create analysis prompt - ENHANCED FOR BETTER EXTRACTION
        analysis_prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=f"""You are the Main Agent in an advanced multi-agent RAG system for audit report analysis.
🎯 PRIMARY GOAL: Intelligently analyze user queries and determine the optimal conversation flow, whether that's answering directly, asking follow-ups, or proceeding to RAG retrieval.
🧠 INTELLIGENCE LEVEL: You are a sophisticated conversational AI that can handle any type of user interaction - from greetings to complex audit queries.
📊 YOUR EXPERTISE: You specialize in analyzing audit reports from various sources (Local Government, Ministry, Hospital, etc.) across different years and districts in Uganda.
🔍 AVAILABLE FILTERS:
- Years: {', '.join(self.year_whitelist)}
- Current year: {self.current_year}, Previous year: {self.previous_year}
- Sources: {', '.join(self.source_whitelist)}
- Districts: {', '.join(self.district_whitelist[:50])}... (and {len(self.district_whitelist)-50} more)
🎛️ UI FILTERS PROVIDED: {ui_filters}
📋 UI FILTER HANDLING:
- If UI filters contain multiple values, extract ALL values
- UI filters take PRIORITY over conversation context
⚠️ CRITICAL EXTRACTION RULES:
1. **RELATIVE YEAR REFERENCES** - Convert to explicit years:
- "last couple of years" / "last 2 years" → [{self.previous_year}, {str(int(self.previous_year)-1)}] (2 years)
- "last few years" / "last 3 years" → [{self.previous_year}, {str(int(self.previous_year)-1)}, {str(int(self.previous_year)-2)}] (3 years)
- "recent years" → [{self.previous_year}, {str(int(self.previous_year)-1)}, {str(int(self.previous_year)-2)}]
- "this year" → ["{self.current_year}"]
- "last year" → ["{self.previous_year}"]
2. **DISTRICT TYPOS & ABBREVIATIONS** - Correct common mistakes:
- "KCC" or "KCCA" → "KCCA" (Kampala Capital City Authority)
- "Padr" or "Padre" → "Pader"
- "Kimboga" → "Kiboga"
- "Kalagala" → "Kalangala"
3. **MULTIPLE VALUES** - Extract ALL mentioned items:
- If user says "Kampala, Kiboga, and Pader" → extract ALL THREE districts
- If user says "2022, 2023, 2024" → extract ALL THREE years
- Use "+" or "and" or "," as separators
🧭 CONVERSATION FLOW INTELLIGENCE:
1. **GREETINGS & GENERAL CHAT**:
- If user greets you, respond warmly and guide them
2. **AUDIT QUERIES**:
- Extract values matching the available lists (with typo correction)
- DO NOT hallucinate values not mentioned by user
3. **SMART FOLLOW-UP STRATEGY**:
- If user provides 2+ pieces of info, proceed to RAG
- If user provides 1 piece of info, ask for missing piece
- If user provides 0 pieces of info, ask for clarification
- NEVER ask the same question twice
🎯 DECISION LOGIC:
- If query is a greeting/general chat → needs_follow_up: true
- If query has 2+ pieces of info → needs_follow_up: false, proceed to RAG
- If query has 1 piece of info → needs_follow_up: true, ask for missing piece
- If query has 0 pieces of info → needs_follow_up: true, ask for clarification
RESPOND WITH JSON ONLY:
{{
"has_district": boolean,
"has_source": boolean,
"has_year": boolean,
"extracted_district": "single or array or null",
"extracted_source": "single or array or null",
"extracted_year": "single or array or null",
"confidence_score": 0.0-1.0,
"needs_follow_up": boolean,
"follow_up_question": "question or null"
}}"""),
            HumanMessage(content=f"""Query: {query}
Conversation Context:
{conversation_context}
CRITICAL: Analyze the FULL conversation context above.
Analyze this query using ONLY the exact values provided above:""")
        ])
        try:
            response = self.llm.invoke(analysis_prompt.format_messages())
            # Clean and parse JSON: strip the code fences and //-style comments
            # the model sometimes emits around its JSON.
            content = response.content.strip()
            if content.startswith("```json"):
                content = content.replace("```json", "").replace("```", "").strip()
            elif content.startswith("```"):
                content = content.replace("```", "").strip()
            # Remove comments
            content = re.sub(r'//.*?$', '', content, flags=re.MULTILINE)
            content = re.sub(r'/\*.*?\*/', '', content, flags=re.DOTALL)
            analysis = json.loads(content)
            logger.info(f"🔍 QUERY ANALYSIS: ✅ Parsed successfully")
            # Validate extracted values (same logic as v1)
            extracted_district = analysis.get("extracted_district")
            extracted_source = analysis.get("extracted_source")
            extracted_year = analysis.get("extracted_year")
            # Validate district: normalize; lists collapse to a single value
            # when exactly one survives, or None when none do.
            if extracted_district:
                if isinstance(extracted_district, list):
                    valid_districts = []
                    for district in extracted_district:
                        normalized = self._normalize_district_name(district)
                        if normalized:
                            valid_districts.append(normalized)
                    extracted_district = valid_districts[0] if len(valid_districts) == 1 else (valid_districts if valid_districts else None)
                else:
                    extracted_district = self._normalize_district_name(extracted_district)
            # Validate source against the whitelist (exact match only).
            if extracted_source:
                if isinstance(extracted_source, list):
                    valid_sources = [s for s in extracted_source if s in self.source_whitelist]
                    extracted_source = valid_sources[0] if len(valid_sources) == 1 else (valid_sources if valid_sources else None)
                else:
                    extracted_source = extracted_source if extracted_source in self.source_whitelist else None
            # Validate year against the whitelist (string-compared).
            if extracted_year:
                if isinstance(extracted_year, list):
                    valid_years = [str(y) for y in extracted_year if str(y) in self.year_whitelist]
                    extracted_year = valid_years[0] if len(valid_years) == 1 else (valid_years if valid_years else None)
                else:
                    extracted_year = str(extracted_year) if str(extracted_year) in self.year_whitelist else None
            # Create QueryContext
            context = QueryContext(
                has_district=bool(extracted_district),
                has_source=bool(extracted_source),
                has_year=bool(extracted_year),
                extracted_district=extracted_district,
                extracted_source=extracted_source,
                extracted_year=extracted_year,
                ui_filters=ui_filters,
                confidence_score=analysis.get("confidence_score", 0.0),
                needs_follow_up=analysis.get("needs_follow_up", False),
                follow_up_question=analysis.get("follow_up_question")
            )
            # If filenames provided, skip follow-ups
            if ui_filters and ui_filters.get("filenames"):
                context.needs_follow_up = False
                context.follow_up_question = None
            # Smart decision logic (same as v1): override the LLM's follow-up
            # verdict when the user already supplied 2+ pieces of info.
            if context.needs_follow_up:
                info_count = sum([bool(context.extracted_district), bool(context.extracted_source), bool(context.extracted_year)])
                query_lower = query.lower()
                is_requesting_info = any(phrase in query_lower for phrase in [
                    "please provide", "could you provide", "can you provide",
                    "what is", "what are", "how much", "which", "what year",
                    "what district", "what source", "tell me about", "how were", "how was"
                ])
                # NOTE(review): both branches below are identical, so any query
                # with 2+ pieces of info skips the follow-up regardless of
                # is_requesting_info -- confirm that is intended.
                if info_count >= 2 and not is_requesting_info:
                    context.needs_follow_up = False
                    context.follow_up_question = None
                elif info_count >= 2 and is_requesting_info:
                    context.needs_follow_up = False
                    context.follow_up_question = None
            return context
        except Exception as e:
            # Fail open: trust the UI filters and let retrieval proceed.
            logger.error(f"❌ Query analysis failed: {e}")
            return QueryContext(
                has_district=bool(ui_filters.get("districts")),
                has_source=bool(ui_filters.get("sources")),
                has_year=bool(ui_filters.get("years")),
                ui_filters=ui_filters,
                confidence_score=0.5,
                needs_follow_up=False
            )
| def _rewrite_query_for_rag(self, messages: List[Any], context: QueryContext) -> str: | |
| """Rewrite query for optimal RAG retrieval - EXACT COPY FROM v1""" | |
| logger.info("🔄 QUERY REWRITING: Starting") | |
| # Build conversation context | |
| conversation_lines = [] | |
| for msg in messages[-6:]: | |
| if isinstance(msg, HumanMessage): | |
| conversation_lines.append(f"User: {msg.content}") | |
| elif isinstance(msg, AIMessage): | |
| conversation_lines.append(f"Assistant: {msg.content}") | |
| convo_text = "\n".join(conversation_lines) | |
| # Create rewrite prompt | |
| rewrite_prompt = ChatPromptTemplate.from_messages([ | |
| SystemMessage(content="""You are a query rewriter for RAG retrieval. | |
| GOAL: Create the best possible search query for document retrieval. | |
| CRITICAL RULES: | |
| 1. Focus on the core information need | |
| 2. Remove meta-verbs like "summarize", "list", "compare", "how much", "what" | |
| 3. DO NOT include filter details (years, districts, sources) | |
| 4. Output ONE clear sentence suitable for vector search | |
| EXAMPLES: | |
| - "What are the top challenges in budget allocation?" → "budget allocation challenges" | |
| - "How were PDM administrative costs utilized?" → "PDM administrative costs utilization" | |
| OUTPUT FORMAT: | |
| EXPLANATION: [reasoning] | |
| QUERY: [one clean sentence]"""), | |
| HumanMessage(content=f"""Conversation: | |
| {convo_text} | |
| Rewrite the best retrieval query:""") | |
| ]) | |
| try: | |
| response = self.llm.invoke(rewrite_prompt.format_messages()) | |
| rewritten = response.content.strip() | |
| # Extract QUERY line | |
| lines = rewritten.split('\n') | |
| for line in lines: | |
| if line.strip().startswith('QUERY:'): | |
| query_line = line.replace('QUERY:', '').strip() | |
| if len(query_line) > 5: | |
| return query_line | |
| # Fallback | |
| for msg in reversed(messages): | |
| if isinstance(msg, HumanMessage): | |
| return msg.content | |
| return "audit report information" | |
| except Exception as e: | |
| logger.error(f"❌ QUERY REWRITING: Error: {e}") | |
| for msg in reversed(messages): | |
| if isinstance(msg, HumanMessage): | |
| return msg.content | |
| return "audit report information" | |
| def _build_filters(self, context: QueryContext) -> Dict[str, Any]: | |
| """Build filters for RAG retrieval""" | |
| logger.info(f"🔧 FILTER BUILDING: Building filters from context: {context}") | |
| filters = {} | |
| # Check for filename filtering first | |
| if context.ui_filters and context.ui_filters.get("filenames"): | |
| filters["filenames"] = context.ui_filters["filenames"] | |
| logger.info(f"🔧 FILTER BUILDING: Using filename filter: {filters}") | |
| return filters | |
| # UI filters take priority | |
| if context.ui_filters: | |
| if context.ui_filters.get("sources"): | |
| filters["sources"] = context.ui_filters["sources"] | |
| if context.ui_filters.get("years"): | |
| filters["year"] = context.ui_filters["years"] | |
| if context.ui_filters.get("districts"): | |
| # Title case for Qdrant compatibility | |
| normalized_districts = [d.title() for d in context.ui_filters['districts']] | |
| filters["district"] = normalized_districts | |
| # Merge with extracted context | |
| if not filters.get("district") and context.extracted_district: | |
| if isinstance(context.extracted_district, list): | |
| # Normalize each district - _normalize_district_name returns correct case | |
| normalized = [self._normalize_district_name(d) for d in context.extracted_district] | |
| filters["district"] = [d for d in normalized if d] | |
| else: | |
| normalized = self._normalize_district_name(context.extracted_district) | |
| if normalized: | |
| filters["district"] = [normalized] | |
| if not filters.get("year") and context.extracted_year: | |
| filters["year"] = [context.extracted_year] if not isinstance(context.extracted_year, list) else context.extracted_year | |
| if not filters.get("sources") and context.extracted_source: | |
| filters["sources"] = [context.extracted_source] if not isinstance(context.extracted_source, list) else context.extracted_source | |
| else: | |
| # Use extracted context (no UI filters) | |
| if context.extracted_source: | |
| filters["sources"] = [context.extracted_source] if not isinstance(context.extracted_source, list) else context.extracted_source | |
| if context.extracted_year: | |
| filters["year"] = [context.extracted_year] if not isinstance(context.extracted_year, list) else context.extracted_year | |
| if context.extracted_district: | |
| if isinstance(context.extracted_district, list): | |
| # Normalize each district - _normalize_district_name returns correct case | |
| normalized = [self._normalize_district_name(d) for d in context.extracted_district] | |
| filters["district"] = [d for d in normalized if d] | |
| else: | |
| normalized = self._normalize_district_name(context.extracted_district) | |
| if normalized: | |
| filters["district"] = [normalized] | |
| return filters | |
| def _generate_conversational_response(self, query: str, documents: List[Any], rag_answer: str, messages: List[Any], filters: Dict[str, Any] = None) -> str: | |
| """Generate conversational response - must be implemented by subclasses""" | |
| pass | |
| def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str: | |
| """Generate response without documents - must be implemented by subclasses""" | |
| pass | |
| def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]: | |
| """Main chat interface""" | |
| logger.info(f"💬 MULTI-AGENT CHAT: Processing '{user_input[:50]}...'") | |
| # Load conversation | |
| conversation_file = self.conversations_dir / f"{conversation_id}.json" | |
| conversation = self._load_conversation(conversation_file) | |
| # Add user message | |
| conversation["messages"].append(HumanMessage(content=user_input)) | |
| # Prepare state | |
| state = MultiAgentState( | |
| conversation_id=conversation_id, | |
| messages=conversation["messages"], | |
| current_query=user_input, | |
| query_context=None, | |
| rag_query=None, | |
| rag_filters=None, | |
| retrieved_documents=None, | |
| final_response=None, | |
| agent_logs=[], | |
| conversation_context=conversation.get("context", {}), | |
| session_start_time=conversation["session_start_time"], | |
| last_ai_message_time=conversation["last_ai_message_time"] | |
| ) | |
| # Run multi-agent graph | |
| final_state = self.graph.invoke(state) | |
| # Add AI response to conversation | |
| if final_state["final_response"]: | |
| conversation["messages"].append(AIMessage(content=final_state["final_response"])) | |
| # Update conversation | |
| conversation["last_ai_message_time"] = final_state["last_ai_message_time"] | |
| conversation["context"] = final_state["conversation_context"] | |
| # Save conversation | |
| self._save_conversation(conversation_file, conversation) | |
| # Return response | |
| return { | |
| 'response': final_state["final_response"], | |
| 'rag_result': { | |
| 'sources': final_state["retrieved_documents"] or [], | |
| 'answer': final_state["final_response"] | |
| }, | |
| 'agent_logs': final_state["agent_logs"], | |
| 'actual_rag_query': final_state.get("rag_query", "") | |
| } | |
| def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]: | |
| """Load conversation from file""" | |
| if conversation_file.exists(): | |
| try: | |
| with open(conversation_file) as f: | |
| data = json.load(f) | |
| messages = [] | |
| for msg_data in data.get("messages", []): | |
| if msg_data["type"] == "human": | |
| messages.append(HumanMessage(content=msg_data["content"])) | |
| elif msg_data["type"] == "ai": | |
| messages.append(AIMessage(content=msg_data["content"])) | |
| data["messages"] = messages | |
| return data | |
| except Exception as e: | |
| logger.warning(f"Could not load conversation: {e}") | |
| return { | |
| "messages": [], | |
| "session_start_time": time.time(), | |
| "last_ai_message_time": time.time(), | |
| "context": {} | |
| } | |
| def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]): | |
| """Save conversation to file""" | |
| try: | |
| conversation_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True) | |
| messages_data = [] | |
| for msg in conversation["messages"]: | |
| if isinstance(msg, HumanMessage): | |
| messages_data.append({"type": "human", "content": msg.content}) | |
| elif isinstance(msg, AIMessage): | |
| messages_data.append({"type": "ai", "content": msg.content}) | |
| conversation_data = { | |
| "messages": messages_data, | |
| "session_start_time": conversation["session_start_time"], | |
| "last_ai_message_time": conversation["last_ai_message_time"], | |
| "context": conversation.get("context", {}) | |
| } | |
| with open(conversation_file, 'w') as f: | |
| json.dump(conversation_data, f, indent=2) | |
| except Exception as e: | |
| logger.error(f"Could not save conversation: {e}") | |