Ara Yeroyan committed on
Commit 72eb0bf · 1 Parent(s): b4984e2

refactor + add gemini
app.py CHANGED
@@ -10,6 +10,7 @@ import uuid
 import logging
 import traceback
 from pathlib import Path
+
 from collections import Counter
 from typing import List, Dict, Any, Optional
 
@@ -19,10 +20,11 @@ import streamlit as st
 import plotly.express as px
 from langchain_core.messages import HumanMessage, AIMessage
 
-from multi_agent_chatbot import get_multi_agent_chatbot
-from smart_chatbot import get_chatbot as get_smart_chatbot
-from src.reporting.snowflake_connector import save_to_snowflake
-from src.reporting.feedback_schema import create_feedback_from_dict
+
+from src.agents import get_multi_agent_chatbot, get_smart_chatbot, get_gemini_chatbot
+from src.feedback import FeedbackManager
+from src.ui_components import get_custom_css, display_chunk_statistics_charts, display_chunk_statistics_table, extract_chunk_statistics
+
 from src.config.paths import (
     IS_DEPLOYED,
     PROJECT_DIR,
@@ -90,116 +92,9 @@ st.set_page_config(
     page_title="Intelligent Audit Report Chatbot"
 )
 
-# Custom CSS
-st.markdown("""
-<style>
-    .main-header {
-        font-size: 2.5rem;
-        font-weight: bold;
-        color: #1f77b4;
-        text-align: center;
-        margin-bottom: 1rem;
-        width: 100%;
-        display: block;
-    }
-
-    .subtitle {
-        font-size: 1.2rem;
-        color: #666;
-        text-align: center;
-        margin-bottom: 2rem;
-        width: 100%;
-        display: block;
-    }
-
-    .session-info {
-        background-color: #f0f2f6;
-        padding: 10px;
-        border-radius: 5px;
-        margin-bottom: 20px;
-        font-size: 0.9rem;
-    }
-
-    .user-message {
-        background-color: #007bff;
-        color: white;
-        padding: 12px 16px;
-        border-radius: 18px 18px 4px 18px;
-        margin: 8px 0;
-        margin-left: 20%;
-        word-wrap: break-word;
-    }
-
-    .bot-message {
-        background-color: #f1f3f4;
-        color: #333;
-        padding: 12px 16px;
-        border-radius: 18px 18px 18px 4px;
-        margin: 8px 0;
-        margin-right: 20%;
-        word-wrap: break-word;
-        border: 1px solid #e0e0e0;
-    }
-
-    .filter-section {
-        margin-bottom: 20px;
-        padding: 15px;
-        background-color: #f8f9fa;
-        border-radius: 8px;
-        border: 1px solid #e9ecef;
-    }
-
-    .filter-title {
-        font-weight: bold;
-        margin-bottom: 10px;
-        color: #495057;
-    }
-
-    .feedback-section {
-        background-color: #f8f9fa;
-        padding: 20px;
-        border-radius: 10px;
-        margin-top: 30px;
-        border: 2px solid #dee2e6;
-    }
-
-    .retrieval-history {
-        background-color: #ffffff;
-        padding: 15px;
-        border-radius: 5px;
-        margin: 10px 0;
-        border-left: 4px solid #007bff;
-    }
-
-    .retrieval-distribution-container {
-        background-color: #ffffff;
-        padding: 25px;
-        border-radius: 10px;
-        margin: 20px 0;
-        border: 2px solid #e0e0e0;
-        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 2px 4px rgba(0, 0, 0, 0.06);
-    }
-
-    .metric-label {
-        font-size: 0.9rem;
-        color: #555;
-        margin-bottom: 5px;
-        text-align: center;
-    }
-
-    .metric-value {
-        font-size: 1.8rem;
-        font-weight: bold;
-        color: #000000;
-        text-align: center;
-    }
-
-    .metric-container {
-        text-align: center;
-        padding: 10px;
-    }
-</style>
-""", unsafe_allow_html=True)
+
+st.markdown(get_custom_css(), unsafe_allow_html=True)
+
 
 def get_system_type():
     """Get the current system type"""
@@ -209,14 +104,17 @@ def get_system_type():
     else:
         return "Multi-Agent System"
 
-def get_chatbot():
-    """Initialize and return the chatbot based on system type"""
-    # Check environment variable for system type
-    system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
-    if system == 'smart':
-        return get_smart_chatbot()
+def get_chatbot(version: str = "v1"):
+    """Initialize and return the chatbot based on version"""
+    if version == "beta":
+        return get_gemini_chatbot()
     else:
-        return get_multi_agent_chatbot()
+        # Check environment variable for system type (v1)
+        system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
+        if system == 'smart':
+            return get_smart_chatbot()
+        else:
+            return get_multi_agent_chatbot()
 
 def serialize_messages(messages):
     """Serialize LangChain messages to dictionaries"""
@@ -262,372 +160,9 @@ def serialize_documents(sources):
     return serialized
 
 
-def extract_transcript(messages: List[Any]) -> List[Dict[str, str]]:
-    """Extract transcript from messages - only user and bot messages, no extra metadata"""
-    transcript = []
-    for msg in messages:
-        if isinstance(msg, HumanMessage):
-            transcript.append({
-                "role": "user",
-                "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
-            })
-        elif isinstance(msg, AIMessage):
-            transcript.append({
-                "role": "assistant",
-                "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
-            })
-    return transcript
-
-
-def build_retrievals_structure(rag_retrieval_history: List[Dict[str, Any]], messages: List[Any]) -> List[Dict[str, Any]]:
-    """Build retrievals structure from retrieval history"""
-    retrievals = []
-
-    for entry in rag_retrieval_history:
-        # Get the user message that triggered this retrieval
-        # The entry has conversation_up_to which includes messages up to that point
-        conversation_up_to = entry.get("conversation_up_to", [])
-
-        # Find the last user message in conversation_up_to (this is the trigger)
-        user_message_trigger = ""
-        for msg_dict in reversed(conversation_up_to):
-            if msg_dict.get("type") == "HumanMessage":
-                user_message_trigger = msg_dict.get("content", "")
-                break
-
-        # Fallback: if not found in conversation_up_to, get from actual messages
-        # This handles edge cases where conversation_up_to might be incomplete
-        if not user_message_trigger:
-            # Find which retrieval this is (0-indexed)
-            retrieval_idx = rag_retrieval_history.index(entry)
-            # The user message that triggered this retrieval is at position (retrieval_idx * 2)
-            # because each retrieval is preceded by: user message, bot response, user message, ...
-            # But we need to account for the fact that the first retrieval happens after the first user message
-            user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
-            if retrieval_idx < len(user_msgs):
-                user_message_trigger = str(user_msgs[retrieval_idx].content)
-            elif user_msgs:
-                # Fallback to last user message
-                user_message_trigger = str(user_msgs[-1].content)
-
-        # Get retrieved documents and truncate content to 100 chars
-        docs_retrieved = entry.get("docs_retrieved", [])
-        retrieved_docs = []
-        for doc in docs_retrieved:
-            doc_copy = doc.copy()
-            # Truncate content to 100 characters (keep all other fields)
-            if "content" in doc_copy:
-                doc_copy["content"] = doc_copy["content"][:100]
-            retrieved_docs.append(doc_copy)
-
-        retrievals.append({
-            "retrieved_docs": retrieved_docs,
-            "user_message_trigger": user_message_trigger
-        })
-
-    return retrievals
-
-
-def build_feedback_score_related_retrieval_docs(
-    is_feedback_about_last_retrieval: bool,
-    messages: List[Any],
-    rag_retrieval_history: List[Dict[str, Any]]
-) -> Optional[Dict[str, Any]]:
-    """Build feedback_score_related_retrieval_docs structure"""
-    if not rag_retrieval_history:
-        return None
-
-    # Get the relevant retrieval entry
-    if is_feedback_about_last_retrieval:
-        relevant_entry = rag_retrieval_history[-1]
-    else:
-        # If feedback is about all retrievals, use the last one as default
-        relevant_entry = rag_retrieval_history[-1]
-
-    # Get conversation up to that point
-    conversation_up_to = relevant_entry.get("conversation_up_to", [])
-
-    # Convert to transcript format (role/content)
-    conversation_up_to_point = []
-    for msg_dict in conversation_up_to:
-        if msg_dict.get("type") == "HumanMessage":
-            conversation_up_to_point.append({
-                "role": "user",
-                "content": msg_dict.get("content", "")
-            })
-        elif msg_dict.get("type") == "AIMessage":
-            conversation_up_to_point.append({
-                "role": "assistant",
-                "content": msg_dict.get("content", "")
-            })
-
-    # Get retrieved docs with full content (not truncated)
-    retrieved_docs = relevant_entry.get("docs_retrieved", [])
-
-    return {
-        "conversation_up_to_point": conversation_up_to_point,
-        "retrieved_docs": retrieved_docs
-    }
-
-
-def extract_chunk_statistics(sources: List[Any]) -> Dict[str, Any]:
-    """Extract statistics from retrieved chunks."""
-    if not sources:
-        return {}
-
-    sources_list = []
-    years = []
-    filenames = []
-    districts = []
-
-    for doc in sources:
-        metadata = getattr(doc, 'metadata', {})
-
-        # Extract source
-        source = metadata.get('source', 'Unknown')
-        sources_list.append(source)
-
-        # Extract year
-        year = metadata.get('year', 'Unknown')
-        if year and year != 'Unknown':
-            try:
-                # Convert to int first, then back to string to ensure it's a proper year
-                year_int = int(float(year))  # Handle both int and float strings
-                if 1900 <= year_int <= 2030:  # Reasonable year range
-                    years.append(str(year_int))
-                else:
-                    years.append('Unknown')
-            except (ValueError, TypeError):
-                years.append('Unknown')
-        else:
-            years.append('Unknown')
-
-        # Extract filename
-        filename = metadata.get('filename', 'Unknown')
-        filenames.append(filename)
-
-        # Extract district
-        district = metadata.get('district', 'Unknown')
-        if district and district != 'Unknown':
-            districts.append(district)
-        else:
-            districts.append('Unknown')
-
-    # Count occurrences
-    source_counts = Counter(sources_list)
-    year_counts = Counter(years)
-    filename_counts = Counter(filenames)
-    district_counts = Counter(districts)
-
-    return {
-        'total_chunks': len(sources),
-        'unique_sources': len(source_counts),
-        'unique_years': len([y for y in year_counts.keys() if y != 'Unknown']),
-        'unique_filenames': len(filename_counts),
-        'unique_districts': len([d for d in district_counts.keys() if d != 'Unknown']),
-        'source_distribution': dict(source_counts),
-        'year_distribution': dict(year_counts),
-        'filename_distribution': dict(filename_counts),
-        'district_distribution': dict(district_counts),
-        'sources': sources_list,
-        'years': years,
-        'filenames': filenames,
-        'districts': districts
-    }
-
-
-def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieval Statistics"):
-    """Display statistics as interactive charts for 10+ results."""
-    if not stats or stats.get('total_chunks', 0) == 0:
-        return
-
-    # Wrap everything in one styled container - open it
-    st.markdown(f"""
-    <div class="retrieval-distribution-container">
-        <h3 style="margin-top: 0;">📊 {title}</h3>
-        <div style="display: flex; justify-content: space-around; align-items: center; padding: 15px 0; border-bottom: 1px solid #e0e0e0; margin-bottom: 20px;">
-            <div class="metric-container">
-                <div class="metric-label">Total Chunks</div>
-                <div class="metric-value">{stats['total_chunks']}</div>
-            </div>
-            <div class="metric-container">
-                <div class="metric-label">Unique Sources</div>
-                <div class="metric-value">{stats['unique_sources']}</div>
-            </div>
-            <div class="metric-container">
-                <div class="metric-label">Unique Years</div>
-                <div class="metric-value">{stats['unique_years']}</div>
-            </div>
-            <div class="metric-container">
-                <div class="metric-label">Unique Files</div>
-                <div class="metric-value">{stats['unique_filenames']}</div>
-            </div>
-        </div>
-    """, unsafe_allow_html=True)
-
-    # Charts - three columns to include Districts
-    col1, col2, col3 = st.columns(3)
-
-    with col1:
-        # Source distribution chart
-        if stats['source_distribution']:
-            source_df = pd.DataFrame(
-                list(stats['source_distribution'].items()),
-                columns=['Source', 'Count']
-            )
-            fig_source = px.bar(
-                source_df,
-                x='Count',
-                y='Source',
-                orientation='h',
-                title='Distribution by Source',
-                color='Count',
-                color_continuous_scale='viridis'
-            )
-            fig_source.update_layout(height=400, showlegend=False)
-            st.plotly_chart(fig_source, use_container_width=True)
-
-    with col2:
-        # Year distribution chart
-        if stats['year_distribution']:
-            # Filter out 'Unknown' years for the chart
-            year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
-            if year_dist_filtered:
-                year_df = pd.DataFrame(
-                    list(year_dist_filtered.items()),
-                    columns=['Year', 'Count']
-                )
-                # Sort by year as integer but keep as string for categorical display
-                year_df['Year_Int'] = year_df['Year'].astype(int)
-                year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)
-
-                fig_year = px.bar(
-                    year_df,
-                    x='Year',
-                    y='Count',
-                    title='Distribution by Year',
-                    color='Count',
-                    color_continuous_scale='plasma'
-                )
-                # Ensure years are treated as categorical (discrete) not continuous
-                fig_year.update_xaxes(type='category')
-                fig_year.update_layout(height=400, showlegend=False)
-                st.plotly_chart(fig_year, use_container_width=True)
-            else:
-                st.info("No valid years found in the results")
-
-    with col3:
-        # District distribution chart
-        if stats.get('district_distribution'):
-            district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
-            if district_dist_filtered:
-                district_df = pd.DataFrame(
-                    list(district_dist_filtered.items()),
-                    columns=['District', 'Count']
-                )
-                district_df = district_df.sort_values('Count', ascending=False)
-
-                fig_district = px.bar(
-                    district_df,
-                    x='Count',
-                    y='District',
-                    orientation='h',
-                    title='Distribution by District',
-                    color='Count',
-                    color_continuous_scale='blues'
-                )
-                fig_district.update_layout(height=400, showlegend=False)
-                st.plotly_chart(fig_district, use_container_width=True)
-            else:
-                st.info("No valid districts found in the results")
-
-    # Close the container
-    st.markdown('</div>', unsafe_allow_html=True)
-
-
-def display_chunk_statistics_table(stats: Dict[str, Any], title: str = "Retrieval Distribution"):
-    """Display statistics as tables for smaller results with fixed alignment."""
-    if not stats or stats.get('total_chunks', 0) == 0:
-        return
-
-    # Wrap in styled container
-    st.markdown('<div class="retrieval-distribution-container">', unsafe_allow_html=True)
-
-    st.subheader(f"📊 {title}")
-
-    # Create a container with fixed height for alignment
-    stats_container = st.container()
-
-    with stats_container:
-        # Create 4 equal columns for consistent alignment
-        col1, col2, col3, col4 = st.columns(4)
-
-        with col1:
-            st.markdown("**🏘️ Districts**")
-            if stats.get('district_distribution'):
-                district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
-                if district_dist_filtered:
-                    district_data = {
-                        "District": list(district_dist_filtered.keys()),
-                        "Count": list(district_dist_filtered.values())
-                    }
-                    district_df = pd.DataFrame(district_data).sort_values('Count', ascending=False)
-                    st.dataframe(district_df, hide_index=True, use_container_width=True)
-                else:
-                    st.write("No district data")
-            else:
-                st.write("No district data")
-
-        with col2:
-            st.markdown("**📂 Sources**")
-            if stats['source_distribution']:
-                source_data = {
-                    "Source": list(stats['source_distribution'].keys()),
-                    "Count": list(stats['source_distribution'].values())
-                }
-                source_df = pd.DataFrame(source_data).sort_values('Count', ascending=False)
-                st.dataframe(source_df, hide_index=True, use_container_width=True)
-            else:
-                st.write("No source data")
-
-        with col3:
-            st.markdown("**📅 Years**")
-            if stats['year_distribution']:
-                year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
-                if year_dist_filtered:
-                    year_data = {
-                        "Year": list(year_dist_filtered.keys()),
-                        "Count": list(year_dist_filtered.values())
-                    }
-                    year_df = pd.DataFrame(year_data)
-                    # Sort by year as integer but display as string
-                    year_df['Year_Int'] = year_df['Year'].astype(int)
-                    year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
-                    st.dataframe(year_df, hide_index=True, use_container_width=True)
-                else:
-                    st.write("No year data")
-            else:
-                st.write("No year data")
-
-        with col4:
-            st.markdown("**📄 Files**")
-            if stats['filename_distribution']:
-                filename_items = list(stats['filename_distribution'].items())
-                filename_items.sort(key=lambda x: x[1], reverse=True)
-
-                # Show top files with truncated names
-                file_data = {
-                    "File": [f[:30] + "..." if len(f) > 30 else f for f, c in filename_items[:5]],
-                    "Count": [c for f, c in filename_items[:5]]
-                }
-                file_df = pd.DataFrame(file_data)
-                st.dataframe(file_df, hide_index=True, use_container_width=True)
-            else:
-                st.write("No file data")
-
-    # Close container
-    st.markdown('</div>', unsafe_allow_html=True)
-
+feedback_manager = FeedbackManager()
 
 
 @st.cache_data
 def load_filter_options():
     try:
@@ -652,11 +187,30 @@ def main():
     # Track RAG retrieval history for feedback
     if 'rag_retrieval_history' not in st.session_state:
        st.session_state.rag_retrieval_history = []
-    # Initialize chatbot only once per app session (cached)
-    if 'chatbot' not in st.session_state:
-        with st.spinner("🔄 Loading AI models and connecting to database..."):
-            st.session_state.chatbot = get_chatbot()
-            st.success("✅ AI system ready!")
+    # Version selection (v1 or beta)
+    if 'chatbot_version' not in st.session_state:
+        st.session_state.chatbot_version = "v1"
+
+    # Initialize chatbot based on version (reinitialize if version changes)
+    chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}"
+    if chatbot_version_key not in st.session_state or st.session_state.get('_last_version') != st.session_state.chatbot_version:
+        try:
+            with st.spinner("🔄 Loading AI models and connecting to database..."):
+                st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version)
+                st.session_state['_last_version'] = st.session_state.chatbot_version
+                st.session_state.chatbot = st.session_state[chatbot_version_key]
+            st.success("✅ AI system ready!")
+        except Exception as e:
+            st.error(f"❌ Failed to initialize chatbot: {str(e)}")
+            st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)")
+            # Reset to v1 to prevent infinite loop
+            st.session_state.chatbot_version = "v1"
+            st.session_state['_last_version'] = "v1"
+            if 'chatbot' in st.session_state:
+                del st.session_state['chatbot']
+            st.stop()  # Stop execution to prevent infinite loop
+    else:
+        st.session_state.chatbot = st.session_state[chatbot_version_key]
 
     # Reset conversation history if needed (but keep chatbot cached)
     if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
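
The initialization above keys the cached bot by version so that switching between v1 and beta rebuilds only when the version actually changes. A dependency-free sketch of the same pattern, using a plain dict where the real app uses `st.session_state`:

```python
# Version-keyed cache, mirroring the session-state logic above
# (hypothetical stand-ins; the real app builds chatbots, not strings).
session = {}

def load_chatbot(version: str):
    key = f"chatbot_{version}"
    if key not in session or session.get('_last_version') != version:
        session[key] = f"<{version} bot>"   # stands in for get_chatbot(version)
        session['_last_version'] = version
    session['chatbot'] = session[key]
    return session['chatbot']

assert load_chatbot("v1") == "<v1 bot>"
assert load_chatbot("beta") == "<beta bot>"            # rebuild on version change
assert load_chatbot("beta") is session["chatbot_beta"]  # cached thereafter
```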
@@ -668,10 +222,13 @@
         st.session_state.reset_conversation = False
         st.rerun()
 
-    # Header - fully center aligned
-    st.markdown('<h1 class="main-header">🤖 Intelligent Audit Report Chatbot</h1>', unsafe_allow_html=True)
+
     st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
 
+    # Show version info
+    if st.session_state.chatbot_version == "beta":
+        st.info("🔬 **Beta Mode**: Using Google Gemini File Search API")
+
     # Session info
     duration = int(time.time() - st.session_state.session_start_time)
     duration_str = f"{duration // 60}m {duration % 60}s"
@@ -829,7 +386,7 @@
     )
 
     with col2:
-        send_button = st.button("Send", key="send_button", use_container_width=True)
+        send_button = st.button("Send", key="send_button", width='stretch')
 
     # Clear chat button
     if st.button("🗑️ Clear Chat", key="clear_chat_button"):
@@ -890,10 +447,20 @@
             else:
                 formatted_query = "No RAG query available"
 
+            # Extract filters from active filters
+            filters_used = {
+                "sources": st.session_state.active_filters.get('sources', []),
+                "years": st.session_state.active_filters.get('years', []),
+                "districts": st.session_state.active_filters.get('districts', []),
+                "filenames": st.session_state.active_filters.get('filenames', [])
+            }
+
             retrieval_entry = {
                 "conversation_up_to": serialize_messages(st.session_state.messages),
                 "rag_query_expansion": formatted_query,
-                "docs_retrieved": serialize_documents(sources)
+                "docs_retrieved": serialize_documents(sources),
+                "filters_applied": filters_used,
+                "timestamp": time.time()
             }
             st.session_state.rag_retrieval_history.append(retrieval_entry)
         else:
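
Each history entry now records the active filters and a timestamp alongside the documents, so later feedback can be tied to the exact search that produced it. A minimal sketch of the resulting record shape, with illustrative values:

```python
import time

# Shape of one rag_retrieval_history entry after this change (values illustrative).
retrieval_entry = {
    "conversation_up_to": [{"type": "HumanMessage", "content": "PDM funds in FY 2022/23?"}],
    "rag_query_expansion": "PDM administrative costs FY 2022/23",
    "docs_retrieved": [{"content": "...", "metadata": {"district": "Lwengo", "year": "2023"}}],
    "filters_applied": {"sources": [], "years": ["2023"], "districts": ["Lwengo"], "filenames": []},
    "timestamp": time.time(),
}
print(sorted(retrieval_entry))
```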
@@ -954,9 +521,18 @@
                 for i, doc in enumerate(sources):  # Show all documents
                     # Get relevance score and ID if available
                     metadata = getattr(doc, 'metadata', {})
-                    score = metadata.get('reranked_score', metadata.get('original_score', None))
-                    chunk_id = metadata.get('_id', 'Unknown')
-                    score_text = f" (Score: {score:.3f}, ID: {chunk_id[:8]}...)" if score is not None else f" (ID: {chunk_id[:8]}...)"
+                    # Handle both standard RAG scores and Gemini scores
+                    score = metadata.get('reranked_score') or metadata.get('original_score') or metadata.get('score')
+                    chunk_id = metadata.get('_id') or metadata.get('chunk_id', 'Unknown')
+                    if score is not None:
+                        try:
+                            score_text = f" (Score: {float(score):.3f})"
+                        except (ValueError, TypeError):
+                            score_text = ""
+                    else:
+                        score_text = ""
+                    if chunk_id and chunk_id != 'Unknown':
+                        score_text += f" (ID: {str(chunk_id)[:8]}...)" if score_text else f" (ID: {str(chunk_id)[:8]}...)"
 
                    with st.expander(f"📄 Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
                        # Display document metadata with emojis
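
One caveat worth knowing about the `or`-chain introduced above: unlike nested `.get(..., default)` calls, `or` also skips falsy values, so a legitimate reranked score of 0.0 would fall through to the next key. A small standalone demonstration:

```python
metadata = {'reranked_score': 0.0, 'original_score': 0.42}

# Nested defaults keep the 0.0; the or-chain skips it as falsy.
nested = metadata.get('reranked_score', metadata.get('original_score'))
chained = metadata.get('reranked_score') or metadata.get('original_score') or metadata.get('score')

assert nested == 0.0
assert chained == 0.42  # 0.0 is falsy, so the chain falls through
```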
@@ -1034,7 +610,7 @@
 
     submitted = st.form_submit_button(
         "📤 Submit Feedback",
-        use_container_width=True,
+        width='stretch',
         disabled=submit_disabled
     )
 
@@ -1046,16 +622,18 @@
             st.write("🔍 **Debug: Feedback Data Being Submitted:**")
 
             # Extract transcript from messages
-            transcript = extract_transcript(st.session_state.messages)
+            transcript = feedback_manager.extract_transcript(st.session_state.messages)
 
             # Build retrievals structure
-            retrievals = build_retrievals_structure(
+            retrievals = feedback_manager.build_retrievals_structure(
+
                 st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
                 st.session_state.messages
             )
 
             # Build feedback_score_related_retrieval_docs
-            feedback_score_related_retrieval_docs = build_feedback_score_related_retrieval_docs(
+
+            feedback_score_related_retrieval_docs = feedback_manager.build_feedback_score_related_retrieval_docs(
                 is_feedback_about_last_retrieval,
                 st.session_state.messages,
                 st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
@@ -1085,7 +663,7 @@
             # Create UserFeedback dataclass instance
             feedback_obj = None  # Initialize outside try block
             try:
-                feedback_obj = create_feedback_from_dict(feedback_dict)
+                feedback_obj = feedback_manager.create_feedback_from_dict(feedback_dict)
                 print(f"✅ FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
                 st.write(f"✅ **Feedback Object Created**")
                 st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
@@ -1141,7 +719,7 @@
                 logger.info("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
                 print("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
 
-                snowflake_success = save_to_snowflake(feedback_obj)
+                snowflake_success = feedback_manager.save_to_snowflake(feedback_obj)
                 if snowflake_success:
                     logger.info("✅ SNOWFLAKE UI: Successfully saved to Snowflake")
                     print("✅ SNOWFLAKE UI: Successfully saved to Snowflake")
@@ -1214,20 +792,111 @@
     st.markdown("---")
     st.markdown("#### 📊 Retrieval History")
 
-    with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=False):
+    with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=True):
         for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
-            st.markdown(f"**Retrieval #{idx}**")
+            st.markdown(f"### **Retrieval #{idx}**")
+
+            # Display timestamp if available
+            if entry.get("timestamp"):
+                timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(entry["timestamp"]))
+                st.caption(f"🕐 {timestamp_str}")
 
             # Display the actual RAG query
             rag_query_expansion = entry.get("rag_query_expansion", "No query available")
+            st.markdown("**🔍 RAG Query:**")
             st.code(rag_query_expansion, language="text")
 
+            # Display filters used
+            filters_applied = entry.get("filters_applied", {})
+            if filters_applied and any(filters_applied.values()):
+                st.markdown("**🎯 Filters Applied:**")
+                filter_display = {}
+                if filters_applied.get("sources"):
+                    filter_display["Sources"] = filters_applied["sources"]
+                if filters_applied.get("years"):
+                    filter_display["Years"] = filters_applied["years"]
+                if filters_applied.get("districts"):
+                    filter_display["Districts"] = filters_applied["districts"]
+                if filters_applied.get("filenames"):
+                    filter_display["Filenames"] = filters_applied["filenames"]
+
+                if filter_display:
+                    st.json(filter_display)
+                else:
+                    st.info("No filters applied")
+            else:
+                st.info("No filters applied")
+
+            # Display conversation history up to retrieval point
+            conversation_up_to = entry.get("conversation_up_to", [])
+            if conversation_up_to:
+                st.markdown("**💬 Conversation History (up to retrieval point):**")
+                with st.expander(f"View {len(conversation_up_to)} messages", expanded=False):
+                    for msg_idx, msg in enumerate(conversation_up_to, 1):
+                        role = msg.get("type", "unknown")
+                        content = msg.get("content", "")
+
+                        if role == "HumanMessage" or role == "human":
+                            st.markdown(f"**👤 User {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
+                        elif role == "AIMessage" or role == "ai":
+                            st.markdown(f"**🤖 Assistant {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
+            else:
+                st.info("No conversation history available")
+
+            # Display documents retrieved
+            docs_retrieved = entry.get("docs_retrieved", [])
+            if docs_retrieved:
+                st.markdown(f"**📄 Documents Retrieved ({len(docs_retrieved)}):**")
+                with st.expander(f"View {len(docs_retrieved)} documents", expanded=False):
+                    for doc_idx, doc in enumerate(docs_retrieved, 1):
+                        st.markdown(f"**Document {doc_idx}:**")
+
+                        # Display metadata
+                        metadata = doc.get("metadata", {})
+                        if metadata:
+                            col1, col2, col3 = st.columns(3)
+                            with col1:
+                                st.write(f"📄 **File:** {metadata.get('filename', 'Unknown')}")
+                            with col2:
+                                st.write(f"🏛️ **Source:** {metadata.get('source', 'Unknown')}")
+                            with col3:
+                                st.write(f"📅 **Year:** {metadata.get('year', 'Unknown')}")
+
+                            # Additional metadata
+                            if metadata.get('district'):
+                                st.write(f"📍 **District:** {metadata.get('district')}")
+                            if metadata.get('page'):
+                                st.write(f"📖 **Page:** {metadata.get('page')}")
+                            if metadata.get('score') is not None:
+                                st.write(f"⭐ **Score:** {metadata.get('score'):.3f}" if isinstance(metadata.get('score'), (int, float)) else f"⭐ **Score:** {metadata.get('score')}")
+
+                        # Display content preview (first 200 chars)
+                        content = doc.get("content", doc.get("page_content", ""))
+                        if content:
+                            st.markdown("**Content Preview:**")
+                            st.text_area(
+                                "Content Preview",
+                                value=content[:200] + ("..." if len(content) > 200 else ""),
+                                height=100,
+                                disabled=True,
+                                label_visibility="collapsed",
+                                key=f"retrieval_{idx}_doc_{doc_idx}_preview"
+                            )
+
+                        if doc_idx < len(docs_retrieved):
+                            st.markdown("---")
+            else:
+                st.info("No documents retrieved")
+
             # Display summary stats
+            st.markdown("**📊 Summary:**")
             st.json({
-                "conversation_length": len(entry.get("conversation_up_to", [])),
-                "documents_retrieved": len(entry.get("docs_retrieved", []))
+                "conversation_length": len(conversation_up_to),
+                "documents_retrieved": len(docs_retrieved)
             })
-            st.markdown("---")
+
+            if idx < len(st.session_state.rag_retrieval_history):
+                st.markdown("---")
 
     # Example Questions Section
     st.markdown("---")
@@ -1307,93 +976,6 @@ def main():
     st.caption("💡 Tip: Use specific terms from the documents (e.g., 'PDM', 'SACCOs', 'FY 2022/23')")
 
 
-    # Store selected question for next render (handled in input section above)
-    # This ensures the question populates the input field correctly
-
-    # Example Questions Section
-    st.markdown("---")
-    st.markdown(
-        "<h3 class='example-questions-header'>💡 Example Questions</h3>",
-        unsafe_allow_html=True
-    )
-    st.markdown(
-        "<p class='example-questions-description'>Click on any question below to use it, or modify the editable examples:</p>",
-        unsafe_allow_html=True
-    )
-
-    # Initialize example question state
-    if 'custom_question_1' not in st.session_state:
-        st.session_state.custom_question_1 = "How were administrative costs managed in the PDM implementation, and what issues arose with budget execution regarding staff salaries?"
-    if 'custom_question_2' not in st.session_state:
-        st.session_state.custom_question_2 = "What did the National Coordinator say about the release of funds for PDM administrative costs in the letter dated 29th September 2022 and how did the funding received affect the activities of the PDCs and PDM SACCOs in the FY 2022/23?"
-
-    # Question 1: Filename insights (fixed, clickable)
-    st.markdown("#### 📄 Question 1: List insights from a specific file")
-    col1, col2 = st.columns([3, 1])
-    with col1:
-        example_q1 = "List couple of insights from the filename."
-        st.markdown(f"**Example:** `{example_q1}`")
-        st.info("💡 **Filter to apply:** Select a Filename from the sidebar panel before asking this question.")
-    with col2:
-        if st.button("📋 Use This Question", key="use_example_1", use_container_width=True):
-            st.session_state.pending_question = example_q1
-            st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
-            st.rerun()
-
-    st.markdown("---")
-
-    # Questions 2 & 3: Editable examples
-    st.markdown("#### ✏️ Customizable Questions (Edit and use)")
-
-    # Question 2
-    custom_q1 = st.text_area(
-        "Edit question 2:",
-        value=st.session_state.custom_question_1,
-        height=80,
-        key="edit_question_2",
-        help="Modify this question to fit your needs, then click 'Use This Question'"
-    )
-    col1, col2 = st.columns([1, 4])
-    with col1:
-        if st.button("📋 Use Question 2", key="use_custom_1", use_container_width=True):
-            if custom_q1.strip():
-                st.session_state.pending_question = custom_q1.strip()
-                st.session_state.custom_question_1 = custom_q1.strip()
-                st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
-                st.rerun()
-            else:
-                st.warning("Please enter a question first!")
-    with col2:
-        st.caption("💡 Tip: Add specific details like dates, names, or amounts to get more precise answers")
-
-    st.info("💡 **Filter to apply:** Select District(s) and Year(s) sidebar panel before asking this question.")
-
-    st.markdown("---")
-
-    # Question 3
-    custom_q2 = st.text_area(
-        "Edit question 3:",
-        value=st.session_state.custom_question_2,
-        height=80,
-        key="edit_question_3",
-        help="Modify this question to fit your needs, then click 'Use This Question'"
-    )
-    col1, col2 = st.columns([1, 4])
-    with col1:
-        if st.button("📋 Use Question 3", key="use_custom_2", use_container_width=True):
-            if custom_q2.strip():
-                st.session_state.pending_question = custom_q2.strip()
-                st.session_state.custom_question_2 = custom_q2.strip()
-                st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
-                st.rerun()
-            else:
-                st.warning("Please enter a question first!")
-    with col2:
-        st.caption("💡 Tip: Use specific terms from the documents (e.g., 'PDM', 'SACCOs', 'FY 2022/23')")
-
-
     # Store selected question for next render (handled in input section above)
     # This ensures the question populates the input field correctly
 
src/agents/__init__.py ADDED
@@ -0,0 +1,10 @@
+"""
+Agent modules for chatbot implementations
+"""
+
+from .smart_chatbot import get_chatbot as get_smart_chatbot
+from .multi_agent_chatbot import get_multi_agent_chatbot
+from .gemini_chatbot import get_gemini_chatbot
+
+__all__ = ["get_smart_chatbot", "get_multi_agent_chatbot", "get_gemini_chatbot"]
+
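
With the package facade in place, callers can pull all three factories from one module. A usage sketch, assuming this repository's `src/` layout is importable and the `chat` return shape shown in `gemini_chatbot.py` below:

```python
# Assumes this repo's src/ package is on the path (illustration only).
from src.agents import (
    get_smart_chatbot,
    get_multi_agent_chatbot,
    get_gemini_chatbot,
)

bot = get_gemini_chatbot()  # beta path; v1 would use the other two factories
result = bot.chat("What did the audit find?", conversation_id="demo")
print(result["response"], len(result["rag_result"]["sources"]))
```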
src/agents/gemini_chatbot.py ADDED
@@ -0,0 +1,372 @@
+"""
+Gemini File Search Chatbot (Beta Version)
+
+This chatbot uses Google Gemini File Search API for RAG.
+It provides a simpler architecture: Main Agent + Gemini Agent
+"""
+
+import os
+import json
+import time
+import logging
+import traceback
+from pathlib import Path
+from typing import Dict, List, Any, Optional, TypedDict
+
+from langgraph.graph import StateGraph, END
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
+
+from src.gemini.file_search import GeminiFileSearchClient, GeminiFileSearchResult
+from src.config.paths import CONVERSATIONS_DIR
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+
+class GeminiState(TypedDict):
+    """State for Gemini chatbot conversation flow"""
+    conversation_id: str
+    messages: List[Any]
+    current_query: str
+    query_context: Optional[Dict[str, Any]]
+    gemini_result: Optional[GeminiFileSearchResult]
+    final_response: Optional[str]
+    agent_logs: List[str]
+    conversation_context: Dict[str, Any]
+    session_start_time: float
+    last_ai_message_time: float
+    filters: Optional[Dict[str, Any]]
+
+
+class GeminiRAGChatbot:
+    """Gemini File Search RAG chatbot (Beta version)"""
+
+    def __init__(self):
+        """Initialize the Gemini chatbot"""
+        logger.info("🤖 INITIALIZING: Gemini File Search Chatbot (Beta)")
+
+        # Initialize Gemini File Search client
+        try:
+            self.gemini_client = GeminiFileSearchClient()
+            logger.info("✅ Gemini File Search client initialized")
+        except Exception as e:
+            logger.error(f"❌ Failed to initialize Gemini client: {e}")
+            raise RuntimeError(f"Gemini client initialization failed: {e}")
+
+        # Build the LangGraph with LangSmith tracing if enabled
+        self.graph = self._build_graph()
+
+        # Enable LangSmith tracing if configured
+        langsmith_enabled = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true"
+        if langsmith_enabled:
+            logger.info("🔍 LangSmith tracing enabled")
+            langsmith_project = os.getenv("LANGCHAIN_PROJECT", "gemini-chatbot")
+            logger.info(f"📊 LangSmith project: {langsmith_project}")
+
+        # Conversations directory
+        self.conversations_dir = CONVERSATIONS_DIR
+        try:
+            self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
+        except (PermissionError, OSError) as e:
+            logger.warning(f"Could not create conversations directory: {e}")
+            self.conversations_dir = Path("conversations")
+            self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
+
+        logger.info("✅ Gemini File Search Chatbot initialized")
+
+    def _build_graph(self) -> StateGraph:
+        """Build the LangGraph for Gemini chatbot"""
+        graph = StateGraph(GeminiState)
+
+        # Add nodes
+        graph.add_node("main_agent", self._main_agent)
+        graph.add_node("gemini_agent", self._gemini_agent)
+
+        # Define the flow
+        graph.set_entry_point("main_agent")
+        graph.add_edge("main_agent", "gemini_agent")
+        graph.add_edge("gemini_agent", END)
+
+        return graph.compile()
+
+    def _main_agent(self, state: GeminiState) -> GeminiState:
+        """Main Agent: Extracts filters and prepares query"""
+        logger.info("🎯 MAIN AGENT: Processing query")
+
+        query = state["current_query"]
+        messages = state["messages"]
+
+        # Extract UI filters if present in query
+        ui_filters = self._extract_ui_filters(query)
+
+        # Extract context from conversation
+        context = self._extract_context_from_conversation(messages, ui_filters)
+
+        # Store context and filters
+        state["query_context"] = context
+        state["filters"] = context.get("filters", {})
+
+        logger.info(f"🎯 MAIN AGENT: Filters extracted: {state['filters']}")
+
+        return state
+
+    def _gemini_agent(self, state: GeminiState) -> GeminiState:
+        """Gemini Agent: Performs file search and generates response"""
+        logger.info("🔍 GEMINI AGENT: Starting file search")
+
+        query = state["current_query"]
+        filters = state.get("filters", {})
+
+        # Perform Gemini file search
+        try:
+            result = self.gemini_client.search(query=query, filters=filters)
+            logger.info(f"✅ GEMINI AGENT: Search completed, {len(result.sources)} sources found")
+
+            # Enhance response with document references
+            enhanced_response = self._enhance_response_with_references(
+                result.answer,
+                result.sources,
+                query
+            )
+
+            state["gemini_result"] = result
+            state["final_response"] = enhanced_response
+            state["last_ai_message_time"] = time.time()
+
+            state["agent_logs"].append(f"GEMINI AGENT: Found {len(result.sources)} sources")
+
+        except Exception as e:
+            logger.error(f"❌ GEMINI AGENT ERROR: {e}")
+            traceback.print_exc()
+            state["final_response"] = "I apologize, but I encountered an error while searching. Please try again."
+            state["last_ai_message_time"] = time.time()
+
+        return state
+
+    def _enhance_response_with_references(self, answer: str, sources: List[Any], query: str) -> str:
+        """Enhance Gemini response to include document references"""
+        if not sources or not answer:
+            return answer
+
+        # Use LLM to intelligently add document references
+        try:
+            from src.llm.adapters import get_llm_client
+            llm = get_llm_client()
+
+            # Prepare document summaries for the LLM
+            doc_summaries = []
+            for idx, doc in enumerate(sources, 1):
+                metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
+                content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
+
+                filename = metadata.get('filename', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
+                year = metadata.get('year', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
+                source = metadata.get('source', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
+
+                doc_summaries.append(f"[Doc {idx}] {filename} ({year}, {source}): {content[:300]}...")
+
+            prompt = f"""You are enhancing a response from a document search system. The original response is:
+
+{answer}
+
+The following documents were retrieved and used to generate this response:
+
+{chr(10).join(doc_summaries)}
+
+CRITICAL RULES:
+1. The response should ONLY contain information from the retrieved documents listed above
+2. If the response mentions information NOT found in the retrieved documents, you must REMOVE or CORRECT that information
+3. Add document references [Doc i] at the end of sentences that use information from specific documents
+4. Only reference documents that are actually used in the response
+5. If the response mentions years, sources, or data that don't match the retrieved documents, you must correct it
+6. Keep the response natural and conversational
+7. Don't change the core content that matches the documents, just add references where appropriate
+8. If multiple documents support the same claim, use [Doc i, Doc j] format
+9. If the response contains information that cannot be verified in the retrieved documents, add a note like: "Note: This information may not be in the retrieved documents."
+
+Return ONLY the enhanced response with references added and any corrections made. Do not include any explanation or meta-commentary."""
+
+            # Invoke the LLM once and unwrap the content (avoids repeated calls)
+            llm_result = llm.invoke(prompt)
+            enhanced = llm_result.content if hasattr(llm_result, 'content') else str(llm_result)
+
+            # Fallback: if LLM fails, just return original
+            if not enhanced or len(enhanced) < len(answer) * 0.5:
+                logger.warning("LLM enhancement failed, using original response")
+                return answer
+
+            return enhanced
+
+        except Exception as e:
+            logger.warning(f"Failed to enhance response with references: {e}")
+            # Fallback: add basic references at the end
+            if sources:
+                ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
+                return f"{answer}\n\n*Based on documents: {ref_list}*"
+            return answer
+
+    def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
+        """Extract UI filters from query if present"""
+        filters = {}
+
+        if "FILTER CONTEXT:" in query:
+            filter_section = query.split("FILTER CONTEXT:")[1]
+            if "USER QUERY:" in filter_section:
+                filter_section = filter_section.split("USER QUERY:")[0]
+            filter_section = filter_section.strip()
+
+            if "Sources:" in filter_section:
+                sources_line = [line for line in filter_section.split('\n') if line.strip().startswith('Sources:')]
+                if sources_line:
+                    sources_str = sources_line[0].split("Sources:")[1].strip()
+                    if sources_str and sources_str != "None":
+                        filters["sources"] = [s.strip() for s in sources_str.split(",")]
+
+            if "Years:" in filter_section:
+                years_line = [line for line in filter_section.split('\n') if line.strip().startswith('Years:')]
+                if years_line:
+                    years_str = years_line[0].split("Years:")[1].strip()
+                    if years_str and years_str != "None":
+                        filters["year"] = [y.strip() for y in years_str.split(",")]
+
+            if "Districts:" in filter_section:
+                districts_line = [line for line in filter_section.split('\n') if line.strip().startswith('Districts:')]
+                if districts_line:
+                    districts_str = districts_line[0].split("Districts:")[1].strip()
+                    if districts_str and districts_str != "None":
+                        filters["district"] = [d.strip() for d in districts_str.split(",")]
+
+            if "Filenames:" in filter_section:
+                filenames_line = [line for line in filter_section.split('\n') if line.strip().startswith('Filenames:')]
+                if filenames_line:
+                    filenames_str = filenames_line[0].split("Filenames:")[1].strip()
+                    if filenames_str and filenames_str != "None":
+                        filters["filenames"] = [f.strip() for f in filenames_str.split(",")]
+
+        return filters
+
+    def _extract_context_from_conversation(
+        self,
+        messages: List[Any],
+        ui_filters: Dict[str, List[str]]
+    ) -> Dict[str, Any]:
+        """Extract context from conversation history"""
+        # Use UI filters if available
+        filters = ui_filters.copy() if ui_filters else {}
+
+        # For Gemini, we pass filters directly to the search function
+        # The filters will be used to add context to the query
+
+        return {
+            "filters": filters,
+            "has_filters": bool(filters)
+        }
+
+    def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
+        """Main chat interface"""
+        logger.info(f"💬 GEMINI CHAT: Processing '{user_input[:50]}...'")
+
+        # Load conversation
+        conversation_file = self.conversations_dir / f"{conversation_id}.json"
+        conversation = self._load_conversation(conversation_file)
+
+        # Add user message
+        conversation["messages"].append(HumanMessage(content=user_input))
+
+        # Prepare state
+        state = GeminiState(
+            conversation_id=conversation_id,
+            messages=conversation["messages"],
+            current_query=user_input,
+            query_context=None,
+            gemini_result=None,
+            final_response=None,
+            agent_logs=[],
+            conversation_context=conversation.get("context", {}),
+            session_start_time=conversation["session_start_time"],
+            last_ai_message_time=conversation["last_ai_message_time"],
+            filters=None
+        )
+
+        # Run graph
+        final_state = self.graph.invoke(state)
+
+        # Add AI response to conversation
+        if final_state["final_response"]:
+            conversation["messages"].append(AIMessage(content=final_state["final_response"]))
+
+        # Update conversation
+        conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
+        conversation["context"] = final_state["conversation_context"]
+
+        # Save conversation
+        self._save_conversation(conversation_file, conversation)
+
+        # Format sources for display
+        sources = []
+        if final_state.get("gemini_result"):
+            sources = self.gemini_client.format_sources_for_display(final_state["gemini_result"])
+
+        return {
+            'response': final_state["final_response"] or "I apologize, but I couldn't process your request.",
+            'rag_result': {
+                'sources': sources,
+                'answer': final_state["final_response"]
+            },
+            'agent_logs': final_state["agent_logs"],
+            'actual_rag_query': final_state["current_query"]
+        }
+
+    def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
+        """Load conversation from file"""
+        if conversation_file.exists():
+            try:
+                with open(conversation_file) as f:
+                    data = json.load(f)
+                messages = []
+                for msg_data in data.get("messages", []):
+                    if msg_data["type"] == "human":
+                        messages.append(HumanMessage(content=msg_data["content"]))
+                    elif msg_data["type"] == "ai":
+                        messages.append(AIMessage(content=msg_data["content"]))
+                data["messages"] = messages
+                return data
+            except Exception as e:
+                logger.warning(f"Could not load conversation: {e}")
+
+        return {
+            "messages": [],
+            "session_start_time": time.time(),
+            "last_ai_message_time": time.time(),
+            "context": {}
+        }
+
+    def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
+        """Save conversation to file"""
+        try:
+            conversation_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
+
+            messages_data = []
+            for msg in conversation["messages"]:
+                if isinstance(msg, HumanMessage):
+                    messages_data.append({"type": "human", "content": msg.content})
+                elif isinstance(msg, AIMessage):
+                    messages_data.append({"type": "ai", "content": msg.content})
+
+            conversation_data = {
+                "messages": messages_data,
+                "session_start_time": conversation["session_start_time"],
+                "last_ai_message_time": conversation["last_ai_message_time"],
+                "context": conversation.get("context", {})
+            }
+
+            with open(conversation_file, 'w') as f:
+                json.dump(conversation_data, f, indent=2)
+
+        except Exception as e:
+            logger.error(f"Could not save conversation: {e}")
+
+
+def get_gemini_chatbot():
+    """Get Gemini chatbot instance"""
+    return GeminiRAGChatbot()
+
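
`_extract_ui_filters` expects the UI to prepend a `FILTER CONTEXT:` block to the raw query. A standalone sketch of that wire format with illustrative values; the parsing shown is a simplified equivalent of the method above, not the method itself:

```python
# The query format _extract_ui_filters looks for (illustrative values).
query = """FILTER CONTEXT:
Sources: OAG
Years: 2022, 2023
Districts: Lwengo, Kiboga
Filenames: None
USER QUERY: How were PDM administrative costs managed?"""

# Simplified equivalent of the per-key parsing in _extract_ui_filters.
filters = {}
section = query.split("FILTER CONTEXT:")[1].split("USER QUERY:")[0]
for line in section.strip().splitlines():
    key, _, value = line.partition(":")
    if value.strip() and value.strip() != "None":
        filters[key.strip().lower()] = [v.strip() for v in value.split(",")]

print(filters)  # {'sources': ['OAG'], 'years': ['2022', '2023'], 'districts': ['Lwengo', 'Kiboga']}
```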
multi_agent_chatbot.py β†’ src/agents/multi_agent_chatbot.py RENAMED
@@ -208,6 +208,59 @@ class MultiAgentRAGChatbot:
208
  logger.info(f" Sources: {self.source_whitelist}")
209
  logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
210
 
211
  def _build_graph(self) -> StateGraph:
212
  """Build the multi-agent LangGraph"""
213
  graph = StateGraph(MultiAgentState)
@@ -512,6 +565,10 @@ class MultiAgentRAGChatbot:
512
  - If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
513
  - If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
514
  - Always return districts as JSON arrays when multiple districts are mentioned
515
  - If no exact matches found, set extracted values to null
516
 
517
  4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
@@ -656,13 +713,9 @@ Analyze this query using ONLY the exact values provided above:""")
656
  # Validate each district in the array
657
  valid_districts = []
658
  for district in extracted_district:
659
- if district in self.district_whitelist:
660
- valid_districts.append(district)
661
- else:
662
- # Try removing "District" suffix
663
- district_name = district.replace(" District", "").replace(" district", "")
664
- if district_name in self.district_whitelist:
665
- valid_districts.append(district_name)
666
 
667
  if valid_districts:
668
  extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
@@ -671,16 +724,15 @@ Analyze this query using ONLY the exact values provided above:""")
671
  logger.warning(f"⚠️ No valid districts found in: '{extracted_district}'")
672
  extracted_district = None
673
  else:
674
- # Single district validation
675
- if extracted_district not in self.district_whitelist:
676
- # Try removing "District" suffix
677
- district_name = extracted_district.replace(" District", "").replace(" district", "")
678
- if district_name in self.district_whitelist:
679
- logger.info(f"πŸ” QUERY ANALYSIS: Normalized district '{extracted_district}' to '{district_name}'")
680
- extracted_district = district_name
681
- else:
682
- logger.warning(f"⚠️ Invalid district extracted: '{extracted_district}' not in whitelist")
683
- extracted_district = None
684
 
685
  # Validate source (handle both single values and arrays)
686
  if extracted_source:
@@ -918,6 +970,23 @@ Rewrite the best retrieval query:""")
918
  logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β†’ normalized: {normalized_districts}")
919
 
920
  # Merge with extracted context for missing filters
921
  if not filters.get("year") and context.extracted_year:
922
  # Handle both single values and arrays
923
  if isinstance(context.extracted_year, list):
@@ -926,16 +995,6 @@ Rewrite the best retrieval query:""")
926
  filters["year"] = [context.extracted_year]
927
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
928
 
929
- if not filters.get("district") and context.extracted_district:
930
- # Handle both single values and arrays
931
- if isinstance(context.extracted_district, list):
932
- # Normalize district names to title case (match Qdrant metadata format)
933
- normalized = [d.title() for d in context.extracted_district]
934
- filters["district"] = normalized
935
- else:
936
- filters["district"] = [context.extracted_district.title()]
937
- logger.info(f"πŸ”§ FILTER BUILDING: Added extracted district filter (UI missing): {context.extracted_district}")
938
-
939
  if not filters.get("sources") and context.extracted_source:
940
  # Handle both single values and arrays
941
  if isinstance(context.extracted_source, list):
@@ -963,12 +1022,21 @@ Rewrite the best retrieval query:""")
963
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
964
 
965
  if context.extracted_district:
966
- # Handle both single values and arrays
967
  if isinstance(context.extracted_district, list):
968
- filters["district"] = context.extracted_district
969
  else:
970
- filters["district"] = [context.extracted_district]
971
- logger.info(f"πŸ”§ FILTER BUILDING: Added extracted district filter: {context.extracted_district}")
 
973
  logger.info(f"πŸ”§ FILTER BUILDING: Final filters: {filters}")
974
  return filters
@@ -978,49 +1046,233 @@ Rewrite the best retrieval query:""")
978
  logger.info("πŸ’¬ RESPONSE GENERATION: Starting conversational response generation")
979
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Processing {len(documents)} documents")
980
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
981
 
982
  # Create response prompt
983
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Building response prompt")
984
  response_prompt = ChatPromptTemplate.from_messages([
985
  SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
986
 
987
  RULES:
988
  1. Answer the user's question directly and clearly
989
- 2. Use the retrieved documents as evidence
990
  3. Be conversational, not technical
991
  4. Don't mention scores, retrieval details, or technical implementation
992
  5. If relevant documents were found, reference them naturally
993
- 6. If no relevant documents, explain based on your knowledge (if you have it) or just say you do not have enough information.
994
- 7. If the passages have useful facts or numbers, use them in your answer.
995
- 8. When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
996
  9. Do not use the sentence 'Doc i says ...' to say where information came from.
997
  10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
998
  11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
999
  12. If it makes sense, use bullet points and lists to make your answers easier to understand.
1000
  13. You do not need to use every passage. Only use the ones that help answer the question.
1001
- 14. If the documents do not have the information needed to answer the question, just say you do not have enough information.
1002
-
 
1003
 
1004
  TONE: Professional but friendly, like talking to a colleague."""),
1005
- HumanMessage(content=f"""User Question: {query}
1006
 
1007
  Retrieved Documents: {len(documents)} documents found
1008
 
1009
  RAG Answer: {rag_answer}
1010
 
1011
- Generate a conversational response:""")
1012
  ])
1013
 
1014
  try:
1015
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Calling LLM for final response")
1016
  response = self.llm.invoke(response_prompt.format_messages())
1017
  logger.info(f"πŸ’¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
1018
- return response.content.strip()
1019
  except Exception as e:
1020
  logger.error(f"❌ RESPONSE GENERATION: Error during generation: {e}")
1021
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Using RAG answer as fallback")
1022
  return rag_answer # Fallback to RAG answer
1023
 
1024
  def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
1025
  """Generate conversational response using only LLM knowledge and conversation history"""
1026
  logger.info("πŸ’¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
 
208
  logger.info(f" Sources: {self.source_whitelist}")
209
  logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
210
 
211
+ def _normalize_district_name(self, district: str) -> Optional[str]:
212
+ """Normalize district name with fuzzy matching for common misspellings."""
213
+ if not district:
214
+ return None
215
+
216
+ district = district.strip()
217
+
218
+ # Direct match
219
+ if district in self.district_whitelist:
220
+ return district
221
+
222
+ # Remove "District" suffix
223
+ district_name = district.replace(" District", "").replace(" district", "").strip()
224
+ if district_name in self.district_whitelist:
225
+ return district_name
226
+
227
+ # Common misspellings mapping
228
+ misspelling_map = {
229
+ "kalagala": "Kalangala",
230
+ "Kalagala": "Kalangala",
231
+ "KALAGALA": "Kalangala",
232
+ "kalangala": "Kalangala",
233
+ "gulu": "Gulu",
234
+ "GULU": "Gulu",
235
+ "kampala": "Kampala",
236
+ "KAMPALA": "Kampala",
237
+ }
238
+
239
+ # Check misspelling map (case-insensitive)
240
+ district_lower = district_name.lower()
241
+ if district_lower in misspelling_map:
242
+ corrected = misspelling_map[district_lower]
243
+ if corrected in self.district_whitelist:
244
+ return corrected
245
+
246
+ # Fuzzy matching for similar names (simple Levenshtein-like check)
247
+ # Check if the district name is very similar to any whitelist entry
248
+ for whitelist_district in self.district_whitelist:
249
+ # Case-insensitive comparison
250
+ if district_name.lower() == whitelist_district.lower():
251
+ return whitelist_district
252
+
253
+ # Check if one is a substring of the other (for partial matches)
254
+ if len(district_name) >= 4 and len(whitelist_district) >= 4:
255
+ if district_name.lower() in whitelist_district.lower() or whitelist_district.lower() in district_name.lower():
256
+ # Only return if it's a strong match (at least 80% of characters match)
257
+ min_len = min(len(district_name), len(whitelist_district))
258
+ max_len = max(len(district_name), len(whitelist_district))
259
+ if min_len / max_len >= 0.8:
260
+ return whitelist_district
261
+
262
+ return None
263
+
264
  def _build_graph(self) -> StateGraph:
265
  """Build the multi-agent LangGraph"""
266
  graph = StateGraph(MultiAgentState)
 
565
  - If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
566
  - If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
567
  - Always return districts as JSON arrays when multiple districts are mentioned
568
+ - **COMMON MISSPELLINGS**: Handle common misspellings intelligently:
569
+ * "Kalagala" (missing 'n') should be extracted as "Kalangala"
570
+ * "kalagala", "Kalagala", "KALAGALA" should all be normalized to "Kalangala"
571
+ * Similar case-insensitive variations should be normalized to the correct district name
572
  - If no exact matches found, set extracted values to null
573
 
574
  4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
 
713
  # Validate each district in the array
714
  valid_districts = []
715
  for district in extracted_district:
716
+ normalized = self._normalize_district_name(district)
717
+ if normalized:
718
+ valid_districts.append(normalized)
719
 
720
  if valid_districts:
721
  extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
 
724
  logger.warning(f"⚠️ No valid districts found in: '{extracted_district}'")
725
  extracted_district = None
726
  else:
727
+ # Single district validation with fuzzy matching
728
+ normalized = self._normalize_district_name(extracted_district)
729
+ if normalized:
730
+ if normalized != extracted_district:
731
+ logger.info(f"πŸ” QUERY ANALYSIS: Normalized district '{extracted_district}' to '{normalized}'")
732
+ extracted_district = normalized
733
+ else:
734
+ logger.warning(f"⚠️ Invalid district extracted: '{extracted_district}' not in whitelist")
735
+ extracted_district = None
 
736
 
737
  # Validate source (handle both single values and arrays)
738
  if extracted_source:
 
970
  logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β†’ normalized: {normalized_districts}")
971
 
972
  # Merge with extracted context for missing filters
973
+ if not filters.get("district") and context.extracted_district:
974
+ # Normalize district names using the normalization function
975
+ if isinstance(context.extracted_district, list):
976
+ normalized_districts = []
977
+ for d in context.extracted_district:
978
+ normalized = self._normalize_district_name(d)
979
+ if normalized:
980
+ normalized_districts.append(normalized)
981
+ if normalized_districts:
982
+ filters["district"] = normalized_districts
983
+ logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β†’ normalized: {normalized_districts}")
984
+ else:
985
+ normalized = self._normalize_district_name(context.extracted_district)
986
+ if normalized:
987
+ filters["district"] = [normalized]
988
+ logger.info(f"πŸ”§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β†’ normalized: {normalized}")
989
+
990
  if not filters.get("year") and context.extracted_year:
991
  # Handle both single values and arrays
992
  if isinstance(context.extracted_year, list):
 
995
  filters["year"] = [context.extracted_year]
996
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
997
 
998
  if not filters.get("sources") and context.extracted_source:
999
  # Handle both single values and arrays
1000
  if isinstance(context.extracted_source, list):
 
1022
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
1023
 
1024
  if context.extracted_district:
1025
+ # Normalize district names using the normalization function
1026
  if isinstance(context.extracted_district, list):
1027
+ normalized_districts = []
1028
+ for d in context.extracted_district:
1029
+ normalized = self._normalize_district_name(d)
1030
+ if normalized:
1031
+ normalized_districts.append(normalized)
1032
+ if normalized_districts:
1033
+ filters["district"] = normalized_districts
1034
+ logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β†’ normalized: {normalized_districts}")
1035
  else:
1036
+ normalized = self._normalize_district_name(context.extracted_district)
1037
+ if normalized:
1038
+ filters["district"] = [normalized]
1039
+ logger.info(f"πŸ”§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β†’ normalized: {normalized}")
1040
 
1041
  logger.info(f"πŸ”§ FILTER BUILDING: Final filters: {filters}")
1042
  return filters
 
1046
  logger.info("πŸ’¬ RESPONSE GENERATION: Starting conversational response generation")
1047
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Processing {len(documents)} documents")
1048
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
1049
+ logger.info(f"πŸ’¬ RESPONSE GENERATION: Conversation history: {len(messages)} messages")
1050
+
1051
+ # Build conversation history context
1052
+ conversation_context = self._build_conversation_context(messages)
1053
+
1054
+ # Build detailed document information
1055
+ document_details = self._build_document_details(documents)
1056
+
1057
+ # Extract correct district/source/year names from documents (to correct misspellings)
1058
+ correct_names = self._extract_correct_names_from_documents(documents)
1059
 
1060
  # Create response prompt
1061
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Building response prompt")
1062
  response_prompt = ChatPromptTemplate.from_messages([
1063
  SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
1064
 
1065
+ CRITICAL RULES - NO HALLUCINATION:
1066
+ 1. **ONLY use information from the retrieved documents provided below**
1067
+ 2. **EVERY sentence with facts, numbers, or specific claims MUST have a [Doc i] reference**
1068
+ 3. **If a document doesn't contain the information, DO NOT make it up**
1069
+ 4. **If the user asks about a year/district that's NOT in the retrieved documents, explicitly state that**
1070
+ 5. **Check the document years/districts before making any claims about them**
1071
+ 6. **USE CORRECT NAMES**: If the conversation mentions a misspelled district/source name (e.g., "Kalagala"), use the CORRECT spelling from the document metadata (e.g., "Kalangala"). Always use the exact names from document metadata, not misspellings from conversation.
1072
+
1073
  RULES:
1074
  1. Answer the user's question directly and clearly
1075
+ 2. Use ONLY the retrieved documents as evidence - DO NOT use your training data
1076
  3. Be conversational, not technical
1077
  4. Don't mention scores, retrieval details, or technical implementation
1078
  5. If relevant documents were found, reference them naturally
1079
+ 6. If no relevant documents, say you do not have enough information - DO NOT hallucinate
1080
+ 7. If the passages have useful facts or numbers, use them in your answer WITH references
1081
+ 8. **MANDATORY**: When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
1082
  9. Do not use the sentence 'Doc i says ...' to say where information came from.
1083
  10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
1084
  11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
1085
  12. If it makes sense, use bullet points and lists to make your answers easier to understand.
1086
  13. You do not need to use every passage. Only use the ones that help answer the question.
1087
+ 14. **VERIFY**: Before mentioning any year, district, or number, check that it exists in the retrieved documents. If it doesn't, say "I don't have information about [year/district] in the retrieved documents."
1088
+ 15. **NO HALLUCINATION**: If documents show years 2021, 2022, 2023 but user asks about 2020, DO NOT provide 2020 data. Instead say "The retrieved documents cover 2021-2023, but I don't have information for 2020."
1089
+ 16. **USE CORRECT SPELLING**: Always use the district/source names exactly as they appear in the document metadata below, even if the conversation history has misspellings.
1090
 
1091
  TONE: Professional but friendly, like talking to a colleague."""),
1092
+ HumanMessage(content=f"""Conversation History:
1093
+ {conversation_context}
1094
+
1095
+ Current User Question: {query}
1096
 
1097
  Retrieved Documents: {len(documents)} documents found
1098
 
1099
+ CORRECT NAMES TO USE (from document metadata - use these exact spellings):
1100
+ {correct_names}
1101
+
1102
+ Full Document Details:
1103
+ {document_details}
1104
+
1105
  RAG Answer: {rag_answer}
1106
 
1107
+ CRITICAL:
1108
+ - Responses should be grounded in what is available in the retrieved documents
1109
+ - If the user asks about a specific year, district, or source that the retrieved documents do not cover, explicitly state "can't provide response on ... because ..."
1110
+ - Every factual claim MUST have [Doc i] reference
1111
+ - If information is not in documents, explicitly state it's not available
1112
+ - **USE THE CORRECT DISTRICT/SOURCE NAMES from the document metadata above, not misspellings from conversation**
1113
+
1114
+ Generate a conversational response with proper document references:""")
1115
  ])
1116
 
1117
  try:
1118
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Calling LLM for final response")
1119
  response = self.llm.invoke(response_prompt.format_messages())
1120
  logger.info(f"πŸ’¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
1121
+
1122
+ # Post-process response to ensure no hallucination
1123
+ final_response = self._validate_and_enhance_response(
1124
+ response.content.strip(),
1125
+ documents,
1126
+ query
1127
+ )
1128
+
1129
+ return final_response
1130
  except Exception as e:
1131
  logger.error(f"❌ RESPONSE GENERATION: Error during generation: {e}")
1132
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Using RAG answer as fallback")
1133
  return rag_answer # Fallback to RAG answer
1134
 
1135
+ def _build_conversation_context(self, messages: List[Any]) -> str:
1136
+ """Build conversation history context for response generation."""
1137
+ if not messages:
1138
+ return "No previous conversation."
1139
+
1140
+ context_lines = []
1141
+ # Show last 6 messages for context (to capture the current exchange)
1142
+ for msg in messages[-6:]:
1143
+ if isinstance(msg, HumanMessage):
1144
+ context_lines.append(f"User: {msg.content}")
1145
+ elif isinstance(msg, AIMessage):
1146
+ context_lines.append(f"Assistant: {msg.content}")
1147
+
1148
+ return "\n".join(context_lines) if context_lines else "No previous conversation."
1149
+
1150
+ def _build_document_details(self, documents: List[Any]) -> str:
1151
+ """Build detailed document information for response generation."""
1152
+ if not documents:
1153
+ return "No documents retrieved."
1154
+
1155
+ details = []
1156
+ for i, doc in enumerate(documents[:15], 1): # Show up to 15 documents
1157
+ metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
1158
+ content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
1159
+
1160
+ if isinstance(metadata, dict):
1161
+ filename = metadata.get('filename', 'Unknown')
1162
+ year = metadata.get('year', 'Unknown')
1163
+ district = metadata.get('district', 'Unknown')
1164
+ source = metadata.get('source', 'Unknown')
1165
+ page = metadata.get('page', metadata.get('page_label', 'Unknown'))
1166
+
1167
+ doc_info = f"[Doc {i}]"
1168
+ doc_info += f"\n Filename: {filename}"
1169
+ doc_info += f"\n Year: {year}"
1170
+ doc_info += f"\n District: {district}"
1171
+ doc_info += f"\n Source: {source}"
1172
+ if page != 'Unknown':
1173
+ doc_info += f"\n Page: {page}"
1174
+ doc_info += f"\n Content: {content[:300]}{'...' if len(content) > 300 else ''}"
1175
+ details.append(doc_info)
1176
+
1177
+ return "\n\n".join(details) if details else "No document details available."
1178
+
1179
+ def _extract_correct_names_from_documents(self, documents: List[Any]) -> str:
1180
+ """Extract correct district/source names from documents to correct misspellings."""
1181
+ districts = set()
1182
+ sources = set()
1183
+ years = set()
1184
+
1185
+ for doc in documents:
1186
+ metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
1187
+ if isinstance(metadata, dict):
1188
+ if metadata.get('district'):
1189
+ districts.add(str(metadata['district']))
1190
+ if metadata.get('source'):
1191
+ sources.add(str(metadata['source']))
1192
+ if metadata.get('year'):
1193
+ years.add(str(metadata['year']))
1194
+
1195
+ result = []
1196
+ if districts:
1197
+ result.append(f"Districts: {', '.join(sorted(districts))}")
1198
+ if sources:
1199
+ result.append(f"Sources: {', '.join(sorted(sources))}")
1200
+ if years:
1201
+ result.append(f"Years: {', '.join(sorted(years))}")
1202
+
1203
+ if result:
1204
+ return "\n".join(result) + "\n\nIMPORTANT: Use these EXACT spellings in your response, even if the conversation history has misspellings."
1205
+ return "No metadata available."
1206
+
1207
+ def _validate_and_enhance_response(self, response: str, documents: List[Any], query: str) -> str:
1208
+ """Validate response and ensure all claims are referenced."""
1209
+ # Extract years and districts from documents
1210
+ doc_years = set()
1211
+ doc_districts = set()
1212
+ doc_sources = set()
1213
+
1214
+ for doc in documents:
1215
+ metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
1216
+ if isinstance(metadata, dict):
1217
+ if metadata.get('year'):
1218
+ doc_years.add(str(metadata['year']))
1219
+ if metadata.get('district'):
1220
+ doc_districts.add(str(metadata['district']))
1221
+ if metadata.get('source'):
1222
+ doc_sources.add(str(metadata['source']))
1223
+
1224
+ # Correct misspellings in response using correct names from documents
1225
+ response = self._correct_misspellings_in_response(response, doc_districts, doc_sources)
1226
+
1227
+ # Check if response mentions years not in documents
1228
+ year_pattern = r'\b(20\d{2})\b'
1229
+ mentioned_years = set(re.findall(year_pattern, response))
1230
+
1231
+ # Check if user query mentions a year
1232
+ query_years = set(re.findall(year_pattern, query))
1233
+
1234
+ # If user asks about a year not in documents, add a warning
1235
+ missing_years = query_years - doc_years
1236
+ if missing_years and doc_years:
1237
+ warning = f"\n\n⚠️ Note: The retrieved documents cover years {', '.join(sorted(doc_years))}, but I don't have information for {', '.join(sorted(missing_years))} in the retrieved documents."
1238
+ if warning not in response:
1239
+ response = response + warning
1240
+
1241
+ # Check if response has document references
1242
+ doc_ref_pattern = r'\[Doc\s+\d+\]'
1243
+ has_refs = bool(re.search(doc_ref_pattern, response))
1244
+
1245
+ # If response has factual claims but no references, add a note
1246
+ if not has_refs and len(documents) > 0:
1247
+ # Check if response has numbers or specific claims (simple heuristic)
1248
+ has_numbers = bool(re.search(r'\d+', response))
1249
+ if has_numbers and len(response) > 50:
1250
+ logger.warning("⚠️ Response contains factual claims but no document references")
1251
+ # Don't modify response, but log the issue
1252
+
1253
+ return response
1254
+
1255
+ def _correct_misspellings_in_response(self, response: str, correct_districts: set, correct_sources: set) -> str:
1256
+ """Correct common misspellings in response using correct names from documents."""
1257
+ # Common misspelling mappings (e.g., "Kalagala" -> "Kalangala")
1258
+ # We'll use fuzzy matching if needed, but first try direct corrections
1259
+
1260
+ corrected = response
1261
+
1262
+ # Correct district names
1263
+ for correct_district in correct_districts:
1264
+ # Try common misspellings
1265
+ if correct_district.lower() == "kalangala":
1266
+ # Replace "Kalagala" (missing 'n') with "Kalangala"
1267
+ corrected = re.sub(r'\bKalagala\b', 'Kalangala', corrected, flags=re.IGNORECASE)
1268
+ # Add more common misspellings as needed
1269
+ # For now, we rely on the LLM to use correct names from the prompt
1270
+
1271
+ # Correct source names if needed
1272
+ # Add source corrections as needed in the future
1273
+
1274
+ return corrected
1275
+
1276
  def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
1277
  """Generate conversational response using only LLM knowledge and conversation history"""
1278
  logger.info("πŸ’¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
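The `_normalize_district_name` cascade added above (exact match, then "District"-suffix strip, then misspelling map, then case-insensitive and length-ratio fuzzy match) is easy to sanity-check in isolation. A standalone sketch with a toy whitelist, not the class method itself:

# Standalone sketch of the normalization cascade above, using a toy whitelist
WHITELIST = ["Kalangala", "Gulu", "Kampala"]
MISSPELLINGS = {"kalagala": "Kalangala"}  # "Kalagala" is missing the 'n'

def normalize(district: str):
    name = district.strip().replace(" District", "").replace(" district", "").strip()
    if name in WHITELIST:
        return name
    if name.lower() in MISSPELLINGS:
        return MISSPELLINGS[name.lower()]
    for w in WHITELIST:
        if name.lower() == w.lower():
            return w
        # substring plus >=80% length-ratio check, mirroring the fuzzy branch above
        if len(name) >= 4 and (name.lower() in w.lower() or w.lower() in name.lower()):
            if min(len(name), len(w)) / max(len(name), len(w)) >= 0.8:
                return w
    return None

assert normalize("gulu District") == "Gulu"    # suffix strip + case-insensitive match
assert normalize("KALAGALA") == "Kalangala"    # misspelling map
assert normalize("Kalangal") == "Kalangala"    # fuzzy: substring, ratio 8/9 ~= 0.89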
smart_chatbot.py β†’ src/agents/smart_chatbot.py RENAMED
File without changes
src/feedback/__init__.py ADDED
@@ -0,0 +1,152 @@
1
+ """
2
+ Feedback Management Module
3
+
4
+ This module provides a unified interface for handling user feedback,
5
+ including data preparation, validation, and Snowflake storage.
6
+ """
7
+
8
+ from typing import Dict, Any, List, Optional
9
+ from langchain_core.messages import HumanMessage, AIMessage
10
+
11
+ from .feedback_schema import UserFeedback, create_feedback_from_dict, generate_snowflake_schema_sql
12
+ from .snowflake_connector import SnowflakeFeedbackConnector, save_to_snowflake, get_snowflake_connector_from_env
13
+
14
+
15
+ class FeedbackManager:
16
+ """
17
+ Unified manager for feedback operations.
18
+
19
+ This class provides a single interface for all feedback-related functionality,
20
+ including data preparation, validation, and storage.
21
+ """
22
+
23
+ def __init__(self):
24
+ """Initialize the FeedbackManager"""
25
+ pass
26
+
27
+ @staticmethod
28
+ def extract_transcript(messages: List[Any]) -> List[Dict[str, str]]:
29
+ """Extract transcript from messages - only user and bot messages, no extra metadata"""
30
+ transcript = []
31
+ for msg in messages:
32
+ if isinstance(msg, HumanMessage):
33
+ transcript.append({
34
+ "role": "user",
35
+ "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
36
+ })
37
+ elif isinstance(msg, AIMessage):
38
+ transcript.append({
39
+ "role": "assistant",
40
+ "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
41
+ })
42
+ return transcript
43
+
44
+ @staticmethod
45
+ def build_retrievals_structure(rag_retrieval_history: List[Dict[str, Any]], messages: List[Any]) -> List[Dict[str, Any]]:
46
+ """Build retrievals structure from retrieval history"""
47
+ retrievals = []
48
+
49
+ for entry in rag_retrieval_history:
50
+ # Get the user message that triggered this retrieval
51
+ # The entry has conversation_up_to which includes messages up to that point
52
+ conversation_up_to = entry.get("conversation_up_to", [])
53
+
54
+ # Find the last user message in conversation_up_to (this is the trigger)
55
+ user_message_trigger = ""
56
+ for msg_dict in reversed(conversation_up_to):
57
+ if msg_dict.get("type") == "HumanMessage":
58
+ user_message_trigger = msg_dict.get("content", "")
59
+ break
60
+
61
+ # Fallback: if not found in conversation_up_to, get from actual messages
62
+ # This handles edge cases where conversation_up_to might be incomplete
63
+ if not user_message_trigger:
64
+ # Find which retrieval this is (0-indexed)
65
+ retrieval_idx = rag_retrieval_history.index(entry)
66
+ # The user message that triggered this retrieval is at position (retrieval_idx * 2)
67
+ # because each retrieval is preceded by: user message, bot response, user message, ...
68
+ # But we need to account for the fact that the first retrieval happens after the first user message
69
+ user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
70
+ if retrieval_idx < len(user_msgs):
71
+ user_message_trigger = str(user_msgs[retrieval_idx].content)
72
+ elif user_msgs:
73
+ # Fallback to last user message
74
+ user_message_trigger = str(user_msgs[-1].content)
75
+
76
+ # Get retrieved documents and truncate content to 100 chars
77
+ docs_retrieved = entry.get("docs_retrieved", [])
78
+ retrieved_docs = []
79
+ for doc in docs_retrieved:
80
+ doc_copy = doc.copy()
81
+ # Truncate content to 100 characters (keep all other fields)
82
+ if "content" in doc_copy:
83
+ doc_copy["content"] = doc_copy["content"][:100]
84
+ retrieved_docs.append(doc_copy)
85
+
86
+ retrievals.append({
87
+ "retrieved_docs": retrieved_docs,
88
+ "user_message_trigger": user_message_trigger
89
+ })
90
+
91
+ return retrievals
92
+
93
+ @staticmethod
94
+ def build_feedback_score_related_retrieval_docs(
95
+ is_feedback_about_last_retrieval: bool,
96
+ messages: List[Any],
97
+ rag_retrieval_history: List[Dict[str, Any]]
98
+ ) -> Optional[Dict[str, Any]]:
99
+ """Build feedback_score_related_retrieval_docs structure"""
100
+ if not rag_retrieval_history:
101
+ return None
102
+
103
+ # Get the relevant retrieval entry
104
+ if is_feedback_about_last_retrieval:
105
+ relevant_entry = rag_retrieval_history[-1]
106
+ else:
107
+ # If feedback is about all retrievals, use the last one as default
108
+ relevant_entry = rag_retrieval_history[-1]
109
+
110
+ # Get conversation up to that point
111
+ conversation_up_to = relevant_entry.get("conversation_up_to", [])
112
+
113
+ # Convert to transcript format (role/content)
114
+ conversation_up_to_point = []
115
+ for msg_dict in conversation_up_to:
116
+ if msg_dict.get("type") == "HumanMessage":
117
+ conversation_up_to_point.append({
118
+ "role": "user",
119
+ "content": msg_dict.get("content", "")
120
+ })
121
+ elif msg_dict.get("type") == "AIMessage":
122
+ conversation_up_to_point.append({
123
+ "role": "assistant",
124
+ "content": msg_dict.get("content", "")
125
+ })
126
+
127
+ # Get retrieved docs with full content (not truncated)
128
+ retrieved_docs = relevant_entry.get("docs_retrieved", [])
129
+
130
+ return {
131
+ "conversation_up_to_point": conversation_up_to_point,
132
+ "retrieved_docs": retrieved_docs
133
+ }
134
+
135
+ @staticmethod
136
+ def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
137
+ """Create UserFeedback instance from dictionary"""
138
+ return create_feedback_from_dict(data)
139
+
140
+ @staticmethod
141
+ def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
142
+ """Save feedback to Snowflake"""
143
+ return save_to_snowflake(feedback, table_name)
144
+
145
+ @staticmethod
146
+ def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
147
+ """Generate Snowflake schema SQL"""
148
+ return generate_snowflake_schema_sql(table_name)
149
+
150
+
151
+ __all__ = ["FeedbackManager", "UserFeedback", "save_to_snowflake", "SnowflakeFeedbackConnector"]
152
+
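A quick sketch exercising the transcript helper above with langchain message objects:

from langchain_core.messages import HumanMessage, AIMessage

msgs = [HumanMessage(content="Hi"), AIMessage(content="Hello! How can I help?")]
transcript = FeedbackManager.extract_transcript(msgs)
# -> [{"role": "user", "content": "Hi"},
#     {"role": "assistant", "content": "Hello! How can I help?"}]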
src/feedback/feedback_schema.py ADDED
@@ -0,0 +1,161 @@
1
+ """
2
+ Feedback Schema for RAG Chatbot
3
+
4
+ This module defines dataclasses for feedback data structures
5
+ and provides Snowflake schema generation.
6
+ """
7
+ import os
8
+ from datetime import datetime
9
+ from dataclasses import dataclass, asdict, field
10
+ from typing import List, Optional, Dict, Any, Union
11
+
12
+
13
+
14
+ @dataclass
15
+ class RetrievedDocument:
16
+ """Single retrieved document metadata"""
17
+ doc_id: str
18
+ filename: str
19
+ page: int
20
+ score: float
21
+ content: str
22
+ metadata: Dict[str, Any]
23
+
24
+
25
+ @dataclass
26
+ class RetrievalEntry:
27
+ """Single retrieval operation metadata"""
28
+ rag_query: str
29
+ documents_retrieved: List[RetrievedDocument]
30
+ conversation_length: int
31
+ filters_applied: Optional[Dict[str, Any]] = None
32
+ timestamp: Optional[float] = None
33
+ _raw_data: Optional[Dict[str, Any]] = None
34
+
35
+
36
+ @dataclass
37
+ class UserFeedback:
38
+ """User feedback submission data"""
39
+ feedback_id: str
40
+ open_ended_feedback: Optional[str]
41
+ score: int
42
+ is_feedback_about_last_retrieval: bool
43
+ conversation_id: str
44
+ timestamp: float
45
+ message_count: int
46
+ has_retrievals: bool
47
+ retrieval_count: int
48
+ transcript: List[Dict[str, str]] # List of {"role": "user"/"assistant", "content": "..."}
49
+ retrievals: List[Dict[str, Any]] # List of retrieval objects with retrieved_docs and user_message_trigger
50
+ feedback_score_related_retrieval_docs: Optional[Dict[str, Any]] = None # Conversation subset + retrieved docs
51
+ retrieved_data: Optional[List[Dict[str, Any]]] = None # Preserved old column for backward compatibility
52
+ created_at: str = field(default_factory=lambda: datetime.now().isoformat())
53
+
54
+ def to_dict(self) -> Dict[str, Any]:
55
+ """Convert to dictionary with nested data structures"""
56
+ result = asdict(self)
57
+ return result
58
+
59
+ def to_snowflake_schema(self) -> Dict[str, Any]:
60
+ """Generate Snowflake schema for this dataclass"""
61
+ schema = {
62
+ "feedback_id": "VARCHAR(255)",
63
+ "open_ended_feedback": "VARCHAR(16777216)", # Large text
64
+ "score": "INTEGER",
65
+ "is_feedback_about_last_retrieval": "BOOLEAN",
66
+ "conversation_id": "VARCHAR(255)",
67
+ "timestamp": "NUMBER(20, 0)",
68
+ "message_count": "INTEGER",
69
+ "has_retrievals": "BOOLEAN",
70
+ "retrieval_count": "INTEGER",
71
+ "transcript": "VARCHAR(16777216)", # JSON string of ARRAY of {"role": "user"/"assistant", "content": "..."}
72
+ "retrievals": "VARCHAR(16777216)", # JSON string of ARRAY of retrieval objects
73
+ "feedback_score_related_retrieval_docs": "VARCHAR(16777216)", # JSON string of OBJECT with conversation subset + retrieved docs
74
+ "retrieved_data": "VARCHAR(16777216)", # JSON string - preserved old column for backward compatibility
75
+ "created_at": "TIMESTAMP_NTZ",
76
+ # transcript structure: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
77
+ # retrievals structure: [
78
+ # {
79
+ # "retrieved_docs": [{"content": "...", "metadata": {...}, ...}], # content truncated to 100 chars
80
+ # "user_message_trigger": "final user message that triggered this retrieval"
81
+ # },
82
+ # ...
83
+ # ]
84
+ # feedback_score_related_retrieval_docs structure: {
85
+ # "conversation_up_to_point": [{"role": "user", "content": "..."}, ...], # subset of transcript
86
+ # "retrieved_docs": [{"content": "...", "metadata": {...}, ...}] # full chunks with all info
87
+ # }
88
+ }
89
+ return schema
90
+
91
+ @classmethod
92
+ def get_snowflake_create_table_sql(cls, table_name: str = "USER_FEEDBACK_V3") -> str:
93
+ """Generate CREATE TABLE SQL for Snowflake"""
94
+ schema = cls.to_snowflake_schema(None)  # to_snowflake_schema never reads self, so passing None here is safe
95
+
96
+ columns = []
97
+ for col_name, col_type in schema.items():
98
+ nullable = "NULL" if col_name not in ["feedback_id", "score", "timestamp"] else "NOT NULL"
99
+ columns.append(f" {col_name} {col_type} {nullable}")
100
+
101
+ # Build SQL string properly
102
+ columns_str = ",\n".join(columns)
103
+
104
+ sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
105
+ {columns_str},
106
+ PRIMARY KEY (feedback_id)
107
+ )
108
+ CLUSTER BY (timestamp, conversation_id, score);
109
+ -- Note: Snowflake doesn't support traditional indexes on regular tables.
110
+ -- Instead, we use CLUSTER BY to optimize queries on these columns.
111
+ -- Snowflake automatically maintains clustering for efficient querying.
112
+ -- Note: transcript, retrievals, and feedback_score_related_retrieval_docs are stored as VARCHAR (JSON strings),
113
+ -- same approach as the old retrieved_data column. This allows easy storage and retrieval without VARIANT type complexity.
114
+ """
115
+ return sql
116
+
117
+
118
+ # Snowflake variant schema for retrieved_data array
119
+ RETRIEVAL_ENTRY_SCHEMA = {
120
+ "rag_query": "VARCHAR",
121
+ "documents_retrieved": "ARRAY", # Array of document objects
122
+ "conversation_length": "INTEGER",
123
+ "filters_applied": "OBJECT",
124
+ "timestamp": "NUMBER"
125
+ }
126
+
127
+ DOCUMENT_SCHEMA = {
128
+ "doc_id": "VARCHAR",
129
+ "filename": "VARCHAR",
130
+ "page": "INTEGER",
131
+ "score": "DOUBLE",
132
+ "content": "VARCHAR(16777216)",
133
+ "metadata": "OBJECT"
134
+ }
135
+
136
+
137
+ def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
138
+ """Generate complete Snowflake schema SQL for feedback system"""
139
+ if table_name is None:
140
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
141
+ return UserFeedback.get_snowflake_create_table_sql(table_name)
142
+
143
+
144
+ def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
145
+ """Create UserFeedback instance from dictionary"""
146
+ return UserFeedback(
147
+ feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
148
+ open_ended_feedback=data.get("open_ended_feedback"),
149
+ score=data["score"],
150
+ is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
151
+ conversation_id=data["conversation_id"],
152
+ timestamp=data["timestamp"],
153
+ message_count=data["message_count"],
154
+ has_retrievals=data["has_retrievals"],
155
+ retrieval_count=data["retrieval_count"],
156
+ transcript=data.get("transcript", []),
157
+ retrievals=data.get("retrievals", []),
158
+ feedback_score_related_retrieval_docs=data.get("feedback_score_related_retrieval_docs"),
159
+ retrieved_data=data.get("retrieved_data")
160
+ )
161
+
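For reference, `create_feedback_from_dict` above requires only the keys it reads without `.get()` defaults; a minimal sketch:

import time

fb = create_feedback_from_dict({
    "score": 4,
    "is_feedback_about_last_retrieval": True,
    "conversation_id": "conv-123",
    "timestamp": time.time(),
    "message_count": 6,
    "has_retrievals": True,
    "retrieval_count": 2,
})
print(fb.feedback_id)  # falls back to f"feedback_{timestamp}" when not supplied
print(fb.transcript)   # defaults to []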
src/feedback/snowflake_connector.py ADDED
@@ -0,0 +1,331 @@
1
+ """
2
+ Snowflake Connector for Feedback System
3
+
4
+ This module handles inserting user feedback into Snowflake.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import logging
10
+ from typing import Dict, Any, Optional
11
+ from .feedback_schema import UserFeedback
12
+
13
+ # Try to import snowflake connector
14
+ try:
15
+ import snowflake.connector
16
+ SNOWFLAKE_AVAILABLE = True
17
+ except ImportError:
18
+ SNOWFLAKE_AVAILABLE = False
19
+ logging.warning("⚠️ snowflake-connector-python not installed. Install with: pip install snowflake-connector-python")
20
+
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class SnowflakeFeedbackConnector:
27
+ """Connector for inserting feedback into Snowflake"""
28
+
29
+ def __init__(
30
+ self,
31
+ user: str,
32
+ password: str,
33
+ account: str,
34
+ warehouse: str,
35
+ database: str = "SNOWFLAKE_LEARNING",
36
+ schema: str = "PUBLIC"
37
+ ):
38
+ self.user = user
39
+ self.password = password
40
+ self.account = account
41
+ self.warehouse = warehouse
42
+ self.database = database
43
+ self.schema = schema
44
+ self._connection = None
45
+
46
+ def connect(self):
47
+ """Establish Snowflake connection"""
48
+ if not SNOWFLAKE_AVAILABLE:
49
+ raise ImportError("snowflake-connector-python is not installed. Install with: pip install snowflake-connector-python")
50
+
51
+ logger.info("=" * 80)
52
+ logger.info("πŸ”Œ SNOWFLAKE CONNECTION: Attempting to connect...")
53
+ logger.info(f" - Account: {self.account}")
54
+ logger.info(f" - Warehouse: {self.warehouse}")
55
+ logger.info(f" - Database: {self.database}")
56
+ logger.info(f" - Schema: {self.schema}")
57
+ logger.info(f" - User: {self.user}")
58
+
59
+ try:
60
+ self._connection = snowflake.connector.connect(
61
+ user=self.user,
62
+ password=self.password,
63
+ account=self.account,
64
+ warehouse=self.warehouse
65
+ # Don't set database/schema in connection - we'll do it per query
66
+ )
67
+ logger.info("βœ… SNOWFLAKE CONNECTION: Successfully connected")
68
+ logger.info("=" * 80)
69
+ print(f"βœ… Connected to Snowflake: {self.database}.{self.schema}")
70
+ except Exception as e:
71
+ logger.error(f"❌ SNOWFLAKE CONNECTION FAILED: {e}")
72
+ logger.error("=" * 80)
73
+ print(f"❌ Failed to connect to Snowflake: {e}")
74
+ raise
75
+
76
+ def disconnect(self):
77
+ """Close Snowflake connection"""
78
+ if self._connection:
79
+ self._connection.close()
80
+ print("βœ… Disconnected from Snowflake")
81
+
82
+ def insert_feedback(self, feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
83
+ """Insert a single feedback record into Snowflake"""
84
+ logger.info("=" * 80)
85
+ logger.info("πŸ”„ SNOWFLAKE INSERT: Starting feedback insertion process")
86
+ logger.info(f"πŸ“ Feedback ID: {feedback.feedback_id}")
87
+
88
+ # Get table name from parameter, env var, or default
89
+ if table_name is None:
90
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
91
+
92
+ if not self._connection:
93
+ logger.error("❌ Not connected to Snowflake. Call connect() first.")
94
+ raise RuntimeError("Not connected to Snowflake. Call connect() first.")
95
+
96
+ try:
97
+ logger.info("πŸ“Š VALIDATION: Validating feedback data structure...")
98
+
99
+ # Validate feedback object
100
+ validation_errors = []
101
+ if not feedback.feedback_id:
102
+ validation_errors.append("Missing feedback_id")
103
+ if feedback.score is None:
104
+ validation_errors.append("Missing score")
105
+ if feedback.timestamp is None:
106
+ validation_errors.append("Missing timestamp")
107
+
108
+ if validation_errors:
109
+ logger.error(f"❌ VALIDATION FAILED: {validation_errors}")
110
+ return False
111
+ else:
112
+ logger.info("βœ… VALIDATION PASSED: All required fields present")
113
+
114
+ logger.info("πŸ“‹ Data Summary:")
115
+ logger.info(f" - Feedback ID: {feedback.feedback_id}")
116
+ logger.info(f" - Score: {feedback.score}")
117
+ logger.info(f" - Conversation ID: {feedback.conversation_id}")
118
+ logger.info(f" - Has Retrievals: {feedback.has_retrievals}")
119
+ logger.info(f" - Retrieval Count: {feedback.retrieval_count}")
120
+ logger.info(f" - Message Count: {feedback.message_count}")
121
+ logger.info(f" - Timestamp: {feedback.timestamp}")
122
+
123
+ cursor = self._connection.cursor()
124
+ logger.info("βœ… SNOWFLAKE CONNECTION: Cursor created")
125
+
126
+ # Set database and schema context
127
+ logger.info(f"πŸ”§ SETTING CONTEXT: Database={self.database}, Schema={self.schema}")
128
+ try:
129
+ cursor.execute(f'USE DATABASE "{self.database}"')
130
+ cursor.execute(f'USE SCHEMA "{self.schema}"')
131
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
132
+ current_db, current_schema = cursor.fetchone()
133
+ logger.info(f"βœ… Current context verified: Database={current_db}, Schema={current_schema}")
134
+ except Exception as e:
135
+ logger.error(f"❌ Could not set context: {e}")
136
+ raise
137
+
138
+ # Prepare data - convert to JSON strings for VARIANT columns (same approach as old retrieved_data)
139
+ logger.info("πŸ”§ DATA PREPARATION: Preparing VARIANT columns...")
140
+ feedback_dict = feedback.to_dict()
141
+
142
+ # Prepare transcript (ARRAY) - convert to JSON string
143
+ transcript_raw = feedback_dict.get('transcript', [])
144
+ if transcript_raw:
145
+ # Convert to JSON string (same approach as old retrieved_data)
146
+ transcript_for_db = json.dumps(transcript_raw)
147
+ logger.info(f" - Transcript: {len(transcript_raw)} messages, JSON length: {len(transcript_for_db)}")
148
+ else:
149
+ transcript_for_db = None
150
+ logger.info(" - Transcript: None")
151
+
152
+ # Prepare retrievals (ARRAY) - convert to JSON string
153
+ retrievals_raw = feedback_dict.get('retrievals', [])
154
+ if retrievals_raw:
155
+ # Convert to JSON string (same approach as old retrieved_data)
156
+ retrievals_for_db = json.dumps(retrievals_raw)
157
+ logger.info(f" - Retrievals: {len(retrievals_raw)} entries, JSON length: {len(retrievals_for_db)}")
158
+ else:
159
+ retrievals_for_db = None
160
+ logger.info(" - Retrievals: None")
161
+
162
+ # Prepare feedback_score_related_retrieval_docs (OBJECT) - convert to JSON string
163
+ feedback_score_related_raw = feedback_dict.get('feedback_score_related_retrieval_docs')
164
+ if feedback_score_related_raw:
165
+ # Convert to JSON string (same approach as old retrieved_data)
166
+ feedback_score_related_for_db = json.dumps(feedback_score_related_raw)
167
+ logger.info(f" - Feedback score related docs: present, JSON length: {len(feedback_score_related_for_db)}")
168
+ else:
169
+ feedback_score_related_for_db = None
170
+ logger.info(" - Feedback score related docs: None")
171
+
172
+ # Prepare retrieved_data (preserved old column) - convert to JSON string
173
+ retrieved_data_raw = feedback_dict.get('retrieved_data')
174
+ if retrieved_data_raw:
175
+ # Convert to JSON string (same approach as old retrieved_data)
176
+ retrieved_data_for_db = json.dumps(retrieved_data_raw)
177
+ logger.info(f" - Retrieved data (preserved): present, JSON length: {len(retrieved_data_for_db)}")
178
+ else:
179
+ retrieved_data_for_db = None
180
+ logger.info(" - Retrieved data (preserved): None")
181
+
182
+ # Build SQL with new column structure
183
+ # Columns are VARCHAR (storing JSON strings), same approach as old retrieved_data
184
+ sql = f"""INSERT INTO {table_name} (
185
+ feedback_id,
186
+ open_ended_feedback,
187
+ score,
188
+ is_feedback_about_last_retrieval,
189
+ conversation_id,
190
+ timestamp,
191
+ message_count,
192
+ has_retrievals,
193
+ retrieval_count,
194
+ transcript,
195
+ retrievals,
196
+ feedback_score_related_retrieval_docs,
197
+ retrieved_data,
198
+ created_at
199
+ ) VALUES (
200
+ %(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
201
+ %(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
202
+ %(retrieval_count)s, %(transcript)s, %(retrievals)s, %(feedback_score_related_retrieval_docs)s,
203
+ %(retrieved_data)s, %(created_at)s
204
+ )"""
205
+
206
+ logger.info("πŸ“ SQL PREPARATION: Building INSERT statement...")
207
+ logger.info(f" - Target table: {table_name}")
208
+ logger.info(f" - Database: {self.database}")
209
+ logger.info(f" - Schema: {self.schema}")
210
+
211
+ # Prepare parameters
212
+ # Pass JSON strings for VARIANT columns (same approach as old retrieved_data)
213
+ params = {
214
+ 'feedback_id': feedback.feedback_id,
215
+ 'open_ended_feedback': feedback.open_ended_feedback,
216
+ 'score': feedback.score,
217
+ 'is_feedback_about_last_retrieval': feedback.is_feedback_about_last_retrieval,
218
+ 'conversation_id': feedback.conversation_id,
219
+ 'timestamp': int(feedback.timestamp),
220
+ 'message_count': feedback.message_count,
221
+ 'has_retrievals': feedback.has_retrievals,
222
+ 'retrieval_count': feedback.retrieval_count,
223
+ 'transcript': transcript_for_db, # JSON string
224
+ 'retrievals': retrievals_for_db, # JSON string
225
+ 'feedback_score_related_retrieval_docs': feedback_score_related_for_db, # JSON string
226
+ 'retrieved_data': retrieved_data_for_db, # JSON string - preserved old column
227
+ 'created_at': feedback.created_at
228
+ }
229
+
230
+ # Execute insert
231
+ logger.info("πŸš€ SQL EXECUTION: Executing INSERT query...")
232
+ cursor.execute(sql, params)
233
+
234
+ logger.info("βœ… SQL EXECUTION: Query executed successfully")
235
+ logger.info(f" - Rows affected: 1")
236
+ logger.info(f" - Status: SUCCESS")
237
+
238
+ cursor.close()
239
+ logger.info("βœ… SNOWFLAKE INSERT: Feedback inserted successfully")
240
+ logger.info(f"πŸ“ Inserted feedback: {feedback.feedback_id}")
241
+ logger.info("=" * 80)
242
+ return True
243
+
244
+ except Exception as e:
245
+ # Check if it's a Snowflake error
246
+ if SNOWFLAKE_AVAILABLE and "ProgrammingError" in str(type(e)):
247
+ logger.error(f"❌ SQL EXECUTION ERROR: {e}")
248
+ logger.error(f" - Error code: {getattr(e, 'errno', 'Unknown')}")
249
+ logger.error(f" - SQL state: {getattr(e, 'sqlstate', 'Unknown')}")
250
+ else:
251
+ logger.error(f"❌ SNOWFLAKE INSERT FAILED: {type(e).__name__}")
252
+ logger.error(f" - Error: {e}")
253
+ logger.error("=" * 80)
254
+ return False
255
+
256
+ def __enter__(self):
257
+ """Context manager entry"""
258
+ self.connect()
259
+ return self
260
+
261
+ def __exit__(self, exc_type, exc_val, exc_tb):
262
+ """Context manager exit"""
263
+ self.disconnect()
264
+
265
+
266
+ def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
267
+ """Create Snowflake connector from environment variables"""
268
+ user = os.getenv("SNOWFLAKE_USER")
269
+ password = os.getenv("SNOWFLAKE_PASSWORD")
270
+ account = os.getenv("SNOWFLAKE_ACCOUNT")
271
+ warehouse = os.getenv("SNOWFLAKE_WAREHOUSE")
272
+ database = os.getenv("SNOWFLAKE_DATABASE", "SNOWFLAKE_LEARN")
273
+ schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
274
+
275
+ if not all([user, password, account, warehouse]):
276
+ print("⚠️ Snowflake credentials not found in environment variables")
277
+ print("Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
278
+ return None
279
+
280
+ return SnowflakeFeedbackConnector(
281
+ user=user,
282
+ password=password,
283
+ account=account,
284
+ warehouse=warehouse,
285
+ database=database,
286
+ schema=schema
287
+ )
288
+
289
+
290
+ def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
291
+ """Helper function to save feedback to Snowflake"""
292
+ logger.info("=" * 80)
293
+ logger.info("πŸ”΅ SNOWFLAKE SAVE: Starting save process")
294
+ logger.info(f"πŸ“ Feedback ID: {feedback.feedback_id}")
295
+
296
+ # Get table name from parameter or env var
297
+ if table_name is None:
298
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
299
+
300
+ connector = get_snowflake_connector_from_env()
301
+
302
+ if not connector:
303
+ logger.warning("⚠️ SNOWFLAKE SAVE: Skipping insertion (credentials not configured)")
304
+ logger.warning(" Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
305
+ logger.info("=" * 80)
306
+ return False
307
+
308
+ try:
309
+ logger.info("πŸ“‘ SNOWFLAKE SAVE: Establishing connection...")
310
+ connector.connect()
311
+ logger.info("βœ… SNOWFLAKE SAVE: Connection established")
312
+
313
+ logger.info("πŸ“₯ SNOWFLAKE SAVE: Attempting to insert feedback...")
314
+ success = connector.insert_feedback(feedback, table_name=table_name)
315
+
316
+ logger.info("πŸ”Œ SNOWFLAKE SAVE: Disconnecting...")
317
+ connector.disconnect()
318
+
319
+ if success:
320
+ logger.info("βœ… SNOWFLAKE SAVE: Successfully saved feedback")
321
+ else:
322
+ logger.error("❌ SNOWFLAKE SAVE: Failed to save feedback")
323
+
324
+ logger.info("=" * 80)
325
+ return success
326
+ except Exception as e:
327
+ logger.error(f"❌ SNOWFLAKE SAVE ERROR: {type(e).__name__}")
328
+ logger.error(f" - Error: {e}")
329
+ logger.info("=" * 80)
330
+ return False
331
+
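Two equivalent ways to use the connector above, assuming SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, and SNOWFLAKE_WAREHOUSE are set and `feedback` is a UserFeedback instance:

# save_to_snowflake wraps connect -> insert_feedback -> disconnect and returns a bool
ok = save_to_snowflake(feedback)  # table defaults to $SNOWFLAKE_FEEDBACK_TABLE or USER_FEEDBACK_V3

# Or, since the class defines __enter__/__exit__, as a context manager:
connector = get_snowflake_connector_from_env()
if connector:  # None when credentials are missing
    with connector as conn:
        conn.insert_feedback(feedback)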
src/gemini/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """
2
+ Gemini File Search Integration Module
3
+
4
+ This module provides integration with Google Gemini File Search API
5
+ for RAG functionality using Gemini's built-in file search capabilities.
6
+ """
7
+
8
+ from .file_search import GeminiFileSearchClient, GeminiFileSearchResult
9
+
10
+ __all__ = ["GeminiFileSearchClient", "GeminiFileSearchResult"]
11
+
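A minimal sketch of these exports in use; GEMINI_API_KEY and GEMINI_FILESTORE_NAME must be set (see the constructor in file_search.py below):

from src.gemini import GeminiFileSearchClient

client = GeminiFileSearchClient()  # reads GEMINI_API_KEY / GEMINI_FILESTORE_NAME from the environment
result = client.search(
    "What were the main audit findings for Gulu?",
    filters={"year": ["2022"], "district": ["Gulu"]},  # folded into the prompt; not API-level filters
)
print(result.answer)
for src in result.sources:
    print(src)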
src/gemini/file_search.py ADDED
@@ -0,0 +1,256 @@
1
+ """
2
+ Gemini File Search Client
3
+
4
+ Handles interaction with Google Gemini File Search API for RAG.
5
+ """
6
+
7
+ import os
8
+ import logging
9
+ from typing import List, Dict, Any, Optional
10
+ from dataclasses import dataclass
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ try:
15
+ from google import genai
16
+ from google.genai import types
17
+ GEMINI_AVAILABLE = True
18
+ except ImportError:
19
+ GEMINI_AVAILABLE = False
20
+
21
+
22
+ @dataclass
23
+ class GeminiFileSearchResult:
24
+ """Result from Gemini File Search query"""
25
+ answer: str
26
+ sources: List[Dict[str, Any]] # List of document references
27
+ grounding_metadata: Optional[Dict[str, Any]] = None
28
+ query: str = ""
29
+
30
+
31
+ class GeminiFileSearchClient:
32
+ """Client for interacting with Gemini File Search API"""
33
+
34
+ def __init__(self, api_key: Optional[str] = None, store_name: Optional[str] = None):
35
+ """
36
+ Initialize Gemini File Search client.
37
+
38
+ Args:
39
+ api_key: Gemini API key (defaults to GEMINI_API_KEY env var)
40
+ store_name: File search store name (defaults to GEMINI_FILESTORE_NAME env var)
41
+ """
42
+ if not GEMINI_AVAILABLE:
43
+ raise ImportError("google-genai package not installed. Install with: pip install google-genai")
44
+
45
+ self.api_key = api_key or os.getenv("GEMINI_API_KEY")
46
+ if not self.api_key:
47
+ raise ValueError("GEMINI_API_KEY not found. Set it in .env file or pass as argument.")
48
+
49
+ self.store_name = store_name or os.getenv("GEMINI_FILESTORE_NAME")
50
+ if not self.store_name:
51
+ raise ValueError("GEMINI_FILESTORE_NAME not found. Set it in .env file or pass as argument.")
52
+
53
+ self.client = genai.Client(api_key=self.api_key)
54
+ self.model = "gemini-2.5-flash" # or "gemini-2.5-pro"
55
+
56
+ def search(
57
+ self,
58
+ query: str,
59
+ filters: Optional[Dict[str, Any]] = None,
60
+ model: Optional[str] = None
61
+ ) -> GeminiFileSearchResult:
62
+ """
63
+ Search using Gemini File Search.
64
+
65
+ Args:
66
+ query: User query
67
+ filters: Optional filters (year, source, district, etc.)
68
+ model: Model to use (defaults to gemini-2.5-flash)
69
+
70
+ Returns:
71
+ GeminiFileSearchResult with answer and sources
72
+ """
73
+ model = model or self.model
74
+
75
+ # Build filter context for the query if filters are provided
76
+ # Gemini File Search doesn't support explicit filters in the API,
77
+ # so we add them as context in the query
78
+ filter_context = ""
79
+ if filters:
80
+ filter_parts = []
81
+ if filters.get("year"):
82
+ years = filters["year"] if isinstance(filters["year"], list) else [filters["year"]]
83
+ filter_parts.append(f"Year: {', '.join(years)}")
84
+ if filters.get("sources"):
85
+ sources = filters["sources"] if isinstance(filters["sources"], list) else [filters["sources"]]
86
+ filter_parts.append(f"Source: {', '.join(sources)}")
87
+ if filters.get("district"):
88
+ districts = filters["district"] if isinstance(filters["district"], list) else [filters["district"]]
89
+ filter_parts.append(f"District: {', '.join(districts)}")
90
+ if filters.get("filenames"):
91
+ filenames = filters["filenames"] if isinstance(filters["filenames"], list) else [filters["filenames"]]
92
+ filter_parts.append(f"Filename: {', '.join(filenames)}")
93
+
94
+ if filter_parts:
95
+ filter_context = f"\n\nPlease focus on documents matching these criteria: {', '.join(filter_parts)}"
96
+
97
+ # Combine query with filter context
98
+ # Add explicit instruction to only use information from retrieved documents
99
+ instruction = "\n\nIMPORTANT: Only use information from the retrieved documents. Do not use information from your training data unless it's explicitly mentioned in the retrieved documents. If the retrieved documents don't contain the requested information, clearly state that.\n\n"
100
+ full_query = query + filter_context + instruction
101
+
102
+ try:
103
+ # Generate content with file search
104
+ # Based on Gemini API docs: https://ai.google.dev/gemini-api/docs/file-search
105
+ try:
106
+ # Try the documented format first
107
+ response = self.client.models.generate_content(
108
+ model=model,
109
+ contents=full_query,
110
+ config=types.GenerateContentConfig(
111
+ tools=[
112
+ types.Tool(
113
+ file_search=types.FileSearch(
114
+ file_search_store_names=[self.store_name]
115
+ )
116
+ )
117
+ ]
118
+ )
119
+ )
120
+ except (AttributeError, TypeError) as e:
121
+ # Fallback: try alternative format
122
+ logger.warning(f"Primary API format failed, trying alternative: {e}")
123
+ try:
124
+ response = self.client.models.generate_content(
125
+ model=model,
126
+ contents=full_query,
127
+ tools=[{
128
+ "file_search": {
129
+ "file_search_store_names": [self.store_name]
130
+ }
131
+ }]
132
+ )
133
+ except Exception as e2:
134
+ raise Exception(f"Failed to call Gemini API: {e2}")
135
+
136
+ # Extract answer
137
+ answer = ""
138
+ if hasattr(response, 'text'):
139
+ answer = response.text
140
+ elif hasattr(response, 'candidates') and response.candidates:
141
+ # Try to get text from first candidate
142
+ candidate = response.candidates[0]
143
+ if hasattr(candidate, 'content') and candidate.content:
144
+ if hasattr(candidate.content, 'parts'):
145
+ text_parts = []
146
+ for part in candidate.content.parts:
147
+ if hasattr(part, 'text'):
148
+ text_parts.append(part.text)
149
+ answer = " ".join(text_parts)
150
+ elif isinstance(candidate.content, str):
151
+ answer = candidate.content
152
+ else:
153
+ answer = str(response)
154
+
155
+ # Extract grounding metadata (document references)
156
+ sources = []
157
+ grounding_metadata = None
158
+
159
+ if hasattr(response, 'candidates') and response.candidates:
160
+ candidate = response.candidates[0]
161
+
162
+ # Get grounding metadata
163
+ if hasattr(candidate, 'grounding_metadata'):
164
+ grounding_metadata = candidate.grounding_metadata
165
+
166
+ # Extract source documents from grounding metadata
167
+ # Handle different response formats
168
+ grounding_chunks = None
169
+ if hasattr(grounding_metadata, 'grounding_chunks'):
170
+ grounding_chunks = grounding_metadata.grounding_chunks
171
+ elif isinstance(grounding_metadata, dict) and 'grounding_chunks' in grounding_metadata:
172
+ grounding_chunks = grounding_metadata['grounding_chunks']
173
+
174
+ if grounding_chunks:
175
+ for chunk in grounding_chunks:
176
+ # Handle both object and dict formats
177
+ try:
178
+ if isinstance(chunk, dict):
179
+ chunk_data = chunk
180
+ else:
181
+ # Object format - convert to dict-like access
182
+ chunk_data = {}
183
+ if hasattr(chunk, 'chunk'):
184
+ chunk_obj = chunk.chunk
185
+ chunk_data['chunk'] = {
186
+ 'text': getattr(chunk_obj, 'text', ''),
187
+ 'file_name': getattr(chunk_obj, 'file_name', '')
188
+ }
189
+ if hasattr(chunk, 'relevance_score'):
190
+ score_obj = chunk.relevance_score
191
+ chunk_data['relevance_score'] = {
192
+ 'score': getattr(score_obj, 'score', 0.0)
193
+ }
194
+
195
+ chunk_info = chunk_data.get('chunk', {})
196
+ text = chunk_info.get('text', '') if isinstance(chunk_info, dict) else ''
197
+ file_name = chunk_info.get('file_name', '') if isinstance(chunk_info, dict) else ''
198
+
199
+ score_data = chunk_data.get('relevance_score', {})
200
+ score = score_data.get('score', 0.0) if isinstance(score_data, dict) else 0.0
201
+
202
+ if text or file_name: # Only add if we have content
203
+ source_info = {
204
+ "content": text,
205
+ "filename": file_name,
206
+ "score": score,
207
+ }
208
+ sources.append(source_info)
209
+ except Exception as e:
210
+ logger.warning(f"Error extracting chunk info: {e}")
211
+ continue
212
+
213
+ return GeminiFileSearchResult(
214
+ answer=answer,
215
+ sources=sources,
216
+ grounding_metadata=grounding_metadata,
217
+ query=query
218
+ )
219
+
220
+ except Exception as e:
221
+ # Return error result
222
+ return GeminiFileSearchResult(
223
+ answer=f"I apologize, but I encountered an error: {str(e)}",
224
+ sources=[],
225
+ query=query
226
+ )
227
+
228
+ def format_sources_for_display(self, result: GeminiFileSearchResult) -> List[Any]:
229
+ """
230
+ Format Gemini sources to match the format expected by the UI.
231
+
232
+ Returns list of document-like objects compatible with existing display code.
233
+ """
234
+ from langchain.docstore.document import Document
235
+
236
+ formatted_sources = []
237
+
238
+ for i, source in enumerate(result.sources):
239
+ # Create a Document object compatible with existing code
240
+ doc = Document(
241
+ page_content=source.get("content", ""),
242
+ metadata={
243
+ "filename": source.get("filename", "Unknown"),
244
+ "source": "Gemini File Search",
245
+ "score": source.get("score"),
246
+ "chunk_index": i,
247
+ # Add default fields that might be expected
248
+ "page": None,
249
+ "year": None,
250
+ "district": None,
251
+ }
252
+ )
253
+ formatted_sources.append(doc)
254
+
255
+ return formatted_sources
256
+
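
A minimal usage sketch of the client above (illustrative, not part of the commit); it assumes GEMINI_API_KEY and GEMINI_FILESTORE_NAME are set in the environment, as the constructor requires:

# Exercises only the public surface defined in src/gemini/file_search.py.
from src.gemini.file_search import GeminiFileSearchClient

client = GeminiFileSearchClient()  # reads GEMINI_API_KEY / GEMINI_FILESTORE_NAME
result = client.search(
    "What were the main audit findings?",
    filters={"year": ["2021"], "district": ["Gulu"]},  # folded into the prompt text, not a hard filter
)
print(result.answer)
for doc in client.format_sources_for_display(result):
    print(doc.metadata["filename"], doc.metadata["score"])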
src/{loader.py β†’ llm/loader.py} RENAMED
File without changes
src/pipeline.py CHANGED
@@ -14,7 +14,7 @@ except ModuleNotFoundError as me:
 
 from .logging import log_error
 
-from .loader import chunks_to_documents
+from .llm.loader import chunks_to_documents
 from .vectorstore import VectorStoreManager
 from .reporting.service import ReportService
 from .retrieval.context import ContextRetriever
src/reporting/__init__.py CHANGED
@@ -1,4 +1,8 @@
-"""Report metadata and utilities."""
+"""Report metadata and utilities.
+
+This module is kept for backward compatibility with pipeline.py.
+For feedback-related functionality, use src.feedback instead.
+"""
 
 from .metadata import get_report_metadata, get_available_sources
 from .service import ReportService
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
-import altair as alt
-import numpy as np
-import pandas as pd
-import streamlit as st
-
-"""
-# Welcome to Streamlit!
-
-Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
-
-In the meantime, below is an example of what you can do with just a few lines of code:
-"""
-
-num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
-num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
-indices = np.linspace(0, 1, num_points)
-theta = 2 * np.pi * num_turns * indices
-radius = indices
-
-x = radius * np.cos(theta)
-y = radius * np.sin(theta)
-
-df = pd.DataFrame({
-    "x": x,
-    "y": y,
-    "idx": indices,
-    "rand": np.random.randn(num_points),
-})
-
-st.altair_chart(alt.Chart(df, height=700, width=700)
-    .mark_point(filled=True)
-    .encode(
-        x=alt.X("x", axis=None),
-        y=alt.Y("y", axis=None),
-        color=alt.Color("idx", legend=None, scale=alt.Scale()),
-        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-    ))
src/ui_components/__init__.py ADDED
@@ -0,0 +1,21 @@
+"""
+UI Components Module
+
+This module contains UI-related components including styles, visualizations,
+and utility functions for the Streamlit application.
+"""
+
+from .styles import get_custom_css
+from .components import (
+    display_chunk_statistics_charts,
+    display_chunk_statistics_table
+)
+from .utils import extract_chunk_statistics
+
+__all__ = [
+    "get_custom_css",
+    "display_chunk_statistics_charts",
+    "display_chunk_statistics_table",
+    "extract_chunk_statistics"
+]
+
src/ui_components/components.py ADDED
@@ -0,0 +1,202 @@
+"""
+UI components for displaying statistics and visualizations
+"""
+
+import streamlit as st
+import pandas as pd
+import plotly.express as px
+from typing import Dict, Any
+
+
+def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieval Statistics"):
+    """Display statistics as interactive charts for 10+ results."""
+    if not stats or stats.get('total_chunks', 0) == 0:
+        return
+
+    # Wrap everything in one styled container - open it
+    st.markdown(f"""
+    <div class="retrieval-distribution-container">
+        <h3 style="margin-top: 0;">πŸ“Š {title}</h3>
+        <div style="display: flex; justify-content: space-around; align-items: center; padding: 15px 0; border-bottom: 1px solid #e0e0e0; margin-bottom: 20px;">
+            <div class="metric-container">
+                <div class="metric-label">Total Chunks</div>
+                <div class="metric-value">{stats['total_chunks']}</div>
+            </div>
+            <div class="metric-container">
+                <div class="metric-label">Unique Sources</div>
+                <div class="metric-value">{stats['unique_sources']}</div>
+            </div>
+            <div class="metric-container">
+                <div class="metric-label">Unique Years</div>
+                <div class="metric-value">{stats['unique_years']}</div>
+            </div>
+            <div class="metric-container">
+                <div class="metric-label">Unique Files</div>
+                <div class="metric-value">{stats['unique_filenames']}</div>
+            </div>
+        </div>
+    """, unsafe_allow_html=True)
+
+    # Charts - three columns to include Districts
+    col1, col2, col3 = st.columns(3)
+
+    with col1:
+        # Source distribution chart
+        if stats['source_distribution']:
+            source_df = pd.DataFrame(
+                list(stats['source_distribution'].items()),
+                columns=['Source', 'Count']
+            )
+            fig_source = px.bar(
+                source_df,
+                x='Count',
+                y='Source',
+                orientation='h',
+                title='Distribution by Source',
+                color='Count',
+                color_continuous_scale='viridis'
+            )
+            fig_source.update_layout(height=400, showlegend=False)
+            st.plotly_chart(fig_source, use_container_width=True)
+
+    with col2:
+        # Year distribution chart
+        if stats['year_distribution']:
+            # Filter out 'Unknown' years for the chart
+            year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
+            if year_dist_filtered:
+                year_df = pd.DataFrame(
+                    list(year_dist_filtered.items()),
+                    columns=['Year', 'Count']
+                )
+                # Sort by year as integer but keep as string for categorical display
+                year_df['Year_Int'] = year_df['Year'].astype(int)
+                year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)
+
+                fig_year = px.bar(
+                    year_df,
+                    x='Year',
+                    y='Count',
+                    title='Distribution by Year',
+                    color='Count',
+                    color_continuous_scale='plasma'
+                )
+                # Ensure years are treated as categorical (discrete), not continuous
+                fig_year.update_xaxes(type='category')
+                fig_year.update_layout(height=400, showlegend=False)
+                st.plotly_chart(fig_year, use_container_width=True)
+            else:
+                st.info("No valid years found in the results")
+
+    with col3:
+        # District distribution chart
+        if stats.get('district_distribution'):
+            district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
+            if district_dist_filtered:
+                district_df = pd.DataFrame(
+                    list(district_dist_filtered.items()),
+                    columns=['District', 'Count']
+                )
+                district_df = district_df.sort_values('Count', ascending=False)
+
+                fig_district = px.bar(
+                    district_df,
+                    x='Count',
+                    y='District',
+                    orientation='h',
+                    title='Distribution by District',
+                    color='Count',
+                    color_continuous_scale='blues'
+                )
+                fig_district.update_layout(height=400, showlegend=False)
+                st.plotly_chart(fig_district, use_container_width=True)
+            else:
+                st.info("No valid districts found in the results")
+
+    # Close the container
+    st.markdown('</div>', unsafe_allow_html=True)
+
+
+def display_chunk_statistics_table(stats: Dict[str, Any], title: str = "Retrieval Distribution"):
+    """Display statistics as tables for smaller results with fixed alignment."""
+    if not stats or stats.get('total_chunks', 0) == 0:
+        return
+
+    # Wrap in a styled container
+    st.markdown('<div class="retrieval-distribution-container">', unsafe_allow_html=True)
+
+    st.subheader(f"πŸ“Š {title}")
+
+    # Create a container with fixed height for alignment
+    stats_container = st.container()
+
+    with stats_container:
+        # Create 4 equal columns for consistent alignment
+        col1, col2, col3, col4 = st.columns(4)
+
+        with col1:
+            st.markdown("**🏘️ Districts**")
+            if stats.get('district_distribution'):
+                district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
+                if district_dist_filtered:
+                    district_data = {
+                        "District": list(district_dist_filtered.keys()),
+                        "Count": list(district_dist_filtered.values())
+                    }
+                    district_df = pd.DataFrame(district_data).sort_values('Count', ascending=False)
+                    st.dataframe(district_df, hide_index=True, use_container_width=True)
+                else:
+                    st.write("No district data")
+            else:
+                st.write("No district data")
+
+        with col2:
+            st.markdown("**πŸ“‚ Sources**")
+            if stats['source_distribution']:
+                source_data = {
+                    "Source": list(stats['source_distribution'].keys()),
+                    "Count": list(stats['source_distribution'].values())
+                }
+                source_df = pd.DataFrame(source_data).sort_values('Count', ascending=False)
+                st.dataframe(source_df, hide_index=True, use_container_width=True)
+            else:
+                st.write("No source data")
+
+        with col3:
+            st.markdown("**πŸ“… Years**")
+            if stats['year_distribution']:
+                year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
+                if year_dist_filtered:
+                    year_data = {
+                        "Year": list(year_dist_filtered.keys()),
+                        "Count": list(year_dist_filtered.values())
+                    }
+                    year_df = pd.DataFrame(year_data)
+                    # Sort by year as integer but display as string
+                    year_df['Year_Int'] = year_df['Year'].astype(int)
+                    year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
+                    st.dataframe(year_df, hide_index=True, use_container_width=True)
+                else:
+                    st.write("No year data")
+            else:
+                st.write("No year data")
+
+        with col4:
+            st.markdown("**πŸ“„ Files**")
+            if stats['filename_distribution']:
+                filename_items = list(stats['filename_distribution'].items())
+                filename_items.sort(key=lambda x: x[1], reverse=True)
+
+                # Show top files with truncated names
+                file_data = {
+                    "File": [f[:30] + "..." if len(f) > 30 else f for f, c in filename_items[:5]],
+                    "Count": [c for f, c in filename_items[:5]]
+                }
+                file_df = pd.DataFrame(file_data)
+                st.dataframe(file_df, hide_index=True, use_container_width=True)
+            else:
+                st.write("No file data")
+
+    # Close the container
+    st.markdown('</div>', unsafe_allow_html=True)
+
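
A small, hypothetical illustration of the stats dictionary both functions consume (the key names come from the code above; the values are made up, and the calls render inside a running Streamlit app):

# Made-up data in the shape produced by extract_chunk_statistics (src/ui_components/utils.py).
stats = {
    "total_chunks": 6,
    "unique_sources": 2,
    "unique_years": 2,
    "unique_filenames": 3,
    "unique_districts": 1,
    "source_distribution": {"Consolidated": 4, "Gulu DLG": 2},
    "year_distribution": {"2021": 3, "2022": 2, "Unknown": 1},
    "filename_distribution": {"report_2021.pdf": 3, "report_2022.pdf": 2, "annex.pdf": 1},
    "district_distribution": {"Gulu": 2, "Unknown": 4},
}
display_chunk_statistics_table(stats)   # compact tables for small result sets
display_chunk_statistics_charts(stats)  # charts, intended for 10+ chunks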
src/ui_components/styles.py ADDED
@@ -0,0 +1,117 @@
+"""
+Custom CSS styles for Streamlit application
+"""
+
+
+def get_custom_css() -> str:
+    """Get custom CSS styles as a string"""
+    return """
+    <style>
+    .main-header {
+        font-size: 2.5rem;
+        font-weight: bold;
+        color: #1f77b4;
+        text-align: center;
+        margin-bottom: 1rem;
+        width: 100%;
+        display: block;
+    }
+
+    .subtitle {
+        font-size: 1.2rem;
+        color: #666;
+        text-align: center;
+        margin-bottom: 2rem;
+        width: 100%;
+        display: block;
+    }
+
+    .session-info {
+        background-color: #f0f2f6;
+        padding: 10px;
+        border-radius: 5px;
+        margin-bottom: 20px;
+        font-size: 0.9rem;
+    }
+
+    .user-message {
+        background-color: #007bff;
+        color: white;
+        padding: 12px 16px;
+        border-radius: 18px 18px 4px 18px;
+        margin: 8px 0;
+        margin-left: 20%;
+        word-wrap: break-word;
+    }
+
+    .bot-message {
+        background-color: #f1f3f4;
+        color: #333;
+        padding: 12px 16px;
+        border-radius: 18px 18px 18px 4px;
+        margin: 8px 0;
+        margin-right: 20%;
+        word-wrap: break-word;
+        border: 1px solid #e0e0e0;
+    }
+
+    .filter-section {
+        margin-bottom: 20px;
+        padding: 15px;
+        background-color: #f8f9fa;
+        border-radius: 8px;
+        border: 1px solid #e9ecef;
+    }
+
+    .filter-title {
+        font-weight: bold;
+        margin-bottom: 10px;
+        color: #495057;
+    }
+
+    .feedback-section {
+        background-color: #f8f9fa;
+        padding: 20px;
+        border-radius: 10px;
+        margin-top: 30px;
+        border: 2px solid #dee2e6;
+    }
+
+    .retrieval-history {
+        background-color: #ffffff;
+        padding: 15px;
+        border-radius: 5px;
+        margin: 10px 0;
+        border-left: 4px solid #007bff;
+    }
+
+    .retrieval-distribution-container {
+        background-color: #ffffff;
+        padding: 25px;
+        border-radius: 10px;
+        margin: 20px 0;
+        border: 2px solid #e0e0e0;
+        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 2px 4px rgba(0, 0, 0, 0.06);
+    }
+
+    .metric-label {
+        font-size: 0.9rem;
+        color: #555;
+        margin-bottom: 5px;
+        text-align: center;
+    }
+
+    .metric-value {
+        font-size: 1.8rem;
+        font-weight: bold;
+        color: #000000;
+        text-align: center;
+    }
+
+    .metric-container {
+        text-align: center;
+        padding: 10px;
+    }
+    </style>
+    """
+
src/ui_components/utils.py ADDED
@@ -0,0 +1,73 @@
+"""
+UI utility functions for data processing and statistics
+"""
+
+from typing import Dict, Any, List
+from collections import Counter
+
+
+def extract_chunk_statistics(sources: List[Any]) -> Dict[str, Any]:
+    """Extract statistics from retrieved chunks."""
+    if not sources:
+        return {}
+
+    sources_list = []
+    years = []
+    filenames = []
+    districts = []
+
+    for doc in sources:
+        metadata = getattr(doc, 'metadata', {})
+
+        # Extract source
+        source = metadata.get('source', 'Unknown')
+        sources_list.append(source)
+
+        # Extract year
+        year = metadata.get('year', 'Unknown')
+        if year and year != 'Unknown':
+            try:
+                # Convert to int first, then back to string to ensure it's a proper year
+                year_int = int(float(year))  # Handle both int and float strings
+                if 1900 <= year_int <= 2030:  # Reasonable year range
+                    years.append(str(year_int))
+                else:
+                    years.append('Unknown')
+            except (ValueError, TypeError):
+                years.append('Unknown')
+        else:
+            years.append('Unknown')
+
+        # Extract filename
+        filename = metadata.get('filename', 'Unknown')
+        filenames.append(filename)
+
+        # Extract district
+        district = metadata.get('district', 'Unknown')
+        if district and district != 'Unknown':
+            districts.append(district)
+        else:
+            districts.append('Unknown')
+
+    # Count occurrences
+    source_counts = Counter(sources_list)
+    year_counts = Counter(years)
+    filename_counts = Counter(filenames)
+    district_counts = Counter(districts)
+
+    return {
+        'total_chunks': len(sources),
+        'unique_sources': len(source_counts),
+        'unique_years': len([y for y in year_counts.keys() if y != 'Unknown']),
+        'unique_filenames': len(filename_counts),
+        'unique_districts': len([d for d in district_counts.keys() if d != 'Unknown']),
+        'source_distribution': dict(source_counts),
+        'year_distribution': dict(year_counts),
+        'filename_distribution': dict(filename_counts),
+        'district_distribution': dict(district_counts),
+        'sources': sources_list,
+        'years': years,
+        'filenames': filenames,
+        'districts': districts
+    }
+
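
For illustration, a quick sanity check of the extractor above, using LangChain Document objects as the chunk type (the documents are made up):

# Hypothetical input: Documents carrying the metadata keys the extractor reads.
from langchain.docstore.document import Document
from src.ui_components.utils import extract_chunk_statistics

docs = [
    Document(page_content="...", metadata={"source": "Gulu DLG", "year": "2021.0", "filename": "gulu_2021.pdf", "district": "Gulu"}),
    Document(page_content="...", metadata={"source": "Consolidated", "year": "2022", "filename": "oag_2022.pdf"}),
]
stats = extract_chunk_statistics(docs)
assert stats["total_chunks"] == 2
assert stats["year_distribution"] == {"2021": 1, "2022": 1}  # "2021.0" is normalized to "2021"
assert stats["district_distribution"] == {"Gulu": 1, "Unknown": 1}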
utils.py β†’ src/utils.py RENAMED
File without changes
src/vectorstore.py CHANGED
@@ -28,11 +28,19 @@ class MatryoshkaEmbeddings(Embeddings):
 
         if truncate_dim and "matryoshka" in model_name.lower():
             # Use SentenceTransformer directly for Matryoshka models
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-            self.model = SentenceTransformer(model_name, truncate_dim=truncate_dim, device=device)
+            # Explicitly load on CPU first to avoid meta tensor issues
+            self.model = SentenceTransformer(
+                model_name,
+                truncate_dim=truncate_dim,
+                device="cpu"  # Load on CPU first, prevents meta tensor error
+            )
             print(f"πŸ”§ Matryoshka model configured for {truncate_dim} dimensions")
         else:
             # Use standard HuggingFaceEmbeddings
+            # Pass device="cpu" to prevent meta tensor issues
+            if "model_kwargs" not in kwargs:
+                kwargs["model_kwargs"] = {}
+            kwargs["model_kwargs"]["device"] = "cpu"
             self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -76,12 +84,15 @@ class VectorStoreManager:
 
     def _create_embeddings(self) -> HuggingFaceEmbeddings:
        """Create embeddings model from configuration."""
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-
         model_name = self.config["retriever"]["model"]
         normalize = self.config["retriever"]["normalize"]
 
-        model_kwargs = {"device": device}
+        # Fix for meta tensor issue: explicitly load on CPU first.
+        # This prevents HuggingFaceEmbeddings from trying to move meta tensors.
+        # The model will be loaded on CPU and can be moved later if needed.
+        model_kwargs = {
+            "device": "cpu"  # Load on CPU first to avoid meta tensor issues
+        }
         encode_kwargs = {
            "normalize_embeddings": normalize,
            "batch_size": 100,
upload_to_gemini_filestore.py ADDED
@@ -0,0 +1,402 @@
+#!/usr/bin/env python3
+"""
+Upload Documents to Google Gemini File Search Store
+
+This script uploads PDF documents to a Gemini File Search store for RAG.
+It processes documents from the reports directory and uploads them with metadata.
+"""
+
+import os
+import sys
+import json
+import time
+from pathlib import Path
+from typing import List, Dict, Any, Optional
+from dotenv import load_dotenv
+
+try:
+    from google import genai
+    from google.genai import types
+    GEMINI_AVAILABLE = True
+except ImportError:
+    GEMINI_AVAILABLE = False
+    print("❌ google-genai package not installed. Install with: pip install google-genai")
+
+# Load .env file
+load_dotenv()
+
+
+def extract_metadata_from_path(file_path: Path) -> Dict[str, Any]:
+    """Extract metadata from the file path structure."""
+    # Example: /path/to/reports/Annual Consolidated OAG audit reports 2018/Annual Consolidated OAG audit reports 2018.pdf
+    parts = file_path.parts
+    filename = file_path.stem  # Without extension
+
+    metadata = {
+        "filename": file_path.name,
+        "filepath": str(file_path),
+    }
+
+    # Extract year
+    year_match = None
+    for part in parts:
+        if any(year in part for year in ['2018', '2019', '2020', '2021', '2022', '2023', '2024', '2025']):
+            for year in ['2018', '2019', '2020', '2021', '2022', '2023', '2024', '2025']:
+                if year in part:
+                    year_match = year
+                    break
+        if year_match:
+            break
+
+    if year_match:
+        metadata["year"] = year_match
+
+    # Extract source/district
+    filename_lower = filename.lower()
+    if "consolidated" in filename_lower or "oag" in filename_lower:
+        metadata["source"] = "Consolidated"
+    elif "gulu" in filename_lower:
+        metadata["source"] = "Gulu DLG"
+        metadata["district"] = "Gulu"
+    elif "kalangala" in filename_lower:
+        metadata["source"] = "Kalangala DLG"
+        metadata["district"] = "Kalangala"
+    elif "kcca" in filename_lower:
+        metadata["source"] = "KCCA"
+        metadata["district"] = "Kampala"
+    elif "maaif" in filename_lower:
+        metadata["source"] = "MAAIF"
+    elif "mwts" in filename_lower:
+        metadata["source"] = "MWTS"
+
+    return metadata
+
+
+def get_or_create_filestore(client: genai.Client, store_name: Optional[str] = None) -> str:
+    """Get an existing file search store or create a new one."""
+    if store_name:
+        # Try to get the existing store
+        try:
+            stores = client.file_search_stores.list()
+            for store in stores:
+                if store.name == store_name or store.display_name == store_name:
+                    print(f"βœ… Using existing store: {store.display_name} ({store.name})")
+                    return store.name
+        except Exception as e:
+            print(f"⚠️ Could not list stores: {e}")
+
+    # Create a new store
+    display_name = store_name or "Audit Reports"
+    print(f"πŸ“ Creating new file search store: '{display_name}'...")
+
+    try:
+        file_search_store = client.file_search_stores.create(
+            config={'display_name': display_name}
+        )
+        print(f"βœ… Created store: {file_search_store.display_name} ({file_search_store.name})")
+        return file_search_store.name
+    except Exception as e:
+        print(f"❌ Failed to create store: {e}")
+        raise
+
+
+def format_metadata_for_gemini(metadata: Dict[str, Any]) -> List[Dict[str, Any]]:
+    """Format a metadata dictionary for the Gemini API customMetadata format.
+
+    Based on the Gemini API, customMetadata should use:
+    - string_value for string fields
+    - numeric_value for numeric fields
+    """
+    custom_metadata = []
+
+    # Add year if available (as numeric_value)
+    if metadata.get('year'):
+        try:
+            year_int = int(metadata['year'])
+            custom_metadata.append({
+                'key': 'year',
+                'numeric_value': year_int
+            })
+        except (ValueError, TypeError):
+            # Fall back to string if not numeric
+            custom_metadata.append({
+                'key': 'year',
+                'string_value': str(metadata['year'])
+            })
+
+    # Add source if available (as string_value)
+    if metadata.get('source'):
+        custom_metadata.append({
+            'key': 'source',
+            'string_value': str(metadata['source'])
+        })
+
+    # Add district if available (as string_value)
+    if metadata.get('district'):
+        custom_metadata.append({
+            'key': 'district',
+            'string_value': str(metadata['district'])
+        })
+
+    # Add filename for reference (as string_value)
+    if metadata.get('filename'):
+        custom_metadata.append({
+            'key': 'filename',
+            'string_value': str(metadata['filename'])
+        })
+
+    return custom_metadata
+
+
+def check_file_exists(client: genai.Client, store_name: str, filename: str) -> bool:
+    """Check if a file with the same name already exists in the store."""
+    try:
+        # List files in the store
+        store = client.file_search_stores.get(name=store_name)
+        # Note: The API might not have a direct list method, so we'll catch errors
+        return False  # Assume it does not exist for now
+    except Exception:
+        return False  # If we can't check, assume it doesn't exist
+
+
+def upload_file_to_store(
+    client: genai.Client,
+    file_path: Path,
+    store_name: str,
+    metadata: Dict[str, Any],
+    skip_existing: bool = True
+) -> Optional[bool]:
+    """Upload a single file to the file search store with metadata."""
+    try:
+        print(f"  πŸ“€ Uploading: {file_path.name}...")
+
+        # Format metadata for the Gemini API
+        custom_metadata = format_metadata_for_gemini(metadata)
+
+        # Display the metadata being uploaded
+        if custom_metadata:
+            metadata_parts = []
+            for m in custom_metadata:
+                if 'numeric_value' in m:
+                    metadata_parts.append(f"{m['key']}={m['numeric_value']}")
+                elif 'string_value' in m:
+                    metadata_parts.append(f"{m['key']}={m['string_value']}")
+            if metadata_parts:
+                print(f"  πŸ“‹ Metadata: {', '.join(metadata_parts)}")
+
+        # Check if the file already exists (if skip_existing is True)
+        if skip_existing:
+            # Note: We'll handle duplicates via error messages
+            pass
+
+        # Upload and import the file with metadata.
+        # Note: the Gemini API may not support customMetadata in upload_to_file_search_store.
+        # We'll try with metadata first, then fall back without it if it fails.
+        upload_params = {
+            'file': str(file_path),
+            'file_search_store_name': store_name,
+        }
+
+        # Build config
+        config = {
+            'display_name': metadata.get('filename', file_path.name),
+        }
+
+        # Upload the file (metadata not supported in the upload config per the API).
+        # Note: the Gemini File Search API doesn't support customMetadata in upload_to_file_search_store.
+        # Metadata would need to be added via a separate API call after upload, if supported.
+        # For now, we upload without metadata - the filename in display_name carries the info.
+        upload_params['config'] = config
+        operation = client.file_search_stores.upload_to_file_search_store(**upload_params)
+
+        # Wait for the import to complete
+        max_wait = 300  # 5 minutes max per file
+        start_time = time.time()
+
+        while not operation.done:
+            if time.time() - start_time > max_wait:
+                print("  ⚠️ Timeout waiting for upload to complete")
+                return False
+
+            time.sleep(2)
+            try:
+                operation = client.operations.get(operation)
+            except Exception as op_error:
+                # Check if it's a "terminated" error (file might already exist)
+                error_str = str(op_error).lower()
+                if 'terminated' in error_str or 'already' in error_str:
+                    print("  ⚠️ File may already exist or the upload was interrupted")
+                    print("  πŸ’‘ Skipping this file")
+                    return None  # Return None to indicate "skipped"
+                raise
+
+        # Check for errors in the operation result
+        if hasattr(operation, 'error') and operation.error:
+            error_msg = str(operation.error)
+            if 'terminated' in error_msg.lower() or 'already' in error_msg.lower():
+                print("  ⚠️ File may already exist in the store")
+                print("  πŸ’‘ Skipping this file")
+                return None  # Return None to indicate "skipped" vs False for "failed"
+            print(f"  ❌ Upload failed: {operation.error}")
+            return False
+
+        print("  βœ… Uploaded successfully")
+        return True
+
+    except Exception as e:
+        error_str = str(e).lower()
+        # Handle specific error cases
+        if 'terminated' in error_str or 'already' in error_str or '400' in error_str:
+            print("  ⚠️ Upload error: file may already exist or the upload was interrupted")
+            print(f"  πŸ’‘ Error details: {e}")
+            print("  πŸ’‘ Skipping this file")
+            return None  # Return None to indicate "skipped"
+        print(f"  ❌ Error uploading {file_path.name}: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def find_report_files(reports_dir: Path) -> List[Path]:
+    """Find all PDF report files in the reports directory."""
+    pdf_files = []
+
+    if not reports_dir.exists():
+        print(f"❌ Reports directory not found: {reports_dir}")
+        return pdf_files
+
+    # Find all PDF files
+    for pdf_file in reports_dir.rglob("*.pdf"):
+        pdf_files.append(pdf_file)
+
+    return sorted(pdf_files)
+
+
+def main():
+    """Main function to upload documents to the Gemini File Search store."""
+    print("=" * 60)
+    print("Gemini File Search Store Upload Tool")
+    print("=" * 60)
+
+    if not GEMINI_AVAILABLE:
+        print("\n❌ Please install the google-genai package:")
+        print("   pip install google-genai")
+        return 1
+
+    # Get API key
+    api_key = os.getenv("GEMINI_API_KEY")
+    if not api_key:
+        print("\n❌ GEMINI_API_KEY not found in environment variables")
+        print("   Please add GEMINI_API_KEY to your .env file")
+        return 1
+
+    # Get store name (optional)
+    store_name = os.getenv("GEMINI_FILESTORE_NAME")
+
+    # Get the reports directory - try multiple possible locations
+    reports_dir_str = os.getenv("REPORTS_DIR")
+    if not reports_dir_str:
+        # Try common locations
+        possible_paths = [
+            "/Users/ayeroyan/workspace/chatbot-rag/reports",
+            Path(__file__).parent / "reports",
+            Path.cwd() / "reports",
+        ]
+        for path in possible_paths:
+            if Path(path).exists():
+                reports_dir_str = str(path)
+                break
+
+    if not reports_dir_str:
+        reports_dir_str = "/Users/ayeroyan/workspace/chatbot-rag/reports"  # Default fallback
+
+    reports_dir = Path(reports_dir_str)
+
+    # Initialize the Gemini client
+    print("\nπŸ”Œ Connecting to Gemini API...")
+    try:
+        client = genai.Client(api_key=api_key)
+        print("  βœ… Connected")
+    except Exception as e:
+        print(f"  ❌ Failed to connect: {e}")
+        return 1
+
+    # Get or create the file search store
+    print("\nπŸ“¦ Setting up file search store...")
+    try:
+        store_name = get_or_create_filestore(client, store_name)
+    except Exception as e:
+        print(f"  ❌ Failed to set up store: {e}")
+        return 1
+
+    # Find all PDF files
+    print(f"\nπŸ” Scanning for PDF files in: {reports_dir}")
+    pdf_files = find_report_files(reports_dir)
+
+    if not pdf_files:
+        print(f"  ❌ No PDF files found in {reports_dir}")
+        return 1
+
+    print(f"  βœ… Found {len(pdf_files)} PDF files")
+
+    # Upload files
+    print("\nπŸ“€ Uploading files to store...")
+    print(f"  Store: {store_name}")
+    print(f"  Files: {len(pdf_files)}")
+
+    uploaded = 0
+    failed = 0
+    skipped = 0
+
+    for i, pdf_file in enumerate(pdf_files, 1):
+        print(f"\n[{i}/{len(pdf_files)}] Processing: {pdf_file.name}")
+
+        # Extract metadata
+        metadata = extract_metadata_from_path(pdf_file)
+
+        # Display the extracted metadata
+        metadata_info = []
+        if metadata.get('year'):
+            metadata_info.append(f"Year: {metadata['year']}")
+        if metadata.get('source'):
+            metadata_info.append(f"Source: {metadata['source']}")
+        if metadata.get('district'):
+            metadata_info.append(f"District: {metadata['district']}")
+
+        if metadata_info:
+            print(f"  πŸ“Š Extracted metadata: {', '.join(metadata_info)}")
+
+        # Upload the file with metadata
+        result = upload_file_to_store(client, pdf_file, store_name, metadata, skip_existing=True)
+
+        if result is True:
+            uploaded += 1
+        elif result is None:  # Skipped (already exists)
+            skipped += 1
+        else:  # Failed
+            failed += 1
+
+        # Small delay between uploads to avoid rate limits
+        if i < len(pdf_files):
+            time.sleep(1)
+
+    # Summary
+    print("\n" + "=" * 60)
+    print("Upload Summary")
+    print("=" * 60)
+    print(f"  βœ… Uploaded: {uploaded}")
+    if skipped > 0:
+        print(f"  ⏭️ Skipped (already exists): {skipped}")
+    print(f"  ❌ Failed: {failed}")
+    print(f"  πŸ“¦ Store: {store_name}")
+
+    if uploaded > 0:
+        print(f"\nβœ… Successfully uploaded {uploaded} files to the Gemini File Search store!")
+        print("   You can now use this store in the beta version of the chatbot.")
+
+    return 0 if failed == 0 else 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
+
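
To make the path-based metadata extraction concrete, a quick illustrative check (the paths are made up, patterned on the example in the function's own comment):

# Hypothetical paths; shows what extract_metadata_from_path derives from them.
from pathlib import Path
from upload_to_gemini_filestore import extract_metadata_from_path

p = Path("reports/Annual Consolidated OAG audit reports 2018/Annual Consolidated OAG audit reports 2018.pdf")
meta = extract_metadata_from_path(p)
assert meta["year"] == "2018" and meta["source"] == "Consolidated"

p2 = Path("reports/Gulu DLG 2021/Gulu DLG audit report 2021.pdf")
meta2 = extract_metadata_from_path(p2)
assert meta2["source"] == "Gulu DLG" and meta2["district"] == "Gulu" and meta2["year"] == "2021"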
verify_qdrant_migration.py ADDED
@@ -0,0 +1,438 @@
+#!/usr/bin/env python3
+"""
+Qdrant Migration Verification Script
+
+This script compares the source and destination Qdrant collections to verify
+that the migration was successful. It:
+1. Compares collection configurations
+2. Fetches sample points from the source
+3. Retrieves the same points from the destination using IDs
+4. Compares vectors, metadata, and all attributes
+"""
+
+import os
+import sys
+from typing import List, Dict, Any, Optional
+from pathlib import Path
+from qdrant_client import QdrantClient
+import json
+
+# Try to import the config loader and dotenv for automatic source detection
+try:
+    from src.config.loader import load_config
+    CONFIG_AVAILABLE = True
+except ImportError:
+    CONFIG_AVAILABLE = False
+
+try:
+    from dotenv import load_dotenv
+    DOTENV_AVAILABLE = True
+except ImportError:
+    DOTENV_AVAILABLE = False
+
+# Load the .env file automatically if available
+if DOTENV_AVAILABLE:
+    project_root = Path(__file__).parent
+    env_file = project_root / ".env"
+    if env_file.exists():
+        load_dotenv(env_file, override=True)
+    else:
+        load_dotenv(override=True)
+
+
+def get_collection_info(client: QdrantClient, collection_name: str) -> Dict[str, Any]:
+    """Get collection information including vector size and point count."""
+    try:
+        collection_info = client.get_collection(collection_name)
+
+        # Handle different Qdrant client versions and response formats
+        if hasattr(collection_info, 'config'):
+            config = collection_info.config
+            if hasattr(config, 'params') and hasattr(config.params, 'vectors'):
+                vectors_config = config.params.vectors
+                if isinstance(vectors_config, dict):
+                    vector_size = vectors_config.get('size')
+                    distance = vectors_config.get('distance')
+                else:
+                    vector_size = getattr(vectors_config, 'size', None)
+                    distance = getattr(vectors_config, 'distance', None)
+            else:
+                vector_size = getattr(config, 'vector_size', None)
+                distance = getattr(config, 'distance', None)
+        else:
+            vector_size = getattr(collection_info, 'vector_size', None)
+            distance = getattr(collection_info, 'distance', None)
+
+        points_count = getattr(collection_info, 'points_count', 0)
+        indexed_vectors_count = getattr(collection_info, 'indexed_vectors_count', 0)
+
+        if vector_size is None:
+            try:
+                result, _ = client.scroll(collection_name=collection_name, limit=1, with_vectors=True)
+                if result and hasattr(result[0], 'vector') and result[0].vector:
+                    vector_size = len(result[0].vector)
+            except Exception:
+                pass
+
+        return {
+            "vector_size": vector_size,
+            "distance": distance or "Cosine",
+            "points_count": points_count,
+            "indexed_vectors_count": indexed_vectors_count,
+        }
+    except Exception as e:
+        print(f"❌ Error getting collection info: {e}")
+        return None
+
+
+def fetch_points_by_ids(client: QdrantClient, collection_name: str, point_ids: List) -> Dict:
+    """Fetch points by their IDs from a collection."""
+    try:
+        points = client.retrieve(
+            collection_name=collection_name,
+            ids=point_ids,
+            with_payload=True,
+            with_vectors=True
+        )
+        return {point.id: point for point in points}
+    except Exception as e:
+        print(f"❌ Error fetching points by IDs: {e}")
+        return {}
+
+
+def compare_points(source_point, dest_point, point_id) -> Dict[str, Any]:
+    """Compare two points and return differences."""
+    differences = []
+    matches = []
+
+    # Compare IDs
+    if source_point.id == dest_point.id:
+        matches.append("ID")
+    else:
+        differences.append(f"ID: source={source_point.id}, dest={dest_point.id}")
+
+    # Compare vectors
+    source_vec = getattr(source_point, 'vector', None)
+    dest_vec = getattr(dest_point, 'vector', None)
+
+    if source_vec is None and dest_vec is None:
+        matches.append("Vector (both None)")
+    elif source_vec is None or dest_vec is None:
+        differences.append(f"Vector: source={'None' if source_vec is None else f'len={len(source_vec)}'}, dest={'None' if dest_vec is None else f'len={len(dest_vec)}'}")
+    elif len(source_vec) != len(dest_vec):
+        differences.append(f"Vector length: source={len(source_vec)}, dest={len(dest_vec)}")
+    else:
+        # Compare vector values (with tolerance for floating point)
+        import numpy as np
+        try:
+            vec_diff = np.abs(np.array(source_vec) - np.array(dest_vec))
+            max_diff = float(np.max(vec_diff))
+            if max_diff < 1e-6:
+                matches.append(f"Vector (max diff: {max_diff:.2e})")
+            else:
+                differences.append(f"Vector values differ (max diff: {max_diff:.2e})")
+        except Exception as e:
+            differences.append(f"Vector comparison error: {e}")
+
+    # Compare payloads
+    source_payload = getattr(source_point, 'payload', {}) or {}
+    dest_payload = getattr(dest_point, 'payload', {}) or {}
+
+    # Convert to dicts if needed
+    if hasattr(source_payload, '__dict__'):
+        source_payload = source_payload.__dict__
+    if hasattr(dest_payload, '__dict__'):
+        dest_payload = dest_payload.__dict__
+
+    source_keys = set(source_payload.keys())
+    dest_keys = set(dest_payload.keys())
+
+    if source_keys != dest_keys:
+        missing_in_dest = source_keys - dest_keys
+        extra_in_dest = dest_keys - source_keys
+        if missing_in_dest:
+            differences.append(f"Payload keys missing in dest: {missing_in_dest}")
+        if extra_in_dest:
+            differences.append(f"Payload keys extra in dest: {extra_in_dest}")
+
+    # Compare payload values
+    common_keys = source_keys & dest_keys
+    for key in common_keys:
+        source_val = source_payload[key]
+        dest_val = dest_payload[key]
+
+        if source_val == dest_val:
+            matches.append(f"Payload.{key}")
+        else:
+            # Handle nested structures
+            if isinstance(source_val, dict) and isinstance(dest_val, dict):
+                if source_val != dest_val:
+                    differences.append(f"Payload.{key}: dicts differ")
+            elif isinstance(source_val, list) and isinstance(dest_val, list):
+                if source_val != dest_val:
+                    differences.append(f"Payload.{key}: lists differ (len: {len(source_val)} vs {len(dest_val)})")
+            else:
+                differences.append(f"Payload.{key}: '{source_val}' != '{dest_val}'")
+
+    return {
+        "point_id": point_id,
+        "matches": matches,
+        "differences": differences,
+        "match_count": len(matches),
+        "diff_count": len(differences)
+    }
+
+
+def main():
+    print("=" * 70)
+    print("Qdrant Migration Verification Script")
+    print("=" * 70)
+
+    # Auto-detect the source from the config and .env file
+    source_url = os.getenv('QDRANT_URL')
+    source_key = os.getenv('QDRANT_API_KEY')
+    source_collection = os.getenv('QDRANT_COLLECTION', 'docling')
+
+    if CONFIG_AVAILABLE:
+        try:
+            config = load_config()
+            qdrant_config = config.get('qdrant', {})
+            if not source_url:
+                source_url = qdrant_config.get('url')
+            if not source_key:
+                source_key = qdrant_config.get('api_key')
+            if not source_collection:
+                source_collection = qdrant_config.get('collection_name', 'docling')
+        except Exception as e:
+            print(f"⚠️ Could not load config: {e}")
+
+    # Get the destination from env
+    dest_url = os.getenv('DEST_QDRANT_URL')
+    dest_key = os.getenv('DEST_QDRANT_API_KEY')
+    dest_collection = os.getenv('DEST_COLLECTION')  # Optional, will auto-detect
+
+    # Validate
+    if not source_url or not source_key:
+        print("❌ Source Qdrant credentials missing!")
+        print("   Set QDRANT_URL and QDRANT_API_KEY in .env or the environment")
+        return 1
+
+    if not dest_url or not dest_key:
+        print("❌ Destination Qdrant credentials missing!")
+        print("   Set DEST_QDRANT_URL and DEST_QDRANT_API_KEY in .env or the environment")
+        return 1
+
+    print("\nπŸ“‹ Configuration:")
+    print(f"  Source: {source_url}")
+    print(f"  Source Collection: {source_collection}")
+    print(f"  Destination: {dest_url}")
+    if dest_collection:
+        print(f"  Destination Collection: {dest_collection} (specified)")
+    else:
+        print("  Destination Collection: (auto-detect)")
+
+    # Connect to the Qdrant instances
+    print("\nπŸ”Œ Connecting to Qdrant instances...")
+    try:
+        source_client = QdrantClient(url=source_url, api_key=source_key, timeout=120)
+        print("  βœ… Connected to source")
+    except Exception as e:
+        print(f"  ❌ Failed to connect to source: {e}")
+        return 1
+
+    try:
+        dest_client = QdrantClient(url=dest_url, api_key=dest_key, timeout=120)
+        print("  βœ… Connected to destination")
+    except Exception as e:
+        print(f"  ❌ Failed to connect to destination: {e}")
+        return 1
+
+    # Auto-detect the destination collection if not specified
+    if not dest_collection:
+        try:
+            collections = dest_client.get_collections().collections
+            collection_names = [c.name for c in collections]
+            if len(collection_names) == 1:
+                dest_collection = collection_names[0]
+                print(f"\nπŸ“‹ Auto-detected destination collection: '{dest_collection}'")
+            elif len(collection_names) > 1:
+                print(f"\n⚠️ Found {len(collection_names)} collections in destination:")
+                for name in collection_names:
+                    print(f"  - {name}")
+                print(f"\n  Using first collection: '{collection_names[0]}'")
+                dest_collection = collection_names[0]
+            else:
+                print("❌ No collections found in destination!")
+                return 1
+        except Exception as e:
+            print(f"❌ Could not list destination collections: {e}")
+            return 1
+
+    # Get collection info
+    print("\nπŸ“Š Collection Information Comparison")
+    print("=" * 70)
+
+    source_info = get_collection_info(source_client, source_collection)
+    dest_info = get_collection_info(dest_client, dest_collection)
+
+    if not source_info:
+        print("❌ Could not get source collection info")
+        return 1
+
+    if not dest_info:
+        print("❌ Could not get destination collection info")
+        return 1
+
+    print(f"\nSource Collection ('{source_collection}'):")
+    print(f"  Vector size: {source_info['vector_size']}")
+    print(f"  Distance: {source_info['distance']}")
+    print(f"  Points: {source_info['points_count']:,}")
+    print(f"  Indexed: {source_info['indexed_vectors_count']:,}")
+
+    print(f"\nDestination Collection ('{dest_collection}'):")
+    print(f"  Vector size: {dest_info['vector_size']}")
+    print(f"  Distance: {dest_info['distance']}")
+    print(f"  Points: {dest_info['points_count']:,}")
+    print(f"  Indexed: {dest_info['indexed_vectors_count']:,}")
+
+    # Compare configs
+    print("\nπŸ” Configuration Comparison:")
+    config_matches = []
+    config_diffs = []
+
+    if source_info['vector_size'] == dest_info['vector_size']:
+        config_matches.append(f"Vector size: {source_info['vector_size']}")
+    else:
+        config_diffs.append(f"Vector size: source={source_info['vector_size']}, dest={dest_info['vector_size']}")
+
+    if str(source_info['distance']) == str(dest_info['distance']):
+        config_matches.append(f"Distance: {source_info['distance']}")
+    else:
+        config_diffs.append(f"Distance: source={source_info['distance']}, dest={dest_info['distance']}")
+
+    if source_info['points_count'] == dest_info['points_count']:
+        config_matches.append(f"Points count: {source_info['points_count']:,}")
+    else:
+        config_diffs.append(f"Points count: source={source_info['points_count']:,}, dest={dest_info['points_count']:,}")
+
+    if config_matches:
+        print(f"  βœ… Matches: {len(config_matches)}")
+        for match in config_matches:
+            print(f"    - {match}")
+
+    if config_diffs:
+        print(f"  ❌ Differences: {len(config_diffs)}")
+        for diff in config_diffs:
+            print(f"    - {diff}")
+
+    # Fetch sample points from the source
+    print("\nπŸ“₯ Fetching sample points from source...")
+    sample_size = 2000  # Fetch up to 2000 sample points
+
+    try:
+        source_points_result, _ = source_client.scroll(
+            collection_name=source_collection,
+            limit=sample_size,
+            with_payload=True,
+            with_vectors=True
+        )
+
+        if not source_points_result:
+            print("❌ No points found in source collection!")
+            return 1
+
+        print(f"  βœ… Fetched {len(source_points_result)} points from source")
+
+        # Extract point IDs
+        source_point_ids = [point.id for point in source_points_result]
+        print(f"  Point IDs: {source_point_ids[:5]}{'...' if len(source_point_ids) > 5 else ''}")
+
+    except Exception as e:
+        print(f"❌ Error fetching source points: {e}")
+        import traceback
+        traceback.print_exc()
+        return 1
+
+    # Fetch the same points from the destination
+    print("\nπŸ“₯ Fetching same points from destination by ID...")
+    try:
+        dest_points_dict = fetch_points_by_ids(dest_client, dest_collection, source_point_ids)
+        print(f"  βœ… Fetched {len(dest_points_dict)} points from destination")
+
+        missing_ids = set(source_point_ids) - set(dest_points_dict.keys())
+        if missing_ids:
+            print(f"  ⚠️ Missing {len(missing_ids)} points in destination: {list(missing_ids)[:5]}{'...' if len(missing_ids) > 5 else ''}")
+
+    except Exception as e:
+        print(f"❌ Error fetching destination points: {e}")
+        import traceback
+        traceback.print_exc()
+        return 1
+
+    # Compare points
+    print("\nπŸ” Point-by-Point Comparison")
+    print("=" * 70)
+
+    comparison_results = []
+    for source_point in source_points_result:
+        point_id = source_point.id
+        dest_point = dest_points_dict.get(point_id)
+
+        if dest_point is None:
+            comparison_results.append({
+                "point_id": point_id,
+                "status": "MISSING",
+                "matches": [],
+                "differences": ["Point not found in destination"]
+            })
+        else:
+            comparison = compare_points(source_point, dest_point, point_id)
+            comparison["status"] = "MATCH" if comparison["diff_count"] == 0 else "DIFF"
+            comparison_results.append(comparison)
+
+    # Summary
+    matches = [r for r in comparison_results if r["status"] == "MATCH"]
+    diffs = [r for r in comparison_results if r["status"] == "DIFF"]
+    missing = [r for r in comparison_results if r["status"] == "MISSING"]
+
+    print("\nπŸ“Š Comparison Summary:")
+    print(f"  Total points compared: {len(comparison_results)}")
+    print(f"  βœ… Perfect matches: {len(matches)}")
+    print(f"  ⚠️ Differences found: {len(diffs)}")
+    print(f"  ❌ Missing in destination: {len(missing)}")
+
+    # Show details for points with differences
+    if diffs:
+        print("\n⚠️ Points with differences:")
+        for diff_result in diffs[:10]:  # Show the first 10
+            print(f"\n  Point ID: {diff_result['point_id']}")
+            if diff_result['matches']:
+                print(f"    βœ… Matches ({len(diff_result['matches'])}): {', '.join(diff_result['matches'][:5])}")
+            if diff_result['differences']:
+                print(f"    ❌ Differences ({len(diff_result['differences'])}):")
+                for d in diff_result['differences'][:5]:
+                    print(f"      - {d}")
+
+    if missing:
+        print("\n❌ Missing points in destination:")
+        for missing_result in missing[:10]:
+            print(f"  - Point ID: {missing_result['point_id']}")
+
+    # Final verdict
+    print("\n" + "=" * 70)
+    if len(missing) == 0 and len(diffs) == 0:
+        print("βœ… VERIFICATION PASSED: All points match perfectly!")
+        return 0
+    elif len(missing) == 0:
+        print(f"⚠️ VERIFICATION PARTIAL: All points present but {len(diffs)} have differences")
+        return 1
+    else:
+        print(f"❌ VERIFICATION FAILED: {len(missing)} points missing, {len(diffs)} have differences")
+        return 1


if __name__ == "__main__":
    sys.exit(main())
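
As a worked illustration of the 1e-6 tolerance compare_points applies to vectors (the arrays are made up):

# Made-up vectors; shows the tolerance check used in compare_points above.
import numpy as np

a = np.array([0.1, 0.2, 0.3])
b = a + 5e-7  # tiny float drift, e.g. from export/import round-tripping
max_diff = float(np.max(np.abs(a - b)))
assert max_diff < 1e-6  # counted as a vector match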