akryldigital commited on
Commit
8d898c4
Β·
verified Β·
1 Parent(s): 69de8d2

Gemini FSA (#6)

Browse files

- refactor (1154c8ddd84455509a7f915a8588d951f925bbc6)
- refactor (72318ee79a0a0bcafca07cc5be70aace39a25f0c)
- add district Metadata (7c8b7838d143b45686e8360bbf64d2ef3c4a5624)
- refactor and add sample questions (02d7f4f76ed3415d70f279e6ab188435ae60fe92)
- add retrieval visualisations (5262a14ec5cccd4d5e9796cf7c881b211809f9fc)
- add Retrieval Distribution stats (763a8b9f7efaf0b823138df29267434d72edd477)
- Merge branch 'main' of https://huggingface.co/spaces/akryldigital/audit_assistant (264ca849bc4217b186eb0f8de0d23bce39856de7)
- Merge branch 'main' of https://huggingface.co/spaces/akryldigital/audit_assistant (b4984e28f7a41b312d3eba90b68651ab903f7d08)
- refactor + add gemini (72eb0bfa173ea05e8fd8e3b63429ef4678a01663)
- fix use_container_width=False (f8a1d4171fe8d1b3b84938b7b14bbe497e4159fe)
- finalize gemini version (3fc1b5f53b40772ba3c8abf400f1d987c12c4ee1)
- add gemini traceability (6f5999e84c97f19e3d1ff442873b52ef6ca8208e)
- fix gemini chunk extraction (06faccdb62f42a514e9f6b8cc93ca73f5cab5fa1)
- fix gemini chunk extraction (54bf55f7a03022a79da5ead74a289cb893efa88f)
- add upload debug (39edab443db3982dfc9f582868c66b9ab787a208)
- Remove scripts and ignore local_* files (de1d74a230264cf4ae8df516291869434f265d9f)

.gitignore CHANGED
@@ -109,4 +109,7 @@ pytest_cache/
109
  tmp/
110
  temp/
111
  *.tmp
112
- *.temp
 
 
 
 
109
  tmp/
110
  temp/
111
  *.tmp
112
+ *.temp
113
+
114
+
115
+ local_*
app.py CHANGED
@@ -10,6 +10,7 @@ import uuid
10
  import logging
11
  import traceback
12
  from pathlib import Path
 
13
  from collections import Counter
14
  from typing import List, Dict, Any, Optional
15
 
@@ -19,10 +20,11 @@ import streamlit as st
19
  import plotly.express as px
20
  from langchain_core.messages import HumanMessage, AIMessage
21
 
22
- from multi_agent_chatbot import get_multi_agent_chatbot
23
- from smart_chatbot import get_chatbot as get_smart_chatbot
24
- from src.reporting.snowflake_connector import save_to_snowflake
25
- from src.reporting.feedback_schema import create_feedback_from_dict
 
26
  from src.config.paths import (
27
  IS_DEPLOYED,
28
  PROJECT_DIR,
@@ -31,6 +33,7 @@ from src.config.paths import (
31
  CONVERSATIONS_DIR,
32
  )
33
 
 
34
  # ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
35
  # Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
36
  omp_threads = os.environ.get("OMP_NUM_THREADS", "")
@@ -70,6 +73,9 @@ if IS_DEPLOYED and HF_CACHE_DIR:
70
  except (PermissionError, OSError):
71
  # If we can't create it, log but continue (might already exist from Dockerfile)
72
  pass
 
 
 
73
 
74
  # Configure logging
75
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -90,116 +96,9 @@ st.set_page_config(
90
  page_title="Intelligent Audit Report Chatbot"
91
  )
92
 
93
- # Custom CSS
94
- st.markdown("""
95
- <style>
96
- .main-header {
97
- font-size: 2.5rem;
98
- font-weight: bold;
99
- color: #1f77b4;
100
- text-align: center;
101
- margin-bottom: 1rem;
102
- width: 100%;
103
- display: block;
104
- }
105
-
106
- .subtitle {
107
- font-size: 1.2rem;
108
- color: #666;
109
- text-align: center;
110
- margin-bottom: 2rem;
111
- width: 100%;
112
- display: block;
113
- }
114
-
115
- .session-info {
116
- background-color: #f0f2f6;
117
- padding: 10px;
118
- border-radius: 5px;
119
- margin-bottom: 20px;
120
- font-size: 0.9rem;
121
- }
122
-
123
- .user-message {
124
- background-color: #007bff;
125
- color: white;
126
- padding: 12px 16px;
127
- border-radius: 18px 18px 4px 18px;
128
- margin: 8px 0;
129
- margin-left: 20%;
130
- word-wrap: break-word;
131
- }
132
-
133
- .bot-message {
134
- background-color: #f1f3f4;
135
- color: #333;
136
- padding: 12px 16px;
137
- border-radius: 18px 18px 18px 4px;
138
- margin: 8px 0;
139
- margin-right: 20%;
140
- word-wrap: break-word;
141
- border: 1px solid #e0e0e0;
142
- }
143
-
144
- .filter-section {
145
- margin-bottom: 20px;
146
- padding: 15px;
147
- background-color: #f8f9fa;
148
- border-radius: 8px;
149
- border: 1px solid #e9ecef;
150
- }
151
-
152
- .filter-title {
153
- font-weight: bold;
154
- margin-bottom: 10px;
155
- color: #495057;
156
- }
157
-
158
- .feedback-section {
159
- background-color: #f8f9fa;
160
- padding: 20px;
161
- border-radius: 10px;
162
- margin-top: 30px;
163
- border: 2px solid #dee2e6;
164
- }
165
-
166
- .retrieval-history {
167
- background-color: #ffffff;
168
- padding: 15px;
169
- border-radius: 5px;
170
- margin: 10px 0;
171
- border-left: 4px solid #007bff;
172
- }
173
-
174
- .retrieval-distribution-container {
175
- background-color: #ffffff;
176
- padding: 25px;
177
- border-radius: 10px;
178
- margin: 20px 0;
179
- border: 2px solid #e0e0e0;
180
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 2px 4px rgba(0, 0, 0, 0.06);
181
- }
182
-
183
- .metric-label {
184
- font-size: 0.9rem;
185
- color: #555;
186
- margin-bottom: 5px;
187
- text-align: center;
188
- }
189
-
190
- .metric-value {
191
- font-size: 1.8rem;
192
- font-weight: bold;
193
- color: #000000;
194
- text-align: center;
195
- }
196
-
197
- .metric-container {
198
- text-align: center;
199
- padding: 10px;
200
- }
201
- </style>
202
- """, unsafe_allow_html=True)
203
 
204
  def get_system_type():
205
  """Get the current system type"""
@@ -209,14 +108,17 @@ def get_system_type():
209
  else:
210
  return "Multi-Agent System"
211
 
212
- def get_chatbot():
213
- """Initialize and return the chatbot based on system type"""
214
- # Check environment variable for system type
215
- system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
216
- if system == 'smart':
217
- return get_smart_chatbot()
218
  else:
219
- return get_multi_agent_chatbot()
 
 
 
 
 
220
 
221
  def serialize_messages(messages):
222
  """Serialize LangChain messages to dictionaries"""
@@ -262,368 +164,8 @@ def serialize_documents(sources):
262
  return serialized
263
 
264
 
265
- def extract_transcript(messages: List[Any]) -> List[Dict[str, str]]:
266
- """Extract transcript from messages - only user and bot messages, no extra metadata"""
267
- transcript = []
268
- for msg in messages:
269
- if isinstance(msg, HumanMessage):
270
- transcript.append({
271
- "role": "user",
272
- "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
273
- })
274
- elif isinstance(msg, AIMessage):
275
- transcript.append({
276
- "role": "assistant",
277
- "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
278
- })
279
- return transcript
280
-
281
-
282
- def build_retrievals_structure(rag_retrieval_history: List[Dict[str, Any]], messages: List[Any]) -> List[Dict[str, Any]]:
283
- """Build retrievals structure from retrieval history"""
284
- retrievals = []
285
-
286
- for entry in rag_retrieval_history:
287
- # Get the user message that triggered this retrieval
288
- # The entry has conversation_up_to which includes messages up to that point
289
- conversation_up_to = entry.get("conversation_up_to", [])
290
-
291
- # Find the last user message in conversation_up_to (this is the trigger)
292
- user_message_trigger = ""
293
- for msg_dict in reversed(conversation_up_to):
294
- if msg_dict.get("type") == "HumanMessage":
295
- user_message_trigger = msg_dict.get("content", "")
296
- break
297
-
298
- # Fallback: if not found in conversation_up_to, get from actual messages
299
- # This handles edge cases where conversation_up_to might be incomplete
300
- if not user_message_trigger:
301
- # Find which retrieval this is (0-indexed)
302
- retrieval_idx = rag_retrieval_history.index(entry)
303
- # The user message that triggered this retrieval is at position (retrieval_idx * 2)
304
- # because each retrieval is preceded by: user message, bot response, user message, ...
305
- # But we need to account for the fact that the first retrieval happens after the first user message
306
- user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
307
- if retrieval_idx < len(user_msgs):
308
- user_message_trigger = str(user_msgs[retrieval_idx].content)
309
- elif user_msgs:
310
- # Fallback to last user message
311
- user_message_trigger = str(user_msgs[-1].content)
312
-
313
- # Get retrieved documents and truncate content to 100 chars
314
- docs_retrieved = entry.get("docs_retrieved", [])
315
- retrieved_docs = []
316
- for doc in docs_retrieved:
317
- doc_copy = doc.copy()
318
- # Truncate content to 100 characters (keep all other fields)
319
- if "content" in doc_copy:
320
- doc_copy["content"] = doc_copy["content"][:100]
321
- retrieved_docs.append(doc_copy)
322
-
323
- retrievals.append({
324
- "retrieved_docs": retrieved_docs,
325
- "user_message_trigger": user_message_trigger
326
- })
327
-
328
- return retrievals
329
-
330
-
331
- def build_feedback_score_related_retrieval_docs(
332
- is_feedback_about_last_retrieval: bool,
333
- messages: List[Any],
334
- rag_retrieval_history: List[Dict[str, Any]]
335
- ) -> Optional[Dict[str, Any]]:
336
- """Build feedback_score_related_retrieval_docs structure"""
337
- if not rag_retrieval_history:
338
- return None
339
-
340
- # Get the relevant retrieval entry
341
- if is_feedback_about_last_retrieval:
342
- relevant_entry = rag_retrieval_history[-1]
343
- else:
344
- # If feedback is about all retrievals, use the last one as default
345
- relevant_entry = rag_retrieval_history[-1]
346
-
347
- # Get conversation up to that point
348
- conversation_up_to = relevant_entry.get("conversation_up_to", [])
349
-
350
- # Convert to transcript format (role/content)
351
- conversation_up_to_point = []
352
- for msg_dict in conversation_up_to:
353
- if msg_dict.get("type") == "HumanMessage":
354
- conversation_up_to_point.append({
355
- "role": "user",
356
- "content": msg_dict.get("content", "")
357
- })
358
- elif msg_dict.get("type") == "AIMessage":
359
- conversation_up_to_point.append({
360
- "role": "assistant",
361
- "content": msg_dict.get("content", "")
362
- })
363
-
364
- # Get retrieved docs with full content (not truncated)
365
- retrieved_docs = relevant_entry.get("docs_retrieved", [])
366
-
367
- return {
368
- "conversation_up_to_point": conversation_up_to_point,
369
- "retrieved_docs": retrieved_docs
370
- }
371
 
372
- def extract_chunk_statistics(sources: List[Any]) -> Dict[str, Any]:
373
- """Extract statistics from retrieved chunks."""
374
- if not sources:
375
- return {}
376
-
377
- sources_list = []
378
- years = []
379
- filenames = []
380
- districts = []
381
-
382
- for doc in sources:
383
- metadata = getattr(doc, 'metadata', {})
384
-
385
- # Extract source
386
- source = metadata.get('source', 'Unknown')
387
- sources_list.append(source)
388
-
389
- # Extract year
390
- year = metadata.get('year', 'Unknown')
391
- if year and year != 'Unknown':
392
- try:
393
- # Convert to int first, then back to string to ensure it's a proper year
394
- year_int = int(float(year)) # Handle both int and float strings
395
- if 1900 <= year_int <= 2030: # Reasonable year range
396
- years.append(str(year_int))
397
- else:
398
- years.append('Unknown')
399
- except (ValueError, TypeError):
400
- years.append('Unknown')
401
- else:
402
- years.append('Unknown')
403
-
404
- # Extract filename
405
- filename = metadata.get('filename', 'Unknown')
406
- filenames.append(filename)
407
-
408
- # Extract district
409
- district = metadata.get('district', 'Unknown')
410
- if district and district != 'Unknown':
411
- districts.append(district)
412
- else:
413
- districts.append('Unknown')
414
-
415
- # Count occurrences
416
- source_counts = Counter(sources_list)
417
- year_counts = Counter(years)
418
- filename_counts = Counter(filenames)
419
- district_counts = Counter(districts)
420
-
421
- return {
422
- 'total_chunks': len(sources),
423
- 'unique_sources': len(source_counts),
424
- 'unique_years': len([y for y in year_counts.keys() if y != 'Unknown']),
425
- 'unique_filenames': len(filename_counts),
426
- 'unique_districts': len([d for d in district_counts.keys() if d != 'Unknown']),
427
- 'source_distribution': dict(source_counts),
428
- 'year_distribution': dict(year_counts),
429
- 'filename_distribution': dict(filename_counts),
430
- 'district_distribution': dict(district_counts),
431
- 'sources': sources_list,
432
- 'years': years,
433
- 'filenames': filenames,
434
- 'districts': districts
435
- }
436
-
437
- def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieval Statistics"):
438
- """Display statistics as interactive charts for 10+ results."""
439
- if not stats or stats.get('total_chunks', 0) == 0:
440
- return
441
-
442
- # Wrap everything in one styled container - open it
443
- st.markdown(f"""
444
- <div class="retrieval-distribution-container">
445
- <h3 style="margin-top: 0;">πŸ“Š {title}</h3>
446
- <div style="display: flex; justify-content: space-around; align-items: center; padding: 15px 0; border-bottom: 1px solid #e0e0e0; margin-bottom: 20px;">
447
- <div class="metric-container">
448
- <div class="metric-label">Total Chunks</div>
449
- <div class="metric-value">{stats['total_chunks']}</div>
450
- </div>
451
- <div class="metric-container">
452
- <div class="metric-label">Unique Sources</div>
453
- <div class="metric-value">{stats['unique_sources']}</div>
454
- </div>
455
- <div class="metric-container">
456
- <div class="metric-label">Unique Years</div>
457
- <div class="metric-value">{stats['unique_years']}</div>
458
- </div>
459
- <div class="metric-container">
460
- <div class="metric-label">Unique Files</div>
461
- <div class="metric-value">{stats['unique_filenames']}</div>
462
- </div>
463
- </div>
464
- """, unsafe_allow_html=True)
465
-
466
- # Charts - three columns to include Districts
467
- col1, col2, col3 = st.columns(3)
468
-
469
- with col1:
470
- # Source distribution chart
471
- if stats['source_distribution']:
472
- source_df = pd.DataFrame(
473
- list(stats['source_distribution'].items()),
474
- columns=['Source', 'Count']
475
- )
476
- fig_source = px.bar(
477
- source_df,
478
- x='Count',
479
- y='Source',
480
- orientation='h',
481
- title='Distribution by Source',
482
- color='Count',
483
- color_continuous_scale='viridis'
484
- )
485
- fig_source.update_layout(height=400, showlegend=False)
486
- st.plotly_chart(fig_source, use_container_width=True)
487
-
488
- with col2:
489
- # Year distribution chart
490
- if stats['year_distribution']:
491
- # Filter out 'Unknown' years for the chart
492
- year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
493
- if year_dist_filtered:
494
- year_df = pd.DataFrame(
495
- list(year_dist_filtered.items()),
496
- columns=['Year', 'Count']
497
- )
498
- # Sort by year as integer but keep as string for categorical display
499
- year_df['Year_Int'] = year_df['Year'].astype(int)
500
- year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)
501
-
502
- fig_year = px.bar(
503
- year_df,
504
- x='Year',
505
- y='Count',
506
- title='Distribution by Year',
507
- color='Count',
508
- color_continuous_scale='plasma'
509
- )
510
- # Ensure years are treated as categorical (discrete) not continuous
511
- fig_year.update_xaxes(type='category')
512
- fig_year.update_layout(height=400, showlegend=False)
513
- st.plotly_chart(fig_year, use_container_width=True)
514
- else:
515
- st.info("No valid years found in the results")
516
-
517
- with col3:
518
- # District distribution chart
519
- if stats.get('district_distribution'):
520
- district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
521
- if district_dist_filtered:
522
- district_df = pd.DataFrame(
523
- list(district_dist_filtered.items()),
524
- columns=['District', 'Count']
525
- )
526
- district_df = district_df.sort_values('Count', ascending=False)
527
-
528
- fig_district = px.bar(
529
- district_df,
530
- x='Count',
531
- y='District',
532
- orientation='h',
533
- title='Distribution by District',
534
- color='Count',
535
- color_continuous_scale='blues'
536
- )
537
- fig_district.update_layout(height=400, showlegend=False)
538
- st.plotly_chart(fig_district, use_container_width=True)
539
- else:
540
- st.info("No valid districts found in the results")
541
-
542
- # Close the container
543
- st.markdown('</div>', unsafe_allow_html=True)
544
-
545
- def display_chunk_statistics_table(stats: Dict[str, Any], title: str = "Retrieval Distribution"):
546
- """Display statistics as tables for smaller results with fixed alignment."""
547
- if not stats or stats.get('total_chunks', 0) == 0:
548
- return
549
-
550
- # Wrap in styled container
551
- st.markdown('<div class="retrieval-distribution-container">', unsafe_allow_html=True)
552
-
553
- st.subheader(f"πŸ“Š {title}")
554
-
555
- # Create a container with fixed height for alignment
556
- stats_container = st.container()
557
-
558
- with stats_container:
559
- # Create 4 equal columns for consistent alignment
560
- col1, col2, col3, col4 = st.columns(4)
561
-
562
- with col1:
563
- st.markdown("**🏘️ Districts**")
564
- if stats.get('district_distribution'):
565
- district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
566
- if district_dist_filtered:
567
- district_data = {
568
- "District": list(district_dist_filtered.keys()),
569
- "Count": list(district_dist_filtered.values())
570
- }
571
- district_df = pd.DataFrame(district_data).sort_values('Count', ascending=False)
572
- st.dataframe(district_df, hide_index=True, use_container_width=True)
573
- else:
574
- st.write("No district data")
575
- else:
576
- st.write("No district data")
577
-
578
- with col2:
579
- st.markdown("**πŸ“‚ Sources**")
580
- if stats['source_distribution']:
581
- source_data = {
582
- "Source": list(stats['source_distribution'].keys()),
583
- "Count": list(stats['source_distribution'].values())
584
- }
585
- source_df = pd.DataFrame(source_data).sort_values('Count', ascending=False)
586
- st.dataframe(source_df, hide_index=True, use_container_width=True)
587
- else:
588
- st.write("No source data")
589
-
590
- with col3:
591
- st.markdown("**πŸ“… Years**")
592
- if stats['year_distribution']:
593
- year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
594
- if year_dist_filtered:
595
- year_data = {
596
- "Year": list(year_dist_filtered.keys()),
597
- "Count": list(year_dist_filtered.values())
598
- }
599
- year_df = pd.DataFrame(year_data)
600
- # Sort by year as integer but display as string
601
- year_df['Year_Int'] = year_df['Year'].astype(int)
602
- year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
603
- st.dataframe(year_df, hide_index=True, use_container_width=True)
604
- else:
605
- st.write("No year data")
606
- else:
607
- st.write("No year data")
608
-
609
- with col4:
610
- st.markdown("**πŸ“„ Files**")
611
- if stats['filename_distribution']:
612
- filename_items = list(stats['filename_distribution'].items())
613
- filename_items.sort(key=lambda x: x[1], reverse=True)
614
-
615
- # Show top files with truncated names
616
- file_data = {
617
- "File": [f[:30] + "..." if len(f) > 30 else f for f, c in filename_items[:5]],
618
- "Count": [c for f, c in filename_items[:5]]
619
- }
620
- file_df = pd.DataFrame(file_data)
621
- st.dataframe(file_df, hide_index=True, use_container_width=True)
622
- else:
623
- st.write("No file data")
624
-
625
- # Close container
626
- st.markdown('</div>', unsafe_allow_html=True)
627
 
628
  @st.cache_data
629
  def load_filter_options():
@@ -649,11 +191,48 @@ def main():
649
  # Track RAG retrieval history for feedback
650
  if 'rag_retrieval_history' not in st.session_state:
651
  st.session_state.rag_retrieval_history = []
652
- # Initialize chatbot only once per app session (cached)
653
- if 'chatbot' not in st.session_state:
654
- with st.spinner("πŸ”„ Loading AI models and connecting to database..."):
655
- st.session_state.chatbot = get_chatbot()
656
- st.success("βœ… AI system ready!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
657
 
658
  # Reset conversation history if needed (but keep chatbot cached)
659
  if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
@@ -665,9 +244,43 @@ def main():
665
  st.session_state.reset_conversation = False
666
  st.rerun()
667
 
668
- # Header - fully center aligned
669
- st.markdown('<h1 class="main-header">πŸ€– Intelligent Audit Report Chatbot</h1>', unsafe_allow_html=True)
670
- st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
671
 
672
  # Session info
673
  duration = int(time.time() - st.session_state.session_start_time)
@@ -729,7 +342,7 @@ def main():
729
  # Determine if filename filter is active
730
  filename_mode = len(selected_filenames) > 0
731
  # Sources filter
732
- st.markdown('<div class="filter-section">', unsafe_allow_html=True)
733
  st.markdown('<div class="filter-title">πŸ“Š Sources</div>', unsafe_allow_html=True)
734
  selected_sources = st.multiselect(
735
  "Select sources:",
@@ -826,7 +439,7 @@ def main():
826
  )
827
 
828
  with col2:
829
- send_button = st.button("Send", key="send_button", use_container_width=True)
830
 
831
  # Clear chat button
832
  if st.button("πŸ—‘οΈ Clear Chat", key="clear_chat_button"):
@@ -878,6 +491,36 @@ def main():
878
  if rag_result:
879
  sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
880
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
881
  # Get the actual RAG query
882
  actual_rag_query = chat_result.get('actual_rag_query', '')
883
  if actual_rag_query:
@@ -887,12 +530,25 @@ def main():
887
  else:
888
  formatted_query = "No RAG query available"
889
 
 
 
 
 
 
 
 
 
890
  retrieval_entry = {
891
  "conversation_up_to": serialize_messages(st.session_state.messages),
892
  "rag_query_expansion": formatted_query,
893
- "docs_retrieved": serialize_documents(sources)
 
 
894
  }
895
  st.session_state.rag_retrieval_history.append(retrieval_entry)
 
 
 
896
  else:
897
  response = chat_result
898
  st.session_state.last_rag_result = None
@@ -922,6 +578,16 @@ def main():
922
  # Dictionary format from multi-agent system
923
  sources = rag_result['sources']
924
 
 
 
 
 
 
 
 
 
 
 
925
  if sources and len(sources) > 0:
926
  # Count unique filenames
927
  unique_filenames = set()
@@ -951,9 +617,18 @@ def main():
951
  for i, doc in enumerate(sources): # Show all documents
952
  # Get relevance score and ID if available
953
  metadata = getattr(doc, 'metadata', {})
954
- score = metadata.get('reranked_score', metadata.get('original_score', None))
955
- chunk_id = metadata.get('_id', 'Unknown')
956
- score_text = f" (Score: {score:.3f}, ID: {chunk_id[:8]}...)" if score is not None else f" (ID: {chunk_id[:8]}...)"
 
 
 
 
 
 
 
 
 
957
 
958
  with st.expander(f"πŸ“„ Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
959
  # Display document metadata with emojis
@@ -1031,7 +706,7 @@ def main():
1031
 
1032
  submitted = st.form_submit_button(
1033
  "πŸ“€ Submit Feedback",
1034
- use_container_width=True,
1035
  disabled=submit_disabled
1036
  )
1037
 
@@ -1043,16 +718,18 @@ def main():
1043
  st.write("πŸ” **Debug: Feedback Data Being Submitted:**")
1044
 
1045
  # Extract transcript from messages
1046
- transcript = extract_transcript(st.session_state.messages)
1047
 
1048
  # Build retrievals structure
1049
- retrievals = build_retrievals_structure(
 
1050
  st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
1051
  st.session_state.messages
1052
  )
1053
 
1054
  # Build feedback_score_related_retrieval_docs
1055
- feedback_score_related_retrieval_docs = build_feedback_score_related_retrieval_docs(
 
1056
  is_feedback_about_last_retrieval,
1057
  st.session_state.messages,
1058
  st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
@@ -1082,7 +759,7 @@ def main():
1082
  # Create UserFeedback dataclass instance
1083
  feedback_obj = None # Initialize outside try block
1084
  try:
1085
- feedback_obj = create_feedback_from_dict(feedback_dict)
1086
  print(f"βœ… FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
1087
  st.write(f"βœ… **Feedback Object Created**")
1088
  st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
@@ -1138,7 +815,11 @@ def main():
1138
  logger.info("πŸ“€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
1139
  print("πŸ“€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
1140
 
1141
- snowflake_success = save_to_snowflake(feedback_obj)
 
 
 
 
1142
  if snowflake_success:
1143
  logger.info("βœ… SNOWFLAKE UI: Successfully saved to Snowflake")
1144
  print("βœ… SNOWFLAKE UI: Successfully saved to Snowflake")
@@ -1193,7 +874,7 @@ def main():
1193
  st.success("βœ… Feedback already submitted for this conversation!")
1194
  col1, col2 = st.columns([1, 1])
1195
  with col1:
1196
- if st.button("πŸ”„ Submit New Feedback", key="new_feedback_button", use_container_width=True):
1197
  try:
1198
  st.session_state.feedback_submitted = False
1199
  st.rerun()
@@ -1202,7 +883,7 @@ def main():
1202
  logger.error(f"Error resetting feedback state: {e}")
1203
  st.error(f"Error resetting feedback. Please refresh the page.")
1204
  with col2:
1205
- if st.button("πŸ“‹ View Conversation", key="view_conversation_button", use_container_width=True):
1206
  # Scroll to conversation - this is handled by the auto-scroll at bottom
1207
  pass
1208
 
@@ -1211,20 +892,111 @@ def main():
1211
  st.markdown("---")
1212
  st.markdown("#### πŸ“Š Retrieval History")
1213
 
1214
- with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=False):
1215
  for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
1216
- st.markdown(f"**Retrieval #{idx}**")
 
 
 
 
 
1217
 
1218
  # Display the actual RAG query
1219
  rag_query_expansion = entry.get("rag_query_expansion", "No query available")
 
1220
  st.code(rag_query_expansion, language="text")
1221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1222
  # Display summary stats
 
1223
  st.json({
1224
- "conversation_length": len(entry.get("conversation_up_to", [])),
1225
- "documents_retrieved": len(entry.get("docs_retrieved", []))
1226
  })
1227
- st.markdown("---")
 
 
1228
 
1229
  # Example Questions Section
1230
  st.markdown("---")
@@ -1245,7 +1017,7 @@ def main():
1245
  st.markdown(f"**Example:** `{example_q1}`")
1246
  st.info("πŸ’‘ **Filter to apply:** Select a Filename from the sidebar panel before asking this question.")
1247
  with col2:
1248
- if st.button("πŸ“‹ Use This Question", key="use_example_1", use_container_width=True):
1249
  st.session_state.pending_question = example_q1
1250
  st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
1251
  st.rerun()
@@ -1266,7 +1038,7 @@ def main():
1266
  )
1267
  col1, col2 = st.columns([1, 4])
1268
  with col1:
1269
- if st.button("πŸ“‹ Use Question 2", key="use_custom_1", use_container_width=True):
1270
  if custom_q1.strip():
1271
  st.session_state.pending_question = custom_q1.strip()
1272
  st.session_state.custom_question_1 = custom_q1.strip()
@@ -1292,7 +1064,7 @@ def main():
1292
  )
1293
  col1, col2 = st.columns([1, 4])
1294
  with col1:
1295
- if st.button("πŸ“‹ Use Question 3", key="use_custom_2", use_container_width=True):
1296
  if custom_q2.strip():
1297
  st.session_state.pending_question = custom_q2.strip()
1298
  st.session_state.custom_question_2 = custom_q2.strip()
 
10
  import logging
11
  import traceback
12
  from pathlib import Path
13
+
14
  from collections import Counter
15
  from typing import List, Dict, Any, Optional
16
 
 
20
  import plotly.express as px
21
  from langchain_core.messages import HumanMessage, AIMessage
22
 
23
+
24
+ from src.agents import get_multi_agent_chatbot, get_smart_chatbot, get_gemini_chatbot
25
+ from src.feedback import FeedbackManager
26
+ from src.ui_components import get_custom_css, display_chunk_statistics_charts, display_chunk_statistics_table, extract_chunk_statistics
27
+
28
  from src.config.paths import (
29
  IS_DEPLOYED,
30
  PROJECT_DIR,
 
33
  CONVERSATIONS_DIR,
34
  )
35
 
36
+
37
  # ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
38
  # Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
39
  omp_threads = os.environ.get("OMP_NUM_THREADS", "")
 
73
  except (PermissionError, OSError):
74
  # If we can't create it, log but continue (might already exist from Dockerfile)
75
  pass
76
+ else:
77
+ from dotenv import load_dotenv
78
+ load_dotenv()
79
 
80
  # Configure logging
81
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
96
  page_title="Intelligent Audit Report Chatbot"
97
  )
98
 
99
+
100
+ st.markdown(get_custom_css(), unsafe_allow_html=True)
101
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  def get_system_type():
104
  """Get the current system type"""
 
108
  else:
109
  return "Multi-Agent System"
110
 
111
+ def get_chatbot(version: str = "v1"):
112
+ """Initialize and return the chatbot based on version"""
113
+ if version == "beta":
114
+ return get_gemini_chatbot()
 
 
115
  else:
116
+ # Check environment variable for system type (v1)
117
+ system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
118
+ if system == 'smart':
119
+ return get_smart_chatbot()
120
+ else:
121
+ return get_multi_agent_chatbot()
122
 
123
  def serialize_messages(messages):
124
  """Serialize LangChain messages to dictionaries"""
 
164
  return serialized
165
 
166
 
167
+ feedback_manager = FeedbackManager()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  @st.cache_data
171
  def load_filter_options():
 
191
  # Track RAG retrieval history for feedback
192
  if 'rag_retrieval_history' not in st.session_state:
193
  st.session_state.rag_retrieval_history = []
194
+ # Version selection (v1 or beta)
195
+ if 'chatbot_version' not in st.session_state:
196
+ st.session_state.chatbot_version = "v1"
197
+
198
+ # Initialize chatbot based on version (only if not already initialized for this version)
199
+ chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}"
200
+
201
+ # Check if we need to initialize: chatbot doesn't exist OR version changed
202
+ needs_init = (
203
+ chatbot_version_key not in st.session_state or
204
+ st.session_state.get('_last_version') != st.session_state.chatbot_version
205
+ )
206
+
207
+ if needs_init:
208
+ try:
209
+ # Different spinner messages for different versions
210
+ if st.session_state.chatbot_version == "beta":
211
+ spinner_msg = "πŸ”„ Initializing Gemini FSA"
212
+ else:
213
+ spinner_msg = "πŸ”„ Loading AI models and connecting to database..."
214
+
215
+ with st.spinner(spinner_msg):
216
+ st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version)
217
+ st.session_state['_last_version'] = st.session_state.chatbot_version
218
+ st.session_state.chatbot = st.session_state[chatbot_version_key]
219
+ print("βœ… AI system ready!")
220
+ except Exception as e:
221
+ st.error(f"❌ Failed to initialize chatbot: {str(e)}")
222
+ # Only show Gemini-specific error message for beta version
223
+ if st.session_state.chatbot_version == "beta":
224
+ st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)")
225
+ else:
226
+ st.error("Please check your configuration and ensure all required models and databases are accessible.")
227
+ # Reset to v1 to prevent infinite loop
228
+ st.session_state.chatbot_version = "v1"
229
+ st.session_state['_last_version'] = "v1"
230
+ if 'chatbot' in st.session_state:
231
+ del st.session_state['chatbot']
232
+ st.stop() # Stop execution to prevent infinite loop
233
+ else:
234
+ # Chatbot already initialized for this version, just use it
235
+ st.session_state.chatbot = st.session_state[chatbot_version_key]
236
 
