Spaces:
Sleeping
Sleeping
| """ | |
| Visual Chatbot - Integrates ColPali visual search with LLM | |
| This chatbot uses visual document retrieval (ColPali) instead of traditional | |
| text-based RAG, then generates responses using an LLM. | |
| """ | |
| import logging | |
| from typing import Dict, Any, List, Optional | |
| import os | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| from langchain_openai import ChatOpenAI | |
| from src.colpali.visual_search import VisualSearchAdapter, create_visual_search_adapter | |
| logger = logging.getLogger(__name__) | |
class VisualChatbot:
    """
    Chatbot that uses visual document retrieval (ColPali) for RAG.

    Flow:
        1. User query -> visual search (ColPali embeddings)
        2. Retrieved visual documents -> context string
        3. Context + query -> LLM -> response
    """

    def __init__(
        self,
        # Quoted annotation: avoids evaluating the project type at class
        # definition time while keeping the documented interface identical.
        visual_search: "VisualSearchAdapter",
        llm_model: str = "gpt-4o-mini",
        top_k: int = 10,
        temperature: float = 0.1
    ):
        """
        Initialize visual chatbot.

        Args:
            visual_search: Visual search adapter used for retrieval.
            llm_model: OpenAI chat model name.
            top_k: Number of documents to retrieve per query.
            temperature: LLM sampling temperature.
        """
        self.visual_search = visual_search
        self.top_k = top_k

        # Initialize LLM; the API key is read from the environment.
        logger.info(f"🤖 Initializing LLM: {llm_model}")
        self.llm = ChatOpenAI(
            model=llm_model,
            temperature=temperature,
            api_key=os.environ.get("OPENAI_API_KEY")
        )
        logger.info("✅ Visual Chatbot initialized!")

    def chat(
        self,
        query: str,
        conversation_id: str,
        filters: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Process a chat query using visual retrieval.

        Args:
            query: User query, optionally prefixed with a "FILTER CONTEXT:"
                block (see _parse_filters_from_query).
            conversation_id: Conversation ID (used for logging/tracking only).
            filters: Optional retrieval filters. NOTE: ignored whenever the
                query itself embeds a FILTER CONTEXT block — parsed filters win.

        Returns:
            Dictionary with:
                - response: LLM response text
                - rag_result: visual search results plus query metadata
                - actual_rag_query: the cleaned query used for retrieval
        """
        logger.info(f"💬 Visual chat (conv={conversation_id}): '{query[:100]}...'")

        # Prefer filters embedded in the query; fall back to the explicit arg.
        parsed_filters = self._parse_filters_from_query(query)
        if parsed_filters:
            logger.info(f"   Parsed filters: {parsed_filters}")
            # Strip the filter-context preamble so retrieval sees only the query.
            clean_query = self._extract_clean_query(query)
        else:
            clean_query = query
            parsed_filters = filters or {}

        # Perform visual search
        logger.info(f"🔍 Visual search: '{clean_query}'")
        visual_results = self.visual_search.search(
            query=clean_query,
            top_k=self.top_k,
            filters=parsed_filters,
            search_strategy="multi_vector"  # Use best strategy
        )

        # Build context from visual results and generate the LLM response.
        context = self._build_context(visual_results)
        logger.info(f"🤖 Generating response with {len(visual_results)} visual documents")
        response = self._generate_response(clean_query, context)

        # Return in the format expected by app.py
        return {
            'response': response,
            'rag_result': {
                'sources': visual_results,
                'query': clean_query,
                'num_results': len(visual_results)
            },
            'actual_rag_query': clean_query
        }

    def _parse_filters_from_query(self, query: str) -> Dict[str, List[Any]]:
        """
        Parse an embedded filter-context block out of the query.

        Expected format:
            FILTER CONTEXT:
            Sources: Source1, Source2
            Years: 2020, 2021
            Districts: District1
            Filenames: file1.pdf, file2.pdf
            USER QUERY:
            actual query text

        Returns:
            Mapping of filter name -> list of values ('years' holds ints);
            empty dict when no FILTER CONTEXT block is present.
        """
        filters: Dict[str, List[Any]] = {}
        if "FILTER CONTEXT:" not in query:
            return filters

        for raw_line in query.split('\n'):
            line = raw_line.strip()
            if line.startswith("Sources:"):
                filters['sources'] = self._split_values(line)
            elif line.startswith("Years:"):
                # Bugfix: a non-numeric year token used to raise ValueError
                # via int(); malformed tokens are now skipped instead.
                filters['years'] = [
                    int(y) for y in self._split_values(line)
                    if y.lstrip('-').isdigit()
                ]
            elif line.startswith("Districts:"):
                filters['districts'] = self._split_values(line)
            elif line.startswith("Filenames:"):
                filters['filenames'] = self._split_values(line)
        return filters

    @staticmethod
    def _split_values(line: str) -> List[str]:
        """Split 'Label: a, b, c' into ['a', 'b', 'c'] (text after first ':')."""
        # split(':', 1) only strips the leading label, unlike the previous
        # replace(), which would remove the label string anywhere in the line.
        return [v.strip() for v in line.split(':', 1)[1].split(',')]

    def _extract_clean_query(self, query: str) -> str:
        """Return the text after the last 'USER QUERY:' marker, or the query as-is."""
        if "USER QUERY:" in query:
            return query.split("USER QUERY:")[-1].strip()
        return query

    def _build_context(self, results: List[Any]) -> str:
        """
        Build an LLM context string from visual search results.

        Args:
            results: VisualSearchResult-like objects exposing .metadata,
                .page_content and .score (as used below).

        Returns:
            Formatted context string, or a fallback message when empty.
        """
        if not results:
            return "No relevant documents found."

        context_parts = []
        for i, result in enumerate(results, 1):
            # Extract metadata
            metadata = result.metadata
            filename = metadata.get('filename', 'Unknown')
            page_number = metadata.get('page_number', '?')
            year = metadata.get('year', 'Unknown')
            source = metadata.get('source', 'Unknown')
            text = result.page_content
            score = result.score

            # Bugfix: filename was extracted but never interpolated — the
            # template previously hard-coded "File: (unknown)".
            doc_str = f"""
Document {i} (Score: {score:.3f}):
Source: {source} | Year: {year} | File: {filename} | Page: {page_number}
Content:
{text}
---
"""
            context_parts.append(doc_str)

        return "\n".join(context_parts)

    def _generate_response(self, query: str, context: str) -> str:
        """
        Generate a response with the LLM, grounded in the retrieved context.

        Args:
            query: User query (cleaned of any filter context).
            context: Context assembled from visual search results.

        Returns:
            LLM response text.
        """
        # Build prompt (kept byte-identical — it is runtime behavior).
        system_prompt = """You are an intelligent assistant helping users analyze audit reports.
You have been provided with relevant document excerpts retrieved using visual document search (ColPali).
These documents were selected based on their visual and semantic similarity to the user's query.
Your task:
1. Analyze the provided documents carefully
2. Answer the user's question based ONLY on the information in the documents
3. Cite specific sources (document number, page, year) when making claims
4. If the documents don't contain enough information, say so clearly
5. Be concise but comprehensive
Remember: The documents were retrieved using advanced visual search, so they may contain tables, figures, or structured data that is highly relevant."""
        user_prompt = f"""Context from visual document search:
{context}
User Question: {query}
Please provide a detailed answer based on the documents above. Cite your sources."""

        # Generate response
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        response = self.llm.invoke(messages)
        return response.content
def get_visual_chatbot(
    collection_name: str = "colSmol-500M",
    top_k: int = 10
) -> VisualChatbot:
    """
    Factory function to create a visual chatbot.

    Uses the same QDRANT_URL and QDRANT_API_KEY as the colpali_colab_package,
    but connects to the ColPali collection (default 'colSmol-500M') instead
    of v1's collections.

    Args:
        collection_name: Qdrant collection holding the ColPali embeddings
            (generalized from the previously hard-coded 'colSmol-500M').
        top_k: Number of documents the chatbot retrieves per query.

    Returns:
        Initialized VisualChatbot.

    Raises:
        ValueError: If no Qdrant credentials are found in the environment.
    """
    logger.info("🎨 Creating Visual Chatbot...")

    # Check for ColPali cluster credentials in the environment (.env);
    # several variable names are supported — first match wins.
    qdrant_url = (
        os.environ.get("QDRANT_URL_AKRYL")
        or os.environ.get("DEST_QDRANT_URL")
        or os.environ.get("QDRANT_URL")
    )
    qdrant_api_key = (
        os.environ.get("QDRANT_API_KEY_AKRYL")
        or os.environ.get("DEST_QDRANT_API_KEY")
        or os.environ.get("QDRANT_API_KEY")
    )

    if not qdrant_url or not qdrant_api_key:
        raise ValueError(
            "Visual mode requires Qdrant credentials for the ColPali cluster.\n"
            "Please set one of these in your .env file:\n"
            "  - QDRANT_URL_AKRYL and QDRANT_API_KEY_AKRYL\n"
            "  - DEST_QDRANT_URL and DEST_QDRANT_API_KEY\n"
            "  - QDRANT_URL and QDRANT_API_KEY\n\n"
            "These should point to the cluster containing the 'colSmol-500M' collection."
        )

    logger.info(f"   Using Qdrant URL: {qdrant_url}")
    # Fix: this log line was an f-string with no placeholder; it now reflects
    # the actual collection in use.
    logger.info(f"   Collection: {collection_name}")

    # Create visual search adapter with explicit credentials
    visual_search = VisualSearchAdapter(
        qdrant_url=qdrant_url,
        qdrant_api_key=qdrant_api_key,
        collection_name=collection_name
    )

    # Get LLM config from settings.yaml; the import is local so the module
    # does not depend on the config loader at import time.
    from src.config.loader import load_config
    config = load_config("src/config/settings.yaml")
    reader_config = config.get('reader', {})
    openai_config = reader_config.get('OPENAI', {})
    llm_model = openai_config.get('model', 'gpt-4o-mini')

    # Create chatbot
    return VisualChatbot(
        visual_search=visual_search,
        llm_model=llm_model,
        top_k=top_k
    )