"""
Multi-Agent RAG Chatbot using LangGraph

This is the TEXT-BASED RAG chatbot that inherits from BaseMultiAgentChatbot.
It implements the retrieval using the PipelineManager (Qdrant + text embeddings).
"""
# Standard library
import logging
import traceback
from typing import Dict, List, Any

# LangChain core: prompt templating and chat message types
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

# Project-local: shared multi-agent base class and the text RAG pipeline
from src.agents.base_multi_agent_chatbot import BaseMultiAgentChatbot
from src.pipeline import PipelineManager

# Module-wide logging: INFO level, timestamped and level-tagged messages
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class MultiAgentRAGChatbot(BaseMultiAgentChatbot):
    """
    Text-based Multi-agent RAG chatbot.

    Inherits all the sophisticated logic from BaseMultiAgentChatbot:
    - LLM-based query analysis
    - Filter extraction and validation
    - Query rewriting
    - Main agent, RAG agent, Response agent

    Implements:
    - _perform_retrieval(): Uses PipelineManager for text-based RAG
    - _generate_conversational_response(): Text-focused response generation
    - _generate_conversational_response_without_docs(): Fallback response
    """

    def __init__(self, config_path: str = "src/config/settings.yaml"):
        """Initialize the text-based multi-agent chatbot.

        Args:
            config_path: Path to the YAML settings consumed by the base class
                and by PipelineManager.

        Raises:
            RuntimeError: If the pipeline manager cannot be created or the
                vector store connection fails.
        """
        # Initialize base class first (loads config, LLM, filters, builds graph)
        super().__init__(config_path)

        # Initialize pipeline manager for text-based retrieval
        logger.info("π Initializing pipeline manager and loading models...")
        try:
            self.pipeline_manager = PipelineManager(self.config)
            logger.info("β Pipeline manager initialized and models loaded")
        except Exception as e:
            logger.error(f"β Failed to initialize pipeline manager: {e}")
            traceback.print_exc()
            # FIX: chain with `from e` so the root cause stays on the traceback
            raise RuntimeError(f"Pipeline manager initialization failed: {e}") from e

        # Connect to vector store
        logger.info("π Connecting to vector store...")
        try:
            if not self.pipeline_manager.connect_vectorstore():
                logger.error("β Failed to connect to vector store")
                raise RuntimeError("Vector store connection failed")
            logger.info("β Vector store connected successfully")
        except RuntimeError:
            # Our own failure signal from the branch above: re-raise untouched
            raise
        except Exception as e:
            logger.error(f"β Error during vector store connection: {e}")
            traceback.print_exc()
            # FIX: chain with `from e` so the root cause stays on the traceback
            raise RuntimeError(f"Vector store connection failed: {e}") from e

        logger.info("π€ Text-based Multi-Agent RAG Chatbot initialized")

    def _perform_retrieval(self, query: str, filters: Dict[str, Any]) -> Any:
        """
        Perform text-based retrieval using PipelineManager.

        Args:
            query: The rewritten query
            filters: The filters to apply

        Returns:
            PipelineResult with .sources and .answer
        """
        logger.info(f"π TEXT RETRIEVAL: Query='{query}', Filters={filters}")
        result = self.pipeline_manager.run(
            query=query,
            sources=filters.get("sources") if filters else None,
            auto_infer_filters=False,  # filters already extracted by the agents
            filters=filters if filters else None
        )
        logger.info(f"π TEXT RETRIEVAL: Retrieved {len(result.sources)} documents")
        return result

    def _generate_conversational_response(self, query: str, documents: List[Any], rag_answer: str, messages: List[Any], filters: Dict[str, Any] = None) -> str:
        """Generate a conversational response grounded in the RAG results.

        Args:
            query: The current user question.
            documents: Retrieved documents (Document-like objects or dicts).
            rag_answer: Raw answer from the RAG pipeline (also the fallback
                return value if LLM generation fails).
            messages: Conversation history (HumanMessage / AIMessage list).
            filters: Filters originally requested; used for coverage validation.

        Returns:
            Response text with [Doc i] references, possibly with a coverage
            warning appended by _validate_and_enhance_response().
        """
        logger.info("π¬ RESPONSE GENERATION: Starting conversational response generation")
        logger.info(f"π¬ RESPONSE GENERATION: Filters for validation: {filters}")

        # Build conversation history context
        conversation_context = self._build_conversation_context(messages)
        # Build detailed document information
        document_details = self._build_document_details(documents)
        # Extract correct district/source/year names from documents
        correct_names = self._extract_correct_names_from_documents(documents)

        # Create response prompt (concrete messages, no template variables)
        response_prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