237
  # Reset conversation history if needed (but keep chatbot cached)
238
  if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
 
244
  st.session_state.reset_conversation = False
245
  st.rerun()
246
 
247
+
248
+ # Version selection radio button (top right)
249
+ col1, col2 = st.columns([3, 1])
250
+ with col1:
251
+ st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
252
+ with col2:
253
+ st.markdown("<br>", unsafe_allow_html=True) # Add some spacing
254
+ selected_version = st.radio(
255
+ "**Version:**",
256
+ options=["v1", "beta"],
257
+ index=0 if st.session_state.chatbot_version == "v1" else 1,
258
+ horizontal=True,
259
+ key="version_selector",
260
+ help="Select v1 (default RAG system) or beta (Gemini FSA)"
261
+ )
262
+
263
+ # Update version if changed
264
+ if selected_version != st.session_state.chatbot_version:
265
+ # Store the old version to check if we need to switch
266
+ old_version = st.session_state.chatbot_version
267
+ st.session_state.chatbot_version = selected_version
268
+
269
+ # If chatbot for new version already exists, just switch to it
270
+ new_chatbot_key = f"chatbot_{selected_version}"
271
+ if new_chatbot_key in st.session_state:
272
+ # Chatbot already exists, just switch
273
+ st.session_state.chatbot = st.session_state[new_chatbot_key]
274
+ st.session_state['_last_version'] = selected_version
275
+ else:
276
+ # Need to initialize new version - will be handled by initialization logic above
277
+ st.session_state['_last_version'] = old_version # Set to old to trigger init check
278
+
279
+ st.rerun()
280
+
281
+ # Show version info
282
+ if st.session_state.chatbot_version == "beta":
283
+ st.info("πŸ”¬ **Beta Mode**: Using Google Gemini FSA")
284
 
285
  # Session info
286
  duration = int(time.time() - st.session_state.session_start_time)
 
342
  # Determine if filename filter is active
343
  filename_mode = len(selected_filenames) > 0
344
  # Sources filter
345
+ # st.markdown('<div class="filter-section">', unsafe_allow_html=True)
346
  st.markdown('<div class="filter-title">πŸ“Š Sources</div>', unsafe_allow_html=True)
347
  selected_sources = st.multiselect(
348
  "Select sources:",
 
439
  )
440
 
441
  with col2:
442
+ send_button = st.button("Send", key="send_button", width='stretch')
443
 
444
  # Clear chat button
445
  if st.button("πŸ—‘οΈ Clear Chat", key="clear_chat_button"):
 
491
  if rag_result:
492
  sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
493
 
494
+ # For Gemini, also check gemini_result for sources
495
+ if not sources or len(sources) == 0:
496
+ gemini_result = chat_result.get('gemini_result')
497
+ print(f"πŸ” DEBUG: Checking gemini_result for sources...")
498
+ print(f" gemini_result exists: {gemini_result is not None}")
499
+ if gemini_result:
500
+ print(f" gemini_result type: {type(gemini_result)}")
501
+ print(f" has sources attr: {hasattr(gemini_result, 'sources')}")
502
+ if hasattr(gemini_result, 'sources'):
503
+ print(f" sources length: {len(gemini_result.sources) if gemini_result.sources else 0}")
504
+
505
+ if gemini_result and hasattr(gemini_result, 'sources'):
506
+ # Format Gemini sources for display
507
+ if hasattr(st.session_state.chatbot, 'gemini_client'):
508
+ sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result)
509
+ print(f"βœ… Formatted {len(sources)} sources from gemini_client")
510
+ elif hasattr(st.session_state.chatbot, '_format_gemini_sources'):
511
+ sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
512
+ print(f"βœ… Formatted {len(sources)} sources from _format_gemini_sources")
513
+
514
+ # Update rag_result with sources if we found them
515
+ if sources and len(sources) > 0:
516
+ if isinstance(rag_result, dict):
517
+ rag_result['sources'] = sources
518
+ elif hasattr(rag_result, 'sources'):
519
+ rag_result.sources = sources
520
+ # Update last_rag_result with sources
521
+ st.session_state.last_rag_result = rag_result
522
+ print(f"βœ… Updated rag_result with {len(sources)} sources")
523
+
524
  # Get the actual RAG query
525
  actual_rag_query = chat_result.get('actual_rag_query', '')
526
  if actual_rag_query:
 
530
  else:
531
  formatted_query = "No RAG query available"
532
 
533
+ # Extract filters from active filters
534
+ filters_used = {
535
+ "sources": st.session_state.active_filters.get('sources', []),
536
+ "years": st.session_state.active_filters.get('years', []),
537
+ "districts": st.session_state.active_filters.get('districts', []),
538
+ "filenames": st.session_state.active_filters.get('filenames', [])
539
+ }
540
+
541
  retrieval_entry = {
542
  "conversation_up_to": serialize_messages(st.session_state.messages),
543
  "rag_query_expansion": formatted_query,
544
+ "docs_retrieved": serialize_documents(sources),
545
+ "filters_applied": filters_used,
546
+ "timestamp": time.time()
547
  }
548
  st.session_state.rag_retrieval_history.append(retrieval_entry)
549
+
550
+ # Debug logging
551
+ print(f"πŸ“Š RETRIEVAL TRACKING: {len(sources)} sources stored in retrieval history")
552
  else:
553
  response = chat_result
554
  st.session_state.last_rag_result = None
 
578
  # Dictionary format from multi-agent system
579
  sources = rag_result['sources']
580
 
