akryldigital committed on
Commit
b0fe395
·
verified ·
1 Parent(s): 31b2f81

add visual chatbot

Browse files
Files changed (1) hide show
  1. src/agents/visual_chatbot.py +300 -0
src/agents/visual_chatbot.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visual Chatbot - Integrates ColPali visual search with LLM
3
+
4
+ This chatbot uses visual document retrieval (ColPali) instead of traditional
5
+ text-based RAG, then generates responses using an LLM.
6
+ """
7
+
8
+ import logging
9
+ from typing import Dict, Any, List, Optional
10
+ import os
11
+
12
+ from langchain_core.messages import HumanMessage, AIMessage
13
+ from langchain_openai import ChatOpenAI
14
+
15
+ from src.colpali.visual_search import VisualSearchAdapter, create_visual_search_adapter
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class VisualChatbot:
    """
    Chatbot that uses visual document retrieval (ColPali) for RAG.

    Flow:
    1. User query → Visual search (ColPali embeddings)
    2. Retrieved visual documents → Context
    3. Context + Query → LLM → Response
    """

    # Markers of the optional envelope that may wrap an incoming query:
    #   FILTER CONTEXT:
    #   <filter lines>
    #   USER QUERY:
    #   <actual question>
    _FILTER_MARKER = "FILTER CONTEXT:"
    _QUERY_MARKER = "USER QUERY:"

    # Line prefix inside the filter section -> key in the filters dict.
    _FILTER_FIELDS = {
        "Sources:": "sources",
        "Years:": "years",
        "Districts:": "districts",
        "Filenames:": "filenames",
    }

    def __init__(
        self,
        visual_search: "VisualSearchAdapter",
        llm_model: str = "gpt-4o-mini",
        top_k: int = 10,
        temperature: float = 0.1
    ):
        """
        Initialize visual chatbot.

        Args:
            visual_search: Visual search adapter used for retrieval
            llm_model: LLM model to use
            top_k: Number of documents to retrieve per query
            temperature: LLM temperature
        """
        self.visual_search = visual_search
        self.top_k = top_k

        # Initialize LLM (API key is read from the OPENAI_API_KEY env var)
        logger.info(f"🤖 Initializing LLM: {llm_model}")
        self.llm = ChatOpenAI(
            model=llm_model,
            temperature=temperature,
            api_key=os.environ.get("OPENAI_API_KEY")
        )

        logger.info("✅ Visual Chatbot initialized!")

    def chat(
        self,
        query: str,
        conversation_id: str,
        filters: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Process a chat query using visual retrieval.

        Args:
            query: User query, optionally wrapped in a
                "FILTER CONTEXT: ... USER QUERY: ..." envelope
            conversation_id: Conversation ID (only used for logging here)
            filters: Optional filters. Used only when no filters could be
                parsed from the query itself; filters embedded in the query
                take precedence.

        Returns:
            Dictionary with:
            - response: LLM response
            - rag_result: Visual search results ('sources', 'query', 'num_results')
            - actual_rag_query: The query used for retrieval
        """
        logger.info(f"💬 Visual chat (conv={conversation_id}): '{query[:100]}...'")

        # Always strip the envelope, even when it carried no usable filters;
        # otherwise the marker lines would be sent to the retriever and the LLM.
        clean_query = self._extract_clean_query(query)

        # Filters embedded in the query win over the explicit argument
        parsed_filters = self._parse_filters_from_query(query)
        if parsed_filters:
            logger.info(f" Parsed filters: {parsed_filters}")
        else:
            parsed_filters = filters or {}

        # Perform visual search
        logger.info(f"🔍 Visual search: '{clean_query}'")
        visual_results = self.visual_search.search(
            query=clean_query,
            top_k=self.top_k,
            filters=parsed_filters,
            search_strategy="multi_vector"  # Use best strategy
        )

        # Build context from visual results
        context = self._build_context(visual_results)

        # Generate response using LLM
        logger.info(f"🤖 Generating response with {len(visual_results)} visual documents")
        response = self._generate_response(clean_query, context)

        # Return in format expected by app.py
        return {
            'response': response,
            'rag_result': {
                'sources': visual_results,
                'query': clean_query,
                'num_results': len(visual_results)
            },
            'actual_rag_query': clean_query
        }

    def _parse_filters_from_query(self, query: str) -> Dict[str, List[Any]]:
        """
        Parse filter context from query.

        Expected format:
            FILTER CONTEXT:
            Sources: Source1, Source2
            Years: 2020, 2021
            Districts: District1
            Filenames: file1.pdf, file2.pdf
            USER QUERY:
            actual query text

        Only the text between "FILTER CONTEXT:" and "USER QUERY:" is
        inspected, so the user's own question is never mistaken for a filter
        line. Empty values are dropped, non-numeric years are skipped with a
        warning, and a field with no usable values is omitted entirely.

        Returns:
            Dict with any of the keys 'sources', 'years', 'districts',
            'filenames'. 'years' holds ints, the others hold strings.
            Empty dict when the query carries no filter context.
        """
        filters: Dict[str, List[Any]] = {}

        if self._FILTER_MARKER not in query:
            return filters

        # Isolate the filter section: after the first filter marker and
        # before the first query marker that follows it.
        section = query.partition(self._FILTER_MARKER)[2]
        section = section.partition(self._QUERY_MARKER)[0]

        for raw_line in section.splitlines():
            line = raw_line.strip()
            for prefix, key in self._FILTER_FIELDS.items():
                if not line.startswith(prefix):
                    continue

                # Slice off only the leading prefix (a replace() would also
                # eat the same text if it appeared inside a value).
                tokens = [t.strip() for t in line[len(prefix):].split(',')]
                values: List[Any] = [t for t in tokens if t]

                if key == 'years':
                    values = self._coerce_years(values)

                if values:
                    filters[key] = values
                break

        return filters

    @staticmethod
    def _coerce_years(tokens: List[str]) -> List[int]:
        """Convert year tokens to ints, skipping (and logging) invalid ones."""
        years = []
        for token in tokens:
            try:
                years.append(int(token))
            except ValueError:
                logging.getLogger(__name__).warning(
                    "Ignoring non-numeric year filter value: %r", token
                )
        return years

    def _extract_clean_query(self, query: str) -> str:
        """
        Extract the actual query without filter context.

        Returns everything after the first "USER QUERY:" marker (stripped),
        or the query unchanged when the marker is absent. Splitting on the
        first marker keeps a question intact even if it mentions the marker
        text itself.
        """
        if self._QUERY_MARKER in query:
            return query.partition(self._QUERY_MARKER)[2].strip()
        return query

    def _build_context(self, results: List[Any]) -> str:
        """
        Build context string from visual search results.

        Args:
            results: List of result objects exposing `metadata` (mapping or
                None), `page_content` and `score` attributes

        Returns:
            Formatted context string, or a fixed "no documents" message when
            `results` is empty
        """
        if not results:
            return "No relevant documents found."

        context_parts = []
        for i, result in enumerate(results, 1):
            # Extract metadata; tolerate results that carry no metadata
            metadata = result.metadata or {}
            filename = metadata.get('filename', 'Unknown')
            page_number = metadata.get('page_number', '?')
            year = metadata.get('year', 'Unknown')
            source = metadata.get('source', 'Unknown')
            text = result.page_content

            # A missing or non-numeric score must not abort the whole answer
            try:
                score_str = f"{float(result.score):.3f}"
            except (TypeError, ValueError):
                score_str = "n/a"

            # Format document
            doc_str = (
                f"\nDocument {i} (Score: {score_str}):\n"
                f"Source: {source} | Year: {year} | File: {filename} | Page: {page_number}\n"
                "Content:\n"
                f"{text}\n"
                "---\n"
            )
            context_parts.append(doc_str)

        return "\n".join(context_parts)

    def _generate_response(self, query: str, context: str) -> str:
        """
        Generate response using LLM with visual retrieval context.

        Args:
            query: User query
            context: Context from visual search

        Returns:
            LLM response text (the `content` of the model's reply)
        """
        # Build prompt
        system_prompt = """You are an intelligent assistant helping users analyze audit reports.

You have been provided with relevant document excerpts retrieved using visual document search (ColPali).
These documents were selected based on their visual and semantic similarity to the user's query.

Your task:
1. Analyze the provided documents carefully
2. Answer the user's question based ONLY on the information in the documents
3. Cite specific sources (document number, page, year) when making claims
4. If the documents don't contain enough information, say so clearly
5. Be concise but comprehensive

Remember: The documents were retrieved using advanced visual search, so they may contain tables, figures, or structured data that is highly relevant."""

        user_prompt = f"""Context from visual document search:

{context}

User Question: {query}

Please provide a detailed answer based on the documents above. Cite your sources."""

        # Generate response
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        response = self.llm.invoke(messages)
        return response.content