CRITICAL RULES - NO HALLUCINATION:
1. **ONLY use information from the retrieved documents provided below**
2. **EVERY sentence with facts, numbers, or specific claims MUST have a [Doc i] reference**
3. **If a document doesn't contain the information, DO NOT make it up**
4. **If the user asks about a year/district that's NOT in the retrieved documents, explicitly state that**
5. **Check the document years/districts before making any claims about them**
6. **USE CORRECT NAMES**: Use the CORRECT spelling from the document metadata
RULES:
1. Answer the user's question directly and clearly
2. Use ONLY the retrieved documents as evidence - DO NOT use your training data
3. Be conversational, not technical
4. Don't mention scores, retrieval details, or technical implementation
5. If relevant documents were found, reference them naturally
6. If no relevant documents, say you do not have enough information - DO NOT hallucinate
7. If the passages have useful facts or numbers, use them in your answer WITH references
8. **MANDATORY**: When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
9. Do not use the sentence 'Doc i says ...' to say where information came from.
10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
12. If it makes sense, use bullet points and lists to make your answers easier to understand.
13. You do not need to use every passage. Only use the ones that help answer the question.
14. **VERIFY**: Before mentioning any year, district, or number, check that it exists in the retrieved documents. If it doesn't, say "I don't have information about [year/district] in the retrieved documents."
15. **NO HALLUCINATION**: If documents show years 2021, 2022, 2023 but user asks about 2020, DO NOT provide 2020 data. Instead say "The retrieved documents cover 2021-2023, but I don't have information for 2020."
TONE: Professional but friendly."""),
            HumanMessage(content=f"""Conversation History:
{conversation_context}
Current User Question: {query}
Retrieved Documents: {len(documents)} documents found
CORRECT NAMES TO USE (from document metadata):
{correct_names}
Full Document Details:
{document_details}
RAG Answer: {rag_answer}
CRITICAL:
- Responses should be grounded to what is available in the retrieved documents
- If user asks about a specific year but documents show other years, or districts or sources then explicitly state "can't provide response on ... because ..."
- Every factual claim MUST have [Doc i] reference
- If information is not in documents, explicitly state it's not available
Generate a conversational response with proper document references:""")
        ])

        try:
            response = self.llm.invoke(response_prompt.format_messages())
            # Post-process response to ensure no hallucination
            final_response = self._validate_and_enhance_response(
                response.content.strip(),
                documents,
                query,
                filters  # Pass filters for coverage validation
            )
            return final_response
        except Exception as e:
            logger.error(f"β RESPONSE GENERATION: Error: {e}")
            # Best-effort fallback: return the raw pipeline answer
            return rag_answer

    def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
        """Generate a conversational response using only LLM knowledge.

        Used as the no-retrieval fallback path; acknowledges that the answer
        is based on general knowledge rather than specific documents.
        """
        logger.info("π¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")

        # Build conversation context from the last 6 messages (3 exchanges)
        conversation_context = ""
        for msg in messages[-6:]:
            if isinstance(msg, HumanMessage):
                conversation_context += f"User: {msg.content}\n"
            elif isinstance(msg, AIMessage):
                conversation_context += f"Assistant: {msg.content}\n"

        response_prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
RULES:
1. Answer the user's question directly and clearly based on your knowledge
2. Use conversation history for context
3. Be conversational, not technical
4. Acknowledge if the answer is based on general knowledge rather than specific documents
5. Stay professional but friendly
TONE: Professional but friendly."""),
            HumanMessage(content=f"""Current Question: {query}
Conversation History:
{conversation_context}
Generate a conversational response based on your knowledge:""")
        ])

        try:
            response = self.llm.invoke(response_prompt.format_messages())
            return response.content.strip()
        except Exception as e:
            logger.error(f"β RESPONSE GENERATION (NO DOCS): Error: {e}")
            return "I apologize, but I encountered an error. Please try asking your question differently."

    # ==================== HELPER METHODS ====================

    @staticmethod
    def _doc_metadata(doc: Any) -> Any:
        """Best-effort metadata lookup shared by the helpers below.

        Returns doc.metadata when present, the doc itself when it is a plain
        dict, else an empty dict. Callers still isinstance-check the result.
        """
        return getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})

    def _build_conversation_context(self, messages: List[Any]) -> str:
        """Build conversation history context for response generation."""
        if not messages:
            return "No previous conversation."
        context_lines = []
        # Only the last 6 messages (3 exchanges) to keep the prompt compact
        for msg in messages[-6:]:
            if isinstance(msg, HumanMessage):
                context_lines.append(f"User: {msg.content}")
            elif isinstance(msg, AIMessage):
                context_lines.append(f"Assistant: {msg.content}")
        return "\n".join(context_lines) if context_lines else "No previous conversation."

    def _build_document_details(self, documents: List[Any]) -> str:
        """Build detailed document information for response generation.

        Formats up to the first 15 documents as "[Doc i]" entries with
        metadata and a 500-char content preview so the LLM can cite them.
        """
        if not documents:
            return "No documents retrieved."
        details = []
        for i, doc in enumerate(documents[:15], 1):
            metadata = self._doc_metadata(doc)
            content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
            if isinstance(metadata, dict):
                filename = metadata.get('filename', 'Unknown')
                year = metadata.get('year', 'Unknown')
                district = metadata.get('district', 'Unknown')
                source = metadata.get('source', 'Unknown')
                page = metadata.get('page', metadata.get('page_label', 'Unknown'))
                doc_info = f"[Doc {i}]"
                # FIX: interpolate the extracted filename instead of a
                # hard-coded "(unknown)" placeholder (filename was unused)
                doc_info += f"\n Filename: {filename}"
                doc_info += f"\n Year: {year}"
                doc_info += f"\n District: {district}"
                doc_info += f"\n Source: {source}"
                if page != 'Unknown':
                    doc_info += f"\n Page: {page}"
                doc_info += f"\n Content: {content[:500]}{'...' if len(content) > 500 else ''}"
                details.append(doc_info)
        return "\n\n".join(details) if details else "No document details available."

    def _extract_correct_names_from_documents(self, documents: List[Any]) -> str:
        """Extract correct district/source names from documents to correct misspellings."""
        districts = set()
        sources = set()
        years = set()
        for doc in documents:
            metadata = self._doc_metadata(doc)
            if isinstance(metadata, dict):
                if metadata.get('district'):
                    districts.add(str(metadata['district']))
                if metadata.get('source'):
                    sources.add(str(metadata['source']))
                if metadata.get('year'):
                    years.add(str(metadata['year']))
        result = []
        if districts:
            result.append(f"Districts: {', '.join(sorted(districts))}")
        if sources:
            result.append(f"Sources: {', '.join(sorted(sources))}")
        if years:
            result.append(f"Years: {', '.join(sorted(years))}")
        return "\n".join(result) if result else "No metadata available."

    def _validate_and_enhance_response(self, response: str, documents: List[Any], query: str, filters: Dict[str, Any] = None) -> str:
        """Validate response and ensure all claims are referenced.

        Compares REQUESTED filters against RETRIEVED document metadata to
        identify gaps, appending a visible warning when requested years or
        districts are missing from the retrieved set.
        """
        import re  # local import: only needed for the query-year fallback below

        # Extract years and districts from RETRIEVED documents
        doc_years = set()
        doc_districts = set()
        for doc in documents:
            metadata = self._doc_metadata(doc)
            if isinstance(metadata, dict):
                if metadata.get('year'):
                    doc_years.add(str(metadata['year']))
                if metadata.get('district'):
                    doc_districts.add(str(metadata['district']))
        logger.info(f"π VALIDATION: Retrieved docs cover years={doc_years}, districts={doc_districts}")

        warnings = []

        # Get REQUESTED filters (normalize scalar or list values to str sets)
        requested_years = set()
        requested_districts = set()
        if filters:
            if filters.get('year'):
                requested_years = set(str(y) for y in filters['year']) if isinstance(filters['year'], list) else {str(filters['year'])}
            if filters.get('district'):
                requested_districts = set(str(d) for d in filters['district']) if isinstance(filters['district'], list) else {str(filters['district'])}
        logger.info(f"π VALIDATION: Requested years={requested_years}, districts={requested_districts}")

        # Compare requested vs retrieved YEARS
        if requested_years and doc_years:
            missing_years = requested_years - doc_years
            if missing_years:
                warnings.append(f"You requested data for years {', '.join(sorted(requested_years))}, but the retrieved documents only cover {', '.join(sorted(doc_years))}. Data for {', '.join(sorted(missing_years))} may not be available in the database.")
        elif requested_years and not doc_years:
            warnings.append(f"You requested data for years {', '.join(sorted(requested_years))}, but no documents were retrieved with year metadata.")

        # Compare requested vs retrieved DISTRICTS
        if requested_districts and doc_districts:
            # Normalize for comparison (case-insensitive)
            requested_districts_lower = {d.lower() for d in requested_districts}
            doc_districts_lower = {d.lower() for d in doc_districts}
            missing_districts_lower = requested_districts_lower - doc_districts_lower
            if missing_districts_lower:
                # Get original case versions for display
                missing_districts = [d for d in requested_districts if d.lower() in missing_districts_lower]
                warnings.append(f"You requested data for districts {', '.join(sorted(requested_districts))}, but the retrieved documents only cover {', '.join(sorted(doc_districts))}. Data for {', '.join(sorted(missing_districts))} may not be available in the database.")
        elif requested_districts and not doc_districts:
            warnings.append(f"You requested data for districts {', '.join(sorted(requested_districts))}, but no documents were retrieved with district metadata.")

        # Fallback: Check query text for explicit years not in documents (for cases without filters)
        if not requested_years:
            year_pattern = r'\b(20\d{2})\b'
            query_years = set(re.findall(year_pattern, query))
            missing_years = query_years - doc_years
            if missing_years and doc_years:
                warnings.append(f"The retrieved documents cover years {', '.join(sorted(doc_years))}, but I don't have information for {', '.join(sorted(missing_years))} in the retrieved documents.")

        # Add warnings to response if any (skip if a warning marker is already present)
        if warnings and "β οΈ" not in response:
            warning_text = "\n\nβ οΈ **Note:** " + " ".join(warnings)
            response = response + warning_text
            logger.info(f"π VALIDATION: Added warning: {warning_text}")

        return response
def get_multi_agent_chatbot():
    """Factory helper: build and return a new MultiAgentRAGChatbot."""
    chatbot = MultiAgentRAGChatbot()
    return chatbot
if __name__ == "__main__":
    # Smoke-test the multi-agent system: construct the chatbot (loads models
    # and connects to the vector store) and run one sample conversation turn.
    chatbot = MultiAgentRAGChatbot()
    # Test conversation
    result = chatbot.chat("List me top 10 challenges in budget allocation for the last 3 years")
    # chat() returns a dict; print the generated answer and the agent trace
    print("Response:", result['response'])
    print("Agent Logs:", result['agent_logs'])