581
+ # For Gemini, also check if we need to format sources from gemini_result
582
+ if (not sources or len(sources) == 0) and isinstance(rag_result, dict):
583
+ gemini_result = rag_result.get('gemini_result')
584
+ if gemini_result and hasattr(gemini_result, 'sources'):
585
+ # Format Gemini sources for display
586
+ if hasattr(st.session_state.chatbot, 'gemini_client'):
587
+ sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result)
588
+ elif hasattr(st.session_state.chatbot, '_format_gemini_sources'):
589
+ sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
590
+
591
  if sources and len(sources) > 0:
592
  # Count unique filenames
593
  unique_filenames = set()
 
617
  for i, doc in enumerate(sources): # Show all documents
618
  # Get relevance score and ID if available
619
  metadata = getattr(doc, 'metadata', {})
620
+ # Handle both standard RAG scores and Gemini scores
621
+ score = metadata.get('reranked_score') or metadata.get('original_score') or metadata.get('score')
622
+ chunk_id = metadata.get('_id') or metadata.get('chunk_id', 'Unknown')
623
+ if score is not None:
624
+ try:
625
+ score_text = f" (Score: {float(score):.3f})"
626
+ except (ValueError, TypeError):
627
+ score_text = ""
628
+ else:
629
+ score_text = ""
630
+ if chunk_id and chunk_id != 'Unknown':
631
+ score_text += f" (ID: {str(chunk_id)[:8]}...)" if score_text else f" (ID: {str(chunk_id)[:8]}...)"
632
 
633
  with st.expander(f"πŸ“„ Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
634
  # Display document metadata with emojis
 
706
 
707
  submitted = st.form_submit_button(
708
  "πŸ“€ Submit Feedback",
709
+ width='stretch',
710
  disabled=submit_disabled
711
  )
712
 
 
718
  st.write("πŸ” **Debug: Feedback Data Being Submitted:**")
719
 
720
  # Extract transcript from messages
721
+ transcript = feedback_manager.extract_transcript(st.session_state.messages)
722
 
723
  # Build retrievals structure
724
+ retrievals = feedback_manager.build_retrievals_structure(
725
+
726
  st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
727
  st.session_state.messages
728
  )
729
 
730
  # Build feedback_score_related_retrieval_docs
731
+
732
+ feedback_score_related_retrieval_docs = feedback_manager.build_feedback_score_related_retrieval_docs(
733
  is_feedback_about_last_retrieval,
734
  st.session_state.messages,
735
  st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
 
759
  # Create UserFeedback dataclass instance
760
  feedback_obj = None # Initialize outside try block
761
  try:
762
+ feedback_obj = feedback_manager.create_feedback_from_dict(feedback_dict)
763
  print(f"βœ… FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
764
  st.write(f"βœ… **Feedback Object Created**")
765
  st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
 
815
  logger.info("πŸ“€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
816
  print("πŸ“€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
817
 
818
+ # Show spinner while saving to Snowflake (can take 10-15 seconds)
819
+ # This includes: connection establishment (~5s), data preparation, and SQL execution (~5s)
820
+ with st.spinner("πŸ’Ύ Saving feedback to Snowflake... This may take 10-15 seconds (connecting to database, preparing data, and executing query)"):
821
+ snowflake_success = feedback_manager.save_to_snowflake(feedback_obj)
822
+
823
  if snowflake_success:
824
  logger.info("βœ… SNOWFLAKE UI: Successfully saved to Snowflake")
825
  print("βœ… SNOWFLAKE UI: Successfully saved to Snowflake")
 
874
  st.success("βœ… Feedback already submitted for this conversation!")
875
  col1, col2 = st.columns([1, 1])
876
  with col1:
877
+ if st.button("πŸ”„ Submit New Feedback", key="new_feedback_button", width='stretch'):
878
  try:
879
  st.session_state.feedback_submitted = False
880
  st.rerun()
 
883
  logger.error(f"Error resetting feedback state: {e}")
884
  st.error(f"Error resetting feedback. Please refresh the page.")
885
  with col2:
886
+ if st.button("πŸ“‹ View Conversation", key="view_conversation_button", width='stretch'):
887
  # Scroll to conversation - this is handled by the auto-scroll at bottom
888
  pass
889
 
 
892
  st.markdown("---")
893
  st.markdown("#### πŸ“Š Retrieval History")
894
 
895
+ with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=True):
896
  for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
897
+ st.markdown(f"### **Retrieval #{idx}**")
898
+
899
+ # Display timestamp if available
900
+ if entry.get("timestamp"):
901
+ timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(entry["timestamp"]))
902
+ st.caption(f"πŸ• {timestamp_str}")
903
 
904
  # Display the actual RAG query
905
  rag_query_expansion = entry.get("rag_query_expansion", "No query available")
906
+ st.markdown("**πŸ” RAG Query:**")
907
  st.code(rag_query_expansion, language="text")
908
 
909
+ # Display filters used
910
+ filters_applied = entry.get("filters_applied", {})
911
+ if filters_applied and any(filters_applied.values()):
912
+ st.markdown("**🎯 Filters Applied:**")
913
+ filter_display = {}
914
+ if filters_applied.get("sources"):
915
+ filter_display["Sources"] = filters_applied["sources"]
916
+ if filters_applied.get("years"):
917
+ filter_display["Years"] = filters_applied["years"]
918
+ if filters_applied.get("districts"):
919
+ filter_display["Districts"] = filters_applied["districts"]
920
+ if filters_applied.get("filenames"):
921
+ filter_display["Filenames"] = filters_applied["filenames"]
922
+
923
+ if filter_display:
924
+ st.json(filter_display)
925
+ else:
926
+ st.info("No filters applied")
927
+ else:
928
+ st.info("No filters applied")
929
+
930
+ # Display conversation history up to retrieval point
931
+ conversation_up_to = entry.get("conversation_up_to", [])
932
+ if conversation_up_to:
933
+ st.markdown("**πŸ’¬ Conversation History (up to retrieval point):**")
934
+ with st.expander(f"View {len(conversation_up_to)} messages", expanded=False):
935
+ for msg_idx, msg in enumerate(conversation_up_to, 1):
936
+ role = msg.get("type", "unknown")
937
+ content = msg.get("content", "")
938
+
939
+ if role == "HumanMessage" or role == "human":
940
+ st.markdown(f"**πŸ‘€ User {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
941
+ elif role == "AIMessage" or role == "ai":
942
+ st.markdown(f"**πŸ€– Assistant {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
943
+ else:
944
+ st.info("No conversation history available")
945
+
946
+ # Display documents retrieved
947
+ docs_retrieved = entry.get("docs_retrieved", [])
948
+ if docs_retrieved:
949
+ st.markdown(f"**πŸ“„ Documents Retrieved ({len(docs_retrieved)}):**")
950
+ with st.expander(f"View {len(docs_retrieved)} documents", expanded=False):
951
+ for doc_idx, doc in enumerate(docs_retrieved, 1):
952
+ st.markdown(f"**Document {doc_idx}:**")
953
+
954
+ # Display metadata
955
+ metadata = doc.get("metadata", {})
956
+ if metadata:
957
+ col1, col2, col3 = st.columns(3)
958
+ with col1:
959
+ st.write(f"πŸ“„ **File:** {metadata.get('filename', 'Unknown')}")
960
+ with col2:
961
+ st.write(f"πŸ›οΈ **Source:** {metadata.get('source', 'Unknown')}")
962
+ with col3:
963
+ st.write(f"πŸ“… **Year:** {metadata.get('year', 'Unknown')}")
964
+
965
+ # Additional metadata
966
+ if metadata.get('district'):
967
+ st.write(f"πŸ“ **District:** {metadata.get('district')}")
968
+ if metadata.get('page'):
969
+ st.write(f"πŸ“– **Page:** {metadata.get('page')}")
970
+ if metadata.get('score') is not None:
971
+ st.write(f"⭐ **Score:** {metadata.get('score'):.3f}" if isinstance(metadata.get('score'), (int, float)) else f"⭐ **Score:** {metadata.get('score')}")
972
+
973
+ # Display content preview (first 200 chars)
974
+ content = doc.get("content", doc.get("page_content", ""))
975
+ if content:
976
+ st.markdown("**Content Preview:**")
977
+ st.text_area(
978
+ "Content Preview",
979
+ value=content[:200] + ("..." if len(content) > 200 else ""),
980
+ height=100,
981
+ disabled=True,
982
+ label_visibility="collapsed",
983
+ key=f"retrieval_{idx}_doc_{doc_idx}_preview"
984
+ )
985
+
986
+ if doc_idx < len(docs_retrieved):
987
+ st.markdown("---")
988
+ else:
989
+ st.info("No documents retrieved")
990
+
991
  # Display summary stats
992
+ st.markdown("**πŸ“Š Summary:**")
993
  st.json({
994
+ "conversation_length": len(conversation_up_to),
995
+ "documents_retrieved": len(docs_retrieved)
996
  })
997
+
998
+ if idx < len(st.session_state.rag_retrieval_history):
999
+ st.markdown("---")
1000
 
1001
  # Example Questions Section
1002
  st.markdown("---")
 
1017
  st.markdown(f"**Example:** `{example_q1}`")
1018
  st.info("πŸ’‘ **Filter to apply:** Select a Filename from the sidebar panel before asking this question.")
1019
  with col2:
1020
+ if st.button("πŸ“‹ Use This Question", key="use_example_1", width='stretch'):
1021
  st.session_state.pending_question = example_q1
1022
  st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
1023
  st.rerun()
 
1038
  )
1039
  col1, col2 = st.columns([1, 4])
1040
  with col1:
1041
+ if st.button("πŸ“‹ Use Question 2", key="use_custom_1", width='stretch'):
1042
  if custom_q1.strip():
1043
  st.session_state.pending_question = custom_q1.strip()
1044
  st.session_state.custom_question_1 = custom_q1.strip()
 
1064
  )
1065
  col1, col2 = st.columns([1, 4])
1066
  with col1:
1067
+ if st.button("πŸ“‹ Use Question 3", key="use_custom_2", width='stretch'):
1068
  if custom_q2.strip():
1069
  st.session_state.pending_question = custom_q2.strip()
1070
  st.session_state.custom_question_2 = custom_q2.strip()
src/agents/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent modules for chatbot implementations
3
+ """
4
+
5
+ from .smart_chatbot import get_chatbot as get_smart_chatbot
6
+ from .multi_agent_chatbot import get_multi_agent_chatbot
7
+ from .gemini_chatbot import get_gemini_chatbot
8
+
9
+ __all__ = ["get_smart_chatbot", "get_multi_agent_chatbot", "get_gemini_chatbot"]
10
+
src/agents/gemini_chatbot.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gemini File Search Chatbot (Beta Version)
3
+
4
+ This chatbot uses Google Gemini File Search API for RAG.
5
+ It provides a simpler architecture: Main Agent + Gemini Agent
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import time
11
+ import logging
12
+ import traceback
13
+ from pathlib import Path
14
+ from typing import Dict, List, Any, Optional, TypedDict
15
+
16
+ from langgraph.graph import StateGraph, END
17
+ from langchain_core.prompts import ChatPromptTemplate
18
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
19
+
20
+ from src.gemini.file_search import GeminiFileSearchClient, GeminiFileSearchResult
21
+ from src.config.paths import CONVERSATIONS_DIR
22
+
23
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class GeminiState(TypedDict):
28
+ """State for Gemini chatbot conversation flow"""
29
+ conversation_id: str
30
+ messages: List[Any]
31
+ current_query: str
32
+ query_context: Optional[Dict[str, Any]]
33
+ gemini_result: Optional[GeminiFileSearchResult]
34
+ final_response: Optional[str]
35
+ agent_logs: List[str]
36
+ conversation_context: Dict[str, Any]
37
+ session_start_time: float
38
+ last_ai_message_time: float
39
+ filters: Optional[Dict[str, Any]]
40
+
41
+
42
+ class GeminiRAGChatbot:
43
+ """Gemini File Search RAG chatbot (Beta version)"""
44
+
45
+ def __init__(self):
46
+ """Initialize the Gemini chatbot"""
47
+ logger.info("πŸ€– INITIALIZING: Gemini File Search Chatbot (Beta)")
48
+
49
+ # Initialize Gemini File Search client
50
+ try:
51
+ self.gemini_client = GeminiFileSearchClient()
52
+ logger.info("βœ… Gemini File Search client initialized")
53
+ except Exception as e:
54
+ logger.error(f"❌ Failed to initialize Gemini client: {e}")
55
+ raise RuntimeError(f"Gemini client initialization failed: {e}")
56
+
57
+ # Build the LangGraph with LangSmith tracing if enabled
58
+ self.graph = self._build_graph()
59
+
60
+ # Enable LangSmith tracing if configured
61
+ langsmith_enabled = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true"
62
+ if langsmith_enabled:
63
+ logger.info("πŸ” LangSmith tracing enabled")
64
+ langsmith_project = os.getenv("LANGCHAIN_PROJECT", "gemini-chatbot")
65
+ logger.info(f"πŸ“Š LangSmith project: {langsmith_project}")
66
+
67
+ # Conversations directory
68
+ self.conversations_dir = CONVERSATIONS_DIR
69
+ try:
70
+ self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
71
+ except (PermissionError, OSError) as e:
72
+ logger.warning(f"Could not create conversations directory: {e}")
73
+ self.conversations_dir = Path("conversations")
74
+ self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
75
+
76
+ logger.info("βœ… Gemini File Search Chatbot initialized")
77
+
78
+ def _build_graph(self) -> StateGraph:
79
+ """Build the LangGraph for Gemini chatbot"""
80
+ graph = StateGraph(GeminiState)
81
+
82
+ # Add nodes
83
+ graph.add_node("main_agent", self._main_agent)
84
+ graph.add_node("gemini_agent", self._gemini_agent)
85
+
86
+ # Define the flow
87
+ graph.set_entry_point("main_agent")
88
+ graph.add_edge("main_agent", "gemini_agent")
89
+ graph.add_edge("gemini_agent", END)
90
+
91
+ return graph.compile()
92
+
93
+ def _main_agent(self, state: GeminiState) -> GeminiState:
94
+ """Main Agent: Extracts filters and prepares query"""
95
+ logger.info("🎯 MAIN AGENT: Processing query")
96
+
97
+ query = state["current_query"]
98
+ messages = state["messages"]
99
+
100
+ # Extract UI filters if present in query
101
+ ui_filters = self._extract_ui_filters(query)
102
+
103
+ # Extract context from conversation
104
+ context = self._extract_context_from_conversation(messages, ui_filters)
105
+
106
+ # Store context and filters
107
+ state["query_context"] = context
108
+ state["filters"] = context.get("filters", {})
109
+
110
+ logger.info(f"🎯 MAIN AGENT: Filters extracted: {state['filters']}")
111
+
112
+ return state
113
+
114
+ def _gemini_agent(self, state: GeminiState) -> GeminiState:
115
+ """Gemini Agent: Performs file search and generates response"""
116
+ logger.info("πŸ” GEMINI AGENT: Starting file search")
117
+
118
+ query = state["current_query"]
119
+ filters = state.get("filters", {})
120
+
121
+ # Perform Gemini file search
122
+ try:
123
+ result = self.gemini_client.search(query=query, filters=filters)
124
+ logger.info(f"βœ… GEMINI AGENT: Search completed, {len(result.sources)} sources found")
125
+
126
+ # Enhance response with document references
127
+ enhanced_response = self._enhance_response_with_references(
128
+ result.answer,
129
+ result.sources,
130
+ query
131
+ )
132
+
133
+ state["gemini_result"] = result
134
+ state["final_response"] = enhanced_response
135
+ state["last_ai_message_time"] = time.time()
136
+
137
+ state["agent_logs"].append(f"GEMINI AGENT: Found {len(result.sources)} sources")
138
+
139
+ except Exception as e:
140
+ logger.error(f"❌ GEMINI AGENT ERROR: {e}")
141
+ traceback.print_exc()
142
+ state["final_response"] = "I apologize, but I encountered an error while searching. Please try again."
143
+ state["last_ai_message_time"] = time.time()
144
+
145
+ return state
146
+
147
+ def _enhance_response_with_references(self, answer: str, sources: List[Any], query: str) -> str:
148
+ """Enhance Gemini response to include document references and format nicely"""
149
+ if not sources or not answer:
150
+ return answer
151
+
152
+ # Use LLM to intelligently add document references and format nicely
153
+ try:
154
+ from src.llm.adapters import get_llm_client
155
+ llm = get_llm_client()
156
+
157
+ # Prepare document summaries for the LLM
158
+ doc_summaries = []
159
+ for idx, doc in enumerate(sources, 1):
160
+ metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
161
+ content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
162
+
163
+ filename = metadata.get('filename', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
164
+ year = metadata.get('year', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
165
+ source = metadata.get('source', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
166
+ district = metadata.get('district', '') if isinstance(metadata, dict) else ''
167
+
168
+ doc_info = f"{filename}"
169
+ if year and year != 'Unknown':
170
+ doc_info += f" ({year})"
171
+ if source and source != 'Unknown':
172
+ doc_info += f" - {source}"
173
+ if district:
174
+ doc_info += f" - {district}"
175
+
176
+ doc_summaries.append(f"[Doc {idx}] {doc_info}: {content[:300]}...")
177
+
178
+ prompt = f"""You are enhancing a response from a document search system. The original response is:
179
+
180
+ {answer}
181
+
182
+ The following documents were retrieved and used to generate this response:
183
+
184
+ {chr(10).join(doc_summaries)}
185
+
186
+ CRITICAL RULES:
187
+ 1. Format the response nicely with proper paragraphs, bullet points, or structured sections where appropriate
188
+ 2. The response should ONLY contain information from the retrieved documents listed above
189
+ 3. If the response mentions information NOT found in the retrieved documents, you must REMOVE or CORRECT that information
190
+ 4. Add document references [Doc i] at the end of sentences that use information from specific documents
191
+ 5. Only reference documents that are actually used in the response
192
+ 6. If the response mentions years, sources, or data that don't match the retrieved documents, you must correct it
193
+ 7. Keep the response natural, conversational, and well-formatted
194
+ 8. Use proper formatting: paragraphs, line breaks, and structure for readability
195
+ 9. Don't change the core content that matches the documents, just add references where appropriate and improve formatting
196
+ 10. If multiple documents support the same claim, use [Doc i, Doc j] format
197
+ 11. If the response contains information that cannot be verified in the retrieved documents, add a note like: "Note: This information may not be in the retrieved documents."
198
+
199
+ Return ONLY the enhanced, well-formatted response with references added and any corrections made. Do not include any explanation or meta-commentary."""
200
+
201
+ enhanced = llm.invoke(prompt).content if hasattr(llm.invoke(prompt), 'content') else str(llm.invoke(prompt))
202
+
203
+ # Fallback: if LLM fails, just return original with basic formatting
204
+ if not enhanced or len(enhanced) < len(answer) * 0.5:
205
+ logger.warning("LLM enhancement failed, using original response with basic formatting")
206
+ # Basic formatting: add line breaks after periods for readability
207
+ formatted = answer.replace('. ', '.\n\n')
208
+ if sources:
209
+ ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
210
+ formatted += f"\n\n*Based on documents: {ref_list}*"
211
+ return formatted
212
+
213
+ return enhanced
214
+
215
+ except Exception as e:
216
+ logger.warning(f"Failed to enhance response with references: {e}")
217
+ # Fallback: add basic formatting and references at the end
218
+ formatted = answer.replace('. ', '.\n\n') # Basic paragraph formatting
219
+ if sources:
220
+ ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
221
+ formatted += f"\n\n*Based on documents: {ref_list}*"
222
+ return formatted
223
+
224
+ def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
225
+ """Extract UI filters from query if present"""
226
+ filters = {}
227
+
228
+ if "FILTER CONTEXT:" in query:
229
+ filter_section = query.split("FILTER CONTEXT:")[1]
230
+ if "USER QUERY:" in filter_section:
231
+ filter_section = filter_section.split("USER QUERY:")[0]
232
+ filter_section = filter_section.strip()
233
+
234
+ if "Sources:" in filter_section:
235
+ sources_line = [line for line in filter_section.split('\n') if line.strip().startswith('Sources:')]
236
+ if sources_line:
237
+ sources_str = sources_line[0].split("Sources:")[1].strip()
238
+ if sources_str and sources_str != "None":
239
+ filters["sources"] = [s.strip() for s in sources_str.split(",")]
240
+
241
+ if "Years:" in filter_section:
242
+ years_line = [line for line in filter_section.split('\n') if line.strip().startswith('Years:')]
243
+ if years_line:
244
+ years_str = years_line[0].split("Years:")[1].strip()
245
+ if years_str and years_str != "None":
246
+ filters["year"] = [y.strip() for y in years_str.split(",")]
247
+
248
+ if "Districts:" in filter_section:
249
+ districts_line = [line for line in filter_section.split('\n') if line.strip().startswith('Districts:')]
250
+ if districts_line:
251
+ districts_str = districts_line[0].split("Districts:")[1].strip()
252
+ if districts_str and districts_str != "None":
253
+ filters["district"] = [d.strip() for d in districts_str.split(",")]
254
+
255
+ if "Filenames:" in filter_section:
256
+ filenames_line = [line for line in filter_section.split('\n') if line.strip().startswith('Filenames:')]
257
+ if filenames_line:
258
+ filenames_str = filenames_line[0].split("Filenames:")[1].strip()
259
+ if filenames_str and filenames_str != "None":
260
+ filters["filenames"] = [f.strip() for f in filenames_str.split(",")]
261
+
262
+ return filters
263
+
264
+ def _extract_context_from_conversation(
265
+ self,
266
+ messages: List[Any],
267
+ ui_filters: Dict[str, List[str]]
268
+ ) -> Dict[str, Any]:
269
+ """Extract context from conversation history"""
270
+ # Use UI filters if available
271
+ filters = ui_filters.copy() if ui_filters else {}
272
+
273
+ # For Gemini, we pass filters directly to the search function
274
+ # The filters will be used to add context to the query
275
+
276
+ return {
277
+ "filters": filters,
278
+ "has_filters": bool(filters)
279
+ }
280
+
281
+ def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
282
+ """Main chat interface"""
283
+ logger.info(f"πŸ’¬ GEMINI CHAT: Processing '{user_input[:50]}...'")
284
+
285
+ # Load conversation
286
+ conversation_file = self.conversations_dir / f"{conversation_id}.json"
287
+ conversation = self._load_conversation(conversation_file)
288
+
289
+ # Add user message
290
+ conversation["messages"].append(HumanMessage(content=user_input))
291
+
292
+ # Prepare state
293
+ state = GeminiState(
294
+ conversation_id=conversation_id,
295
+ messages=conversation["messages"],
296
+ current_query=user_input,
297
+ query_context=None,
298
+ gemini_result=None,
299
+ final_response=None,
300
+ agent_logs=[],
301
+ conversation_context=conversation.get("context", {}),
302
+ session_start_time=conversation["session_start_time"],
303
+ last_ai_message_time=conversation["last_ai_message_time"],
304
+ filters=None
305
+ )
306
+
307
+ # Run graph
308
+ final_state = self.graph.invoke(state)
309
+
310
+ # Add AI response to conversation
311
+ if final_state["final_response"]:
312
+ conversation["messages"].append(AIMessage(content=final_state["final_response"]))
313
+
314
+ # Update conversation
315
+ conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
316
+ conversation["context"] = final_state["conversation_context"]
317
+
318
+ # Save conversation
319
+ self._save_conversation(conversation_file, conversation)
320
+
321
+ # Format sources for display
322
+ sources = []
323
+ gemini_result = final_state.get("gemini_result")
324
+ if gemini_result:
325
+ sources = self.gemini_client.format_sources_for_display(gemini_result)
326
+ logger.info(f"πŸ“‹ GEMINI CHAT: Formatted {len(sources)} sources for display")
327
+
328
+ return {
329
+ 'response': final_state["final_response"] or "I apologize, but I couldn't process your request.",
330
+ 'rag_result': {
331
+ 'sources': sources,
332
+ 'answer': final_state["final_response"]
333
+ },
334
+ 'agent_logs': final_state["agent_logs"],
335
+ 'actual_rag_query': final_state["current_query"],
336
+ 'gemini_result': gemini_result # Include raw result for tracking
337
+ }
338
+
339
+ def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
340
+ """Load conversation from file"""
341
+ if conversation_file.exists():
342
+ try:
343
+ with open(conversation_file) as f:
344
+ data = json.load(f)
345
+ messages = []
346
+ for msg_data in data.get("messages", []):
347
+ if msg_data["type"] == "human":
348
+ messages.append(HumanMessage(content=msg_data["content"]))
349
+ elif msg_data["type"] == "ai":
350
+ messages.append(AIMessage(content=msg_data["content"]))
351
+ data["messages"] = messages
352
+ return data
353
+ except Exception as e:
354
+ logger.warning(f"Could not load conversation: {e}")
355
+
356
+ return {
357
+ "messages": [],
358
+ "session_start_time": time.time(),
359
+ "last_ai_message_time": time.time(),
360
+ "context": {}
361
+ }
362
+
363
+ def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
364
+ """Save conversation to file"""
365
+ try:
366
+ conversation_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
367
+
368
+ messages_data = []
369
+ for msg in conversation["messages"]:
370
+ if isinstance(msg, HumanMessage):
371
+ messages_data.append({"type": "human", "content": msg.content})
372
+ elif isinstance(msg, AIMessage):
373
+ messages_data.append({"type": "ai", "content": msg.content})
374
+
375
+ conversation_data = {
376
+ "messages": messages_data,
377
+ "session_start_time": conversation["session_start_time"],
378
+ "last_ai_message_time": conversation["last_ai_message_time"],
379
+ "context": conversation.get("context", {})
380
+ }
381
+
382
+ with open(conversation_file, 'w') as f:
383
+ json.dump(conversation_data, f, indent=2)
384
+
385
+ except Exception as e:
386
+ logger.error(f"Could not save conversation: {e}")
387
+
388
+
389
+ def get_gemini_chatbot():
390
+ """Get Gemini chatbot instance"""
391
+ return GeminiRAGChatbot()
392
+
multi_agent_chatbot.py β†’ src/agents/multi_agent_chatbot.py RENAMED
@@ -208,6 +208,59 @@ class MultiAgentRAGChatbot:
208
  logger.info(f" Sources: {self.source_whitelist}")
209
  logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  def _build_graph(self) -> StateGraph:
212
  """Build the multi-agent LangGraph"""
213
  graph = StateGraph(MultiAgentState)
@@ -512,6 +565,10 @@ class MultiAgentRAGChatbot:
512
  - If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
513
  - If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
514
  - Always return districts as JSON arrays when multiple districts are mentioned
 
 
 
 
515
  - If no exact matches found, set extracted values to null
516
 
517
  4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
@@ -656,13 +713,9 @@ Analyze this query using ONLY the exact values provided above:""")
656
  # Validate each district in the array
657
  valid_districts = []
658
  for district in extracted_district:
659
- if district in self.district_whitelist:
660
- valid_districts.append(district)
661
- else:
662
- # Try removing "District" suffix
663
- district_name = district.replace(" District", "").replace(" district", "")
664
- if district_name in self.district_whitelist:
665
- valid_districts.append(district_name)
666
 
667
  if valid_districts:
668
  extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
@@ -671,16 +724,15 @@ Analyze this query using ONLY the exact values provided above:""")
671
  logger.warning(f"⚠️ No valid districts found in: '{extracted_district}'")
672
  extracted_district = None
673
  else:
674
- # Single district validation
675
- if extracted_district not in self.district_whitelist:
676
- # Try removing "District" suffix
677
- district_name = extracted_district.replace(" District", "").replace(" district", "")
678
- if district_name in self.district_whitelist:
679
- logger.info(f"πŸ” QUERY ANALYSIS: Normalized district '{extracted_district}' to '{district_name}'")
680
- extracted_district = district_name
681
- else:
682
- logger.warning(f"⚠️ Invalid district extracted: '{extracted_district}' not in whitelist")
683
- extracted_district = None
684
 
685
  # Validate source (handle both single values and arrays)
686
  if extracted_source:
@@ -918,6 +970,23 @@ Rewrite the best retrieval query:""")
918
  logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β†’ normalized: {normalized_districts}")
919
 
920
  # Merge with extracted context for missing filters
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
921
  if not filters.get("year") and context.extracted_year:
922
  # Handle both single values and arrays
923
  if isinstance(context.extracted_year, list):
@@ -926,16 +995,6 @@ Rewrite the best retrieval query:""")
926
  filters["year"] = [context.extracted_year]
927
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
928
 
929
- if not filters.get("district") and context.extracted_district:
930
- # Handle both single values and arrays
931
- if isinstance(context.extracted_district, list):
932
- # Normalize district names to title case (match Qdrant metadata format)
933
- normalized = [d.title() for d in context.extracted_district]
934
- filters["district"] = normalized
935
- else:
936
- filters["district"] = [context.extracted_district.title()]
937
- logger.info(f"πŸ”§ FILTER BUILDING: Added extracted district filter (UI missing): {context.extracted_district}")
938
-
939
  if not filters.get("sources") and context.extracted_source:
940
  # Handle both single values and arrays
941
  if isinstance(context.extracted_source, list):
@@ -963,12 +1022,21 @@ Rewrite the best retrieval query:""")
963
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
964
 