238
+
239
+
240
def get_visual_chatbot() -> VisualChatbot:
    """
    Factory function to create a visual chatbot.

    Reads Qdrant credentials from the environment and connects to the
    'colSmol-500M' collection. Credentials are looked up as (URL, API key)
    pairs in this priority order:

        1. QDRANT_URL_AKRYL / QDRANT_API_KEY_AKRYL
        2. DEST_QDRANT_URL  / DEST_QDRANT_API_KEY
        3. QDRANT_URL       / QDRANT_API_KEY

    The first pair that is completely set is used, so a URL is never combined
    with a key from a different pair when a matching pair is available. If no
    pair is complete, the first available URL and the first available key are
    used independently.

    The LLM model name is read from `reader.OPENAI.model` in
    src/config/settings.yaml, defaulting to 'gpt-4o-mini'.

    Returns:
        Initialized VisualChatbot

    Raises:
        ValueError: If no Qdrant URL or no Qdrant API key is configured
    """
    logger.info("🎨 Creating Visual Chatbot...")

    collection_name = "colSmol-500M"

    # (URL variable, API-key variable) pairs, highest priority first
    credential_vars = (
        ("QDRANT_URL_AKRYL", "QDRANT_API_KEY_AKRYL"),
        ("DEST_QDRANT_URL", "DEST_QDRANT_API_KEY"),
        ("QDRANT_URL", "QDRANT_API_KEY"),
    )

    qdrant_url = None
    qdrant_api_key = None
    for url_var, key_var in credential_vars:
        url_value = os.environ.get(url_var)
        key_value = os.environ.get(key_var)
        if url_value and key_value:
            qdrant_url, qdrant_api_key = url_value, key_value
            break
    else:
        # No complete pair: fall back to resolving each value on its own so
        # existing mixed configurations keep working.
        qdrant_url = next(
            (os.environ.get(u) for u, _ in credential_vars if os.environ.get(u)),
            None,
        )
        qdrant_api_key = next(
            (os.environ.get(k) for _, k in credential_vars if os.environ.get(k)),
            None,
        )

    if not qdrant_url or not qdrant_api_key:
        raise ValueError(
            "Visual mode requires Qdrant credentials for the ColPali cluster.\n"
            "Please set one of these in your .env file:\n"
            " - QDRANT_URL_AKRYL and QDRANT_API_KEY_AKRYL\n"
            " - DEST_QDRANT_URL and DEST_QDRANT_API_KEY\n"
            " - QDRANT_URL and QDRANT_API_KEY\n\n"
            "These should point to the cluster containing the 'colSmol-500M' collection."
        )

    logger.info(f" Using Qdrant URL: {qdrant_url}")
    logger.info(f" Collection: {collection_name}")

    # Create visual search adapter with explicit credentials
    visual_search = VisualSearchAdapter(
        qdrant_url=qdrant_url,
        qdrant_api_key=qdrant_api_key,
        collection_name=collection_name
    )

    # Get LLM config from settings.yaml; tolerate missing/empty sections
    from src.config.loader import load_config
    config = load_config("src/config/settings.yaml") or {}
    reader_config = config.get('reader') or {}
    openai_config = reader_config.get('OPENAI') or {}
    llm_model = openai_config.get('model') or 'gpt-4o-mini'

    # Create chatbot
    chatbot = VisualChatbot(
        visual_search=visual_search,
        llm_model=llm_model,
        top_k=10
    )

    return chatbot
300
+