# src/agents/base_multi_agent_chatbot.py
"""
Base Multi-Agent Chatbot - Abstract base class with sophisticated query analysis
This module extracts the core multi-agent logic from MultiAgentRAGChatbot:
- Sophisticated LLM-based query analysis
- Filter extraction and validation
- Query rewriting
- Conversation management
- Main agent, RAG agent, Response agent logic
Subclasses need to implement:
- _perform_retrieval(): The actual retrieval mechanism (text-based RAG vs. visual search)
- _generate_conversational_response() / _generate_conversational_response_without_docs():
  final answer generation with and without retrieved documents
"""
import re
import json
import time
import logging
import traceback
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass
from typing import Dict, List, Any, Optional, TypedDict, Union
from abc import ABC, abstractmethod
from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from src.llm.adapters import get_llm_client
from src.config.paths import PROJECT_DIR, CONVERSATIONS_DIR
from src.config.loader import load_config
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@dataclass
class QueryContext:
"""Context extracted from conversation"""
has_district: bool = False
has_source: bool = False
has_year: bool = False
extracted_district: Optional[Union[str, List[str]]] = None
extracted_source: Optional[Union[str, List[str]]] = None
extracted_year: Optional[Union[str, List[str]]] = None
    ui_filters: Optional[Dict[str, List[str]]] = None
confidence_score: float = 0.0
needs_follow_up: bool = False
follow_up_question: Optional[str] = None
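    # Title-case extracted district/source values up front so they match the
    # title-cased payload values used for Qdrant filtering (see
    # _normalize_district_name below).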
def __post_init__(self):
self._process_multiple("extracted_source")
self._process_multiple("extracted_district")
def _process_multiple(self, key):
if isinstance(self.__dict__[key], list):
self.__dict__[key] = [d.title() for d in self.__dict__[key]]
else:
self.__dict__[key] = self.__dict__[key].title() if self.__dict__[key] else None
class MultiAgentState(TypedDict):
"""State for the multi-agent conversation flow"""
conversation_id: str
messages: List[Any]
current_query: str
query_context: Optional[QueryContext]
rag_query: Optional[str]
rag_filters: Optional[Dict[str, Any]]
retrieved_documents: Optional[List[Any]]
final_response: Optional[str]
agent_logs: List[str]
conversation_context: Dict[str, Any]
session_start_time: float
last_ai_message_time: float
class BaseMultiAgentChatbot(ABC):
"""
Abstract base class for multi-agent chatbots.
Provides all the sophisticated logic from MultiAgentRAGChatbot:
- LLM-based query analysis
- Filter extraction and validation
- Query rewriting
- Main agent, RAG agent, Response agent
    Subclasses need to implement:
    - _perform_retrieval(): The actual retrieval mechanism
    - _generate_conversational_response() /
      _generate_conversational_response_without_docs(): final answer generation
"""
def __init__(self, config_path: str = "src/config/settings.yaml"):
"""Initialize the base multi-agent chatbot"""
self.config = load_config(config_path)
# Get LLM provider from config
reader_config = self.config.get("reader", {})
default_type = reader_config.get("default_type", "INF_PROVIDERS")
provider_name = default_type.lower()
self.llm_adapter = get_llm_client(provider_name, self.config)
# Create LangChain-compatible wrapper
class LLMWrapper:
def __init__(self, adapter):
self.adapter = adapter
def invoke(self, messages):
if isinstance(messages, list):
formatted_messages = []
for msg in messages:
if hasattr(msg, 'content'):
role = "user" if msg.__class__.__name__ == "HumanMessage" else "assistant"
formatted_messages.append({"role": role, "content": msg.content})
else:
formatted_messages.append({"role": "user", "content": str(msg)})
else:
formatted_messages = [{"role": "user", "content": str(messages)}]
response = self.adapter.generate(formatted_messages)
class MockResponse:
def __init__(self, content):
self.content = content
return MockResponse(response.content)
self.llm = LLMWrapper(self.llm_adapter)
# Load dynamic data (filter options)
self._load_dynamic_data()
# Build the multi-agent graph
self.graph = self._build_graph()
# Conversations directory
self.conversations_dir = CONVERSATIONS_DIR
try:
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
except (PermissionError, OSError) as e:
logger.warning(f"Could not create conversations directory at {self.conversations_dir}: {e}")
self.conversations_dir = Path("conversations")
try:
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
except (PermissionError, OSError) as e2:
logger.error(f"Could not create conversations directory at {self.conversations_dir}: {e2}")
raise RuntimeError(f"Failed to create conversations directory: {e2}")
logger.info("🤖 Base Multi-Agent Chatbot initialized")
def _load_dynamic_data(self):
"""Load dynamic data from filter_options.json"""
try:
fo = PROJECT_DIR / "src" / "config" / "filter_options.json"
if fo.exists():
with open(fo) as f:
data = json.load(f)
self.year_whitelist = [str(y).strip() for y in data.get("years", [])]
self.source_whitelist = [str(s).strip() for s in data.get("sources", [])]
self.district_whitelist = [str(d).strip() for d in data.get("districts", [])]
else:
self.year_whitelist = ['2018', '2019', '2020', '2021', '2022', '2023', '2024']
self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
except Exception as e:
logger.warning(f"Could not load filter options: {e}")
self.year_whitelist = ['2018', '2019', '2020', '2021', '2022', '2023', '2024']
self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
# Enrich district list
try:
from add_district_metadata import DistrictMetadataProcessor
proc = DistrictMetadataProcessor()
names = set()
            for mapping in proc.district_mappings.values():
if getattr(mapping, 'is_district', True):
names.add(mapping.name)
if names:
merged = list(self.district_whitelist)
for n in sorted(names):
if n not in merged:
merged.append(n)
self.district_whitelist = merged
logger.info(f"🧭 District whitelist enriched: {len(self.district_whitelist)} entries")
except Exception as e:
logger.info(f"ℹ️ Could not enrich districts: {e}")
# Calculate current year dynamically
self.current_year = str(datetime.now().year)
self.previous_year = str(datetime.now().year - 1)
logger.info(f"📊 ACTUAL FILTER VALUES:")
logger.info(f" Years: {self.year_whitelist}")
logger.info(f" Sources: {self.source_whitelist}")
logger.info(f" Districts: {len(self.district_whitelist)} districts (first 30: {self.district_whitelist[:30]})")
def _normalize_district_name(self, district: str) -> Optional[str]:
"""Normalize district name with fuzzy matching - ALWAYS returns title case for Qdrant compatibility"""
if not district:
return None
district = district.strip()
district_title = district.title()
# Check if district exists in whitelist (case-insensitive)
district_lower = district.lower()
whitelist_lower = {d.lower(): d for d in self.district_whitelist}
# Direct match (case-insensitive) - always return title case
if district_lower in whitelist_lower:
return district_title # Return title case, not whitelist value
# Remove "District" suffix and try again
district_name = district.replace(" District", "").replace(" district", "").strip()
district_name_lower = district_name.lower()
district_name_title = district_name.title()
if district_name_lower in whitelist_lower:
return district_name_title # Return title case
# Common misspellings and abbreviations - return correct case
misspelling_map = {
"kalagala": "Kalangala",
"kalangala": "Kalangala",
"gulu": "Gulu",
"kampala": "Kampala",
"padr": "Pader",
"padre": "Pader",
"pader": "Pader",
"kcc": "Kcca", # Match whitelist format
"kcca": "Kcca", # Match whitelist format
"kimboga": "Kiboga",
"kiboga": "Kiboga",
"jinja": "Jinja",
"mbale": "Mbale",
"mbarara": "Mbarara",
"soroti": "Soroti",
"lira": "Lira",
"arua": "Arua",
"masaka": "Masaka",
"fort portal": "Fort Portal",
"fortportal": "Fort Portal",
}
if district_name_lower in misspelling_map:
return misspelling_map[district_name_lower] # Already title case
# Fuzzy matching (case-insensitive) - return title case
for whitelist_district in self.district_whitelist:
if district_name_lower == whitelist_district.lower():
return district_name_title # Return title case
if len(district_name) >= 4 and len(whitelist_district) >= 4:
if district_name_lower in whitelist_district.lower() or whitelist_district.lower() in district_name_lower:
min_len = min(len(district_name), len(whitelist_district))
max_len = max(len(district_name), len(whitelist_district))
if min_len / max_len >= 0.8:
return district_name_title # Return title case
# Last resort: if input looks valid, return title case anyway
# This handles cases where whitelist might be incomplete
if len(district_name) >= 3:
return district_name_title
return None
def _build_graph(self) -> StateGraph:
"""Build the multi-agent LangGraph"""
graph = StateGraph(MultiAgentState)
# Add nodes for each agent
graph.add_node("main_agent", self._main_agent)
graph.add_node("rag_agent", self._rag_agent)
graph.add_node("response_agent", self._response_agent)
# Define the flow
graph.set_entry_point("main_agent")
# Main agent decides next step
graph.add_conditional_edges(
"main_agent",
self._should_call_rag,
{
"follow_up": END,
"call_rag": "rag_agent"
}
)
# RAG agent calls response agent
graph.add_edge("rag_agent", "response_agent")
# Response agent returns to main agent
graph.add_edge("response_agent", "main_agent")
return graph.compile()
def _should_call_rag(self, state: MultiAgentState) -> str:
"""Determine if we should call RAG or ask follow-up"""
if state.get("final_response"):
return "follow_up"
context = state["query_context"]
if context and context.needs_follow_up:
return "follow_up"
return "call_rag"
def _main_agent(self, state: MultiAgentState) -> MultiAgentState:
"""Main Agent: Handles conversation flow and follow-ups"""
logger.info("🎯 MAIN AGENT: Starting analysis")
if state.get("final_response"):
logger.info("🎯 MAIN AGENT: Final response already exists, ending")
return state
query = state["current_query"]
messages = state["messages"]
logger.info(f"🎯 MAIN AGENT: Extracting UI filters from query")
ui_filters = self._extract_ui_filters(query)
logger.info(f"🎯 MAIN AGENT: UI filters extracted: {ui_filters}")
# Analyze query context using LLM
logger.info(f"🎯 MAIN AGENT: Analyzing query context")
context = self._analyze_query_context(query, messages, ui_filters)
state["agent_logs"].append(f"MAIN AGENT: Context analyzed - district={context.has_district}, source={context.has_source}, year={context.has_year}")
logger.info(f"🎯 MAIN AGENT: Context analysis complete")
state["query_context"] = context
# If follow-up needed, generate response
if context.needs_follow_up:
logger.info(f"🎯 MAIN AGENT: Follow-up needed, generating question")
response = context.follow_up_question
state["final_response"] = response
state["last_ai_message_time"] = time.time()
else:
logger.info("🎯 MAIN AGENT: No follow-up needed, proceeding to RAG")
return state
def _rag_agent(self, state: MultiAgentState) -> MultiAgentState:
"""RAG Agent: Rewrites queries and applies filters"""
logger.info("🔍 RAG AGENT: Starting query rewriting and filter preparation")
context = state["query_context"]
messages = state["messages"]
# Rewrite query for RAG
logger.info(f"🔍 RAG AGENT: Rewriting query for optimal retrieval")
rag_query = self._rewrite_query_for_rag(messages, context)
logger.info(f"🔍 RAG AGENT: Query rewritten: '{rag_query}'")
# Build filters
logger.info(f"🔍 RAG AGENT: Building filters from context: {context}")
filters = self._build_filters(context)
logger.info(f"🔍 RAG AGENT: Filters built: {filters}")
state["agent_logs"].append(f"RAG AGENT: Query='{rag_query}', Filters={filters}")
state["rag_query"] = rag_query
state["rag_filters"] = filters
return state
def _response_agent(self, state: MultiAgentState) -> MultiAgentState:
"""Response Agent: Generates final answer from retrieved documents"""
logger.info("📝 RESPONSE AGENT: Starting document retrieval and answer generation")
rag_query = state["rag_query"]
filters = state["rag_filters"]
logger.info(f"📝 RESPONSE AGENT: Calling retrieval with query: '{rag_query}'")
logger.info(f"📝 RESPONSE AGENT: Using filters: {filters}")
try:
# Call subclass-specific retrieval method
result = self._perform_retrieval(rag_query, filters)
state["retrieved_documents"] = result.sources
state["agent_logs"].append(f"RESPONSE AGENT: Retrieved {len(result.sources)} documents")
logger.info(f"📝 RESPONSE AGENT: Retrieved {len(result.sources)} documents")
# Check highest similarity score
highest_score = 0.0
if result.sources:
                for doc in result.sources:
                    if hasattr(doc, 'metadata'):
                        meta = doc.metadata or {}
                        score = meta.get('reranked_score') or meta.get('original_score', 0.0)
                    else:
                        score = getattr(doc, 'score', 0.0)
                    if score > highest_score:
                        highest_score = score
logger.info(f"📝 RESPONSE AGENT: Highest similarity score: {highest_score:.4f}")
# If highest score is too low, use LLM knowledge only
if highest_score <= 0.15:
logger.warning(f"⚠️ RESPONSE AGENT: Low similarity score, using LLM knowledge only")
response = self._generate_conversational_response_without_docs(
state["current_query"],
state["messages"]
)
else:
# Generate conversational response with documents
response = self._generate_conversational_response(
state["current_query"],
result.sources,
result.answer,
state["messages"],
filters # Pass filters for coverage validation
)
state["final_response"] = response
state["last_ai_message_time"] = time.time()
logger.info(f"📝 RESPONSE AGENT: Answer generation complete")
except Exception as e:
logger.error(f"❌ RESPONSE AGENT ERROR: {e}")
traceback.print_exc()
state["final_response"] = "I apologize, but I encountered an error while retrieving information. Please try again."
state["last_ai_message_time"] = time.time()
return state
@abstractmethod
def _perform_retrieval(self, query: str, filters: Dict[str, Any]) -> Any:
"""
Perform retrieval - must be implemented by subclasses.
Args:
query: The rewritten query
filters: The filters to apply
Returns:
Result object with .sources and .answer attributes
"""
pass
def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
"""Extract UI filters from query"""
        filters = {}
        if "FILTER CONTEXT:" in query:
            filter_section = query.split("FILTER CONTEXT:")[1]
            if "USER QUERY:" in filter_section:
                filter_section = filter_section.split("USER QUERY:")[0]
            filter_section = filter_section.strip()
            # Parse each "Label: value, value" line; skip absent or "None" entries.
            for label, key in (("Sources:", "sources"), ("Years:", "years"),
                               ("Districts:", "districts"), ("Filenames:", "filenames")):
                line = next((ln for ln in filter_section.split('\n')
                             if ln.strip().startswith(label)), None)
                if line is None:
                    continue
                value = line.split(label, 1)[1].strip()
                if value and value != "None":
                    filters[key] = [v.strip() for v in value.split(",")]
        return filters
def _analyze_query_context(self, query: str, messages: List[Any], ui_filters: Dict[str, List[str]]) -> QueryContext:
"""Analyze query context using LLM - EXACT COPY FROM v1"""
logger.info(f"🔍 QUERY ANALYSIS: '{query[:50]}...' | UI filters: {ui_filters}")
# Build conversation context
conversation_context = ""
for msg in messages[-6:]:
if isinstance(msg, HumanMessage):
conversation_context += f"User: {msg.content}\n"
elif isinstance(msg, AIMessage):
conversation_context += f"Assistant: {msg.content}\n"
# Create analysis prompt - ENHANCED FOR BETTER EXTRACTION
analysis_prompt = ChatPromptTemplate.from_messages([
SystemMessage(content=f"""You are the Main Agent in an advanced multi-agent RAG system for audit report analysis.
🎯 PRIMARY GOAL: Intelligently analyze user queries and determine the optimal conversation flow, whether that's answering directly, asking follow-ups, or proceeding to RAG retrieval.
🧠 INTELLIGENCE LEVEL: You are a sophisticated conversational AI that can handle any type of user interaction - from greetings to complex audit queries.
📊 YOUR EXPERTISE: You specialize in analyzing audit reports from various sources (Local Government, Ministry, Hospital, etc.) across different years and districts in Uganda.
🔍 AVAILABLE FILTERS:
- Years: {', '.join(self.year_whitelist)}
- Current year: {self.current_year}, Previous year: {self.previous_year}
- Sources: {', '.join(self.source_whitelist)}
- Districts: {', '.join(self.district_whitelist[:50])}... (and {len(self.district_whitelist)-50} more)
🎛️ UI FILTERS PROVIDED: {ui_filters}
📋 UI FILTER HANDLING:
- If UI filters contain multiple values, extract ALL values
- UI filters take PRIORITY over conversation context
⚠️ CRITICAL EXTRACTION RULES:
1. **RELATIVE YEAR REFERENCES** - Convert to explicit years:
- "last couple of years" / "last 2 years" → [{self.previous_year}, {str(int(self.previous_year)-1)}] (2 years)
- "last few years" / "last 3 years" → [{self.previous_year}, {str(int(self.previous_year)-1)}, {str(int(self.previous_year)-2)}] (3 years)
- "recent years" → [{self.previous_year}, {str(int(self.previous_year)-1)}, {str(int(self.previous_year)-2)}]
- "this year" → ["{self.current_year}"]
- "last year" → ["{self.previous_year}"]
2. **DISTRICT TYPOS & ABBREVIATIONS** - Correct common mistakes:
- "KCC" or "KCCA" → "KCCA" (Kampala Capital City Authority)
- "Padr" or "Padre" → "Pader"
- "Kimboga" → "Kiboga"
- "Kalagala" → "Kalangala"
3. **MULTIPLE VALUES** - Extract ALL mentioned items:
- If user says "Kampala, Kiboga, and Pader" → extract ALL THREE districts
- If user says "2022, 2023, 2024" → extract ALL THREE years
- Use "+" or "and" or "," as separators
🧭 CONVERSATION FLOW INTELLIGENCE:
1. **GREETINGS & GENERAL CHAT**:
- If user greets you, respond warmly and guide them
2. **AUDIT QUERIES**:
- Extract values matching the available lists (with typo correction)
- DO NOT hallucinate values not mentioned by user
3. **SMART FOLLOW-UP STRATEGY**:
- If user provides 2+ pieces of info, proceed to RAG
- If user provides 1 piece of info, ask for missing piece
- If user provides 0 pieces of info, ask for clarification
- NEVER ask the same question twice
🎯 DECISION LOGIC:
- If query is a greeting/general chat → needs_follow_up: true
- If query has 2+ pieces of info → needs_follow_up: false, proceed to RAG
- If query has 1 piece of info → needs_follow_up: true, ask for missing piece
- If query has 0 pieces of info → needs_follow_up: true, ask for clarification
RESPOND WITH JSON ONLY:
{{
"has_district": boolean,
"has_source": boolean,
"has_year": boolean,
"extracted_district": "single or array or null",
"extracted_source": "single or array or null",
"extracted_year": "single or array or null",
"confidence_score": 0.0-1.0,
"needs_follow_up": boolean,
"follow_up_question": "question or null"
}}"""),
HumanMessage(content=f"""Query: {query}
Conversation Context:
{conversation_context}
CRITICAL: Analyze the FULL conversation context above.
Analyze this query using ONLY the exact values provided above:""")
])
try:
response = self.llm.invoke(analysis_prompt.format_messages())
# Clean and parse JSON
content = response.content.strip()
if content.startswith("```json"):
content = content.replace("```json", "").replace("```", "").strip()
elif content.startswith("```"):
content = content.replace("```", "").strip()
            # Strip // and /* */ comments that some models emit inside JSON
content = re.sub(r'//.*?$', '', content, flags=re.MULTILINE)
content = re.sub(r'/\*.*?\*/', '', content, flags=re.DOTALL)
analysis = json.loads(content)
logger.info(f"🔍 QUERY ANALYSIS: ✅ Parsed successfully")
# Validate extracted values (same logic as v1)
extracted_district = analysis.get("extracted_district")
extracted_source = analysis.get("extracted_source")
extracted_year = analysis.get("extracted_year")
# Validate district
if extracted_district:
if isinstance(extracted_district, list):
valid_districts = []
for district in extracted_district:
normalized = self._normalize_district_name(district)
if normalized:
valid_districts.append(normalized)
extracted_district = valid_districts[0] if len(valid_districts) == 1 else (valid_districts if valid_districts else None)
else:
extracted_district = self._normalize_district_name(extracted_district)
# Validate source
if extracted_source:
if isinstance(extracted_source, list):
valid_sources = [s for s in extracted_source if s in self.source_whitelist]
extracted_source = valid_sources[0] if len(valid_sources) == 1 else (valid_sources if valid_sources else None)
else:
extracted_source = extracted_source if extracted_source in self.source_whitelist else None
# Validate year
if extracted_year:
if isinstance(extracted_year, list):
valid_years = [str(y) for y in extracted_year if str(y) in self.year_whitelist]
extracted_year = valid_years[0] if len(valid_years) == 1 else (valid_years if valid_years else None)
else:
extracted_year = str(extracted_year) if str(extracted_year) in self.year_whitelist else None
# Create QueryContext
context = QueryContext(
has_district=bool(extracted_district),
has_source=bool(extracted_source),
has_year=bool(extracted_year),
extracted_district=extracted_district,
extracted_source=extracted_source,
extracted_year=extracted_year,
ui_filters=ui_filters,
confidence_score=analysis.get("confidence_score", 0.0),
needs_follow_up=analysis.get("needs_follow_up", False),
follow_up_question=analysis.get("follow_up_question")
)
# If filenames provided, skip follow-ups
if ui_filters and ui_filters.get("filenames"):
context.needs_follow_up = False
context.follow_up_question = None
            # Smart decision logic (same as v1): if the user has already
            # supplied two or more filter values, skip the follow-up.
            if context.needs_follow_up:
                info_count = sum([bool(context.extracted_district), bool(context.extracted_source), bool(context.extracted_year)])
                if info_count >= 2:
                    context.needs_follow_up = False
                    context.follow_up_question = None
return context
except Exception as e:
logger.error(f"❌ Query analysis failed: {e}")
return QueryContext(
has_district=bool(ui_filters.get("districts")),
has_source=bool(ui_filters.get("sources")),
has_year=bool(ui_filters.get("years")),
ui_filters=ui_filters,
confidence_score=0.5,
needs_follow_up=False
)
def _rewrite_query_for_rag(self, messages: List[Any], context: QueryContext) -> str:
"""Rewrite query for optimal RAG retrieval - EXACT COPY FROM v1"""
logger.info("🔄 QUERY REWRITING: Starting")
# Build conversation context
conversation_lines = []
for msg in messages[-6:]:
if isinstance(msg, HumanMessage):
conversation_lines.append(f"User: {msg.content}")
elif isinstance(msg, AIMessage):
conversation_lines.append(f"Assistant: {msg.content}")
convo_text = "\n".join(conversation_lines)
# Create rewrite prompt
rewrite_prompt = ChatPromptTemplate.from_messages([
SystemMessage(content="""You are a query rewriter for RAG retrieval.
GOAL: Create the best possible search query for document retrieval.
CRITICAL RULES:
1. Focus on the core information need
    2. Remove meta-phrases like "summarize", "list", "compare", "how much", "what"
3. DO NOT include filter details (years, districts, sources)
4. Output ONE clear sentence suitable for vector search
EXAMPLES:
- "What are the top challenges in budget allocation?" → "budget allocation challenges"
- "How were PDM administrative costs utilized?" → "PDM administrative costs utilization"
OUTPUT FORMAT:
EXPLANATION: [reasoning]
QUERY: [one clean sentence]"""),
HumanMessage(content=f"""Conversation:
{convo_text}
Rewrite the best retrieval query:""")
])
try:
response = self.llm.invoke(rewrite_prompt.format_messages())
rewritten = response.content.strip()
# Extract QUERY line
lines = rewritten.split('\n')
for line in lines:
if line.strip().startswith('QUERY:'):
query_line = line.replace('QUERY:', '').strip()
if len(query_line) > 5:
return query_line
# Fallback
for msg in reversed(messages):
if isinstance(msg, HumanMessage):
return msg.content
return "audit report information"
except Exception as e:
logger.error(f"❌ QUERY REWRITING: Error: {e}")
for msg in reversed(messages):
if isinstance(msg, HumanMessage):
return msg.content
return "audit report information"
def _build_filters(self, context: QueryContext) -> Dict[str, Any]:
"""Build filters for RAG retrieval"""
logger.info(f"🔧 FILTER BUILDING: Building filters from context: {context}")
filters = {}
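        # Resulting shape (example):
        #   {"sources": ["Local Government"], "year": ["2023"], "district": ["Kampala"]}
        # Keys are intentionally mixed singular/plural ("sources" vs "year"/
        # "district"), matching what subclass _perform_retrieval implementations
        # receive.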
# Check for filename filtering first
if context.ui_filters and context.ui_filters.get("filenames"):
filters["filenames"] = context.ui_filters["filenames"]
logger.info(f"🔧 FILTER BUILDING: Using filename filter: {filters}")
return filters
# UI filters take priority
if context.ui_filters:
if context.ui_filters.get("sources"):
filters["sources"] = context.ui_filters["sources"]
if context.ui_filters.get("years"):
filters["year"] = context.ui_filters["years"]
if context.ui_filters.get("districts"):
# Title case for Qdrant compatibility
normalized_districts = [d.title() for d in context.ui_filters['districts']]
filters["district"] = normalized_districts
# Merge with extracted context
if not filters.get("district") and context.extracted_district:
if isinstance(context.extracted_district, list):
# Normalize each district - _normalize_district_name returns correct case
normalized = [self._normalize_district_name(d) for d in context.extracted_district]
filters["district"] = [d for d in normalized if d]
else:
normalized = self._normalize_district_name(context.extracted_district)
if normalized:
filters["district"] = [normalized]
if not filters.get("year") and context.extracted_year:
filters["year"] = [context.extracted_year] if not isinstance(context.extracted_year, list) else context.extracted_year
if not filters.get("sources") and context.extracted_source:
filters["sources"] = [context.extracted_source] if not isinstance(context.extracted_source, list) else context.extracted_source
else:
# Use extracted context (no UI filters)
if context.extracted_source:
filters["sources"] = [context.extracted_source] if not isinstance(context.extracted_source, list) else context.extracted_source
if context.extracted_year:
filters["year"] = [context.extracted_year] if not isinstance(context.extracted_year, list) else context.extracted_year
if context.extracted_district:
if isinstance(context.extracted_district, list):
# Normalize each district - _normalize_district_name returns correct case
normalized = [self._normalize_district_name(d) for d in context.extracted_district]
filters["district"] = [d for d in normalized if d]
else:
normalized = self._normalize_district_name(context.extracted_district)
if normalized:
filters["district"] = [normalized]
return filters
@abstractmethod
    def _generate_conversational_response(self, query: str, documents: List[Any], rag_answer: str, messages: List[Any], filters: Optional[Dict[str, Any]] = None) -> str:
"""Generate conversational response - must be implemented by subclasses"""
pass
@abstractmethod
def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
"""Generate response without documents - must be implemented by subclasses"""
pass
def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
"""Main chat interface"""
logger.info(f"💬 MULTI-AGENT CHAT: Processing '{user_input[:50]}...'")
# Load conversation
conversation_file = self.conversations_dir / f"{conversation_id}.json"
conversation = self._load_conversation(conversation_file)
# Add user message
conversation["messages"].append(HumanMessage(content=user_input))
# Prepare state
state = MultiAgentState(
conversation_id=conversation_id,
messages=conversation["messages"],
current_query=user_input,
query_context=None,
rag_query=None,
rag_filters=None,
retrieved_documents=None,
final_response=None,
agent_logs=[],
conversation_context=conversation.get("context", {}),
session_start_time=conversation["session_start_time"],
last_ai_message_time=conversation["last_ai_message_time"]
)
# Run multi-agent graph
final_state = self.graph.invoke(state)
# Add AI response to conversation
if final_state["final_response"]:
conversation["messages"].append(AIMessage(content=final_state["final_response"]))
# Update conversation
conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
conversation["context"] = final_state["conversation_context"]
# Save conversation
self._save_conversation(conversation_file, conversation)
# Return response
return {
'response': final_state["final_response"],
'rag_result': {
'sources': final_state["retrieved_documents"] or [],
'answer': final_state["final_response"]
},
'agent_logs': final_state["agent_logs"],
'actual_rag_query': final_state.get("rag_query", "")
}
def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
"""Load conversation from file"""
if conversation_file.exists():
try:
with open(conversation_file) as f:
data = json.load(f)
messages = []
for msg_data in data.get("messages", []):
if msg_data["type"] == "human":
messages.append(HumanMessage(content=msg_data["content"]))
elif msg_data["type"] == "ai":
messages.append(AIMessage(content=msg_data["content"]))
data["messages"] = messages
return data
except Exception as e:
logger.warning(f"Could not load conversation: {e}")
return {
"messages": [],
"session_start_time": time.time(),
"last_ai_message_time": time.time(),
"context": {}
}
def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
"""Save conversation to file"""
try:
conversation_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
messages_data = []
for msg in conversation["messages"]:
if isinstance(msg, HumanMessage):
messages_data.append({"type": "human", "content": msg.content})
elif isinstance(msg, AIMessage):
messages_data.append({"type": "ai", "content": msg.content})
conversation_data = {
"messages": messages_data,
"session_start_time": conversation["session_start_time"],
"last_ai_message_time": conversation["last_ai_message_time"],
"context": conversation.get("context", {})
}
with open(conversation_file, 'w') as f:
json.dump(conversation_data, f, indent=2)
except Exception as e:
logger.error(f"Could not save conversation: {e}")