965
  if context.extracted_district:
966
- # Handle both single values and arrays
967
  if isinstance(context.extracted_district, list):
968
- filters["district"] = context.extracted_district
 
 
 
 
 
 
 
969
  else:
970
- filters["district"] = [context.extracted_district]
971
- logger.info(f"πŸ”§ FILTER BUILDING: Added extracted district filter: {context.extracted_district}")
 
 
972
 
973
  logger.info(f"πŸ”§ FILTER BUILDING: Final filters: {filters}")
974
  return filters
@@ -978,49 +1046,212 @@ Rewrite the best retrieval query:""")
978
  logger.info("πŸ’¬ RESPONSE GENERATION: Starting conversational response generation")
979
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Processing {len(documents)} documents")
980
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
 
 
 
 
 
 
 
 
 
 
981
 
982
  # Create response prompt
983
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Building response prompt")
984
  response_prompt = ChatPromptTemplate.from_messages([
985
  SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
986
 
 
 
 
 
 
 
 
 
987
  RULES:
988
  1. Answer the user's question directly and clearly
989
- 2. Use the retrieved documents as evidence
990
  3. Be conversational, not technical
991
  4. Don't mention scores, retrieval details, or technical implementation
992
  5. If relevant documents were found, reference them naturally
993
- 6. If no relevant documents, explain based on your knowledge (if you have it) or just say you do not have enough information.
994
- 7. If the passages have useful facts or numbers, use them in your answer.
995
- 8. When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
996
  9. Do not use the sentence 'Doc i says ...' to say where information came from.
997
  10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
998
  11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
999
  12. If it makes sense, use bullet points and lists to make your answers easier to understand.
1000
  13. You do not need to use every passage. Only use the ones that help answer the question.
1001
- 14. If the documents do not have the information needed to answer the question, just say you do not have enough information.
1002
-
 
1003
 
1004
  TONE: Professional but friendly, like talking to a colleague."""),
1005
- HumanMessage(content=f"""User Question: {query}
 
 
 
1006
 
1007
  Retrieved Documents: {len(documents)} documents found
1008
 
 
 
 
 
 
 
1009
  RAG Answer: {rag_answer}
1010
 
1011
- Generate a conversational response:""")
 
 
 
 
 
 
 
1012
  ])
1013
 
1014
  try:
1015
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Calling LLM for final response")
1016
  response = self.llm.invoke(response_prompt.format_messages())
1017
  logger.info(f"πŸ’¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
1018
- return response.content.strip()
 
 
 
 
 
 
 
 
1019
  except Exception as e:
1020
  logger.error(f"❌ RESPONSE GENERATION: Error during generation: {e}")
1021
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Using RAG answer as fallback")
1022
  return rag_answer # Fallback to RAG answer
1023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1024
  def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
1025
  """Generate conversational response using only LLM knowledge and conversation history"""
1026
  logger.info("πŸ’¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
 
208
  logger.info(f" Sources: {self.source_whitelist}")
209
  logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
210
 
211
+ def _normalize_district_name(self, district: str) -> Optional[str]:
212
+ """Normalize district name with fuzzy matching for common misspellings."""
213
+ if not district:
214
+ return None
215
+
216
+ district = district.strip()
217
+
218
+ # Direct match
219
+ if district in self.district_whitelist:
220
+ return district
221
+
222
+ # Remove "District" suffix
223
+ district_name = district.replace(" District", "").replace(" district", "").strip()
224
+ if district_name in self.district_whitelist:
225
+ return district_name
226
+
227
+ # Common misspellings mapping
228
+ misspelling_map = {
229
+ "kalagala": "Kalangala",
230
+ "Kalagala": "Kalangala",
231
+ "KALAGALA": "Kalangala",
232
+ "kalangala": "Kalangala",
233
+ "gulu": "Gulu",
234
+ "GULU": "Gulu",
235
+ "kampala": "Kampala",
236
+ "KAMPALA": "Kampala",
237
+ }
238
+
239
+ # Check misspelling map (case-insensitive)
240
+ district_lower = district_name.lower()
241
+ if district_lower in misspelling_map:
242
+ corrected = misspelling_map[district_lower]
243
+ if corrected in self.district_whitelist:
244
+ return corrected
245
+
246
+ # Fuzzy matching for similar names (simple Levenshtein-like check)
247
+ # Check if the district name is very similar to any whitelist entry
248
+ for whitelist_district in self.district_whitelist:
249
+ # Case-insensitive comparison
250
+ if district_name.lower() == whitelist_district.lower():
251
+ return whitelist_district
252
+
253
+ # Check if one is a substring of the other (for partial matches)
254
+ if len(district_name) >= 4 and len(whitelist_district) >= 4:
255
+ if district_name.lower() in whitelist_district.lower() or whitelist_district.lower() in district_name.lower():
256
+ # Only return if it's a strong match (at least 80% of characters match)
257
+ min_len = min(len(district_name), len(whitelist_district))
258
+ max_len = max(len(district_name), len(whitelist_district))
259
+ if min_len / max_len >= 0.8:
260
+ return whitelist_district
261
+
262
+ return None
263
+
264
  def _build_graph(self) -> StateGraph:
265
  """Build the multi-agent LangGraph"""
266
  graph = StateGraph(MultiAgentState)
 
565
  - If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
566
  - If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
567
  - Always return districts as JSON arrays when multiple districts are mentioned
568
+ - **COMMON MISSPELLINGS**: Handle common misspellings intelligently:
569
+ * "Kalagala" (missing 'n') should be extracted as "Kalangala"
570
+ * "kalagala", "Kalagala", "KALAGALA" should all be normalized to "Kalangala"
571
+ * Similar case-insensitive variations should be normalized to the correct district name
572
  - If no exact matches found, set extracted values to null
573
 
574
  4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
 
713
  # Validate each district in the array
714
  valid_districts = []
715
  for district in extracted_district:
716
+ normalized = self._normalize_district_name(district)
717
+ if normalized:
718
+ valid_districts.append(normalized)
 
 
 
 
719
 
720
  if valid_districts:
721
  extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
 
724
  logger.warning(f"⚠️ No valid districts found in: '{extracted_district}'")
725
  extracted_district = None
726
  else:
727
+ # Single district validation with fuzzy matching
728
+ normalized = self._normalize_district_name(extracted_district)
729
+ if normalized:
730
+ if normalized != extracted_district:
731
+ logger.info(f"πŸ” QUERY ANALYSIS: Normalized district '{extracted_district}' to '{normalized}'")
732
+ extracted_district = normalized
733
+ else:
734
+ logger.warning(f"⚠️ Invalid district extracted: '{extracted_district}' not in whitelist")
735
+ extracted_district = None
 
736
 
737
  # Validate source (handle both single values and arrays)
738
  if extracted_source:
 
970
  logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β†’ normalized: {normalized_districts}")
971
 
972
  # Merge with extracted context for missing filters
973
+ if not filters.get("district") and context.extracted_district:
974
+ # Normalize district names using the normalization function
975
+ if isinstance(context.extracted_district, list):
976
+ normalized_districts = []
977
+ for d in context.extracted_district:
978
+ normalized = self._normalize_district_name(d)
979
+ if normalized:
980
+ normalized_districts.append(normalized)
981
+ if normalized_districts:
982
+ filters["district"] = normalized_districts
983
+ logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β†’ normalized: {normalized_districts}")
984
+ else:
985
+ normalized = self._normalize_district_name(context.extracted_district)
986
+ if normalized:
987
+ filters["district"] = [normalized]
988
+ logger.info(f"πŸ”§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β†’ normalized: {normalized}")
989
+
990
  if not filters.get("year") and context.extracted_year:
991
  # Handle both single values and arrays
992
  if isinstance(context.extracted_year, list):
 
995
  filters["year"] = [context.extracted_year]
996
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
997
 
 
 
 
 
 
 
 
 
 
 
998
  if not filters.get("sources") and context.extracted_source:
999
  # Handle both single values and arrays
1000
  if isinstance(context.extracted_source, list):
 
1022
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
1023
 
1024
  if context.extracted_district:
1025
+ # Normalize district names using the normalization function
1026
  if isinstance(context.extracted_district, list):
1027
+ normalized_districts = []
1028
+ for d in context.extracted_district:
1029
+ normalized = self._normalize_district_name(d)
1030
+ if normalized:
1031
+ normalized_districts.append(normalized)
1032
+ if normalized_districts:
1033
+ filters["district"] = normalized_districts
1034
+ logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β†’ normalized: {normalized_districts}")
1035
  else:
1036
+ normalized = self._normalize_district_name(context.extracted_district)
1037
+ if normalized:
1038
+ filters["district"] = [normalized]
1039
+ logger.info(f"πŸ”§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β†’ normalized: {normalized}")
1040
 
1041
  logger.info(f"πŸ”§ FILTER BUILDING: Final filters: {filters}")
1042
  return filters
 
1046
  logger.info("πŸ’¬ RESPONSE GENERATION: Starting conversational response generation")
1047
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Processing {len(documents)} documents")
1048
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
1049
+ logger.info(f"πŸ’¬ RESPONSE GENERATION: Conversation history: {len(messages)} messages")
1050
+
1051
+ # Build conversation history context
1052
+ conversation_context = self._build_conversation_context(messages)
1053
+
1054
+ # Build detailed document information
1055
+ document_details = self._build_document_details(documents)
1056
+
1057
+ # Extract correct district/source/year names from documents (to correct misspellings)
1058
+ correct_names = self._extract_correct_names_from_documents(documents)
1059
 
1060
  # Create response prompt
1061
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Building response prompt")
1062
  response_prompt = ChatPromptTemplate.from_messages([
1063
  SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
1064
 
1065
+ CRITICAL RULES - NO HALLUCINATION:
1066
+ 1. **ONLY use information from the retrieved documents provided below**
1067
+ 2. **EVERY sentence with facts, numbers, or specific claims MUST have a [Doc i] reference**
1068
+ 3. **If a document doesn't contain the information, DO NOT make it up**
1069
+ 4. **If the user asks about a year/district that's NOT in the retrieved documents, explicitly state that**
1070
+ 5. **Check the document years/districts before making any claims about them**
1071
+ 6. **USE CORRECT NAMES**: If the conversation mentions a misspelled district/source name (e.g., "Kalagala"), use the CORRECT spelling from the document metadata (e.g., "Kalangala"). Always use the exact names from document metadata, not misspellings from conversation.
1072
+
1073
  RULES:
1074
  1. Answer the user's question directly and clearly
1075
+ 2. Use ONLY the retrieved documents as evidence - DO NOT use your training data
1076
  3. Be conversational, not technical
1077
  4. Don't mention scores, retrieval details, or technical implementation
1078
  5. If relevant documents were found, reference them naturally
1079
+ 6. If no relevant documents, say you do not have enough information - DO NOT hallucinate
1080
+ 7. If the passages have useful facts or numbers, use them in your answer WITH references
1081
+ 8. **MANDATORY**: When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
1082
  9. Do not use the sentence 'Doc i says ...' to say where information came from.
1083
  10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
1084
  11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
1085
  12. If it makes sense, use bullet points and lists to make your answers easier to understand.
1086
  13. You do not need to use every passage. Only use the ones that help answer the question.
1087
+ 14. **VERIFY**: Before mentioning any year, district, or number, check that it exists in the retrieved documents. If it doesn't, say "I don't have information about [year/district] in the retrieved documents."
1088
+ 15. **NO HALLUCINATION**: If documents show years 2021, 2022, 2023 but user asks about 2020, DO NOT provide 2020 data. Instead say "The retrieved documents cover 2021-2023, but I don't have information for 2020."
1089
+ 16. **USE CORRECT SPELLING**: Always use the district/source names exactly as they appear in the document metadata below, even if the conversation history has misspellings.
1090
 
1091
  TONE: Professional but friendly, like talking to a colleague."""),
1092
+ HumanMessage(content=f"""Conversation History:
1093
+ {conversation_context}
1094
+
1095
+ Current User Question: {query}
1096
 
1097
  Retrieved Documents: {len(documents)} documents found
1098
 
1099
+ CORRECT NAMES TO USE (from document metadata - use these exact spellings):
1100
+ {correct_names}
1101
+
1102
+ Full Document Details:
1103
+ {document_details}
1104
+
1105
  RAG Answer: {rag_answer}
1106
 
1107
+ CRITICAL:
1108
+ - Responses should be grounded to what is available in the retrieved documents
1109
+ - If user asks about a specific year but documents show other years, or districts or sources then explicitly state "can't provide response on ... because ..."
1110
+ - Every factual claim MUST have [Doc i] reference
1111
+ - If information is not in documents, explicitly state it's not available
1112
+ - **USE THE CORRECT DISTRICT/SOURCE NAMES from the document metadata above, not misspellings from conversation**
1113
+
1114
+ Generate a conversational response with proper document references:""")
1115
  ])
1116
 
1117
  try:
1118
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Calling LLM for final response")
1119
  response = self.llm.invoke(response_prompt.format_messages())
1120
  logger.info(f"πŸ’¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
1121
+
1122
+ # Post-process response to ensure no hallucination
1123
+ final_response = self._validate_and_enhance_response(
1124
+ response.content.strip(),
1125
+ documents,
1126
+ query
1127
+ )
1128
+
1129
+ return final_response
1130
  except Exception as e:
1131
  logger.error(f"❌ RESPONSE GENERATION: Error during generation: {e}")
1132
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Using RAG answer as fallback")
1133
  return rag_answer # Fallback to RAG answer
1134
 
1135
+ def _build_conversation_context(self, messages: List[Any]) -> str:
1136
+ """Build conversation history context for response generation."""
1137
+ if not messages:
1138
+ return "No previous conversation."
1139
+
1140
+ context_lines = []
1141
+ # Show last 6 messages for context (to capture the current exchange)
1142
+ for msg in messages[-6:]:
1143
+ if isinstance(msg, HumanMessage):
1144
+ context_lines.append(f"User: {msg.content}")
1145
+ elif isinstance(msg, AIMessage):
1146
+ context_lines.append(f"Assistant: {msg.content}")
1147
+
1148
+ return "\n".join(context_lines) if context_lines else "No previous conversation."
1149
+
1150
+ def _build_document_details(self, documents: List[Any]) -> str:
1151
+ """Build detailed document information for response generation."""
1152
+ if not documents:
1153
+ return "No documents retrieved."
1154
+
1155
+ details = []
1156
+ for i, doc in enumerate(documents[:15], 1): # Show up to 15 documents
1157
+ metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
1158
+ content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
1159
+
1160
+ if isinstance(metadata, dict):
1161
+ filename = metadata.get('filename', 'Unknown')
1162
+ year = metadata.get('year', 'Unknown')
1163
+ district = metadata.get('district', 'Unknown')
1164
+ source = metadata.get('source', 'Unknown')
1165
+ page = metadata.get('page', metadata.get('page_label', 'Unknown'))
1166
+
1167
+ doc_info = f"[Doc {i}]"
1168
+ doc_info += f"\n Filename: {filename}"
1169
+ doc_info += f"\n Year: {year}"
1170
+ doc_info += f"\n District: {district}"
1171
+ doc_info += f"\n Source: {source}"
1172
+ if page != 'Unknown':
1173
+ doc_info += f"\n Page: {page}"
1174
+ doc_info += f"\n Content: {content[:300]}{'...' if len(content) > 300 else ''}"
1175
+ details.append(doc_info)
1176
+
1177
+ return "\n\n".join(details) if details else "No document details available."
1178
+
1179
+ def _extract_correct_names_from_documents(self, documents: List[Any]) -> str:
1180
+ """Extract correct district/source names from documents to correct misspellings."""
1181
+ districts = set()
1182
+ sources = set()
1183
+ years = set()
1184
+
1185
+ for doc in documents:
1186
+ metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
1187
+ if isinstance(metadata, dict):
1188
+ if metadata.get('district'):
1189
+ districts.add(str(metadata['district']))
1190
+ if metadata.get('source'):
1191
+ sources.add(str(metadata['source']))
1192
+ if metadata.get('year'):
1193
+ years.add(str(metadata['year']))
1194
+
1195
+ result = []
1196
+ if districts:
1197
+ result.append(f"Districts: {', '.join(sorted(districts))}")
1198
+ if sources:
1199
+ result.append(f"Sources: {', '.join(sorted(sources))}")
1200
+ if years:
1201
+ result.append(f"Years: {', '.join(sorted(years))}")
1202
+
1203
+ if result:
1204
+ return "\n".join(result) + "\n\nIMPORTANT: Use these EXACT spellings in your response, even if the conversation history has misspellings."
1205
+ return "No metadata available."
1206
+
1207
+ def _validate_and_enhance_response(self, response: str, documents: List[Any], query: str) -> str:
1208
+ """Validate response and ensure all claims are referenced."""
1209
+ # Extract years and districts from documents
1210
+ doc_years = set()
1211
+ doc_districts = set()
1212
+ doc_sources = set()
1213
+
1214
+ for doc in documents:
1215
+ metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
1216
+ if isinstance(metadata, dict):
1217
+ if metadata.get('year'):
1218
+ doc_years.add(str(metadata['year']))
1219
+ if metadata.get('district'):
1220
+ doc_districts.add(str(metadata['district']))
1221
+ if metadata.get('source'):
1222
+ doc_sources.add(str(metadata['source']))
1223
+
1224
+ # Correct misspellings in response using correct names from documents
1225
+ # response = self._correct_misspellings_in_response(response, doc_districts, doc_sources)
1226
+
1227
+ # Check if response mentions years not in documents
1228
+ year_pattern = r'\b(20\d{2})\b'
1229
+ mentioned_years = set(re.findall(year_pattern, response))
1230
+
1231
+ # Check if user query mentions a year
1232
+ query_years = set(re.findall(year_pattern, query))
1233
+
1234
+ # If user asks about a year not in documents, add a warning
1235
+ missing_years = query_years - doc_years
1236
+ if missing_years and doc_years:
1237
+ warning = f"\n\n⚠️ Note: The retrieved documents cover years {', '.join(sorted(doc_years))}, but I don't have information for {', '.join(sorted(missing_years))} in the retrieved documents."
1238
+ if warning not in response:
1239
+ response = response + warning
1240
+
1241
+ # Check if response has document references
1242
+ doc_ref_pattern = r'\[Doc\s+\d+\]'
1243
+ has_refs = bool(re.search(doc_ref_pattern, response))
1244
+
1245
+ # If response has factual claims but no references, add a note
1246
+ if not has_refs and len(documents) > 0:
1247
+ # Check if response has numbers or specific claims (simple heuristic)
1248
+ has_numbers = bool(re.search(r'\d+', response))
1249
+ if has_numbers and len(response) > 50:
1250
+ logger.warning("⚠️ Response contains factual claims but no document references")
1251
+ # Don't modify response, but log the issue
1252
+
1253
+ return response
1254
+
1255
  def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
1256
  """Generate conversational response using only LLM knowledge and conversation history"""
1257
  logger.info("πŸ’¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
smart_chatbot.py β†’ src/agents/smart_chatbot.py RENAMED
File without changes
src/feedback/__init__.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Feedback Management Module
3
+
4
+ This module provides a unified interface for handling user feedback,
5
+ including data preparation, validation, and Snowflake storage.
6
+ """
7
+
8
+ from typing import Dict, Any, List, Optional
9
+ from langchain_core.messages import HumanMessage, AIMessage
10
+
11
+ from .feedback_schema import UserFeedback, create_feedback_from_dict, generate_snowflake_schema_sql
12
+ from .snowflake_connector import SnowflakeFeedbackConnector, save_to_snowflake, get_snowflake_connector_from_env
13
+
14
+
15
+ class FeedbackManager:
16
+ """
17
+ Unified manager for feedback operations.
18
+
19
+ This class provides a single interface for all feedback-related functionality,
20
+ including data preparation, validation, and storage.
21
+ """
22
+
23
+ def __init__(self):
24
+ """Initialize the FeedbackManager"""
25
+ pass
26
+
27
+ @staticmethod
28
+ def extract_transcript(messages: List[Any]) -> List[Dict[str, str]]:
29
+ """Extract transcript from messages - only user and bot messages, no extra metadata"""
30
+ transcript = []
31
+ for msg in messages:
32
+ if isinstance(msg, HumanMessage):
33
+ transcript.append({
34
+ "role": "user",
35
+ "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
36
+ })
37
+ elif isinstance(msg, AIMessage):
38
+ transcript.append({
39
+ "role": "assistant",
40
+ "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
41
+ })
42
+ return transcript
43
+
44
+ @staticmethod
45
+ def build_retrievals_structure(rag_retrieval_history: List[Dict[str, Any]], messages: List[Any]) -> List[Dict[str, Any]]:
46
+ """Build retrievals structure from retrieval history"""
47
+ retrievals = []
48
+
49
+ for entry in rag_retrieval_history:
50
+ # Get the user message that triggered this retrieval
51
+ # The entry has conversation_up_to which includes messages up to that point
52
+ conversation_up_to = entry.get("conversation_up_to", [])
53
+
54
+ # Find the last user message in conversation_up_to (this is the trigger)
55
+ user_message_trigger = ""
56
+ for msg_dict in reversed(conversation_up_to):
57
+ if msg_dict.get("type") == "HumanMessage":
58
+ user_message_trigger = msg_dict.get("content", "")
59
+ break
60
+
61
+ # Fallback: if not found in conversation_up_to, get from actual messages
62
+ # This handles edge cases where conversation_up_to might be incomplete
63
+ if not user_message_trigger:
64
+ # Find which retrieval this is (0-indexed)
65
+ retrieval_idx = rag_retrieval_history.index(entry)
66
+ # The user message that triggered this retrieval is at position (retrieval_idx * 2)
67
+ # because each retrieval is preceded by: user message, bot response, user message, ...
68
+ # But we need to account for the fact that the first retrieval happens after the first user message
69
+ user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
70
+ if retrieval_idx < len(user_msgs):
71
+ user_message_trigger = str(user_msgs[retrieval_idx].content)
72
+ elif user_msgs:
73
+ # Fallback to last user message
74
+ user_message_trigger = str(user_msgs[-1].content)
75
+
76
+ # Get retrieved documents and truncate content to 100 chars
77
+ docs_retrieved = entry.get("docs_retrieved", [])
78
+ retrieved_docs = []
79
+ for doc in docs_retrieved:
80
+ doc_copy = doc.copy()
81
+ # Truncate content to 100 characters (keep all other fields)
82
+ if "content" in doc_copy:
83
+ doc_copy["content"] = doc_copy["content"][:100]
84
+ retrieved_docs.append(doc_copy)
85
+
86
+ retrievals.append({
87
+ "retrieved_docs": retrieved_docs,
88
+ "user_message_trigger": user_message_trigger
89
+ })
90
+
91
+ return retrievals
92
+
93
+ @staticmethod
94
+ def build_feedback_score_related_retrieval_docs(
95
+ is_feedback_about_last_retrieval: bool,
96
+ messages: List[Any],
97
+ rag_retrieval_history: List[Dict[str, Any]]
98
+ ) -> Optional[Dict[str, Any]]:
99
+ """Build feedback_score_related_retrieval_docs structure"""
100
+ if not rag_retrieval_history:
101
+ return None
102
+
103
+ # Get the relevant retrieval entry
104
+ if is_feedback_about_last_retrieval:
105
+ relevant_entry = rag_retrieval_history[-1]
106
+ else:
107
+ # If feedback is about all retrievals, use the last one as default
108
+ relevant_entry = rag_retrieval_history[-1]
109
+
110
+ # Get conversation up to that point
111
+ conversation_up_to = relevant_entry.get("conversation_up_to", [])
112
+
113
+ # Convert to transcript format (role/content)
114
+ conversation_up_to_point = []
115
+ for msg_dict in conversation_up_to:
116
+ if msg_dict.get("type") == "HumanMessage":
117
+ conversation_up_to_point.append({
118
+ "role": "user",
119
+ "content": msg_dict.get("content", "")
120
+ })
121
+ elif msg_dict.get("type") == "AIMessage":
122
+ conversation_up_to_point.append({
123
+ "role": "assistant",
124
+ "content": msg_dict.get("content", "")
125
+ })
126
+
127
+ # Get retrieved docs with full content (not truncated)
128
+ retrieved_docs = relevant_entry.get("docs_retrieved", [])
129
+
130
+ return {
131
+ "conversation_up_to_point": conversation_up_to_point,
132
+ "retrieved_docs": retrieved_docs
133
+ }
134
+
135
+ @staticmethod
136
+ def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
137
+ """Create UserFeedback instance from dictionary"""
138
+ return create_feedback_from_dict(data)
139
+
140
+ @staticmethod
141
+ def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
142
+ """Save feedback to Snowflake"""
143
+ return save_to_snowflake(feedback, table_name)
144
+
145
+ @staticmethod
146
+ def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
147
+ """Generate Snowflake schema SQL"""
148
+ return generate_snowflake_schema_sql(table_name)
149
+
150
+
151
+ __all__ = ["FeedbackManager", "UserFeedback", "save_to_snowflake", "SnowflakeFeedbackConnector"]
152
+
src/feedback/feedback_schema.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Feedback Schema for RAG Chatbot
3
+
4
+ This module defines dataclasses for feedback data structures
5
+ and provides Snowflake schema generation.
6
+ """
7
+ import os
8
+ from datetime import datetime
9
+ from dataclasses import dataclass, asdict, field
10
+ from typing import List, Optional, Dict, Any, Union
11
+
12
+
13
+
14
+ @dataclass
15
+ class RetrievedDocument:
16
+ """Single retrieved document metadata"""
17
+ doc_id: str
18
+ filename: str
19
+ page: int
20
+ score: float
21
+ content: str
22
+ metadata: Dict[str, Any]
23
+
24
+
25
+ @dataclass
26
+ class RetrievalEntry:
27
+ """Single retrieval operation metadata"""
28
+ rag_query: str
29
+ documents_retrieved: List[RetrievedDocument]
30
+ conversation_length: int
31
+ filters_applied: Optional[Dict[str, Any]] = None
32
+ timestamp: Optional[float] = None
33
+ _raw_data: Optional[Dict[str, Any]] = None
34
+
35
+
36
+ @dataclass
37
+ class UserFeedback:
38
+ """User feedback submission data"""
39
+ feedback_id: str
40
+ open_ended_feedback: Optional[str]
41
+ score: int
42
+ is_feedback_about_last_retrieval: bool
43
+ conversation_id: str
44
+ timestamp: float
45
+ message_count: int
46
+ has_retrievals: bool
47
+ retrieval_count: int
48
+ transcript: List[Dict[str, str]] # List of {"role": "user"/"assistant", "content": "..."}
49
+ retrievals: List[Dict[str, Any]] # List of retrieval objects with retrieved_docs and user_message_trigger
50
+ feedback_score_related_retrieval_docs: Optional[Dict[str, Any]] = None # Conversation subset + retrieved docs
51
+ retrieved_data: Optional[List[Dict[str, Any]]] = None # Preserved old column for backward compatibility
52
+ created_at: str = field(default_factory=lambda: datetime.now().isoformat())
53
+
54
+ def to_dict(self) -> Dict[str, Any]:
55
+ """Convert to dictionary with nested data structures"""
56
+ result = asdict(self)
57
+ return result
58
+
59
+ def to_snowflake_schema(self) -> Dict[str, Any]:
60
+ """Generate Snowflake schema for this dataclass"""
61
+ schema = {
62
+ "feedback_id": "VARCHAR(255)",
63
+ "open_ended_feedback": "VARCHAR(16777216)", # Large text
64
+ "score": "INTEGER",
65
+ "is_feedback_about_last_retrieval": "BOOLEAN",
66
+ "conversation_id": "VARCHAR(255)",
67
+ "timestamp": "NUMBER(20, 0)",
68
+ "message_count": "INTEGER",
69
+ "has_retrievals": "BOOLEAN",
70
+ "retrieval_count": "INTEGER",
71
+ "transcript": "VARCHAR(16777216)", # JSON string of ARRAY of {"role": "user"/"assistant", "content": "..."}
72
+ "retrievals": "VARCHAR(16777216)", # JSON string of ARRAY of retrieval objects
73
+ "feedback_score_related_retrieval_docs": "VARCHAR(16777216)", # JSON string of OBJECT with conversation subset + retrieved docs
74
+ "retrieved_data": "VARCHAR(16777216)", # JSON string - preserved old column for backward compatibility
75
+ "created_at": "TIMESTAMP_NTZ",
76
+ # transcript structure: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
77
+ # retrievals structure: [
78
+ # {
79
+ # "retrieved_docs": [{"content": "...", "metadata": {...}, ...}], # content truncated to 100 chars
80
+ # "user_message_trigger": "final user message that triggered this retrieval"
81
+ # },
82
+ # ...
83
+ # ]
84
+ # feedback_score_related_retrieval_docs structure: {
85
+ # "conversation_up_to_point": [{"role": "user", "content": "..."}, ...], # subset of transcript
86
+ # "retrieved_docs": [{"content": "...", "metadata": {...}, ...}] # full chunks with all info
87
+ # }
88
+ }
89
+ return schema
90
+
91
+ @classmethod
92
+ def get_snowflake_create_table_sql(cls, table_name: str = "USER_FEEDBACK_V3") -> str:
93
+ """Generate CREATE TABLE SQL for Snowflake"""
94
+ schema = cls.to_snowflake_schema(None)
95
+
96
+ columns = []
97
+ for col_name, col_type in schema.items():
98
+ nullable = "NULL" if col_name not in ["feedback_id", "score", "timestamp"] else "NOT NULL"
99
+ columns.append(f" {col_name} {col_type} {nullable}")
100
+
101
+ # Build SQL string properly
102
+ columns_str = ",\n".join(columns)
103
+
104
+ sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
105
+ {columns_str},
106
+ PRIMARY KEY (feedback_id)
107
+ )
108
+ CLUSTER BY (timestamp, conversation_id, score);
109
+ -- Note: Snowflake doesn't support traditional indexes on regular tables.
110
+ -- Instead, we use CLUSTER BY to optimize queries on these columns.
111
+ -- Snowflake automatically maintains clustering for efficient querying.
112
+ -- Note: transcript, retrievals, and feedback_score_related_retrieval_docs are stored as VARCHAR (JSON strings),
113
+ -- same approach as the old retrieved_data column. This allows easy storage and retrieval without VARIANT type complexity.
114
+ """
115
+ return sql
116
+
117
+
118
+ # Snowflake variant schema for retrieved_data array
119
+ RETRIEVAL_ENTRY_SCHEMA = {
120
+ "rag_query": "VARCHAR",
121
+ "documents_retrieved": "ARRAY", # Array of document objects
122
+ "conversation_length": "INTEGER",
123
+ "filters_applied": "OBJECT",
124
+ "timestamp": "NUMBER"
125
+ }
126
+
127
+ DOCUMENT_SCHEMA = {
128
+ "doc_id": "VARCHAR",
129
+ "filename": "VARCHAR",
130
+ "page": "INTEGER",
131
+ "score": "DOUBLE",
132
+ "content": "VARCHAR(16777216)",
133
+ "metadata": "OBJECT"
134
+ }
135
+
136
+
137
+ def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
138
+ """Generate complete Snowflake schema SQL for feedback system"""
139
+ if table_name is None:
140
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
141
+ return UserFeedback.get_snowflake_create_table_sql(table_name)
142
+
143
+
144
+ def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
145
+ """Create UserFeedback instance from dictionary"""
146
+ return UserFeedback(
147
+ feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
148
+ open_ended_feedback=data.get("open_ended_feedback"),
149
+ score=data["score"],
150
+ is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
151
+ conversation_id=data["conversation_id"],
152
+ timestamp=data["timestamp"],
153
+ message_count=data["message_count"],
154
+ has_retrievals=data["has_retrievals"],
155
+ retrieval_count=data["retrieval_count"],
156
+ transcript=data.get("transcript", []),
157
+ retrievals=data.get("retrievals", []),
158
+ feedback_score_related_retrieval_docs=data.get("feedback_score_related_retrieval_docs"),
159
+ retrieved_data=data.get("retrieved_data")
160
+ )
161
+
src/feedback/snowflake_connector.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Snowflake Connector for Feedback System
3
+
4
+ This module handles inserting user feedback into Snowflake.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import logging
10
+ from typing import Dict, Any, Optional
11
+ from .feedback_schema import UserFeedback
12
+
13
+ # Try to import snowflake connector
14
+ try:
15
+ import snowflake.connector
16
+ SNOWFLAKE_AVAILABLE = True
17
+ except ImportError:
18
+ SNOWFLAKE_AVAILABLE = False
19
+ logging.warning("⚠️ snowflake-connector-python not installed. Install with: pip install snowflake-connector-python")
20
+
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class SnowflakeFeedbackConnector:
27
+ """Connector for inserting feedback into Snowflake"""
28
+
29
+ def __init__(
30
+ self,
31
+ user: str,
32
+ password: str,
33
+ account: str,
34
+ warehouse: str,
35
+ database: str = "SNOWFLAKE_LEARNING",
36
+ schema: str = "PUBLIC"
37
+ ):
38
+ self.user = user
39
+ self.password = password
40
+ self.account = account
41
+ self.warehouse = warehouse
42
+ self.database = database
43
+ self.schema = schema
44
+ self._connection = None
45
+
46
+ def connect(self):
47
+ """Establish Snowflake connection"""
48
+ if not SNOWFLAKE_AVAILABLE:
49
+ raise ImportError("snowflake-connector-python is not installed. Install with: pip install snowflake-connector-python")
50
+
51
+ logger.info("=" * 80)
52
+ logger.info("πŸ”Œ SNOWFLAKE CONNECTION: Attempting to connect...")
53
+ logger.info(f" - Account: {self.account}")
54
+ logger.info(f" - Warehouse: {self.warehouse}")
55
+ logger.info(f" - Database: {self.database}")
56
+ logger.info(f" - Schema: {self.schema}")
57
+ logger.info(f" - User: {self.user}")
58
+
59
+ try:
60
+ self._connection = snowflake.connector.connect(
61
+ user=self.user,
62
+ password=self.password,
63
+ account=self.account,
64
+ warehouse=self.warehouse
65
+ # Don't set database/schema in connection - we'll do it per query
66
+ )
67
+ logger.info("βœ… SNOWFLAKE CONNECTION: Successfully connected")
68
+ logger.info("=" * 80)
69
+ print(f"βœ… Connected to Snowflake: {self.database}.{self.schema}")
70
+ except Exception as e:
71
+ logger.error(f"❌ SNOWFLAKE CONNECTION FAILED: {e}")
72
+ logger.error("=" * 80)
73
+ print(f"❌ Failed to connect to Snowflake: {e}")
74
+ raise
75
+
76
+ def disconnect(self):
77
+ """Close Snowflake connection"""
78
+ if self._connection:
79
+ self._connection.close()
80
+ print("βœ… Disconnected from Snowflake")
81
+
82
+ def insert_feedback(self, feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
83
+ """Insert a single feedback record into Snowflake"""
84
+ logger.info("=" * 80)
85
+ logger.info("πŸ”„ SNOWFLAKE INSERT: Starting feedback insertion process")
86
+ logger.info(f"πŸ“ Feedback ID: {feedback.feedback_id}")
87
+
88
+ # Get table name from parameter, env var, or default
89
+ if table_name is None:
90
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
91
+
92
+ if not self._connection:
93
+ logger.error("❌ Not connected to Snowflake. Call connect() first.")
94
+ raise RuntimeError("Not connected to Snowflake. Call connect() first.")
95
+
96
+ try:
97
+ logger.info("πŸ“Š VALIDATION: Validating feedback data structure...")
98
+
99
+ # Validate feedback object
100
+ validation_errors = []
101
+ if not feedback.feedback_id:
102
+ validation_errors.append("Missing feedback_id")
103
+ if feedback.score is None:
104
+ validation_errors.append("Missing score")
105
+ if feedback.timestamp is None:
106
+ validation_errors.append("Missing timestamp")
107
+
108
+ if validation_errors:
109
+ logger.error(f"❌ VALIDATION FAILED: {validation_errors}")
110
+ return False
111
+ else:
112
+ logger.info("βœ… VALIDATION PASSED: All required fields present")
113
+
114
+ logger.info("πŸ“‹ Data Summary:")
115
+ logger.info(f" - Feedback ID: {feedback.feedback_id}")
116
+ logger.info(f" - Score: {feedback.score}")
117
+ logger.info(f" - Conversation ID: {feedback.conversation_id}")
118
+ logger.info(f" - Has Retrievals: {feedback.has_retrievals}")
119
+ logger.info(f" - Retrieval Count: {feedback.retrieval_count}")
120
+ logger.info(f" - Message Count: {feedback.message_count}")
121
+ logger.info(f" - Timestamp: {feedback.timestamp}")
122
+
123
+ cursor = self._connection.cursor()
124
+ logger.info("βœ… SNOWFLAKE CONNECTION: Cursor created")
125
+
126
+ # Set database and schema context
127
+ logger.info(f"πŸ”§ SETTING CONTEXT: Database={self.database}, Schema={self.schema}")
128
+ try:
129
+ cursor.execute(f'USE DATABASE "{self.database}"')
130
+ cursor.execute(f'USE SCHEMA "{self.schema}"')
131
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
132
+ current_db, current_schema = cursor.fetchone()
133
+ logger.info(f"βœ… Current context verified: Database={current_db}, Schema={current_schema}")
134
+ except Exception as e:
135
+ logger.error(f"❌ Could not set context: {e}")
136
+ raise
137
+
138
+ # Prepare data - convert to JSON strings for VARIANT columns (same approach as old retrieved_data)
139
+ logger.info("πŸ”§ DATA PREPARATION: Preparing VARIANT columns...")
140
+ feedback_dict = feedback.to_dict()
141
+
142
+ # Prepare transcript (ARRAY) - convert to JSON string
143
+ transcript_raw = feedback_dict.get('transcript', [])
144
+ if transcript_raw:
145
+ # Convert to JSON string (same approach as old retrieved_data)
146
+ transcript_for_db = json.dumps(transcript_raw)
147
+ logger.info(f" - Transcript: {len(transcript_raw)} messages, JSON length: {len(transcript_for_db)}")
148
+ else:
149
+ transcript_for_db = None
150
+ logger.info(" - Transcript: None")
151
+
152
+ # Prepare retrievals (ARRAY) - convert to JSON string
153
+ retrievals_raw = feedback_dict.get('retrievals', [])
154
+ if retrievals_raw:
155
+ # Convert to JSON string (same approach as old retrieved_data)
156
+ retrievals_for_db = json.dumps(retrievals_raw)
157
+ logger.info(f" - Retrievals: {len(retrievals_raw)} entries, JSON length: {len(retrievals_for_db)}")
158
+ else:
159
+ retrievals_for_db = None
160
+ logger.info(" - Retrievals: None")
161
+
162
+ # Prepare feedback_score_related_retrieval_docs (OBJECT) - convert to JSON string
163
+ feedback_score_related_raw = feedback_dict.get('feedback_score_related_retrieval_docs')
164
+ if feedback_score_related_raw:
165
+ # Convert to JSON string (same approach as old retrieved_data)
166
+ feedback_score_related_for_db = json.dumps(feedback_score_related_raw)
167
+ logger.info(f" - Feedback score related docs: present, JSON length: {len(feedback_score_related_for_db)}")
168
+ else:
169
+ feedback_score_related_for_db = None
170
+ logger.info(" - Feedback score related docs: None")
171
+
172
+ # Prepare retrieved_data (preserved old column) - convert to JSON string
173
+ retrieved_data_raw = feedback_dict.get('retrieved_data')
174
+ if retrieved_data_raw:
175
+ # Convert to JSON string (same approach as old retrieved_data)
176
+ retrieved_data_for_db = json.dumps(retrieved_data_raw)
177
+ logger.info(f" - Retrieved data (preserved): present, JSON length: {len(retrieved_data_for_db)}")
178
+ else:
179
+ retrieved_data_for_db = None
180
+ logger.info(" - Retrieved data (preserved): None")
181
+
182
+ # Build SQL with new column structure
183
+ # Columns are VARCHAR (storing JSON strings), same approach as old retrieved_data
184
+ sql = f"""INSERT INTO {table_name} (
185
+ feedback_id,
186
+ open_ended_feedback,
187
+ score,
188
+ is_feedback_about_last_retrieval,
189
+ conversation_id,
190
+ timestamp,
191
+ message_count,
192
+ has_retrievals,
193
+ retrieval_count,
194
+ transcript,
195
+ retrievals,
196
+ feedback_score_related_retrieval_docs,
197
+ retrieved_data,
198
+ created_at
199
+ ) VALUES (
200
+ %(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
201
+ %(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
202
+ %(retrieval_count)s, %(transcript)s, %(retrievals)s, %(feedback_score_related_retrieval_docs)s,
203
+ %(retrieved_data)s, %(created_at)s
204
+ )"""
205
+
206
+ logger.info("πŸ“ SQL PREPARATION: Building INSERT statement...")
207
+ logger.info(f" - Target table: {table_name}")
208
+ logger.info(f" - Database: {self.database}")
209
+ logger.info(f" - Schema: {self.schema}")
210
+
211
+ # Prepare parameters
212
+ # Pass JSON strings for VARIANT columns (same approach as old retrieved_data)
213
+ params = {
214
+ 'feedback_id': feedback.feedback_id,
215
+ 'open_ended_feedback': feedback.open_ended_feedback,
216
+ 'score': feedback.score,
217
+ 'is_feedback_about_last_retrieval': feedback.is_feedback_about_last_retrieval,
218
+ 'conversation_id': feedback.conversation_id,
219
+ 'timestamp': int(feedback.timestamp),
220
+ 'message_count': feedback.message_count,
221
+ 'has_retrievals': feedback.has_retrievals,
222
+ 'retrieval_count': feedback.retrieval_count,
223
+ 'transcript': transcript_for_db, # JSON string
224
+ 'retrievals': retrievals_for_db, # JSON string
225
+ 'feedback_score_related_retrieval_docs': feedback_score_related_for_db, # JSON string
226
+ 'retrieved_data': retrieved_data_for_db, # JSON string - preserved old column
227
+ 'created_at': feedback.created_at
228
+ }
229
+
230
+ # Execute insert
231
+ logger.info("πŸš€ SQL EXECUTION: Executing INSERT query...")
232
+ cursor.execute(sql, params)
233
+
234
+ logger.info("βœ… SQL EXECUTION: Query executed successfully")
235
+ logger.info(f" - Rows affected: 1")
236
+ logger.info(f" - Status: SUCCESS")
237
+
238
+ cursor.close()
239
+ logger.info("βœ… SNOWFLAKE INSERT: Feedback inserted successfully")
240
+ logger.info(f"πŸ“ Inserted feedback: {feedback.feedback_id}")
241
+ logger.info("=" * 80)
242
+ return True
243
+
244
+ except Exception as e:
245
+ # Check if it's a Snowflake error
246
+ if SNOWFLAKE_AVAILABLE and "ProgrammingError" in str(type(e)):
247
+ logger.error(f"❌ SQL EXECUTION ERROR: {e}")
248
+ logger.error(f" - Error code: {getattr(e, 'errno', 'Unknown')}")
249
+ logger.error(f" - SQL state: {getattr(e, 'sqlstate', 'Unknown')}")
250
+ else:
251
+ logger.error(f"❌ SNOWFLAKE INSERT FAILED: {type(e).__name__}")
252
+ logger.error(f" - Error: {e}")
253
+ logger.error("=" * 80)
254
+ return False
255
+
256
+ def __enter__(self):
257
+ """Context manager entry"""
258
+ self.connect()
259
+ return self
260
+
261
+ def __exit__(self, exc_type, exc_val, exc_tb):
262
+ """Context manager exit"""
263
+ self.disconnect()
264
+
265
+
266
+ def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
267
+ """Create Snowflake connector from environment variables"""
268
+ user = os.getenv("SNOWFLAKE_USER")
269
+ password = os.getenv("SNOWFLAKE_PASSWORD")
270
+ account = os.getenv("SNOWFLAKE_ACCOUNT")
271
+ warehouse = os.getenv("SNOWFLAKE_WAREHOUSE")
272
+ database = os.getenv("SNOWFLAKE_DATABASE", "SNOWFLAKE_LEARN")
273
+ schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
274
+
275
+ if not all([user, password, account, warehouse]):
276
+ print("⚠️ Snowflake credentials not found in environment variables")
277
+ print("Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
278
+ return None
279
+
280
+ return SnowflakeFeedbackConnector(
281
+ user=user,
282
+ password=password,
283
+ account=account,
284
+ warehouse=warehouse,
285
+ database=database,
286
+ schema=schema
287
+ )
288
+
289
+
290
+ def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
291
+ """Helper function to save feedback to Snowflake"""
292
+ logger.info("=" * 80)
293
+ logger.info("πŸ”΅ SNOWFLAKE SAVE: Starting save process")
294
+ logger.info(f"πŸ“ Feedback ID: {feedback.feedback_id}")
295
+
296
+ # Get table name from parameter or env var
297
+ if table_name is None:
298
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
299
+
300
+ connector = get_snowflake_connector_from_env()
301
+
302
+ if not connector:
303
+ logger.warning("⚠️ SNOWFLAKE SAVE: Skipping insertion (credentials not configured)")
304
+ logger.warning(" Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
305
+ logger.info("=" * 80)
306
+ return False
307
+
308
+ try:
309
+ logger.info("πŸ“‘ SNOWFLAKE SAVE: Establishing connection...")
310
+ connector.connect()
311
+ logger.info("βœ… SNOWFLAKE SAVE: Connection established")
312
+
313
+ logger.info("πŸ“₯ SNOWFLAKE SAVE: Attempting to insert feedback...")
314
+ success = connector.insert_feedback(feedback, table_name=table_name)
315
+
316
+ logger.info("πŸ”Œ SNOWFLAKE SAVE: Disconnecting...")
317
+ connector.disconnect()
318
+
319
+ if success:
320
+ logger.info("βœ… SNOWFLAKE SAVE: Successfully saved feedback")
321
+ else:
322
+ logger.error("❌ SNOWFLAKE SAVE: Failed to save feedback")
323
+
324
+ logger.info("=" * 80)
325
+ return success
326
+ except Exception as e:
327
+ logger.error(f"❌ SNOWFLAKE SAVE ERROR: {type(e).__name__}")
328
+ logger.error(f" - Error: {e}")
329
+ logger.info("=" * 80)
330
+ return False
331
+
src/gemini/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gemini File Search Integration Module
3
+
4
+ This module provides integration with Google Gemini File Search API
5
+ for RAG functionality using Gemini's built-in file search capabilities.
6
+ """
7
+
8
+ from .file_search import GeminiFileSearchClient, GeminiFileSearchResult
9
+
10
+ __all__ = ["GeminiFileSearchClient", "GeminiFileSearchResult"]
11
+
src/gemini/file_search.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gemini File Search Client
3
+
4
+ Handles interaction with Google Gemini File Search API for RAG.
5
+ """
6
+
7
+ import os
8
+ import logging
9
+ from typing import List, Dict, Any, Optional
10
+ from dataclasses import dataclass
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ try:
15
+ from google import genai
16
+ from google.genai import types
17
+ GEMINI_AVAILABLE = True
18
+ except ImportError:
19
+ GEMINI_AVAILABLE = False
20
+
21
+
22
+ @dataclass
23
+ class GeminiFileSearchResult:
24
+ """Result from Gemini File Search query"""
25
+ answer: str
26
+ sources: List[Dict[str, Any]] # List of document references
27
+ grounding_metadata: Optional[Dict[str, Any]] = None
28
+ query: str = ""
29
+
30
+
31
+ class GeminiFileSearchClient:
32
+ """Client for interacting with Gemini File Search API"""
33
+
34
+ def __init__(self, api_key: Optional[str] = None, store_name: Optional[str] = None):
35
+ """
36
+ Initialize Gemini File Search client.
37
+
38
+ Args:
39
+ api_key: Gemini API key (defaults to GEMINI_API_KEY env var)
40
+ store_name: File search store name (defaults to GEMINI_FILESTORE_NAME env var)
41
+ """
42
+ if not GEMINI_AVAILABLE:
43
+ raise ImportError("google-genai package not installed. Install with: pip install google-genai")
44
+
45
+ self.api_key = api_key or os.getenv("GEMINI_API_KEY")
46
+ if not self.api_key:
47
+ raise ValueError("GEMINI_API_KEY not found. Set it in .env file or pass as argument.")
48
+
49
+ store_name_raw = store_name or os.getenv("GEMINI_FILESTORE_NAME")
50
+ if not store_name_raw:
51
+ raise ValueError("GEMINI_FILESTORE_NAME not found. Set it in .env file or pass as argument.")
52
+
53
+ # Normalize store name: API expects the FULL path format (fileSearchStores/xxx)
54
+ # If just the ID is provided, construct the full path
55
+ if store_name_raw.startswith("fileSearchStores/"):
56
+ self.store_name = store_name_raw # Already full path
57
+ else:
58
+ # Just the ID provided, construct full path
59
+ self.store_name = f"fileSearchStores/{store_name_raw}"
60
+
61
+ logger.info(f"πŸ“¦ Using file search store: {self.store_name}")
62
+
63
+ self.client = genai.Client(api_key=self.api_key)
64
+ self.model = "gemini-2.5-flash" # or "gemini-2.5-pro"
65
+
66
+ def search(
67
+ self,
68
+ query: str,
69
+ filters: Optional[Dict[str, Any]] = None,
70
+ model: Optional[str] = None
71
+ ) -> GeminiFileSearchResult:
72
+ """
73
+ Search using Gemini File Search.
74
+
75
+ Args:
76
+ query: User query
77
+ filters: Optional filters (year, source, district, etc.)
78
+ model: Model to use (defaults to gemini-2.5-flash)
79
+
80
+ Returns:
81
+ GeminiFileSearchResult with answer and sources
82
+ """
83
+ model = model or self.model
84
+
85
+ # Build filter context for the query if filters are provided
86
+ # Gemini File Search doesn't support explicit filters in the API,
87
+ # so we add them as context in the query
88
+ filter_context = ""
89
+ if filters:
90
+ filter_parts = []
91
+ if filters.get("year"):
92
+ years = filters["year"] if isinstance(filters["year"], list) else [filters["year"]]
93
+ filter_parts.append(f"Year: {', '.join(years)}")
94
+ if filters.get("sources"):
95
+ sources = filters["sources"] if isinstance(filters["sources"], list) else [filters["sources"]]
96
+ filter_parts.append(f"Source: {', '.join(sources)}")
97
+ if filters.get("district"):
98
+ districts = filters["district"] if isinstance(filters["district"], list) else [filters["district"]]
99
+ filter_parts.append(f"District: {', '.join(districts)}")
100
+ if filters.get("filenames"):
101
+ filenames = filters["filenames"] if isinstance(filters["filenames"], list) else [filters["filenames"]]
102
+ filter_parts.append(f"Filename: {', '.join(filenames)}")
103
+
104
+ if filter_parts:
105
+ filter_context = f"\n\nPlease focus on documents matching these criteria: {', '.join(filter_parts)}"
106
+
107
+ # Combine query with filter context
108
+ # Add comprehensive system instructions similar to multi-agent system
109
+ system_instructions = """You are a helpful audit report assistant specialized in analyzing government audit reports from Uganda's Office of the Auditor General.
110
+
111
+ CRITICAL RULES:
112
+ 1. **NO HALLUCINATION**: Only use information that is explicitly stated in the retrieved documents. Do not make up facts, numbers, or details.
113
+ 2. **Document References**: Always cite which documents you're using with [Doc i] references at the end of sentences that use specific information.
114
+ 3. **Formatting**: Structure your response with clear paragraphs, bullet points, or sections for readability.
115
+ 4. **Accuracy**: If the retrieved documents don't contain the requested information, explicitly state "The retrieved documents do not contain information about [topic]."
116
+ 5. **Years and Data**: Pay careful attention to years mentioned in documents. If a user asks about a specific year but documents show different years, explicitly state this.
117
+ 6. **District/Source Names**: Use the exact district and source names as they appear in the document metadata (e.g., "Kalangala" not "Kalagala").
118
+ 7. **Financial Data**: When providing financial figures, include the currency (UGX) and be precise about amounts.
119
+ 8. **Conversational Tone**: Be helpful, clear, and conversational while maintaining accuracy.
120
+
121
+ IMPORTANT: Only use information from the retrieved documents. Do not use information from your training data unless it's explicitly mentioned in the retrieved documents."""
122
+
123
+ # Combine system instructions with query
124
+ full_query = f"{system_instructions}\n\nUser Question: {query}{filter_context}\n\nPlease provide a detailed, well-formatted response with proper document references."
125
+
126
+ try:
127
+ # Generate content with file search
128
+ # Based on Gemini API docs: https://ai.google.dev/gemini-api/docs/file-search
129
+ # Try with full path format first, then fallback to just ID if needed
130
+ store_name_to_try = self.store_name
131
+
132
+ try:
133
+ # Try the documented format first with full path
134
+ response = self.client.models.generate_content(
135
+ model=model,
136
+ contents=full_query,
137
+ config=types.GenerateContentConfig(
138
+ tools=[
139
+ types.Tool(
140
+ file_search=types.FileSearch(
141
+ file_search_store_names=[store_name_to_try]
142
+ )
143
+ )
144
+ ]
145
+ )
146
+ )
147
+ except Exception as api_error:
148
+ error_str = str(api_error).lower()
149
+ # If format error, try with just the ID (without fileSearchStores/ prefix)
150
+ if 'format' in error_str or 'invalid' in error_str or 'too long' in error_str:
151
+ logger.warning(f"Full path format failed, trying with just store ID: {api_error}")
152
+ # Extract just the ID part
153
+ if store_name_to_try.startswith("fileSearchStores/"):
154
+ store_id = store_name_to_try.split("/", 1)[1]
155
+ store_name_to_try = store_id
156
+
157
+ try:
158
+ response = self.client.models.generate_content(
159
+ model=model,
160
+ contents=full_query,
161
+ config=types.GenerateContentConfig(
162
+ tools=[
163
+ types.Tool(
164
+ file_search=types.FileSearch(
165
+ file_search_store_names=[store_name_to_try]
166
+ )
167
+ )
168
+ ]
169
+ )
170
+ )
171
+ except Exception as e2:
172
+ raise Exception(f"Failed to call Gemini API with both formats. Full path error: {api_error}, ID-only error: {e2}")
173
+ else:
174
+ # Try alternative dict format
175
+ logger.warning(f"Primary API format failed, trying alternative: {api_error}")
176
+ try:
177
+ response = self.client.models.generate_content(
178
+ model=model,
179
+ contents=full_query,
180
+ tools=[{
181
+ "file_search": {
182
+ "file_search_store_names": [store_name_to_try]
183
+ }
184
+ }]
185
+ )
186
+ except Exception as e2:
187
+ raise Exception(f"Failed to call Gemini API: {e2}")
188
+
189
+ # Extract answer
190
+ answer = ""
191
+ if hasattr(response, 'text'):
192
+ answer = response.text
193
+ elif hasattr(response, 'candidates') and response.candidates:
194
+ # Try to get text from first candidate
195
+ candidate = response.candidates[0]
196
+ if hasattr(candidate, 'content') and candidate.content:
197
+ if hasattr(candidate.content, 'parts'):
198
+ text_parts = []
199
+ for part in candidate.content.parts:
200
+ if hasattr(part, 'text'):
201
+ text_parts.append(part.text)
202
+ answer = " ".join(text_parts)
203
+ elif isinstance(candidate.content, str):
204
+ answer = candidate.content
205
+ else:
206
+ answer = str(response)
207
+
208
+ # Extract grounding metadata (document references)
209
+ sources = []
210
+ grounding_metadata = None
211
+
212
+ logger.info(f"πŸ” Extracting sources from Gemini response...")
213
+
214
+ if hasattr(response, 'candidates') and response.candidates:
215
+ candidate = response.candidates[0]
216
+ logger.info(f" Found candidate, checking for grounding_metadata...")
217
+
218
+ # Get grounding metadata
219
+ if hasattr(candidate, 'grounding_metadata'):
220
+ grounding_metadata = candidate.grounding_metadata
221
+ logger.info(f" Found grounding_metadata: {type(grounding_metadata)}")
222
+
223
+ # Extract source documents from grounding metadata
224
+ # Handle different response formats
225
+ grounding_chunks = None
226
+ if hasattr(grounding_metadata, 'grounding_chunks'):
227
+ grounding_chunks = grounding_metadata.grounding_chunks
228
+ logger.info(f" Found grounding_chunks (attr): {len(grounding_chunks) if grounding_chunks else 0}")
229
+ elif isinstance(grounding_metadata, dict) and 'grounding_chunks' in grounding_metadata:
230
+ grounding_chunks = grounding_metadata['grounding_chunks']
231
+ logger.info(f" Found grounding_chunks (dict): {len(grounding_chunks) if grounding_chunks else 0}")
232
+ elif hasattr(grounding_metadata, '__dict__'):
233
+ # Try to access as object attributes
234
+ metadata_dict = grounding_metadata.__dict__
235
+ if 'grounding_chunks' in metadata_dict:
236
+ grounding_chunks = metadata_dict['grounding_chunks']
237
+ logger.info(f" Found grounding_chunks (__dict__): {len(grounding_chunks) if grounding_chunks else 0}")
238
+
239
+ if grounding_chunks:
240
+ logger.info(f" Processing {len(grounding_chunks)} grounding chunks...")
241
+ for idx, chunk in enumerate(grounding_chunks):
242
+ # Handle both object and dict formats
243
+ try:
244
+ if isinstance(chunk, dict):
245
+ chunk_data = chunk
246
+ else:
247
+ # Object format - convert to dict-like access
248
+ chunk_data = {}
249
+ if hasattr(chunk, 'chunk'):
250
+ chunk_obj = chunk.chunk
251
+ chunk_data['chunk'] = {
252
+ 'text': getattr(chunk_obj, 'text', ''),
253
+ 'file_name': getattr(chunk_obj, 'file_name', '')
254
+ }
255
+ if hasattr(chunk, 'relevance_score'):
256
+ score_obj = chunk.relevance_score
257
+ chunk_data['relevance_score'] = {
258
+ 'score': getattr(score_obj, 'score', 0.0)
259
+ }
260
+
261
+ chunk_info = chunk_data.get('chunk', {})
262
+ text = chunk_info.get('text', '') if isinstance(chunk_info, dict) else ''
263
+ file_name = chunk_info.get('file_name', '') if isinstance(chunk_info, dict) else ''
264
+
265
+ # Try to extract file URI and parse metadata from it
266
+ file_uri = chunk_info.get('file_uri', '') if isinstance(chunk_info, dict) else ''
267
+
268
+ # Also check for 'web' attribute (GroundingChunkData structure)
269
+ if hasattr(chunk, 'web') and chunk.web:
270
+ web_data = chunk.web
271
+ file_uri = getattr(web_data, 'file_uri', '') or file_uri
272
+ file_name = getattr(web_data, 'title', '') or getattr(web_data, 'filename', '') or file_name
273
+ text = getattr(web_data, 'text', '') or getattr(web_data, 'content', '') or text
274
+
275
+ # Check retrieved_context - this is where the actual data seems to be!
276
+ if hasattr(chunk, 'retrieved_context') and chunk.retrieved_context:
277
+ rc = chunk.retrieved_context
278
+ # Get text content
279
+ if hasattr(rc, 'text'):
280
+ text = getattr(rc, 'text', '') or text
281
+ # Get document name
282
+ if hasattr(rc, 'document_name'):
283
+ doc_name = getattr(rc, 'document_name', '')
284
+ if doc_name:
285
+ file_name = doc_name or file_name
286
+
287
+ # Fallback: Parse from string representation if we still don't have filename
288
+ if not file_name:
289
+ chunk_str = str(chunk)
290
+ import re
291
+ # Look for PDF filenames
292
+ pdf_match = re.search(r"([A-Za-z0-9\s_-]+\.pdf)", chunk_str)
293
+ if pdf_match:
294
+ file_name = pdf_match.group(1)
295
+ # Or look for title= pattern
296
+ if not file_name and 'title=' in chunk_str:
297
+ title_match = re.search(r"title=['\"]([^'\"]+)['\"]", chunk_str)
298
+ if title_match:
299
+ file_name = title_match.group(1)
300
+
301
+ if not file_name and file_uri:
302
+ # Extract filename from URI if available
303
+ file_name = file_uri.split('/')[-1] if '/' in file_uri else file_uri
304
+
305
+ score_data = chunk_data.get('relevance_score', {})
306
+ score = score_data.get('score', 0.0) if isinstance(score_data, dict) else 0.0
307
+
308
+ if text or file_name: # Only add if we have content
309
+ source_info = {
310
+ "content": text,
311
+ "filename": file_name,
312
+ "score": score,
313
+ "file_uri": file_uri,
314
+ }
315
+ sources.append(source_info)
316
+ logger.info(f"πŸ“„ Extracted source {idx+1}: {file_name} (score: {score:.3f}, content length: {len(text)})")
317
+ except Exception as e:
318
+ logger.warning(f"Error extracting chunk {idx+1} info: {e}")
319
+ import traceback
320
+ logger.debug(traceback.format_exc())
321
+ continue
322
+ else:
323
+ logger.warning(f" No grounding_chunks found in grounding_metadata")
324
+ else:
325
+ logger.warning(f" Candidate does not have grounding_metadata attribute")
326
+
327
+ # Also try to get file references from other parts of the response
328
+ # Sometimes Gemini includes file references in the response itself
329
+ if not sources or len(sources) == 0:
330
+ logger.info(f" No sources from grounding_metadata, trying alternative extraction...")
331
+ # Check if response has file references in other attributes
332
+ if hasattr(candidate, 'content') and candidate.content:
333
+ if hasattr(candidate.content, 'parts'):
334
+ for part in candidate.content.parts:
335
+ if hasattr(part, 'file_data'):
336
+ file_data = part.file_data
337
+ if hasattr(file_data, 'file_uri') or (isinstance(file_data, dict) and 'file_uri' in file_data):
338
+ file_uri = getattr(file_data, 'file_uri', None) or (file_data.get('file_uri') if isinstance(file_data, dict) else None)
339
+ if file_uri:
340
+ file_name = file_uri.split('/')[-1] if '/' in file_uri else file_uri
341
+ sources.append({
342
+ "content": "",
343
+ "filename": file_name,
344
+ "score": 0.0,
345
+ "file_uri": file_uri,
346
+ })
347
+ logger.info(f"πŸ“„ Extracted source from file_data: {file_name}")
348
+
349
+ logger.info(f"βœ… Total sources extracted: {len(sources)}")
350
+
351
+ return GeminiFileSearchResult(
352
+ answer=answer,
353
+ sources=sources,
354
+ grounding_metadata=grounding_metadata,
355
+ query=query
356
+ )
357
+
358
+ except Exception as e:
359
+ # Return error result
360
+ return GeminiFileSearchResult(
361
+ answer=f"I apologize, but I encountered an error: {str(e)}",
362
+ sources=[],
363
+ query=query
364
+ )
365
+
366
+ def format_sources_for_display(self, result: GeminiFileSearchResult) -> List[Any]:
367
+ """
368
+ Format Gemini sources to match the format expected by the UI.
369
+
370
+ Returns list of document-like objects compatible with existing display code.
371
+ """
372
+ from langchain.docstore.document import Document
373
+
374
+ formatted_sources = []
375
+
376
+ for i, source in enumerate(result.sources):
377
+ filename = source.get("filename", "Unknown")
378
+
379
+ # Try to extract metadata from filename (e.g., "Kalangala DLG Report of Auditor General 2021.pdf")
380
+ year = None
381
+ district = None
382
+ source_name = "Gemini File Search"
383
+
384
+ # Parse filename for year
385
+ import re
386
+ year_match = re.search(r'\b(20\d{2})\b', filename)
387
+ if year_match:
388
+ year = int(year_match.group(1))
389
+
390
+ # Parse filename for district/source
391
+ if "Kalangala" in filename:
392
+ district = "Kalangala"
393
+ source_name = "Kalangala DLG"
394
+ elif "Gulu" in filename:
395
+ district = "Gulu"
396
+ source_name = "Gulu DLG"
397
+ elif "KCCA" in filename:
398
+ district = "Kampala"
399
+ source_name = "KCCA"
400
+ elif "MAAIF" in filename:
401
+ source_name = "MAAIF"
402
+ elif "MWTS" in filename:
403
+ source_name = "MWTS"
404
+ elif "Consolidated" in filename:
405
+ source_name = "Consolidated"
406
+
407
+ # Create a Document object compatible with existing code
408
+ doc = Document(
409
+ page_content=source.get("content", ""),
410
+ metadata={
411
+ "filename": filename,
412
+ "source": source_name,
413
+ "score": source.get("score"),
414
+ "chunk_index": i,
415
+ "page": None, # Gemini doesn't provide page numbers
416
+ "year": year,
417
+ "district": district,
418
+ "chunk_id": f"gemini_{i}",
419
+ "_id": f"gemini_{i}",
420
+ }
421
+ )
422
+ formatted_sources.append(doc)
423
+ logger.info(f"πŸ“‹ Formatted source {i+1}: {filename} ({year}, {source_name})")
424
+
425
+ logger.info(f"βœ… Formatted {len(formatted_sources)} sources for display")
426
+ return formatted_sources
427
+
src/{loader.py β†’ llm/loader.py} RENAMED
File without changes
src/pipeline.py CHANGED
@@ -14,7 +14,7 @@ except ModuleNotFoundError as me:
14
 
15
  from .logging import log_error
16
 
17
- from .loader import chunks_to_documents
18
  from .vectorstore import VectorStoreManager
19
  from .reporting.service import ReportService
20
  from .retrieval.context import ContextRetriever
 
14
 
15
  from .logging import log_error
16
 
17
+ from .llm.loader import chunks_to_documents
18
  from .vectorstore import VectorStoreManager
19
  from .reporting.service import ReportService
20
  from .retrieval.context import ContextRetriever
src/reporting/__init__.py CHANGED
@@ -1,4 +1,8 @@
1
- """Report metadata and utilities."""
 
 
 
 
2
 
3
  from .metadata import get_report_metadata, get_available_sources
4
  from .service import ReportService
 
1
+ """Report metadata and utilities.
2
+
3
+ This module is kept for backward compatibility with pipeline.py.
4
+ For feedback-related functionality, use src.feedback instead.
5
+ """
6
 
7
  from .metadata import get_report_metadata, get_available_sources
8
  from .service import ReportService
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/ui_components/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UI Components Module
3
+
4
+ This module contains UI-related components including styles, visualizations,
5
+ and utility functions for the Streamlit application.
6
+ """
7
+
8
+ from .styles import get_custom_css
9
+ from .components import (
10
+ display_chunk_statistics_charts,
11
+ display_chunk_statistics_table
12
+ )
13
+ from .utils import extract_chunk_statistics
14
+
15
+ __all__ = [
16
+ "get_custom_css",
17
+ "display_chunk_statistics_charts",
18
+ "display_chunk_statistics_table",
19
+ "extract_chunk_statistics"
20
+ ]
21
+
src/ui_components/components.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UI components for displaying statistics and visualizations
3
+ """
4
+
5
+ import streamlit as st
6
+ import pandas as pd
7
+ import plotly.express as px
8
+ from typing import Dict, Any
9
+
10
+
11
+ def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieval Statistics"):
12
+ """Display statistics as interactive charts for 10+ results."""
13
+ if not stats or stats.get('total_chunks', 0) == 0:
14
+ return
15
+
16
+ # Wrap everything in one styled container - open it
17
+ st.markdown(f"""
18
+ <div class="retrieval-distribution-container">
19
+ <h3 style="margin-top: 0;">πŸ“Š {title}</h3>
20
+ <div style="display: flex; justify-content: space-around; align-items: center; padding: 15px 0; border-bottom: 1px solid #e0e0e0; margin-bottom: 20px;">
21
+ <div class="metric-container">
22
+ <div class="metric-label">Total Chunks</div>
23
+ <div class="metric-value">{stats['total_chunks']}</div>
24
+ </div>
25
+ <div class="metric-container">
26
+ <div class="metric-label">Unique Sources</div>
27
+ <div class="metric-value">{stats['unique_sources']}</div>
28
+ </div>
29
+ <div class="metric-container">
30
+ <div class="metric-label">Unique Years</div>
31
+ <div class="metric-value">{stats['unique_years']}</div>
32
+ </div>
33
+ <div class="metric-container">
34
+ <div class="metric-label">Unique Files</div>
35
+ <div class="metric-value">{stats['unique_filenames']}</div>
36
+ </div>
37
+ </div>
38
+ """, unsafe_allow_html=True)
39
+
40
+ # Charts - three columns to include Districts
41
+ col1, col2, col3 = st.columns(3)
42
+
43
+ with col1:
44
+ # Source distribution chart
45
+ if stats['source_distribution']:
46
+ source_df = pd.DataFrame(
47
+ list(stats['source_distribution'].items()),
48
+ columns=['Source', 'Count']
49
+ )
50
+ fig_source = px.bar(
51
+ source_df,
52
+ x='Count',
53
+ y='Source',
54
+ orientation='h',
55
+ title='Distribution by Source',
56
+ color='Count',
57
+ color_continuous_scale='viridis'
58
+ )
59
+ fig_source.update_layout(height=400, showlegend=False)
60
+ st.plotly_chart(fig_source, use_container_width=True) # Note: plotly_chart still uses use_container_width
61
+
62
+ with col2:
63
+ # Year distribution chart
64
+ if stats['year_distribution']:
65
+ # Filter out 'Unknown' years for the chart
66
+ year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
67
+ if year_dist_filtered:
68
+ year_df = pd.DataFrame(
69
+ list(year_dist_filtered.items()),
70
+ columns=['Year', 'Count']
71
+ )
72
+ # Sort by year as integer but keep as string for categorical display
73
+ year_df['Year_Int'] = year_df['Year'].astype(int)
74
+ year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)
75
+
76
+ fig_year = px.bar(
77
+ year_df,
78
+ x='Year',
79
+ y='Count',
80
+ title='Distribution by Year',
81
+ color='Count',
82
+ color_continuous_scale='plasma'
83
+ )
84
+ # Ensure years are treated as categorical (discrete) not continuous
85
+ fig_year.update_xaxes(type='category')
86
+ fig_year.update_layout(height=400, showlegend=False)
87
+ st.plotly_chart(fig_year, use_container_width=True) # Note: plotly_chart still uses use_container_width
88
+ else:
89
+ st.info("No valid years found in the results")
90
+
91
+ with col3:
92
+ # District distribution chart
93
+ if stats.get('district_distribution'):
94
+ district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
95
+ if district_dist_filtered:
96
+ district_df = pd.DataFrame(
97
+ list(district_dist_filtered.items()),
98
+ columns=['District', 'Count']
99
+ )
100
+ district_df = district_df.sort_values('Count', ascending=False)
101
+
102
+ fig_district = px.bar(
103
+ district_df,
104
+ x='Count',
105
+ y='District',
106
+ orientation='h',
107
+ title='Distribution by District',
108
+ color='Count',
109
+ color_continuous_scale='blues'
110
+ )
111
+ fig_district.update_layout(height=400, showlegend=False)
112
+ st.plotly_chart(fig_district, use_container_width=True) # Note: plotly_chart still uses use_container_width
113
+ else:
114
+ st.info("No valid districts found in the results")
115
+
116
+ # Close the container
117
+ st.markdown('</div>', unsafe_allow_html=True)
118
+
119
+
120
+ def display_chunk_statistics_table(stats: Dict[str, Any], title: str = "Retrieval Distribution"):
121
+ """Display statistics as tables for smaller results with fixed alignment."""
122
+ if not stats or stats.get('total_chunks', 0) == 0:
123
+ return
124
+
125
+ # Wrap in styled container
126
+ st.markdown('<div class="retrieval-distribution-container">', unsafe_allow_html=True)
127
+
128
+ st.subheader(f"πŸ“Š {title}")
129
+
130
+ # Create a container with fixed height for alignment
131
+ stats_container = st.container()
132
+
133
+ with stats_container:
134
+ # Create 4 equal columns for consistent alignment
135
+ col1, col2, col3, col4 = st.columns(4)
136
+
137
+ with col1:
138
+ st.markdown("**🏘️ Districts**")
139
+ if stats.get('district_distribution'):
140
+ district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
141
+ if district_dist_filtered:
142
+ district_data = {
143
+ "District": list(district_dist_filtered.keys()),
144
+ "Count": list(district_dist_filtered.values())
145
+ }
146
+ district_df = pd.DataFrame(district_data).sort_values('Count', ascending=False)
147
+ st.dataframe(district_df, hide_index=True, width='stretch')
148
+ else:
149
+ st.write("No district data")
150
+ else:
151
+ st.write("No district data")
152
+
153
+ with col2:
154
+ st.markdown("**πŸ“‚ Sources**")
155
+ if stats['source_distribution']:
156
+ source_data = {
157
+ "Source": list(stats['source_distribution'].keys()),
158
+ "Count": list(stats['source_distribution'].values())
159
+ }
160
+ source_df = pd.DataFrame(source_data).sort_values('Count', ascending=False)
161
+ st.dataframe(source_df, hide_index=True, width='stretch')
162
+ else:
163
+ st.write("No source data")
164
+
165
+ with col3:
166
+ st.markdown("**πŸ“… Years**")
167
+ if stats['year_distribution']:
168
+ year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
169
+ if year_dist_filtered:
170
+ year_data = {
171
+ "Year": list(year_dist_filtered.keys()),
172
+ "Count": list(year_dist_filtered.values())
173
+ }
174
+ year_df = pd.DataFrame(year_data)
175
+ # Sort by year as integer but display as string
176
+ year_df['Year_Int'] = year_df['Year'].astype(int)
177
+ year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
178
+ st.dataframe(year_df, hide_index=True, width='stretch')
179
+ else:
180
+ st.write("No year data")
181
+ else:
182
+ st.write("No year data")
183
+
184
+ with col4:
185
+ st.markdown("**πŸ“„ Files**")
186
+ if stats['filename_distribution']:
187
+ filename_items = list(stats['filename_distribution'].items())
188
+ filename_items.sort(key=lambda x: x[1], reverse=True)
189
+
190
+ # Show top files with truncated names
191
+ file_data = {
192
+ "File": [f[:30] + "..." if len(f) > 30 else f for f, c in filename_items[:5]],
193
+ "Count": [c for f, c in filename_items[:5]]
194
+ }
195
+ file_df = pd.DataFrame(file_data)
196
+ st.dataframe(file_df, hide_index=True, width='stretch')
197
+ else:
198
+ st.write("No file data")
199
+
200
+ # Close container
201
+ st.markdown('</div>', unsafe_allow_html=True)
202
+
src/ui_components/styles.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom CSS styles for Streamlit application
3
+ """
4
+
5
+
6
+ def get_custom_css() -> str:
7
+ """Get custom CSS styles as a string"""
8
+ return """
9
+ <style>
10
+ .main-header {
11
+ font-size: 2.5rem;
12
+ font-weight: bold;
13
+ color: #1f77b4;
14
+ text-align: center;
15
+ margin-bottom: 1rem;
16
+ width: 100%;
17
+ display: block;
18
+ }
19
+
20
+ .subtitle {
21
+ font-size: 1.2rem;
22
+ color: #666;
23
+ text-align: center;
24
+ margin-bottom: 2rem;
25
+ width: 100%;
26
+ display: block;
27
+ }
28
+
29
+ .session-info {
30
+ background-color: #f0f2f6;
31
+ padding: 10px;
32
+ border-radius: 5px;
33
+ margin-bottom: 20px;
34
+ font-size: 0.9rem;
35
+ }
36
+
37
+ .user-message {
38
+ background-color: #007bff;
39
+ color: white;
40
+ padding: 12px 16px;
41
+ border-radius: 18px 18px 4px 18px;
42
+ margin: 8px 0;
43
+ margin-left: 20%;
44
+ word-wrap: break-word;
45
+ }
46
+
47
+ .bot-message {
48
+ background-color: #f1f3f4;
49
+ color: #333;
50
+ padding: 12px 16px;
51
+ border-radius: 18px 18px 18px 4px;
52
+ margin: 8px 0;
53
+ margin-right: 20%;
54
+ word-wrap: break-word;
55
+ border: 1px solid #e0e0e0;
56
+ }
57
+
58
+ .filter-section {
59
+ margin-bottom: 20px;
60
+ padding: 15px;
61
+ background-color: #f8f9fa;
62
+ border-radius: 8px;
63
+ border: 1px solid #e9ecef;
64
+ }
65
+
66
+ .filter-title {
67
+ font-weight: bold;
68
+ margin-bottom: 10px;
69
+ color: #495057;
70
+ }
71
+
72
+ .feedback-section {
73
+ background-color: #f8f9fa;
74
+ padding: 20px;
75
+ border-radius: 10px;
76
+ margin-top: 30px;
77
+ border: 2px solid #dee2e6;
78
+ }
79
+
80
+ .retrieval-history {
81
+ background-color: #ffffff;
82
+ padding: 15px;
83
+ border-radius: 5px;
84
+ margin: 10px 0;
85
+ border-left: 4px solid #007bff;
86
+ }
87
+
88
+ .retrieval-distribution-container {
89
+ background-color: #ffffff;
90
+ padding: 25px;
91
+ border-radius: 10px;
92
+ margin: 20px 0;
93
+ border: 2px solid #e0e0e0;
94
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 2px 4px rgba(0, 0, 0, 0.06);
95
+ }
96
+
97
+ .metric-label {
98
+ font-size: 0.9rem;
99
+ color: #555;
100
+ margin-bottom: 5px;
101
+ text-align: center;
102
+ }
103
+
104
+ .metric-value {
105
+ font-size: 1.8rem;
106
+ font-weight: bold;
107
+ color: #000000;
108
+ text-align: center;
109
+ }
110
+
111
+ .metric-container {
112
+ text-align: center;
113
+ padding: 10px;
114
+ }
115
+ </style>
116
+ """
117
+
src/ui_components/utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UI utility functions for data processing and statistics
3
+ """
4
+
5
+ from typing import Dict, Any, List
6
+ from collections import Counter
7
+
8
+
9
+ def extract_chunk_statistics(sources: List[Any]) -> Dict[str, Any]:
10
+ """Extract statistics from retrieved chunks."""
11
+ if not sources:
12
+ return {}
13
+
14
+ sources_list = []
15
+ years = []
16
+ filenames = []
17
+ districts = []
18
+
19
+ for doc in sources:
20
+ metadata = getattr(doc, 'metadata', {})
21
+
22
+ # Extract source
23
+ source = metadata.get('source', 'Unknown')
24
+ sources_list.append(source)
25
+
26
+ # Extract year
27
+ year = metadata.get('year', 'Unknown')
28
+ if year and year != 'Unknown':
29
+ try:
30
+ # Convert to int first, then back to string to ensure it's a proper year
31
+ year_int = int(float(year)) # Handle both int and float strings
32
+ if 1900 <= year_int <= 2030: # Reasonable year range
33
+ years.append(str(year_int))
34
+ else:
35
+ years.append('Unknown')
36
+ except (ValueError, TypeError):
37
+ years.append('Unknown')
38
+ else:
39
+ years.append('Unknown')
40
+
41
+ # Extract filename
42
+ filename = metadata.get('filename', 'Unknown')
43
+ filenames.append(filename)
44
+
45
+ # Extract district
46
+ district = metadata.get('district', 'Unknown')
47
+ if district and district != 'Unknown':
48
+ districts.append(district)
49
+ else:
50
+ districts.append('Unknown')
51
+
52
+ # Count occurrences
53
+ source_counts = Counter(sources_list)
54
+ year_counts = Counter(years)
55
+ filename_counts = Counter(filenames)
56
+ district_counts = Counter(districts)
57
+
58
+ return {
59
+ 'total_chunks': len(sources),
60
+ 'unique_sources': len(source_counts),
61
+ 'unique_years': len([y for y in year_counts.keys() if y != 'Unknown']),
62
+ 'unique_filenames': len(filename_counts),
63
+ 'unique_districts': len([d for d in district_counts.keys() if d != 'Unknown']),
64
+ 'source_distribution': dict(source_counts),
65
+ 'year_distribution': dict(year_counts),
66
+ 'filename_distribution': dict(filename_counts),
67
+ 'district_distribution': dict(district_counts),
68
+ 'sources': sources_list,
69
+ 'years': years,
70
+ 'filenames': filenames,
71
+ 'districts': districts
72
+ }
73
+
utils.py β†’ src/utils.py RENAMED
File without changes
src/vectorstore.py CHANGED
@@ -1,9 +1,20 @@
1
  """Vector store management and operations."""
 
 
 
 
 
2
  from pathlib import Path
3
  from typing import Dict, Any, List, Optional
4
 
5
 
6
  import torch
 
 
 
 
 
 
7
  from langchain_qdrant import QdrantVectorStore
8
  from langchain.docstore.document import Document
9
  from langchain_core.embeddings import Embeddings
@@ -28,11 +39,23 @@ class MatryoshkaEmbeddings(Embeddings):
28
 
29
  if truncate_dim and "matryoshka" in model_name.lower():
30
  # Use SentenceTransformer directly for Matryoshka models
31
- device = "cuda" if torch.cuda.is_available() else "cpu"
32
- self.model = SentenceTransformer(model_name, truncate_dim=truncate_dim, device=device)
 
 
 
 
 
 
33
  print(f"πŸ”§ Matryoshka model configured for {truncate_dim} dimensions")
34
  else:
35
  # Use standard HuggingFaceEmbeddings
 
 
 
 
 
 
36
  self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
37
 
38
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -76,12 +99,17 @@ class VectorStoreManager:
76
 
77
  def _create_embeddings(self) -> HuggingFaceEmbeddings:
78
  """Create embeddings model from configuration."""
79
- device = "cuda" if torch.cuda.is_available() else "cpu"
80
-
81
  model_name = self.config["retriever"]["model"]
82
  normalize = self.config["retriever"]["normalize"]
83
 
84
- model_kwargs = {"device": device}
 
 
 
 
 
 
 
85
  encode_kwargs = {
86
  "normalize_embeddings": normalize,
87
  "batch_size": 100,
@@ -108,6 +136,8 @@ class VectorStoreManager:
108
  return embeddings
109
 
110
  # Use standard HuggingFaceEmbeddings for non-Matryoshka models
 
 
111
  embeddings = HuggingFaceEmbeddings(
112
  model_name=model_name,
113
  model_kwargs=model_kwargs,
 
1
  """Vector store management and operations."""
2
+ import os
3
+ # Disable MPS before importing torch to prevent meta tensor issues on Mac
4
+ os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
5
+ os.environ.setdefault("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0")
6
+
7
  from pathlib import Path
8
  from typing import Dict, Any, List, Optional
9
 
10
 
11
  import torch
12
+ # Disable MPS backend explicitly to prevent meta tensor issues
13
+ if hasattr(torch.backends, 'mps'):
14
+ # Monkey patch to disable MPS
15
+ original_mps_available = torch.backends.mps.is_available
16
+ torch.backends.mps.is_available = lambda: False
17
+
18
  from langchain_qdrant import QdrantVectorStore
19
  from langchain.docstore.document import Document
20
  from langchain_core.embeddings import Embeddings
 
39
 
40
  if truncate_dim and "matryoshka" in model_name.lower():
41
  # Use SentenceTransformer directly for Matryoshka models
42
+ # Fix for meta tensor issue: Explicitly force CPU
43
+ # MPS is already disabled at module level
44
+ # Explicitly pass device="cpu" to prevent MPS/CUDA detection
45
+ self.model = SentenceTransformer(
46
+ model_name,
47
+ truncate_dim=truncate_dim,
48
+ device="cpu" # Force CPU to prevent meta tensor issues
49
+ )
50
  print(f"πŸ”§ Matryoshka model configured for {truncate_dim} dimensions")
51
  else:
52
  # Use standard HuggingFaceEmbeddings
53
+ # Don't pass device parameter - let it load naturally on CPU
54
+ # This prevents the meta tensor error
55
+ if "model_kwargs" not in kwargs:
56
+ kwargs["model_kwargs"] = {}
57
+ # Remove device from model_kwargs if present to prevent meta tensor issues
58
+ kwargs["model_kwargs"].pop("device", None)
59
  self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
60
 
61
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
 
99
 
100
  def _create_embeddings(self) -> HuggingFaceEmbeddings:
101
  """Create embeddings model from configuration."""
 
 
102
  model_name = self.config["retriever"]["model"]
103
  normalize = self.config["retriever"]["normalize"]
104
 
105
+ # Fix for meta tensor issue: Force CPU usage to prevent MPS/CUDA detection
106
+ # The error occurs when SentenceTransformer detects MPS/CUDA and tries to move meta tensors
107
+ # MPS is already disabled at module level, now we explicitly force CPU in model_kwargs
108
+ model_kwargs = {
109
+ "device": "cpu", # Explicitly force CPU to prevent MPS/CUDA detection
110
+ "trust_remote_code": True, # Some models need this
111
+ }
112
+
113
  encode_kwargs = {
114
  "normalize_embeddings": normalize,
115
  "batch_size": 100,
 
136
  return embeddings
137
 
138
  # Use standard HuggingFaceEmbeddings for non-Matryoshka models
139
+ # Don't pass device in model_kwargs - let HuggingFaceEmbeddings handle it
140
+ # but ensure we're not using meta device
141
  embeddings = HuggingFaceEmbeddings(
142
  model_name=model_name,
143
  model_kwargs=model_kwargs,