Spaces:
Sleeping
Sleeping
Ara Yeroyan
commited on
Commit
Β·
72eb0bf
1
Parent(s):
b4984e2
refactor + add gemini
Browse files- app.py +177 -595
- src/agents/__init__.py +10 -0
- src/agents/gemini_chatbot.py +372 -0
- multi_agent_chatbot.py β src/agents/multi_agent_chatbot.py +292 -40
- smart_chatbot.py β src/agents/smart_chatbot.py +0 -0
- src/feedback/__init__.py +152 -0
- src/feedback/feedback_schema.py +161 -0
- src/feedback/snowflake_connector.py +331 -0
- src/gemini/__init__.py +11 -0
- src/gemini/file_search.py +256 -0
- src/{loader.py β llm/loader.py} +0 -0
- src/pipeline.py +1 -1
- src/reporting/__init__.py +5 -1
- src/streamlit_app.py +0 -40
- src/ui_components/__init__.py +21 -0
- src/ui_components/components.py +202 -0
- src/ui_components/styles.py +117 -0
- src/ui_components/utils.py +73 -0
- utils.py β src/utils.py +0 -0
- src/vectorstore.py +16 -5
- upload_to_gemini_filestore.py +402 -0
- verify_qdrant_migration.py +438 -0
app.py
CHANGED
|
@@ -10,6 +10,7 @@ import uuid
|
|
| 10 |
import logging
|
| 11 |
import traceback
|
| 12 |
from pathlib import Path
|
|
|
|
| 13 |
from collections import Counter
|
| 14 |
from typing import List, Dict, Any, Optional
|
| 15 |
|
|
@@ -19,10 +20,11 @@ import streamlit as st
|
|
| 19 |
import plotly.express as px
|
| 20 |
from langchain_core.messages import HumanMessage, AIMessage
|
| 21 |
|
| 22 |
-
|
| 23 |
-
from
|
| 24 |
-
from src.
|
| 25 |
-
from src.
|
|
|
|
| 26 |
from src.config.paths import (
|
| 27 |
IS_DEPLOYED,
|
| 28 |
PROJECT_DIR,
|
|
@@ -90,116 +92,9 @@ st.set_page_config(
|
|
| 90 |
page_title="Intelligent Audit Report Chatbot"
|
| 91 |
)
|
| 92 |
|
| 93 |
-
|
| 94 |
-
st.markdown(
|
| 95 |
-
|
| 96 |
-
.main-header {
|
| 97 |
-
font-size: 2.5rem;
|
| 98 |
-
font-weight: bold;
|
| 99 |
-
color: #1f77b4;
|
| 100 |
-
text-align: center;
|
| 101 |
-
margin-bottom: 1rem;
|
| 102 |
-
width: 100%;
|
| 103 |
-
display: block;
|
| 104 |
-
}
|
| 105 |
-
|
| 106 |
-
.subtitle {
|
| 107 |
-
font-size: 1.2rem;
|
| 108 |
-
color: #666;
|
| 109 |
-
text-align: center;
|
| 110 |
-
margin-bottom: 2rem;
|
| 111 |
-
width: 100%;
|
| 112 |
-
display: block;
|
| 113 |
-
}
|
| 114 |
-
|
| 115 |
-
.session-info {
|
| 116 |
-
background-color: #f0f2f6;
|
| 117 |
-
padding: 10px;
|
| 118 |
-
border-radius: 5px;
|
| 119 |
-
margin-bottom: 20px;
|
| 120 |
-
font-size: 0.9rem;
|
| 121 |
-
}
|
| 122 |
-
|
| 123 |
-
.user-message {
|
| 124 |
-
background-color: #007bff;
|
| 125 |
-
color: white;
|
| 126 |
-
padding: 12px 16px;
|
| 127 |
-
border-radius: 18px 18px 4px 18px;
|
| 128 |
-
margin: 8px 0;
|
| 129 |
-
margin-left: 20%;
|
| 130 |
-
word-wrap: break-word;
|
| 131 |
-
}
|
| 132 |
-
|
| 133 |
-
.bot-message {
|
| 134 |
-
background-color: #f1f3f4;
|
| 135 |
-
color: #333;
|
| 136 |
-
padding: 12px 16px;
|
| 137 |
-
border-radius: 18px 18px 18px 4px;
|
| 138 |
-
margin: 8px 0;
|
| 139 |
-
margin-right: 20%;
|
| 140 |
-
word-wrap: break-word;
|
| 141 |
-
border: 1px solid #e0e0e0;
|
| 142 |
-
}
|
| 143 |
-
|
| 144 |
-
.filter-section {
|
| 145 |
-
margin-bottom: 20px;
|
| 146 |
-
padding: 15px;
|
| 147 |
-
background-color: #f8f9fa;
|
| 148 |
-
border-radius: 8px;
|
| 149 |
-
border: 1px solid #e9ecef;
|
| 150 |
-
}
|
| 151 |
-
|
| 152 |
-
.filter-title {
|
| 153 |
-
font-weight: bold;
|
| 154 |
-
margin-bottom: 10px;
|
| 155 |
-
color: #495057;
|
| 156 |
-
}
|
| 157 |
-
|
| 158 |
-
.feedback-section {
|
| 159 |
-
background-color: #f8f9fa;
|
| 160 |
-
padding: 20px;
|
| 161 |
-
border-radius: 10px;
|
| 162 |
-
margin-top: 30px;
|
| 163 |
-
border: 2px solid #dee2e6;
|
| 164 |
-
}
|
| 165 |
-
|
| 166 |
-
.retrieval-history {
|
| 167 |
-
background-color: #ffffff;
|
| 168 |
-
padding: 15px;
|
| 169 |
-
border-radius: 5px;
|
| 170 |
-
margin: 10px 0;
|
| 171 |
-
border-left: 4px solid #007bff;
|
| 172 |
-
}
|
| 173 |
-
|
| 174 |
-
.retrieval-distribution-container {
|
| 175 |
-
background-color: #ffffff;
|
| 176 |
-
padding: 25px;
|
| 177 |
-
border-radius: 10px;
|
| 178 |
-
margin: 20px 0;
|
| 179 |
-
border: 2px solid #e0e0e0;
|
| 180 |
-
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 2px 4px rgba(0, 0, 0, 0.06);
|
| 181 |
-
}
|
| 182 |
-
|
| 183 |
-
.metric-label {
|
| 184 |
-
font-size: 0.9rem;
|
| 185 |
-
color: #555;
|
| 186 |
-
margin-bottom: 5px;
|
| 187 |
-
text-align: center;
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
.metric-value {
|
| 191 |
-
font-size: 1.8rem;
|
| 192 |
-
font-weight: bold;
|
| 193 |
-
color: #000000;
|
| 194 |
-
text-align: center;
|
| 195 |
-
}
|
| 196 |
-
|
| 197 |
-
.metric-container {
|
| 198 |
-
text-align: center;
|
| 199 |
-
padding: 10px;
|
| 200 |
-
}
|
| 201 |
-
</style>
|
| 202 |
-
""", unsafe_allow_html=True)
|
| 203 |
|
| 204 |
def get_system_type():
|
| 205 |
"""Get the current system type"""
|
|
@@ -209,14 +104,17 @@ def get_system_type():
|
|
| 209 |
else:
|
| 210 |
return "Multi-Agent System"
|
| 211 |
|
| 212 |
-
def get_chatbot():
|
| 213 |
-
"""Initialize and return the chatbot based on
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
if system == 'smart':
|
| 217 |
-
return get_smart_chatbot()
|
| 218 |
else:
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
def serialize_messages(messages):
|
| 222 |
"""Serialize LangChain messages to dictionaries"""
|
|
@@ -262,372 +160,9 @@ def serialize_documents(sources):
|
|
| 262 |
return serialized
|
| 263 |
|
| 264 |
|
| 265 |
-
|
| 266 |
-
"""Extract transcript from messages - only user and bot messages, no extra metadata"""
|
| 267 |
-
transcript = []
|
| 268 |
-
for msg in messages:
|
| 269 |
-
if isinstance(msg, HumanMessage):
|
| 270 |
-
transcript.append({
|
| 271 |
-
"role": "user",
|
| 272 |
-
"content": str(msg.content) if hasattr(msg, 'content') else str(msg)
|
| 273 |
-
})
|
| 274 |
-
elif isinstance(msg, AIMessage):
|
| 275 |
-
transcript.append({
|
| 276 |
-
"role": "assistant",
|
| 277 |
-
"content": str(msg.content) if hasattr(msg, 'content') else str(msg)
|
| 278 |
-
})
|
| 279 |
-
return transcript
|
| 280 |
|
| 281 |
|
| 282 |
-
def build_retrievals_structure(rag_retrieval_history: List[Dict[str, Any]], messages: List[Any]) -> List[Dict[str, Any]]:
|
| 283 |
-
"""Build retrievals structure from retrieval history"""
|
| 284 |
-
retrievals = []
|
| 285 |
-
|
| 286 |
-
for entry in rag_retrieval_history:
|
| 287 |
-
# Get the user message that triggered this retrieval
|
| 288 |
-
# The entry has conversation_up_to which includes messages up to that point
|
| 289 |
-
conversation_up_to = entry.get("conversation_up_to", [])
|
| 290 |
-
|
| 291 |
-
# Find the last user message in conversation_up_to (this is the trigger)
|
| 292 |
-
user_message_trigger = ""
|
| 293 |
-
for msg_dict in reversed(conversation_up_to):
|
| 294 |
-
if msg_dict.get("type") == "HumanMessage":
|
| 295 |
-
user_message_trigger = msg_dict.get("content", "")
|
| 296 |
-
break
|
| 297 |
-
|
| 298 |
-
# Fallback: if not found in conversation_up_to, get from actual messages
|
| 299 |
-
# This handles edge cases where conversation_up_to might be incomplete
|
| 300 |
-
if not user_message_trigger:
|
| 301 |
-
# Find which retrieval this is (0-indexed)
|
| 302 |
-
retrieval_idx = rag_retrieval_history.index(entry)
|
| 303 |
-
# The user message that triggered this retrieval is at position (retrieval_idx * 2)
|
| 304 |
-
# because each retrieval is preceded by: user message, bot response, user message, ...
|
| 305 |
-
# But we need to account for the fact that the first retrieval happens after the first user message
|
| 306 |
-
user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
|
| 307 |
-
if retrieval_idx < len(user_msgs):
|
| 308 |
-
user_message_trigger = str(user_msgs[retrieval_idx].content)
|
| 309 |
-
elif user_msgs:
|
| 310 |
-
# Fallback to last user message
|
| 311 |
-
user_message_trigger = str(user_msgs[-1].content)
|
| 312 |
-
|
| 313 |
-
# Get retrieved documents and truncate content to 100 chars
|
| 314 |
-
docs_retrieved = entry.get("docs_retrieved", [])
|
| 315 |
-
retrieved_docs = []
|
| 316 |
-
for doc in docs_retrieved:
|
| 317 |
-
doc_copy = doc.copy()
|
| 318 |
-
# Truncate content to 100 characters (keep all other fields)
|
| 319 |
-
if "content" in doc_copy:
|
| 320 |
-
doc_copy["content"] = doc_copy["content"][:100]
|
| 321 |
-
retrieved_docs.append(doc_copy)
|
| 322 |
-
|
| 323 |
-
retrievals.append({
|
| 324 |
-
"retrieved_docs": retrieved_docs,
|
| 325 |
-
"user_message_trigger": user_message_trigger
|
| 326 |
-
})
|
| 327 |
-
|
| 328 |
-
return retrievals
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
def build_feedback_score_related_retrieval_docs(
|
| 332 |
-
is_feedback_about_last_retrieval: bool,
|
| 333 |
-
messages: List[Any],
|
| 334 |
-
rag_retrieval_history: List[Dict[str, Any]]
|
| 335 |
-
) -> Optional[Dict[str, Any]]:
|
| 336 |
-
"""Build feedback_score_related_retrieval_docs structure"""
|
| 337 |
-
if not rag_retrieval_history:
|
| 338 |
-
return None
|
| 339 |
-
|
| 340 |
-
# Get the relevant retrieval entry
|
| 341 |
-
if is_feedback_about_last_retrieval:
|
| 342 |
-
relevant_entry = rag_retrieval_history[-1]
|
| 343 |
-
else:
|
| 344 |
-
# If feedback is about all retrievals, use the last one as default
|
| 345 |
-
relevant_entry = rag_retrieval_history[-1]
|
| 346 |
-
|
| 347 |
-
# Get conversation up to that point
|
| 348 |
-
conversation_up_to = relevant_entry.get("conversation_up_to", [])
|
| 349 |
-
|
| 350 |
-
# Convert to transcript format (role/content)
|
| 351 |
-
conversation_up_to_point = []
|
| 352 |
-
for msg_dict in conversation_up_to:
|
| 353 |
-
if msg_dict.get("type") == "HumanMessage":
|
| 354 |
-
conversation_up_to_point.append({
|
| 355 |
-
"role": "user",
|
| 356 |
-
"content": msg_dict.get("content", "")
|
| 357 |
-
})
|
| 358 |
-
elif msg_dict.get("type") == "AIMessage":
|
| 359 |
-
conversation_up_to_point.append({
|
| 360 |
-
"role": "assistant",
|
| 361 |
-
"content": msg_dict.get("content", "")
|
| 362 |
-
})
|
| 363 |
-
|
| 364 |
-
# Get retrieved docs with full content (not truncated)
|
| 365 |
-
retrieved_docs = relevant_entry.get("docs_retrieved", [])
|
| 366 |
-
|
| 367 |
-
return {
|
| 368 |
-
"conversation_up_to_point": conversation_up_to_point,
|
| 369 |
-
"retrieved_docs": retrieved_docs
|
| 370 |
-
}
|
| 371 |
-
|
| 372 |
-
def extract_chunk_statistics(sources: List[Any]) -> Dict[str, Any]:
|
| 373 |
-
"""Extract statistics from retrieved chunks."""
|
| 374 |
-
if not sources:
|
| 375 |
-
return {}
|
| 376 |
-
|
| 377 |
-
sources_list = []
|
| 378 |
-
years = []
|
| 379 |
-
filenames = []
|
| 380 |
-
districts = []
|
| 381 |
-
|
| 382 |
-
for doc in sources:
|
| 383 |
-
metadata = getattr(doc, 'metadata', {})
|
| 384 |
-
|
| 385 |
-
# Extract source
|
| 386 |
-
source = metadata.get('source', 'Unknown')
|
| 387 |
-
sources_list.append(source)
|
| 388 |
-
|
| 389 |
-
# Extract year
|
| 390 |
-
year = metadata.get('year', 'Unknown')
|
| 391 |
-
if year and year != 'Unknown':
|
| 392 |
-
try:
|
| 393 |
-
# Convert to int first, then back to string to ensure it's a proper year
|
| 394 |
-
year_int = int(float(year)) # Handle both int and float strings
|
| 395 |
-
if 1900 <= year_int <= 2030: # Reasonable year range
|
| 396 |
-
years.append(str(year_int))
|
| 397 |
-
else:
|
| 398 |
-
years.append('Unknown')
|
| 399 |
-
except (ValueError, TypeError):
|
| 400 |
-
years.append('Unknown')
|
| 401 |
-
else:
|
| 402 |
-
years.append('Unknown')
|
| 403 |
-
|
| 404 |
-
# Extract filename
|
| 405 |
-
filename = metadata.get('filename', 'Unknown')
|
| 406 |
-
filenames.append(filename)
|
| 407 |
-
|
| 408 |
-
# Extract district
|
| 409 |
-
district = metadata.get('district', 'Unknown')
|
| 410 |
-
if district and district != 'Unknown':
|
| 411 |
-
districts.append(district)
|
| 412 |
-
else:
|
| 413 |
-
districts.append('Unknown')
|
| 414 |
-
|
| 415 |
-
# Count occurrences
|
| 416 |
-
source_counts = Counter(sources_list)
|
| 417 |
-
year_counts = Counter(years)
|
| 418 |
-
filename_counts = Counter(filenames)
|
| 419 |
-
district_counts = Counter(districts)
|
| 420 |
-
|
| 421 |
-
return {
|
| 422 |
-
'total_chunks': len(sources),
|
| 423 |
-
'unique_sources': len(source_counts),
|
| 424 |
-
'unique_years': len([y for y in year_counts.keys() if y != 'Unknown']),
|
| 425 |
-
'unique_filenames': len(filename_counts),
|
| 426 |
-
'unique_districts': len([d for d in district_counts.keys() if d != 'Unknown']),
|
| 427 |
-
'source_distribution': dict(source_counts),
|
| 428 |
-
'year_distribution': dict(year_counts),
|
| 429 |
-
'filename_distribution': dict(filename_counts),
|
| 430 |
-
'district_distribution': dict(district_counts),
|
| 431 |
-
'sources': sources_list,
|
| 432 |
-
'years': years,
|
| 433 |
-
'filenames': filenames,
|
| 434 |
-
'districts': districts
|
| 435 |
-
}
|
| 436 |
-
|
| 437 |
-
def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieval Statistics"):
|
| 438 |
-
"""Display statistics as interactive charts for 10+ results."""
|
| 439 |
-
if not stats or stats.get('total_chunks', 0) == 0:
|
| 440 |
-
return
|
| 441 |
-
|
| 442 |
-
# Wrap everything in one styled container - open it
|
| 443 |
-
st.markdown(f"""
|
| 444 |
-
<div class="retrieval-distribution-container">
|
| 445 |
-
<h3 style="margin-top: 0;">π {title}</h3>
|
| 446 |
-
<div style="display: flex; justify-content: space-around; align-items: center; padding: 15px 0; border-bottom: 1px solid #e0e0e0; margin-bottom: 20px;">
|
| 447 |
-
<div class="metric-container">
|
| 448 |
-
<div class="metric-label">Total Chunks</div>
|
| 449 |
-
<div class="metric-value">{stats['total_chunks']}</div>
|
| 450 |
-
</div>
|
| 451 |
-
<div class="metric-container">
|
| 452 |
-
<div class="metric-label">Unique Sources</div>
|
| 453 |
-
<div class="metric-value">{stats['unique_sources']}</div>
|
| 454 |
-
</div>
|
| 455 |
-
<div class="metric-container">
|
| 456 |
-
<div class="metric-label">Unique Years</div>
|
| 457 |
-
<div class="metric-value">{stats['unique_years']}</div>
|
| 458 |
-
</div>
|
| 459 |
-
<div class="metric-container">
|
| 460 |
-
<div class="metric-label">Unique Files</div>
|
| 461 |
-
<div class="metric-value">{stats['unique_filenames']}</div>
|
| 462 |
-
</div>
|
| 463 |
-
</div>
|
| 464 |
-
""", unsafe_allow_html=True)
|
| 465 |
-
|
| 466 |
-
# Charts - three columns to include Districts
|
| 467 |
-
col1, col2, col3 = st.columns(3)
|
| 468 |
-
|
| 469 |
-
with col1:
|
| 470 |
-
# Source distribution chart
|
| 471 |
-
if stats['source_distribution']:
|
| 472 |
-
source_df = pd.DataFrame(
|
| 473 |
-
list(stats['source_distribution'].items()),
|
| 474 |
-
columns=['Source', 'Count']
|
| 475 |
-
)
|
| 476 |
-
fig_source = px.bar(
|
| 477 |
-
source_df,
|
| 478 |
-
x='Count',
|
| 479 |
-
y='Source',
|
| 480 |
-
orientation='h',
|
| 481 |
-
title='Distribution by Source',
|
| 482 |
-
color='Count',
|
| 483 |
-
color_continuous_scale='viridis'
|
| 484 |
-
)
|
| 485 |
-
fig_source.update_layout(height=400, showlegend=False)
|
| 486 |
-
st.plotly_chart(fig_source, use_container_width=True)
|
| 487 |
-
|
| 488 |
-
with col2:
|
| 489 |
-
# Year distribution chart
|
| 490 |
-
if stats['year_distribution']:
|
| 491 |
-
# Filter out 'Unknown' years for the chart
|
| 492 |
-
year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
|
| 493 |
-
if year_dist_filtered:
|
| 494 |
-
year_df = pd.DataFrame(
|
| 495 |
-
list(year_dist_filtered.items()),
|
| 496 |
-
columns=['Year', 'Count']
|
| 497 |
-
)
|
| 498 |
-
# Sort by year as integer but keep as string for categorical display
|
| 499 |
-
year_df['Year_Int'] = year_df['Year'].astype(int)
|
| 500 |
-
year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)
|
| 501 |
-
|
| 502 |
-
fig_year = px.bar(
|
| 503 |
-
year_df,
|
| 504 |
-
x='Year',
|
| 505 |
-
y='Count',
|
| 506 |
-
title='Distribution by Year',
|
| 507 |
-
color='Count',
|
| 508 |
-
color_continuous_scale='plasma'
|
| 509 |
-
)
|
| 510 |
-
# Ensure years are treated as categorical (discrete) not continuous
|
| 511 |
-
fig_year.update_xaxes(type='category')
|
| 512 |
-
fig_year.update_layout(height=400, showlegend=False)
|
| 513 |
-
st.plotly_chart(fig_year, use_container_width=True)
|
| 514 |
-
else:
|
| 515 |
-
st.info("No valid years found in the results")
|
| 516 |
-
|
| 517 |
-
with col3:
|
| 518 |
-
# District distribution chart
|
| 519 |
-
if stats.get('district_distribution'):
|
| 520 |
-
district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
|
| 521 |
-
if district_dist_filtered:
|
| 522 |
-
district_df = pd.DataFrame(
|
| 523 |
-
list(district_dist_filtered.items()),
|
| 524 |
-
columns=['District', 'Count']
|
| 525 |
-
)
|
| 526 |
-
district_df = district_df.sort_values('Count', ascending=False)
|
| 527 |
-
|
| 528 |
-
fig_district = px.bar(
|
| 529 |
-
district_df,
|
| 530 |
-
x='Count',
|
| 531 |
-
y='District',
|
| 532 |
-
orientation='h',
|
| 533 |
-
title='Distribution by District',
|
| 534 |
-
color='Count',
|
| 535 |
-
color_continuous_scale='blues'
|
| 536 |
-
)
|
| 537 |
-
fig_district.update_layout(height=400, showlegend=False)
|
| 538 |
-
st.plotly_chart(fig_district, use_container_width=True)
|
| 539 |
-
else:
|
| 540 |
-
st.info("No valid districts found in the results")
|
| 541 |
-
|
| 542 |
-
# Close the container
|
| 543 |
-
st.markdown('</div>', unsafe_allow_html=True)
|
| 544 |
-
|
| 545 |
-
def display_chunk_statistics_table(stats: Dict[str, Any], title: str = "Retrieval Distribution"):
|
| 546 |
-
"""Display statistics as tables for smaller results with fixed alignment."""
|
| 547 |
-
if not stats or stats.get('total_chunks', 0) == 0:
|
| 548 |
-
return
|
| 549 |
-
|
| 550 |
-
# Wrap in styled container
|
| 551 |
-
|
| 552 |
-
st.markdown('<div class="retrieval-distribution-container">', unsafe_allow_html=True)
|
| 553 |
-
# st.markdown('<div class="retrieval-distribution-container">', unsafe_allow_html=True)
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
st.subheader(f"π {title}")
|
| 557 |
-
|
| 558 |
-
# Create a container with fixed height for alignment
|
| 559 |
-
stats_container = st.container()
|
| 560 |
-
|
| 561 |
-
with stats_container:
|
| 562 |
-
# Create 4 equal columns for consistent alignment
|
| 563 |
-
col1, col2, col3, col4 = st.columns(4)
|
| 564 |
-
|
| 565 |
-
with col1:
|
| 566 |
-
st.markdown("**ποΈ Districts**")
|
| 567 |
-
if stats.get('district_distribution'):
|
| 568 |
-
district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
|
| 569 |
-
if district_dist_filtered:
|
| 570 |
-
district_data = {
|
| 571 |
-
"District": list(district_dist_filtered.keys()),
|
| 572 |
-
"Count": list(district_dist_filtered.values())
|
| 573 |
-
}
|
| 574 |
-
district_df = pd.DataFrame(district_data).sort_values('Count', ascending=False)
|
| 575 |
-
st.dataframe(district_df, hide_index=True, use_container_width=True)
|
| 576 |
-
else:
|
| 577 |
-
st.write("No district data")
|
| 578 |
-
else:
|
| 579 |
-
st.write("No district data")
|
| 580 |
-
|
| 581 |
-
with col2:
|
| 582 |
-
st.markdown("**π Sources**")
|
| 583 |
-
if stats['source_distribution']:
|
| 584 |
-
source_data = {
|
| 585 |
-
"Source": list(stats['source_distribution'].keys()),
|
| 586 |
-
"Count": list(stats['source_distribution'].values())
|
| 587 |
-
}
|
| 588 |
-
source_df = pd.DataFrame(source_data).sort_values('Count', ascending=False)
|
| 589 |
-
st.dataframe(source_df, hide_index=True, use_container_width=True)
|
| 590 |
-
else:
|
| 591 |
-
st.write("No source data")
|
| 592 |
-
|
| 593 |
-
with col3:
|
| 594 |
-
st.markdown("**π
Years**")
|
| 595 |
-
if stats['year_distribution']:
|
| 596 |
-
year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
|
| 597 |
-
if year_dist_filtered:
|
| 598 |
-
year_data = {
|
| 599 |
-
"Year": list(year_dist_filtered.keys()),
|
| 600 |
-
"Count": list(year_dist_filtered.values())
|
| 601 |
-
}
|
| 602 |
-
year_df = pd.DataFrame(year_data)
|
| 603 |
-
# Sort by year as integer but display as string
|
| 604 |
-
year_df['Year_Int'] = year_df['Year'].astype(int)
|
| 605 |
-
year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
|
| 606 |
-
st.dataframe(year_df, hide_index=True, use_container_width=True)
|
| 607 |
-
else:
|
| 608 |
-
st.write("No year data")
|
| 609 |
-
else:
|
| 610 |
-
st.write("No year data")
|
| 611 |
-
|
| 612 |
-
with col4:
|
| 613 |
-
st.markdown("**π Files**")
|
| 614 |
-
if stats['filename_distribution']:
|
| 615 |
-
filename_items = list(stats['filename_distribution'].items())
|
| 616 |
-
filename_items.sort(key=lambda x: x[1], reverse=True)
|
| 617 |
-
|
| 618 |
-
# Show top files with truncated names
|
| 619 |
-
file_data = {
|
| 620 |
-
"File": [f[:30] + "..." if len(f) > 30 else f for f, c in filename_items[:5]],
|
| 621 |
-
"Count": [c for f, c in filename_items[:5]]
|
| 622 |
-
}
|
| 623 |
-
file_df = pd.DataFrame(file_data)
|
| 624 |
-
st.dataframe(file_df, hide_index=True, use_container_width=True)
|
| 625 |
-
else:
|
| 626 |
-
st.write("No file data")
|
| 627 |
-
|
| 628 |
-
# Close container
|
| 629 |
-
st.markdown('</div>', unsafe_allow_html=True)
|
| 630 |
-
|
| 631 |
@st.cache_data
|
| 632 |
def load_filter_options():
|
| 633 |
try:
|
|
@@ -652,11 +187,30 @@ def main():
|
|
| 652 |
# Track RAG retrieval history for feedback
|
| 653 |
if 'rag_retrieval_history' not in st.session_state:
|
| 654 |
st.session_state.rag_retrieval_history = []
|
| 655 |
-
#
|
| 656 |
-
if '
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
|
| 661 |
# Reset conversation history if needed (but keep chatbot cached)
|
| 662 |
if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
|
|
@@ -668,10 +222,13 @@ def main():
|
|
| 668 |
st.session_state.reset_conversation = False
|
| 669 |
st.rerun()
|
| 670 |
|
| 671 |
-
|
| 672 |
-
st.markdown('<h1 class="main-header">π€ Intelligent Audit Report Chatbot</h1>', unsafe_allow_html=True)
|
| 673 |
st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
|
| 674 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
# Session info
|
| 676 |
duration = int(time.time() - st.session_state.session_start_time)
|
| 677 |
duration_str = f"{duration // 60}m {duration % 60}s"
|
|
@@ -829,7 +386,7 @@ def main():
|
|
| 829 |
)
|
| 830 |
|
| 831 |
with col2:
|
| 832 |
-
send_button = st.button("Send", key="send_button",
|
| 833 |
|
| 834 |
# Clear chat button
|
| 835 |
if st.button("ποΈ Clear Chat", key="clear_chat_button"):
|
|
@@ -890,10 +447,20 @@ def main():
|
|
| 890 |
else:
|
| 891 |
formatted_query = "No RAG query available"
|
| 892 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 893 |
retrieval_entry = {
|
| 894 |
"conversation_up_to": serialize_messages(st.session_state.messages),
|
| 895 |
"rag_query_expansion": formatted_query,
|
| 896 |
-
"docs_retrieved": serialize_documents(sources)
|
|
|
|
|
|
|
| 897 |
}
|
| 898 |
st.session_state.rag_retrieval_history.append(retrieval_entry)
|
| 899 |
else:
|
|
@@ -954,9 +521,18 @@ def main():
|
|
| 954 |
for i, doc in enumerate(sources): # Show all documents
|
| 955 |
# Get relevance score and ID if available
|
| 956 |
metadata = getattr(doc, 'metadata', {})
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 960 |
|
| 961 |
with st.expander(f"π Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
|
| 962 |
# Display document metadata with emojis
|
|
@@ -1034,7 +610,7 @@ def main():
|
|
| 1034 |
|
| 1035 |
submitted = st.form_submit_button(
|
| 1036 |
"π€ Submit Feedback",
|
| 1037 |
-
|
| 1038 |
disabled=submit_disabled
|
| 1039 |
)
|
| 1040 |
|
|
@@ -1046,16 +622,18 @@ def main():
|
|
| 1046 |
st.write("π **Debug: Feedback Data Being Submitted:**")
|
| 1047 |
|
| 1048 |
# Extract transcript from messages
|
| 1049 |
-
transcript = extract_transcript(st.session_state.messages)
|
| 1050 |
|
| 1051 |
# Build retrievals structure
|
| 1052 |
-
retrievals = build_retrievals_structure(
|
|
|
|
| 1053 |
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
|
| 1054 |
st.session_state.messages
|
| 1055 |
)
|
| 1056 |
|
| 1057 |
# Build feedback_score_related_retrieval_docs
|
| 1058 |
-
|
|
|
|
| 1059 |
is_feedback_about_last_retrieval,
|
| 1060 |
st.session_state.messages,
|
| 1061 |
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
|
|
@@ -1085,7 +663,7 @@ def main():
|
|
| 1085 |
# Create UserFeedback dataclass instance
|
| 1086 |
feedback_obj = None # Initialize outside try block
|
| 1087 |
try:
|
| 1088 |
-
feedback_obj = create_feedback_from_dict(feedback_dict)
|
| 1089 |
print(f"β
FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
|
| 1090 |
st.write(f"β
**Feedback Object Created**")
|
| 1091 |
st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
|
|
@@ -1141,7 +719,7 @@ def main():
|
|
| 1141 |
logger.info("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 1142 |
print("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 1143 |
|
| 1144 |
-
snowflake_success = save_to_snowflake(feedback_obj)
|
| 1145 |
if snowflake_success:
|
| 1146 |
logger.info("β
SNOWFLAKE UI: Successfully saved to Snowflake")
|
| 1147 |
print("β
SNOWFLAKE UI: Successfully saved to Snowflake")
|
|
@@ -1214,20 +792,111 @@ def main():
|
|
| 1214 |
st.markdown("---")
|
| 1215 |
st.markdown("#### π Retrieval History")
|
| 1216 |
|
| 1217 |
-
with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=
|
| 1218 |
for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
|
| 1219 |
-
st.markdown(f"**Retrieval #{idx}**")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1220 |
|
| 1221 |
# Display the actual RAG query
|
| 1222 |
rag_query_expansion = entry.get("rag_query_expansion", "No query available")
|
|
|
|
| 1223 |
st.code(rag_query_expansion, language="text")
|
| 1224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1225 |
# Display summary stats
|
|
|
|
| 1226 |
st.json({
|
| 1227 |
-
"conversation_length": len(
|
| 1228 |
-
"documents_retrieved": len(
|
| 1229 |
})
|
| 1230 |
-
|
|
|
|
|
|
|
| 1231 |
|
| 1232 |
# Example Questions Section
|
| 1233 |
st.markdown("---")
|
|
@@ -1307,93 +976,6 @@ def main():
|
|
| 1307 |
st.caption("π‘ Tip: Use specific terms from the documents (e.g., 'PDM', 'SACCOs', 'FY 2022/23')")
|
| 1308 |
|
| 1309 |
|
| 1310 |
-
# Store selected question for next render (handled in input section above)
|
| 1311 |
-
# This ensures the question populates the input field correctly
|
| 1312 |
-
|
| 1313 |
-
# Example Questions Section
|
| 1314 |
-
st.markdown("---")
|
| 1315 |
-
st.markdown(
|
| 1316 |
-
"<h3 class='example-questions-header'>π‘ Example Questions</h3>",
|
| 1317 |
-
unsafe_allow_html=True
|
| 1318 |
-
)
|
| 1319 |
-
st.markdown(
|
| 1320 |
-
"<p class='example-questions-description'>Click on any question below to use it, or modify the editable examples:</p>",
|
| 1321 |
-
unsafe_allow_html=True
|
| 1322 |
-
)
|
| 1323 |
-
|
| 1324 |
-
# Initialize example question state
|
| 1325 |
-
if 'custom_question_1' not in st.session_state:
|
| 1326 |
-
st.session_state.custom_question_1 = "How were administrative costs managed in the PDM implementation, and what issues arose with budget execution regarding staff salaries?"
|
| 1327 |
-
if 'custom_question_2' not in st.session_state:
|
| 1328 |
-
st.session_state.custom_question_2 = "What did the National Coordinator say about the release of funds for PDM administrative costs in the letter dated 29th September 2022 and how did the funding received affect the activities of the PDCs and PDM SACCOs in the FY 2022/23?"
|
| 1329 |
-
|
| 1330 |
-
# Question 1: Filename insights (fixed, clickable)
|
| 1331 |
-
st.markdown("#### π Question 1: List insights from a specific file")
|
| 1332 |
-
col1, col2 = st.columns([3, 1])
|
| 1333 |
-
with col1:
|
| 1334 |
-
example_q1 = "List couple of insights from the filename."
|
| 1335 |
-
st.markdown(f"**Example:** `{example_q1}`")
|
| 1336 |
-
st.info("π‘ **Filter to apply:** Select a Filename from the sidebar panel before asking this question.")
|
| 1337 |
-
with col2:
|
| 1338 |
-
if st.button("π Use This Question", key="use_example_1", use_container_width=True):
|
| 1339 |
-
st.session_state.pending_question = example_q1
|
| 1340 |
-
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
|
| 1341 |
-
st.rerun()
|
| 1342 |
-
|
| 1343 |
-
st.markdown("---")
|
| 1344 |
-
|
| 1345 |
-
# Questions 2 & 3: Editable examples
|
| 1346 |
-
st.markdown("#### βοΈ Customizable Questions (Edit and use)")
|
| 1347 |
-
|
| 1348 |
-
# Question 2
|
| 1349 |
-
# st.markdown("**Question 2:**")
|
| 1350 |
-
custom_q1 = st.text_area(
|
| 1351 |
-
"Edit question 2:",
|
| 1352 |
-
value=st.session_state.custom_question_1,
|
| 1353 |
-
height=80,
|
| 1354 |
-
key="edit_question_2",
|
| 1355 |
-
help="Modify this question to fit your needs, then click 'Use This Question'"
|
| 1356 |
-
)
|
| 1357 |
-
col1, col2 = st.columns([1, 4])
|
| 1358 |
-
with col1:
|
| 1359 |
-
if st.button("π Use Question 2", key="use_custom_1", use_container_width=True):
|
| 1360 |
-
if custom_q1.strip():
|
| 1361 |
-
st.session_state.pending_question = custom_q1.strip()
|
| 1362 |
-
st.session_state.custom_question_1 = custom_q1.strip()
|
| 1363 |
-
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
|
| 1364 |
-
st.rerun()
|
| 1365 |
-
else:
|
| 1366 |
-
st.warning("Please enter a question first!")
|
| 1367 |
-
with col2:
|
| 1368 |
-
st.caption("π‘ Tip: Add specific details like dates, names, or amounts to get more precise answers")
|
| 1369 |
-
|
| 1370 |
-
st.info("π‘ **Filter to apply:** Select District(s) and Year(s) sidebar panel before asking this question.")
|
| 1371 |
-
|
| 1372 |
-
st.markdown("---")
|
| 1373 |
-
|
| 1374 |
-
# Question 3
|
| 1375 |
-
# st.markdown("**Question 3:**")
|
| 1376 |
-
custom_q2 = st.text_area(
|
| 1377 |
-
"Edit question 3:",
|
| 1378 |
-
value=st.session_state.custom_question_2,
|
| 1379 |
-
height=80,
|
| 1380 |
-
key="edit_question_3",
|
| 1381 |
-
help="Modify this question to fit your needs, then click 'Use This Question'"
|
| 1382 |
-
)
|
| 1383 |
-
col1, col2 = st.columns([1, 4])
|
| 1384 |
-
with col1:
|
| 1385 |
-
if st.button("π Use Question 3", key="use_custom_2", use_container_width=True):
|
| 1386 |
-
if custom_q2.strip():
|
| 1387 |
-
st.session_state.pending_question = custom_q2.strip()
|
| 1388 |
-
st.session_state.custom_question_2 = custom_q2.strip()
|
| 1389 |
-
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
|
| 1390 |
-
st.rerun()
|
| 1391 |
-
else:
|
| 1392 |
-
st.warning("Please enter a question first!")
|
| 1393 |
-
with col2:
|
| 1394 |
-
st.caption("π‘ Tip: Use specific terms from the documents (e.g., 'PDM', 'SACCOs', 'FY 2022/23')")
|
| 1395 |
-
|
| 1396 |
-
|
| 1397 |
# Store selected question for next render (handled in input section above)
|
| 1398 |
# This ensures the question populates the input field correctly
|
| 1399 |
|
|
|
|
| 10 |
import logging
|
| 11 |
import traceback
|
| 12 |
from pathlib import Path
|
| 13 |
+
|
| 14 |
from collections import Counter
|
| 15 |
from typing import List, Dict, Any, Optional
|
| 16 |
|
|
|
|
| 20 |
import plotly.express as px
|
| 21 |
from langchain_core.messages import HumanMessage, AIMessage
|
| 22 |
|
| 23 |
+
|
| 24 |
+
from src.agents import get_multi_agent_chatbot, get_smart_chatbot, get_gemini_chatbot
|
| 25 |
+
from src.feedback import FeedbackManager
|
| 26 |
+
from src.ui_components import get_custom_css, display_chunk_statistics_charts, display_chunk_statistics_table, extract_chunk_statistics
|
| 27 |
+
|
| 28 |
from src.config.paths import (
|
| 29 |
IS_DEPLOYED,
|
| 30 |
PROJECT_DIR,
|
|
|
|
| 92 |
page_title="Intelligent Audit Report Chatbot"
|
| 93 |
)
|
| 94 |
|
| 95 |
+
|
| 96 |
+
st.markdown(get_custom_css(), unsafe_allow_html=True)
|
| 97 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
def get_system_type():
|
| 100 |
"""Get the current system type"""
|
|
|
|
| 104 |
else:
|
| 105 |
return "Multi-Agent System"
|
| 106 |
|
| 107 |
+
def get_chatbot(version: str = "v1"):
|
| 108 |
+
"""Initialize and return the chatbot based on version"""
|
| 109 |
+
if version == "beta":
|
| 110 |
+
return get_gemini_chatbot()
|
|
|
|
|
|
|
| 111 |
else:
|
| 112 |
+
# Check environment variable for system type (v1)
|
| 113 |
+
system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
|
| 114 |
+
if system == 'smart':
|
| 115 |
+
return get_smart_chatbot()
|
| 116 |
+
else:
|
| 117 |
+
return get_multi_agent_chatbot()
|
| 118 |
|
| 119 |
def serialize_messages(messages):
|
| 120 |
"""Serialize LangChain messages to dictionaries"""
|
|
|
|
| 160 |
return serialized
|
| 161 |
|
| 162 |
|
| 163 |
+
feedback_manager = FeedbackManager()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
@st.cache_data
|
| 167 |
def load_filter_options():
|
| 168 |
try:
|
|
|
|
| 187 |
# Track RAG retrieval history for feedback
|
| 188 |
if 'rag_retrieval_history' not in st.session_state:
|
| 189 |
st.session_state.rag_retrieval_history = []
|
| 190 |
+
# Version selection (v1 or beta)
|
| 191 |
+
if 'chatbot_version' not in st.session_state:
|
| 192 |
+
st.session_state.chatbot_version = "v1"
|
| 193 |
+
|
| 194 |
+
# Initialize chatbot based on version (reinitialize if version changes)
|
| 195 |
+
chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}"
|
| 196 |
+
if chatbot_version_key not in st.session_state or st.session_state.get('_last_version') != st.session_state.chatbot_version:
|
| 197 |
+
try:
|
| 198 |
+
with st.spinner("π Loading AI models and connecting to database..."):
|
| 199 |
+
st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version)
|
| 200 |
+
st.session_state['_last_version'] = st.session_state.chatbot_version
|
| 201 |
+
st.session_state.chatbot = st.session_state[chatbot_version_key]
|
| 202 |
+
st.success("β
AI system ready!")
|
| 203 |
+
except Exception as e:
|
| 204 |
+
st.error(f"β Failed to initialize chatbot: {str(e)}")
|
| 205 |
+
st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)")
|
| 206 |
+
# Reset to v1 to prevent infinite loop
|
| 207 |
+
st.session_state.chatbot_version = "v1"
|
| 208 |
+
st.session_state['_last_version'] = "v1"
|
| 209 |
+
if 'chatbot' in st.session_state:
|
| 210 |
+
del st.session_state['chatbot']
|
| 211 |
+
st.stop() # Stop execution to prevent infinite loop
|
| 212 |
+
else:
|
| 213 |
+
st.session_state.chatbot = st.session_state[chatbot_version_key]
|
| 214 |
|
| 215 |
# Reset conversation history if needed (but keep chatbot cached)
|
| 216 |
if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
|
|
|
|
| 222 |
st.session_state.reset_conversation = False
|
| 223 |
st.rerun()
|
| 224 |
|
| 225 |
+
|
|
|
|
| 226 |
st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
|
| 227 |
|
| 228 |
+
# Show version info
|
| 229 |
+
if st.session_state.chatbot_version == "beta":
|
| 230 |
+
st.info("π¬ **Beta Mode**: Using Google Gemini File Search API")
|
| 231 |
+
|
| 232 |
# Session info
|
| 233 |
duration = int(time.time() - st.session_state.session_start_time)
|
| 234 |
duration_str = f"{duration // 60}m {duration % 60}s"
|
|
|
|
| 386 |
)
|
| 387 |
|
| 388 |
with col2:
|
| 389 |
+
send_button = st.button("Send", key="send_button", width='stretch')
|
| 390 |
|
| 391 |
# Clear chat button
|
| 392 |
if st.button("ποΈ Clear Chat", key="clear_chat_button"):
|
|
|
|
| 447 |
else:
|
| 448 |
formatted_query = "No RAG query available"
|
| 449 |
|
| 450 |
+
# Extract filters from active filters
|
| 451 |
+
filters_used = {
|
| 452 |
+
"sources": st.session_state.active_filters.get('sources', []),
|
| 453 |
+
"years": st.session_state.active_filters.get('years', []),
|
| 454 |
+
"districts": st.session_state.active_filters.get('districts', []),
|
| 455 |
+
"filenames": st.session_state.active_filters.get('filenames', [])
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
retrieval_entry = {
|
| 459 |
"conversation_up_to": serialize_messages(st.session_state.messages),
|
| 460 |
"rag_query_expansion": formatted_query,
|
| 461 |
+
"docs_retrieved": serialize_documents(sources),
|
| 462 |
+
"filters_applied": filters_used,
|
| 463 |
+
"timestamp": time.time()
|
| 464 |
}
|
| 465 |
st.session_state.rag_retrieval_history.append(retrieval_entry)
|
| 466 |
else:
|
|
|
|
| 521 |
for i, doc in enumerate(sources): # Show all documents
|
| 522 |
# Get relevance score and ID if available
|
| 523 |
metadata = getattr(doc, 'metadata', {})
|
| 524 |
+
# Handle both standard RAG scores and Gemini scores
|
| 525 |
+
score = metadata.get('reranked_score') or metadata.get('original_score') or metadata.get('score')
|
| 526 |
+
chunk_id = metadata.get('_id') or metadata.get('chunk_id', 'Unknown')
|
| 527 |
+
if score is not None:
|
| 528 |
+
try:
|
| 529 |
+
score_text = f" (Score: {float(score):.3f})"
|
| 530 |
+
except (ValueError, TypeError):
|
| 531 |
+
score_text = ""
|
| 532 |
+
else:
|
| 533 |
+
score_text = ""
|
| 534 |
+
if chunk_id and chunk_id != 'Unknown':
|
| 535 |
+
score_text += f" (ID: {str(chunk_id)[:8]}...)" if score_text else f" (ID: {str(chunk_id)[:8]}...)"
|
| 536 |
|
| 537 |
with st.expander(f"π Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
|
| 538 |
# Display document metadata with emojis
|
|
|
|
| 610 |
|
| 611 |
submitted = st.form_submit_button(
|
| 612 |
"π€ Submit Feedback",
|
| 613 |
+
width='stretch',
|
| 614 |
disabled=submit_disabled
|
| 615 |
)
|
| 616 |
|
|
|
|
| 622 |
st.write("π **Debug: Feedback Data Being Submitted:**")
|
| 623 |
|
| 624 |
# Extract transcript from messages
|
| 625 |
+
transcript = feedback_manager.extract_transcript(st.session_state.messages)
|
| 626 |
|
| 627 |
# Build retrievals structure
|
| 628 |
+
retrievals = feedback_manager.build_retrievals_structure(
|
| 629 |
+
|
| 630 |
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
|
| 631 |
st.session_state.messages
|
| 632 |
)
|
| 633 |
|
| 634 |
# Build feedback_score_related_retrieval_docs
|
| 635 |
+
|
| 636 |
+
feedback_score_related_retrieval_docs = feedback_manager.build_feedback_score_related_retrieval_docs(
|
| 637 |
is_feedback_about_last_retrieval,
|
| 638 |
st.session_state.messages,
|
| 639 |
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
|
|
|
|
| 663 |
# Create UserFeedback dataclass instance
|
| 664 |
feedback_obj = None # Initialize outside try block
|
| 665 |
try:
|
| 666 |
+
feedback_obj = feedback_manager.create_feedback_from_dict(feedback_dict)
|
| 667 |
print(f"β
FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
|
| 668 |
st.write(f"β
**Feedback Object Created**")
|
| 669 |
st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
|
|
|
|
| 719 |
logger.info("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 720 |
print("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 721 |
|
| 722 |
+
snowflake_success = feedback_manager.save_to_snowflake(feedback_obj)
|
| 723 |
if snowflake_success:
|
| 724 |
logger.info("β
SNOWFLAKE UI: Successfully saved to Snowflake")
|
| 725 |
print("β
SNOWFLAKE UI: Successfully saved to Snowflake")
|
|
|
|
| 792 |
st.markdown("---")
|
| 793 |
st.markdown("#### π Retrieval History")
|
| 794 |
|
| 795 |
+
with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=True):
|
| 796 |
for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
|
| 797 |
+
st.markdown(f"### **Retrieval #{idx}**")
|
| 798 |
+
|
| 799 |
+
# Display timestamp if available
|
| 800 |
+
if entry.get("timestamp"):
|
| 801 |
+
timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(entry["timestamp"]))
|
| 802 |
+
st.caption(f"π {timestamp_str}")
|
| 803 |
|
| 804 |
# Display the actual RAG query
|
| 805 |
rag_query_expansion = entry.get("rag_query_expansion", "No query available")
|
| 806 |
+
st.markdown("**π RAG Query:**")
|
| 807 |
st.code(rag_query_expansion, language="text")
|
| 808 |
|
| 809 |
+
# Display filters used
|
| 810 |
+
filters_applied = entry.get("filters_applied", {})
|
| 811 |
+
if filters_applied and any(filters_applied.values()):
|
| 812 |
+
st.markdown("**π― Filters Applied:**")
|
| 813 |
+
filter_display = {}
|
| 814 |
+
if filters_applied.get("sources"):
|
| 815 |
+
filter_display["Sources"] = filters_applied["sources"]
|
| 816 |
+
if filters_applied.get("years"):
|
| 817 |
+
filter_display["Years"] = filters_applied["years"]
|
| 818 |
+
if filters_applied.get("districts"):
|
| 819 |
+
filter_display["Districts"] = filters_applied["districts"]
|
| 820 |
+
if filters_applied.get("filenames"):
|
| 821 |
+
filter_display["Filenames"] = filters_applied["filenames"]
|
| 822 |
+
|
| 823 |
+
if filter_display:
|
| 824 |
+
st.json(filter_display)
|
| 825 |
+
else:
|
| 826 |
+
st.info("No filters applied")
|
| 827 |
+
else:
|
| 828 |
+
st.info("No filters applied")
|
| 829 |
+
|
| 830 |
+
# Display conversation history up to retrieval point
|
| 831 |
+
conversation_up_to = entry.get("conversation_up_to", [])
|
| 832 |
+
if conversation_up_to:
|
| 833 |
+
st.markdown("**π¬ Conversation History (up to retrieval point):**")
|
| 834 |
+
with st.expander(f"View {len(conversation_up_to)} messages", expanded=False):
|
| 835 |
+
for msg_idx, msg in enumerate(conversation_up_to, 1):
|
| 836 |
+
role = msg.get("type", "unknown")
|
| 837 |
+
content = msg.get("content", "")
|
| 838 |
+
|
| 839 |
+
if role == "HumanMessage" or role == "human":
|
| 840 |
+
st.markdown(f"**π€ User {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
|
| 841 |
+
elif role == "AIMessage" or role == "ai":
|
| 842 |
+
st.markdown(f"**π€ Assistant {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
|
| 843 |
+
else:
|
| 844 |
+
st.info("No conversation history available")
|
| 845 |
+
|
| 846 |
+
# Display documents retrieved
|
| 847 |
+
docs_retrieved = entry.get("docs_retrieved", [])
|
| 848 |
+
if docs_retrieved:
|
| 849 |
+
st.markdown(f"**π Documents Retrieved ({len(docs_retrieved)}):**")
|
| 850 |
+
with st.expander(f"View {len(docs_retrieved)} documents", expanded=False):
|
| 851 |
+
for doc_idx, doc in enumerate(docs_retrieved, 1):
|
| 852 |
+
st.markdown(f"**Document {doc_idx}:**")
|
| 853 |
+
|
| 854 |
+
# Display metadata
|
| 855 |
+
metadata = doc.get("metadata", {})
|
| 856 |
+
if metadata:
|
| 857 |
+
col1, col2, col3 = st.columns(3)
|
| 858 |
+
with col1:
|
| 859 |
+
st.write(f"π **File:** {metadata.get('filename', 'Unknown')}")
|
| 860 |
+
with col2:
|
| 861 |
+
st.write(f"ποΈ **Source:** {metadata.get('source', 'Unknown')}")
|
| 862 |
+
with col3:
|
| 863 |
+
st.write(f"π
**Year:** {metadata.get('year', 'Unknown')}")
|
| 864 |
+
|
| 865 |
+
# Additional metadata
|
| 866 |
+
if metadata.get('district'):
|
| 867 |
+
st.write(f"π **District:** {metadata.get('district')}")
|
| 868 |
+
if metadata.get('page'):
|
| 869 |
+
st.write(f"π **Page:** {metadata.get('page')}")
|
| 870 |
+
if metadata.get('score') is not None:
|
| 871 |
+
st.write(f"β **Score:** {metadata.get('score'):.3f}" if isinstance(metadata.get('score'), (int, float)) else f"β **Score:** {metadata.get('score')}")
|
| 872 |
+
|
| 873 |
+
# Display content preview (first 200 chars)
|
| 874 |
+
content = doc.get("content", doc.get("page_content", ""))
|
| 875 |
+
if content:
|
| 876 |
+
st.markdown("**Content Preview:**")
|
| 877 |
+
st.text_area(
|
| 878 |
+
"Content Preview",
|
| 879 |
+
value=content[:200] + ("..." if len(content) > 200 else ""),
|
| 880 |
+
height=100,
|
| 881 |
+
disabled=True,
|
| 882 |
+
label_visibility="collapsed",
|
| 883 |
+
key=f"retrieval_{idx}_doc_{doc_idx}_preview"
|
| 884 |
+
)
|
| 885 |
+
|
| 886 |
+
if doc_idx < len(docs_retrieved):
|
| 887 |
+
st.markdown("---")
|
| 888 |
+
else:
|
| 889 |
+
st.info("No documents retrieved")
|
| 890 |
+
|
| 891 |
# Display summary stats
|
| 892 |
+
st.markdown("**π Summary:**")
|
| 893 |
st.json({
|
| 894 |
+
"conversation_length": len(conversation_up_to),
|
| 895 |
+
"documents_retrieved": len(docs_retrieved)
|
| 896 |
})
|
| 897 |
+
|
| 898 |
+
if idx < len(st.session_state.rag_retrieval_history):
|
| 899 |
+
st.markdown("---")
|
| 900 |
|
| 901 |
# Example Questions Section
|
| 902 |
st.markdown("---")
|
|
|
|
| 976 |
st.caption("π‘ Tip: Use specific terms from the documents (e.g., 'PDM', 'SACCOs', 'FY 2022/23')")
|
| 977 |
|
| 978 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 979 |
# Store selected question for next render (handled in input section above)
|
| 980 |
# This ensures the question populates the input field correctly
|
| 981 |
|
src/agents/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent modules for chatbot implementations
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .smart_chatbot import get_chatbot as get_smart_chatbot
|
| 6 |
+
from .multi_agent_chatbot import get_multi_agent_chatbot
|
| 7 |
+
from .gemini_chatbot import get_gemini_chatbot
|
| 8 |
+
|
| 9 |
+
__all__ = ["get_smart_chatbot", "get_multi_agent_chatbot", "get_gemini_chatbot"]
|
| 10 |
+
|
src/agents/gemini_chatbot.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini File Search Chatbot (Beta Version)
|
| 3 |
+
|
| 4 |
+
This chatbot uses Google Gemini File Search API for RAG.
|
| 5 |
+
It provides a simpler architecture: Main Agent + Gemini Agent
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
import logging
|
| 12 |
+
import traceback
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Dict, List, Any, Optional, TypedDict
|
| 15 |
+
|
| 16 |
+
from langgraph.graph import StateGraph, END
|
| 17 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 18 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 19 |
+
|
| 20 |
+
from src.gemini.file_search import GeminiFileSearchClient, GeminiFileSearchResult
|
| 21 |
+
from src.config.paths import CONVERSATIONS_DIR
|
| 22 |
+
|
| 23 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class GeminiState(TypedDict):
|
| 28 |
+
"""State for Gemini chatbot conversation flow"""
|
| 29 |
+
conversation_id: str
|
| 30 |
+
messages: List[Any]
|
| 31 |
+
current_query: str
|
| 32 |
+
query_context: Optional[Dict[str, Any]]
|
| 33 |
+
gemini_result: Optional[GeminiFileSearchResult]
|
| 34 |
+
final_response: Optional[str]
|
| 35 |
+
agent_logs: List[str]
|
| 36 |
+
conversation_context: Dict[str, Any]
|
| 37 |
+
session_start_time: float
|
| 38 |
+
last_ai_message_time: float
|
| 39 |
+
filters: Optional[Dict[str, Any]]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class GeminiRAGChatbot:
|
| 43 |
+
"""Gemini File Search RAG chatbot (Beta version)"""
|
| 44 |
+
|
| 45 |
+
def __init__(self):
|
| 46 |
+
"""Initialize the Gemini chatbot"""
|
| 47 |
+
logger.info("π€ INITIALIZING: Gemini File Search Chatbot (Beta)")
|
| 48 |
+
|
| 49 |
+
# Initialize Gemini File Search client
|
| 50 |
+
try:
|
| 51 |
+
self.gemini_client = GeminiFileSearchClient()
|
| 52 |
+
logger.info("β
Gemini File Search client initialized")
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.error(f"β Failed to initialize Gemini client: {e}")
|
| 55 |
+
raise RuntimeError(f"Gemini client initialization failed: {e}")
|
| 56 |
+
|
| 57 |
+
# Build the LangGraph with LangSmith tracing if enabled
|
| 58 |
+
self.graph = self._build_graph()
|
| 59 |
+
|
| 60 |
+
# Enable LangSmith tracing if configured
|
| 61 |
+
langsmith_enabled = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true"
|
| 62 |
+
if langsmith_enabled:
|
| 63 |
+
logger.info("π LangSmith tracing enabled")
|
| 64 |
+
langsmith_project = os.getenv("LANGCHAIN_PROJECT", "gemini-chatbot")
|
| 65 |
+
logger.info(f"π LangSmith project: {langsmith_project}")
|
| 66 |
+
|
| 67 |
+
# Conversations directory
|
| 68 |
+
self.conversations_dir = CONVERSATIONS_DIR
|
| 69 |
+
try:
|
| 70 |
+
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 71 |
+
except (PermissionError, OSError) as e:
|
| 72 |
+
logger.warning(f"Could not create conversations directory: {e}")
|
| 73 |
+
self.conversations_dir = Path("conversations")
|
| 74 |
+
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 75 |
+
|
| 76 |
+
logger.info("β
Gemini File Search Chatbot initialized")
|
| 77 |
+
|
| 78 |
+
def _build_graph(self) -> StateGraph:
|
| 79 |
+
"""Build the LangGraph for Gemini chatbot"""
|
| 80 |
+
graph = StateGraph(GeminiState)
|
| 81 |
+
|
| 82 |
+
# Add nodes
|
| 83 |
+
graph.add_node("main_agent", self._main_agent)
|
| 84 |
+
graph.add_node("gemini_agent", self._gemini_agent)
|
| 85 |
+
|
| 86 |
+
# Define the flow
|
| 87 |
+
graph.set_entry_point("main_agent")
|
| 88 |
+
graph.add_edge("main_agent", "gemini_agent")
|
| 89 |
+
graph.add_edge("gemini_agent", END)
|
| 90 |
+
|
| 91 |
+
return graph.compile()
|
| 92 |
+
|
| 93 |
+
def _main_agent(self, state: GeminiState) -> GeminiState:
|
| 94 |
+
"""Main Agent: Extracts filters and prepares query"""
|
| 95 |
+
logger.info("π― MAIN AGENT: Processing query")
|
| 96 |
+
|
| 97 |
+
query = state["current_query"]
|
| 98 |
+
messages = state["messages"]
|
| 99 |
+
|
| 100 |
+
# Extract UI filters if present in query
|
| 101 |
+
ui_filters = self._extract_ui_filters(query)
|
| 102 |
+
|
| 103 |
+
# Extract context from conversation
|
| 104 |
+
context = self._extract_context_from_conversation(messages, ui_filters)
|
| 105 |
+
|
| 106 |
+
# Store context and filters
|
| 107 |
+
state["query_context"] = context
|
| 108 |
+
state["filters"] = context.get("filters", {})
|
| 109 |
+
|
| 110 |
+
logger.info(f"π― MAIN AGENT: Filters extracted: {state['filters']}")
|
| 111 |
+
|
| 112 |
+
return state
|
| 113 |
+
|
| 114 |
+
def _gemini_agent(self, state: GeminiState) -> GeminiState:
|
| 115 |
+
"""Gemini Agent: Performs file search and generates response"""
|
| 116 |
+
logger.info("π GEMINI AGENT: Starting file search")
|
| 117 |
+
|
| 118 |
+
query = state["current_query"]
|
| 119 |
+
filters = state.get("filters", {})
|
| 120 |
+
|
| 121 |
+
# Perform Gemini file search
|
| 122 |
+
try:
|
| 123 |
+
result = self.gemini_client.search(query=query, filters=filters)
|
| 124 |
+
logger.info(f"β
GEMINI AGENT: Search completed, {len(result.sources)} sources found")
|
| 125 |
+
|
| 126 |
+
# Enhance response with document references
|
| 127 |
+
enhanced_response = self._enhance_response_with_references(
|
| 128 |
+
result.answer,
|
| 129 |
+
result.sources,
|
| 130 |
+
query
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
state["gemini_result"] = result
|
| 134 |
+
state["final_response"] = enhanced_response
|
| 135 |
+
state["last_ai_message_time"] = time.time()
|
| 136 |
+
|
| 137 |
+
state["agent_logs"].append(f"GEMINI AGENT: Found {len(result.sources)} sources")
|
| 138 |
+
|
| 139 |
+
except Exception as e:
|
| 140 |
+
logger.error(f"β GEMINI AGENT ERROR: {e}")
|
| 141 |
+
traceback.print_exc()
|
| 142 |
+
state["final_response"] = "I apologize, but I encountered an error while searching. Please try again."
|
| 143 |
+
state["last_ai_message_time"] = time.time()
|
| 144 |
+
|
| 145 |
+
return state
|
| 146 |
+
|
| 147 |
+
def _enhance_response_with_references(self, answer: str, sources: List[Any], query: str) -> str:
|
| 148 |
+
"""Enhance Gemini response to include document references"""
|
| 149 |
+
if not sources or not answer:
|
| 150 |
+
return answer
|
| 151 |
+
|
| 152 |
+
# Use LLM to intelligently add document references
|
| 153 |
+
try:
|
| 154 |
+
from src.llm.adapters import get_llm_client
|
| 155 |
+
llm = get_llm_client()
|
| 156 |
+
|
| 157 |
+
# Prepare document summaries for the LLM
|
| 158 |
+
doc_summaries = []
|
| 159 |
+
for idx, doc in enumerate(sources, 1):
|
| 160 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 161 |
+
content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
|
| 162 |
+
|
| 163 |
+
filename = metadata.get('filename', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
|
| 164 |
+
year = metadata.get('year', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
|
| 165 |
+
source = metadata.get('source', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
|
| 166 |
+
|
| 167 |
+
doc_summaries.append(f"[Doc {idx}] {filename} ({year}, {source}): {content[:300]}...")
|
| 168 |
+
|
| 169 |
+
prompt = f"""You are enhancing a response from a document search system. The original response is:
|
| 170 |
+
|
| 171 |
+
{answer}
|
| 172 |
+
|
| 173 |
+
The following documents were retrieved and used to generate this response:
|
| 174 |
+
|
| 175 |
+
{chr(10).join(doc_summaries)}
|
| 176 |
+
|
| 177 |
+
CRITICAL RULES:
|
| 178 |
+
1. The response should ONLY contain information from the retrieved documents listed above
|
| 179 |
+
2. If the response mentions information NOT found in the retrieved documents, you must REMOVE or CORRECT that information
|
| 180 |
+
3. Add document references [Doc i] at the end of sentences that use information from specific documents
|
| 181 |
+
4. Only reference documents that are actually used in the response
|
| 182 |
+
5. If the response mentions years, sources, or data that don't match the retrieved documents, you must correct it
|
| 183 |
+
6. Keep the response natural and conversational
|
| 184 |
+
7. Don't change the core content that matches the documents, just add references where appropriate
|
| 185 |
+
8. If multiple documents support the same claim, use [Doc i, Doc j] format
|
| 186 |
+
9. If the response contains information that cannot be verified in the retrieved documents, add a note like: "Note: This information may not be in the retrieved documents."
|
| 187 |
+
|
| 188 |
+
Return ONLY the enhanced response with references added and any corrections made. Do not include any explanation or meta-commentary."""
|
| 189 |
+
|
| 190 |
+
enhanced = llm.invoke(prompt).content if hasattr(llm.invoke(prompt), 'content') else str(llm.invoke(prompt))
|
| 191 |
+
|
| 192 |
+
# Fallback: if LLM fails, just return original
|
| 193 |
+
if not enhanced or len(enhanced) < len(answer) * 0.5:
|
| 194 |
+
logger.warning("LLM enhancement failed, using original response")
|
| 195 |
+
return answer
|
| 196 |
+
|
| 197 |
+
return enhanced
|
| 198 |
+
|
| 199 |
+
except Exception as e:
|
| 200 |
+
logger.warning(f"Failed to enhance response with references: {e}")
|
| 201 |
+
# Fallback: add basic references at the end
|
| 202 |
+
if sources:
|
| 203 |
+
ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
|
| 204 |
+
return f"{answer}\n\n*Based on documents: {ref_list}*"
|
| 205 |
+
return answer
|
| 206 |
+
|
| 207 |
+
def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
|
| 208 |
+
"""Extract UI filters from query if present"""
|
| 209 |
+
filters = {}
|
| 210 |
+
|
| 211 |
+
if "FILTER CONTEXT:" in query:
|
| 212 |
+
filter_section = query.split("FILTER CONTEXT:")[1]
|
| 213 |
+
if "USER QUERY:" in filter_section:
|
| 214 |
+
filter_section = filter_section.split("USER QUERY:")[0]
|
| 215 |
+
filter_section = filter_section.strip()
|
| 216 |
+
|
| 217 |
+
if "Sources:" in filter_section:
|
| 218 |
+
sources_line = [line for line in filter_section.split('\n') if line.strip().startswith('Sources:')]
|
| 219 |
+
if sources_line:
|
| 220 |
+
sources_str = sources_line[0].split("Sources:")[1].strip()
|
| 221 |
+
if sources_str and sources_str != "None":
|
| 222 |
+
filters["sources"] = [s.strip() for s in sources_str.split(",")]
|
| 223 |
+
|
| 224 |
+
if "Years:" in filter_section:
|
| 225 |
+
years_line = [line for line in filter_section.split('\n') if line.strip().startswith('Years:')]
|
| 226 |
+
if years_line:
|
| 227 |
+
years_str = years_line[0].split("Years:")[1].strip()
|
| 228 |
+
if years_str and years_str != "None":
|
| 229 |
+
filters["year"] = [y.strip() for y in years_str.split(",")]
|
| 230 |
+
|
| 231 |
+
if "Districts:" in filter_section:
|
| 232 |
+
districts_line = [line for line in filter_section.split('\n') if line.strip().startswith('Districts:')]
|
| 233 |
+
if districts_line:
|
| 234 |
+
districts_str = districts_line[0].split("Districts:")[1].strip()
|
| 235 |
+
if districts_str and districts_str != "None":
|
| 236 |
+
filters["district"] = [d.strip() for d in districts_str.split(",")]
|
| 237 |
+
|
| 238 |
+
if "Filenames:" in filter_section:
|
| 239 |
+
filenames_line = [line for line in filter_section.split('\n') if line.strip().startswith('Filenames:')]
|
| 240 |
+
if filenames_line:
|
| 241 |
+
filenames_str = filenames_line[0].split("Filenames:")[1].strip()
|
| 242 |
+
if filenames_str and filenames_str != "None":
|
| 243 |
+
filters["filenames"] = [f.strip() for f in filenames_str.split(",")]
|
| 244 |
+
|
| 245 |
+
return filters
|
| 246 |
+
|
| 247 |
+
def _extract_context_from_conversation(
|
| 248 |
+
self,
|
| 249 |
+
messages: List[Any],
|
| 250 |
+
ui_filters: Dict[str, List[str]]
|
| 251 |
+
) -> Dict[str, Any]:
|
| 252 |
+
"""Extract context from conversation history"""
|
| 253 |
+
# Use UI filters if available
|
| 254 |
+
filters = ui_filters.copy() if ui_filters else {}
|
| 255 |
+
|
| 256 |
+
# For Gemini, we pass filters directly to the search function
|
| 257 |
+
# The filters will be used to add context to the query
|
| 258 |
+
|
| 259 |
+
return {
|
| 260 |
+
"filters": filters,
|
| 261 |
+
"has_filters": bool(filters)
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
|
| 265 |
+
"""Main chat interface"""
|
| 266 |
+
logger.info(f"π¬ GEMINI CHAT: Processing '{user_input[:50]}...'")
|
| 267 |
+
|
| 268 |
+
# Load conversation
|
| 269 |
+
conversation_file = self.conversations_dir / f"{conversation_id}.json"
|
| 270 |
+
conversation = self._load_conversation(conversation_file)
|
| 271 |
+
|
| 272 |
+
# Add user message
|
| 273 |
+
conversation["messages"].append(HumanMessage(content=user_input))
|
| 274 |
+
|
| 275 |
+
# Prepare state
|
| 276 |
+
state = GeminiState(
|
| 277 |
+
conversation_id=conversation_id,
|
| 278 |
+
messages=conversation["messages"],
|
| 279 |
+
current_query=user_input,
|
| 280 |
+
query_context=None,
|
| 281 |
+
gemini_result=None,
|
| 282 |
+
final_response=None,
|
| 283 |
+
agent_logs=[],
|
| 284 |
+
conversation_context=conversation.get("context", {}),
|
| 285 |
+
session_start_time=conversation["session_start_time"],
|
| 286 |
+
last_ai_message_time=conversation["last_ai_message_time"],
|
| 287 |
+
filters=None
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
# Run graph
|
| 291 |
+
final_state = self.graph.invoke(state)
|
| 292 |
+
|
| 293 |
+
# Add AI response to conversation
|
| 294 |
+
if final_state["final_response"]:
|
| 295 |
+
conversation["messages"].append(AIMessage(content=final_state["final_response"]))
|
| 296 |
+
|
| 297 |
+
# Update conversation
|
| 298 |
+
conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
|
| 299 |
+
conversation["context"] = final_state["conversation_context"]
|
| 300 |
+
|
| 301 |
+
# Save conversation
|
| 302 |
+
self._save_conversation(conversation_file, conversation)
|
| 303 |
+
|
| 304 |
+
# Format sources for display
|
| 305 |
+
sources = []
|
| 306 |
+
if final_state.get("gemini_result"):
|
| 307 |
+
sources = self.gemini_client.format_sources_for_display(final_state["gemini_result"])
|
| 308 |
+
|
| 309 |
+
return {
|
| 310 |
+
'response': final_state["final_response"] or "I apologize, but I couldn't process your request.",
|
| 311 |
+
'rag_result': {
|
| 312 |
+
'sources': sources,
|
| 313 |
+
'answer': final_state["final_response"]
|
| 314 |
+
},
|
| 315 |
+
'agent_logs': final_state["agent_logs"],
|
| 316 |
+
'actual_rag_query': final_state["current_query"]
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
|
| 320 |
+
"""Load conversation from file"""
|
| 321 |
+
if conversation_file.exists():
|
| 322 |
+
try:
|
| 323 |
+
with open(conversation_file) as f:
|
| 324 |
+
data = json.load(f)
|
| 325 |
+
messages = []
|
| 326 |
+
for msg_data in data.get("messages", []):
|
| 327 |
+
if msg_data["type"] == "human":
|
| 328 |
+
messages.append(HumanMessage(content=msg_data["content"]))
|
| 329 |
+
elif msg_data["type"] == "ai":
|
| 330 |
+
messages.append(AIMessage(content=msg_data["content"]))
|
| 331 |
+
data["messages"] = messages
|
| 332 |
+
return data
|
| 333 |
+
except Exception as e:
|
| 334 |
+
logger.warning(f"Could not load conversation: {e}")
|
| 335 |
+
|
| 336 |
+
return {
|
| 337 |
+
"messages": [],
|
| 338 |
+
"session_start_time": time.time(),
|
| 339 |
+
"last_ai_message_time": time.time(),
|
| 340 |
+
"context": {}
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
|
| 344 |
+
"""Save conversation to file"""
|
| 345 |
+
try:
|
| 346 |
+
conversation_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 347 |
+
|
| 348 |
+
messages_data = []
|
| 349 |
+
for msg in conversation["messages"]:
|
| 350 |
+
if isinstance(msg, HumanMessage):
|
| 351 |
+
messages_data.append({"type": "human", "content": msg.content})
|
| 352 |
+
elif isinstance(msg, AIMessage):
|
| 353 |
+
messages_data.append({"type": "ai", "content": msg.content})
|
| 354 |
+
|
| 355 |
+
conversation_data = {
|
| 356 |
+
"messages": messages_data,
|
| 357 |
+
"session_start_time": conversation["session_start_time"],
|
| 358 |
+
"last_ai_message_time": conversation["last_ai_message_time"],
|
| 359 |
+
"context": conversation.get("context", {})
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
with open(conversation_file, 'w') as f:
|
| 363 |
+
json.dump(conversation_data, f, indent=2)
|
| 364 |
+
|
| 365 |
+
except Exception as e:
|
| 366 |
+
logger.error(f"Could not save conversation: {e}")
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def get_gemini_chatbot():
|
| 370 |
+
"""Get Gemini chatbot instance"""
|
| 371 |
+
return GeminiRAGChatbot()
|
| 372 |
+
|
multi_agent_chatbot.py β src/agents/multi_agent_chatbot.py
RENAMED
|
@@ -208,6 +208,59 @@ class MultiAgentRAGChatbot:
|
|
| 208 |
logger.info(f" Sources: {self.source_whitelist}")
|
| 209 |
logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
def _build_graph(self) -> StateGraph:
|
| 212 |
"""Build the multi-agent LangGraph"""
|
| 213 |
graph = StateGraph(MultiAgentState)
|
|
@@ -512,6 +565,10 @@ class MultiAgentRAGChatbot:
|
|
| 512 |
- If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
|
| 513 |
- If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
|
| 514 |
- Always return districts as JSON arrays when multiple districts are mentioned
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
- If no exact matches found, set extracted values to null
|
| 516 |
|
| 517 |
4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
|
|
@@ -656,13 +713,9 @@ Analyze this query using ONLY the exact values provided above:""")
|
|
| 656 |
# Validate each district in the array
|
| 657 |
valid_districts = []
|
| 658 |
for district in extracted_district:
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
# Try removing "District" suffix
|
| 663 |
-
district_name = district.replace(" District", "").replace(" district", "")
|
| 664 |
-
if district_name in self.district_whitelist:
|
| 665 |
-
valid_districts.append(district_name)
|
| 666 |
|
| 667 |
if valid_districts:
|
| 668 |
extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
|
|
@@ -671,16 +724,15 @@ Analyze this query using ONLY the exact values provided above:""")
|
|
| 671 |
logger.warning(f"β οΈ No valid districts found in: '{extracted_district}'")
|
| 672 |
extracted_district = None
|
| 673 |
else:
|
| 674 |
-
# Single district validation
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
extracted_district = None
|
| 684 |
|
| 685 |
# Validate source (handle both single values and arrays)
|
| 686 |
if extracted_source:
|
|
@@ -918,6 +970,23 @@ Rewrite the best retrieval query:""")
|
|
| 918 |
logger.info(f"π§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β normalized: {normalized_districts}")
|
| 919 |
|
| 920 |
# Merge with extracted context for missing filters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 921 |
if not filters.get("year") and context.extracted_year:
|
| 922 |
# Handle both single values and arrays
|
| 923 |
if isinstance(context.extracted_year, list):
|
|
@@ -926,16 +995,6 @@ Rewrite the best retrieval query:""")
|
|
| 926 |
filters["year"] = [context.extracted_year]
|
| 927 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
|
| 928 |
|
| 929 |
-
if not filters.get("district") and context.extracted_district:
|
| 930 |
-
# Handle both single values and arrays
|
| 931 |
-
if isinstance(context.extracted_district, list):
|
| 932 |
-
# Normalize district names to title case (match Qdrant metadata format)
|
| 933 |
-
normalized = [d.title() for d in context.extracted_district]
|
| 934 |
-
filters["district"] = normalized
|
| 935 |
-
else:
|
| 936 |
-
filters["district"] = [context.extracted_district.title()]
|
| 937 |
-
logger.info(f"π§ FILTER BUILDING: Added extracted district filter (UI missing): {context.extracted_district}")
|
| 938 |
-
|
| 939 |
if not filters.get("sources") and context.extracted_source:
|
| 940 |
# Handle both single values and arrays
|
| 941 |
if isinstance(context.extracted_source, list):
|
|
@@ -963,12 +1022,21 @@ Rewrite the best retrieval query:""")
|
|
| 963 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
|
| 964 |
|
| 965 |
if context.extracted_district:
|
| 966 |
-
#
|
| 967 |
if isinstance(context.extracted_district, list):
|
| 968 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 969 |
else:
|
| 970 |
-
|
| 971 |
-
|
|
|
|
|
|
|
| 972 |
|
| 973 |
logger.info(f"π§ FILTER BUILDING: Final filters: {filters}")
|
| 974 |
return filters
|
|
@@ -978,49 +1046,233 @@ Rewrite the best retrieval query:""")
|
|
| 978 |
logger.info("π¬ RESPONSE GENERATION: Starting conversational response generation")
|
| 979 |
logger.info(f"π¬ RESPONSE GENERATION: Processing {len(documents)} documents")
|
| 980 |
logger.info(f"π¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 981 |
|
| 982 |
# Create response prompt
|
| 983 |
logger.info(f"π¬ RESPONSE GENERATION: Building response prompt")
|
| 984 |
response_prompt = ChatPromptTemplate.from_messages([
|
| 985 |
SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
|
| 986 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 987 |
RULES:
|
| 988 |
1. Answer the user's question directly and clearly
|
| 989 |
-
2. Use the retrieved documents as evidence
|
| 990 |
3. Be conversational, not technical
|
| 991 |
4. Don't mention scores, retrieval details, or technical implementation
|
| 992 |
5. If relevant documents were found, reference them naturally
|
| 993 |
-
6. If no relevant documents,
|
| 994 |
-
7. If the passages have useful facts or numbers, use them in your answer
|
| 995 |
-
8. When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
|
| 996 |
9. Do not use the sentence 'Doc i says ...' to say where information came from.
|
| 997 |
10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
|
| 998 |
11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
|
| 999 |
12. If it makes sense, use bullet points and lists to make your answers easier to understand.
|
| 1000 |
13. You do not need to use every passage. Only use the ones that help answer the question.
|
| 1001 |
-
14.
|
| 1002 |
-
|
|
|
|
| 1003 |
|
| 1004 |
TONE: Professional but friendly, like talking to a colleague."""),
|
| 1005 |
-
HumanMessage(content=f"""
|
|
|
|
|
|
|
|
|
|
| 1006 |
|
| 1007 |
Retrieved Documents: {len(documents)} documents found
|
| 1008 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1009 |
RAG Answer: {rag_answer}
|
| 1010 |
|
| 1011 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1012 |
])
|
| 1013 |
|
| 1014 |
try:
|
| 1015 |
logger.info(f"π¬ RESPONSE GENERATION: Calling LLM for final response")
|
| 1016 |
response = self.llm.invoke(response_prompt.format_messages())
|
| 1017 |
logger.info(f"π¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
|
| 1018 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1019 |
except Exception as e:
|
| 1020 |
logger.error(f"β RESPONSE GENERATION: Error during generation: {e}")
|
| 1021 |
logger.info(f"π¬ RESPONSE GENERATION: Using RAG answer as fallback")
|
| 1022 |
return rag_answer # Fallback to RAG answer
|
| 1023 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1024 |
def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
|
| 1025 |
"""Generate conversational response using only LLM knowledge and conversation history"""
|
| 1026 |
logger.info("π¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
|
|
|
|
| 208 |
logger.info(f" Sources: {self.source_whitelist}")
|
| 209 |
logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
|
| 210 |
|
| 211 |
+
def _normalize_district_name(self, district: str) -> Optional[str]:
|
| 212 |
+
"""Normalize district name with fuzzy matching for common misspellings."""
|
| 213 |
+
if not district:
|
| 214 |
+
return None
|
| 215 |
+
|
| 216 |
+
district = district.strip()
|
| 217 |
+
|
| 218 |
+
# Direct match
|
| 219 |
+
if district in self.district_whitelist:
|
| 220 |
+
return district
|
| 221 |
+
|
| 222 |
+
# Remove "District" suffix
|
| 223 |
+
district_name = district.replace(" District", "").replace(" district", "").strip()
|
| 224 |
+
if district_name in self.district_whitelist:
|
| 225 |
+
return district_name
|
| 226 |
+
|
| 227 |
+
# Common misspellings mapping
|
| 228 |
+
misspelling_map = {
|
| 229 |
+
"kalagala": "Kalangala",
|
| 230 |
+
"Kalagala": "Kalangala",
|
| 231 |
+
"KALAGALA": "Kalangala",
|
| 232 |
+
"kalangala": "Kalangala",
|
| 233 |
+
"gulu": "Gulu",
|
| 234 |
+
"GULU": "Gulu",
|
| 235 |
+
"kampala": "Kampala",
|
| 236 |
+
"KAMPALA": "Kampala",
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
# Check misspelling map (case-insensitive)
|
| 240 |
+
district_lower = district_name.lower()
|
| 241 |
+
if district_lower in misspelling_map:
|
| 242 |
+
corrected = misspelling_map[district_lower]
|
| 243 |
+
if corrected in self.district_whitelist:
|
| 244 |
+
return corrected
|
| 245 |
+
|
| 246 |
+
# Fuzzy matching for similar names (simple Levenshtein-like check)
|
| 247 |
+
# Check if the district name is very similar to any whitelist entry
|
| 248 |
+
for whitelist_district in self.district_whitelist:
|
| 249 |
+
# Case-insensitive comparison
|
| 250 |
+
if district_name.lower() == whitelist_district.lower():
|
| 251 |
+
return whitelist_district
|
| 252 |
+
|
| 253 |
+
# Check if one is a substring of the other (for partial matches)
|
| 254 |
+
if len(district_name) >= 4 and len(whitelist_district) >= 4:
|
| 255 |
+
if district_name.lower() in whitelist_district.lower() or whitelist_district.lower() in district_name.lower():
|
| 256 |
+
# Only return if it's a strong match (at least 80% of characters match)
|
| 257 |
+
min_len = min(len(district_name), len(whitelist_district))
|
| 258 |
+
max_len = max(len(district_name), len(whitelist_district))
|
| 259 |
+
if min_len / max_len >= 0.8:
|
| 260 |
+
return whitelist_district
|
| 261 |
+
|
| 262 |
+
return None
|
| 263 |
+
|
| 264 |
def _build_graph(self) -> StateGraph:
|
| 265 |
"""Build the multi-agent LangGraph"""
|
| 266 |
graph = StateGraph(MultiAgentState)
|
|
|
|
| 565 |
- If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
|
| 566 |
- If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
|
| 567 |
- Always return districts as JSON arrays when multiple districts are mentioned
|
| 568 |
+
- **COMMON MISSPELLINGS**: Handle common misspellings intelligently:
|
| 569 |
+
* "Kalagala" (missing 'n') should be extracted as "Kalangala"
|
| 570 |
+
* "kalagala", "Kalagala", "KALAGALA" should all be normalized to "Kalangala"
|
| 571 |
+
* Similar case-insensitive variations should be normalized to the correct district name
|
| 572 |
- If no exact matches found, set extracted values to null
|
| 573 |
|
| 574 |
4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
|
|
|
|
| 713 |
# Validate each district in the array
|
| 714 |
valid_districts = []
|
| 715 |
for district in extracted_district:
|
| 716 |
+
normalized = self._normalize_district_name(district)
|
| 717 |
+
if normalized:
|
| 718 |
+
valid_districts.append(normalized)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 719 |
|
| 720 |
if valid_districts:
|
| 721 |
extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
|
|
|
|
| 724 |
logger.warning(f"β οΈ No valid districts found in: '{extracted_district}'")
|
| 725 |
extracted_district = None
|
| 726 |
else:
|
| 727 |
+
# Single district validation with fuzzy matching
|
| 728 |
+
normalized = self._normalize_district_name(extracted_district)
|
| 729 |
+
if normalized:
|
| 730 |
+
if normalized != extracted_district:
|
| 731 |
+
logger.info(f"π QUERY ANALYSIS: Normalized district '{extracted_district}' to '{normalized}'")
|
| 732 |
+
extracted_district = normalized
|
| 733 |
+
else:
|
| 734 |
+
logger.warning(f"β οΈ Invalid district extracted: '{extracted_district}' not in whitelist")
|
| 735 |
+
extracted_district = None
|
|
|
|
| 736 |
|
| 737 |
# Validate source (handle both single values and arrays)
|
| 738 |
if extracted_source:
|
|
|
|
| 970 |
logger.info(f"π§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β normalized: {normalized_districts}")
|
| 971 |
|
| 972 |
# Merge with extracted context for missing filters
|
| 973 |
+
if not filters.get("district") and context.extracted_district:
|
| 974 |
+
# Normalize district names using the normalization function
|
| 975 |
+
if isinstance(context.extracted_district, list):
|
| 976 |
+
normalized_districts = []
|
| 977 |
+
for d in context.extracted_district:
|
| 978 |
+
normalized = self._normalize_district_name(d)
|
| 979 |
+
if normalized:
|
| 980 |
+
normalized_districts.append(normalized)
|
| 981 |
+
if normalized_districts:
|
| 982 |
+
filters["district"] = normalized_districts
|
| 983 |
+
logger.info(f"π§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β normalized: {normalized_districts}")
|
| 984 |
+
else:
|
| 985 |
+
normalized = self._normalize_district_name(context.extracted_district)
|
| 986 |
+
if normalized:
|
| 987 |
+
filters["district"] = [normalized]
|
| 988 |
+
logger.info(f"π§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β normalized: {normalized}")
|
| 989 |
+
|
| 990 |
if not filters.get("year") and context.extracted_year:
|
| 991 |
# Handle both single values and arrays
|
| 992 |
if isinstance(context.extracted_year, list):
|
|
|
|
| 995 |
filters["year"] = [context.extracted_year]
|
| 996 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
|
| 997 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 998 |
if not filters.get("sources") and context.extracted_source:
|
| 999 |
# Handle both single values and arrays
|
| 1000 |
if isinstance(context.extracted_source, list):
|
|
|
|
| 1022 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
|
| 1023 |
|
| 1024 |
if context.extracted_district:
|
| 1025 |
+
# Normalize district names using the normalization function
|
| 1026 |
if isinstance(context.extracted_district, list):
|
| 1027 |
+
normalized_districts = []
|
| 1028 |
+
for d in context.extracted_district:
|
| 1029 |
+
normalized = self._normalize_district_name(d)
|
| 1030 |
+
if normalized:
|
| 1031 |
+
normalized_districts.append(normalized)
|
| 1032 |
+
if normalized_districts:
|
| 1033 |
+
filters["district"] = normalized_districts
|
| 1034 |
+
logger.info(f"π§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β normalized: {normalized_districts}")
|
| 1035 |
else:
|
| 1036 |
+
normalized = self._normalize_district_name(context.extracted_district)
|
| 1037 |
+
if normalized:
|
| 1038 |
+
filters["district"] = [normalized]
|
| 1039 |
+
logger.info(f"π§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β normalized: {normalized}")
|
| 1040 |
|
| 1041 |
logger.info(f"π§ FILTER BUILDING: Final filters: {filters}")
|
| 1042 |
return filters
|
|
|
|
| 1046 |
logger.info("π¬ RESPONSE GENERATION: Starting conversational response generation")
|
| 1047 |
logger.info(f"π¬ RESPONSE GENERATION: Processing {len(documents)} documents")
|
| 1048 |
logger.info(f"π¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
|
| 1049 |
+
logger.info(f"π¬ RESPONSE GENERATION: Conversation history: {len(messages)} messages")
|
| 1050 |
+
|
| 1051 |
+
# Build conversation history context
|
| 1052 |
+
conversation_context = self._build_conversation_context(messages)
|
| 1053 |
+
|
| 1054 |
+
# Build detailed document information
|
| 1055 |
+
document_details = self._build_document_details(documents)
|
| 1056 |
+
|
| 1057 |
+
# Extract correct district/source/year names from documents (to correct misspellings)
|
| 1058 |
+
correct_names = self._extract_correct_names_from_documents(documents)
|
| 1059 |
|
| 1060 |
# Create response prompt
|
| 1061 |
logger.info(f"π¬ RESPONSE GENERATION: Building response prompt")
|
| 1062 |
response_prompt = ChatPromptTemplate.from_messages([
|
| 1063 |
SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
|
| 1064 |
|
| 1065 |
+
CRITICAL RULES - NO HALLUCINATION:
|
| 1066 |
+
1. **ONLY use information from the retrieved documents provided below**
|
| 1067 |
+
2. **EVERY sentence with facts, numbers, or specific claims MUST have a [Doc i] reference**
|
| 1068 |
+
3. **If a document doesn't contain the information, DO NOT make it up**
|
| 1069 |
+
4. **If the user asks about a year/district that's NOT in the retrieved documents, explicitly state that**
|
| 1070 |
+
5. **Check the document years/districts before making any claims about them**
|
| 1071 |
+
6. **USE CORRECT NAMES**: If the conversation mentions a misspelled district/source name (e.g., "Kalagala"), use the CORRECT spelling from the document metadata (e.g., "Kalangala"). Always use the exact names from document metadata, not misspellings from conversation.
|
| 1072 |
+
|
| 1073 |
RULES:
|
| 1074 |
1. Answer the user's question directly and clearly
|
| 1075 |
+
2. Use ONLY the retrieved documents as evidence - DO NOT use your training data
|
| 1076 |
3. Be conversational, not technical
|
| 1077 |
4. Don't mention scores, retrieval details, or technical implementation
|
| 1078 |
5. If relevant documents were found, reference them naturally
|
| 1079 |
+
6. If no relevant documents, say you do not have enough information - DO NOT hallucinate
|
| 1080 |
+
7. If the passages have useful facts or numbers, use them in your answer WITH references
|
| 1081 |
+
8. **MANDATORY**: When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
|
| 1082 |
9. Do not use the sentence 'Doc i says ...' to say where information came from.
|
| 1083 |
10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
|
| 1084 |
11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
|
| 1085 |
12. If it makes sense, use bullet points and lists to make your answers easier to understand.
|
| 1086 |
13. You do not need to use every passage. Only use the ones that help answer the question.
|
| 1087 |
+
14. **VERIFY**: Before mentioning any year, district, or number, check that it exists in the retrieved documents. If it doesn't, say "I don't have information about [year/district] in the retrieved documents."
|
| 1088 |
+
15. **NO HALLUCINATION**: If documents show years 2021, 2022, 2023 but user asks about 2020, DO NOT provide 2020 data. Instead say "The retrieved documents cover 2021-2023, but I don't have information for 2020."
|
| 1089 |
+
16. **USE CORRECT SPELLING**: Always use the district/source names exactly as they appear in the document metadata below, even if the conversation history has misspellings.
|
| 1090 |
|
| 1091 |
TONE: Professional but friendly, like talking to a colleague."""),
|
| 1092 |
+
HumanMessage(content=f"""Conversation History:
|
| 1093 |
+
{conversation_context}
|
| 1094 |
+
|
| 1095 |
+
Current User Question: {query}
|
| 1096 |
|
| 1097 |
Retrieved Documents: {len(documents)} documents found
|
| 1098 |
|
| 1099 |
+
CORRECT NAMES TO USE (from document metadata - use these exact spellings):
|
| 1100 |
+
{correct_names}
|
| 1101 |
+
|
| 1102 |
+
Full Document Details:
|
| 1103 |
+
{document_details}
|
| 1104 |
+
|
| 1105 |
RAG Answer: {rag_answer}
|
| 1106 |
|
| 1107 |
+
CRITICAL:
|
| 1108 |
+
- Responses should be grounded to what is available in the retrieved documents
|
| 1109 |
+
- If user asks about a specific year but documents show other years, or districts or sources then explicitly state "can't provide response on ... because ..."
|
| 1110 |
+
- Every factual claim MUST have [Doc i] reference
|
| 1111 |
+
- If information is not in documents, explicitly state it's not available
|
| 1112 |
+
- **USE THE CORRECT DISTRICT/SOURCE NAMES from the document metadata above, not misspellings from conversation**
|
| 1113 |
+
|
| 1114 |
+
Generate a conversational response with proper document references:""")
|
| 1115 |
])
|
| 1116 |
|
| 1117 |
try:
|
| 1118 |
logger.info(f"π¬ RESPONSE GENERATION: Calling LLM for final response")
|
| 1119 |
response = self.llm.invoke(response_prompt.format_messages())
|
| 1120 |
logger.info(f"π¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
|
| 1121 |
+
|
| 1122 |
+
# Post-process response to ensure no hallucination
|
| 1123 |
+
final_response = self._validate_and_enhance_response(
|
| 1124 |
+
response.content.strip(),
|
| 1125 |
+
documents,
|
| 1126 |
+
query
|
| 1127 |
+
)
|
| 1128 |
+
|
| 1129 |
+
return final_response
|
| 1130 |
except Exception as e:
|
| 1131 |
logger.error(f"β RESPONSE GENERATION: Error during generation: {e}")
|
| 1132 |
logger.info(f"π¬ RESPONSE GENERATION: Using RAG answer as fallback")
|
| 1133 |
return rag_answer # Fallback to RAG answer
|
| 1134 |
|
| 1135 |
+
def _build_conversation_context(self, messages: List[Any]) -> str:
|
| 1136 |
+
"""Build conversation history context for response generation."""
|
| 1137 |
+
if not messages:
|
| 1138 |
+
return "No previous conversation."
|
| 1139 |
+
|
| 1140 |
+
context_lines = []
|
| 1141 |
+
# Show last 6 messages for context (to capture the current exchange)
|
| 1142 |
+
for msg in messages[-6:]:
|
| 1143 |
+
if isinstance(msg, HumanMessage):
|
| 1144 |
+
context_lines.append(f"User: {msg.content}")
|
| 1145 |
+
elif isinstance(msg, AIMessage):
|
| 1146 |
+
context_lines.append(f"Assistant: {msg.content}")
|
| 1147 |
+
|
| 1148 |
+
return "\n".join(context_lines) if context_lines else "No previous conversation."
|
| 1149 |
+
|
| 1150 |
+
def _build_document_details(self, documents: List[Any]) -> str:
|
| 1151 |
+
"""Build detailed document information for response generation."""
|
| 1152 |
+
if not documents:
|
| 1153 |
+
return "No documents retrieved."
|
| 1154 |
+
|
| 1155 |
+
details = []
|
| 1156 |
+
for i, doc in enumerate(documents[:15], 1): # Show up to 15 documents
|
| 1157 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 1158 |
+
content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
|
| 1159 |
+
|
| 1160 |
+
if isinstance(metadata, dict):
|
| 1161 |
+
filename = metadata.get('filename', 'Unknown')
|
| 1162 |
+
year = metadata.get('year', 'Unknown')
|
| 1163 |
+
district = metadata.get('district', 'Unknown')
|
| 1164 |
+
source = metadata.get('source', 'Unknown')
|
| 1165 |
+
page = metadata.get('page', metadata.get('page_label', 'Unknown'))
|
| 1166 |
+
|
| 1167 |
+
doc_info = f"[Doc {i}]"
|
| 1168 |
+
doc_info += f"\n Filename: {filename}"
|
| 1169 |
+
doc_info += f"\n Year: {year}"
|
| 1170 |
+
doc_info += f"\n District: {district}"
|
| 1171 |
+
doc_info += f"\n Source: {source}"
|
| 1172 |
+
if page != 'Unknown':
|
| 1173 |
+
doc_info += f"\n Page: {page}"
|
| 1174 |
+
doc_info += f"\n Content: {content[:300]}{'...' if len(content) > 300 else ''}"
|
| 1175 |
+
details.append(doc_info)
|
| 1176 |
+
|
| 1177 |
+
return "\n\n".join(details) if details else "No document details available."
|
| 1178 |
+
|
| 1179 |
+
def _extract_correct_names_from_documents(self, documents: List[Any]) -> str:
|
| 1180 |
+
"""Extract correct district/source names from documents to correct misspellings."""
|
| 1181 |
+
districts = set()
|
| 1182 |
+
sources = set()
|
| 1183 |
+
years = set()
|
| 1184 |
+
|
| 1185 |
+
for doc in documents:
|
| 1186 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 1187 |
+
if isinstance(metadata, dict):
|
| 1188 |
+
if metadata.get('district'):
|
| 1189 |
+
districts.add(str(metadata['district']))
|
| 1190 |
+
if metadata.get('source'):
|
| 1191 |
+
sources.add(str(metadata['source']))
|
| 1192 |
+
if metadata.get('year'):
|
| 1193 |
+
years.add(str(metadata['year']))
|
| 1194 |
+
|
| 1195 |
+
result = []
|
| 1196 |
+
if districts:
|
| 1197 |
+
result.append(f"Districts: {', '.join(sorted(districts))}")
|
| 1198 |
+
if sources:
|
| 1199 |
+
result.append(f"Sources: {', '.join(sorted(sources))}")
|
| 1200 |
+
if years:
|
| 1201 |
+
result.append(f"Years: {', '.join(sorted(years))}")
|
| 1202 |
+
|
| 1203 |
+
if result:
|
| 1204 |
+
return "\n".join(result) + "\n\nIMPORTANT: Use these EXACT spellings in your response, even if the conversation history has misspellings."
|
| 1205 |
+
return "No metadata available."
|
| 1206 |
+
|
| 1207 |
+
def _validate_and_enhance_response(self, response: str, documents: List[Any], query: str) -> str:
|
| 1208 |
+
"""Validate response and ensure all claims are referenced."""
|
| 1209 |
+
# Extract years and districts from documents
|
| 1210 |
+
doc_years = set()
|
| 1211 |
+
doc_districts = set()
|
| 1212 |
+
doc_sources = set()
|
| 1213 |
+
|
| 1214 |
+
for doc in documents:
|
| 1215 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 1216 |
+
if isinstance(metadata, dict):
|
| 1217 |
+
if metadata.get('year'):
|
| 1218 |
+
doc_years.add(str(metadata['year']))
|
| 1219 |
+
if metadata.get('district'):
|
| 1220 |
+
doc_districts.add(str(metadata['district']))
|
| 1221 |
+
if metadata.get('source'):
|
| 1222 |
+
doc_sources.add(str(metadata['source']))
|
| 1223 |
+
|
| 1224 |
+
# Correct misspellings in response using correct names from documents
|
| 1225 |
+
response = self._correct_misspellings_in_response(response, doc_districts, doc_sources)
|
| 1226 |
+
|
| 1227 |
+
# Check if response mentions years not in documents
|
| 1228 |
+
year_pattern = r'\b(20\d{2})\b'
|
| 1229 |
+
mentioned_years = set(re.findall(year_pattern, response))
|
| 1230 |
+
|
| 1231 |
+
# Check if user query mentions a year
|
| 1232 |
+
query_years = set(re.findall(year_pattern, query))
|
| 1233 |
+
|
| 1234 |
+
# If user asks about a year not in documents, add a warning
|
| 1235 |
+
missing_years = query_years - doc_years
|
| 1236 |
+
if missing_years and doc_years:
|
| 1237 |
+
warning = f"\n\nβ οΈ Note: The retrieved documents cover years {', '.join(sorted(doc_years))}, but I don't have information for {', '.join(sorted(missing_years))} in the retrieved documents."
|
| 1238 |
+
if warning not in response:
|
| 1239 |
+
response = response + warning
|
| 1240 |
+
|
| 1241 |
+
# Check if response has document references
|
| 1242 |
+
doc_ref_pattern = r'\[Doc\s+\d+\]'
|
| 1243 |
+
has_refs = bool(re.search(doc_ref_pattern, response))
|
| 1244 |
+
|
| 1245 |
+
# If response has factual claims but no references, add a note
|
| 1246 |
+
if not has_refs and len(documents) > 0:
|
| 1247 |
+
# Check if response has numbers or specific claims (simple heuristic)
|
| 1248 |
+
has_numbers = bool(re.search(r'\d+', response))
|
| 1249 |
+
if has_numbers and len(response) > 50:
|
| 1250 |
+
logger.warning("β οΈ Response contains factual claims but no document references")
|
| 1251 |
+
# Don't modify response, but log the issue
|
| 1252 |
+
|
| 1253 |
+
return response
|
| 1254 |
+
|
| 1255 |
+
def _correct_misspellings_in_response(self, response: str, correct_districts: set, correct_sources: set) -> str:
|
| 1256 |
+
"""Correct common misspellings in response using correct names from documents."""
|
| 1257 |
+
# Common misspelling mappings (e.g., "Kalagala" -> "Kalangala")
|
| 1258 |
+
# We'll use fuzzy matching if needed, but first try direct corrections
|
| 1259 |
+
|
| 1260 |
+
corrected = response
|
| 1261 |
+
|
| 1262 |
+
# Correct district names
|
| 1263 |
+
for correct_district in correct_districts:
|
| 1264 |
+
# Try common misspellings
|
| 1265 |
+
if correct_district.lower() == "kalangala":
|
| 1266 |
+
# Replace "Kalagala" (missing 'n') with "Kalangala"
|
| 1267 |
+
corrected = re.sub(r'\bKalagala\b', 'Kalangala', corrected, flags=re.IGNORECASE)
|
| 1268 |
+
# Add more common misspellings as needed
|
| 1269 |
+
# For now, we rely on the LLM to use correct names from the prompt
|
| 1270 |
+
|
| 1271 |
+
# Correct source names if needed
|
| 1272 |
+
# Add source corrections as needed in the future
|
| 1273 |
+
|
| 1274 |
+
return corrected
|
| 1275 |
+
|
| 1276 |
def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
|
| 1277 |
"""Generate conversational response using only LLM knowledge and conversation history"""
|
| 1278 |
logger.info("π¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
|
smart_chatbot.py β src/agents/smart_chatbot.py
RENAMED
|
File without changes
|
src/feedback/__init__.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feedback Management Module
|
| 3 |
+
|
| 4 |
+
This module provides a unified interface for handling user feedback,
|
| 5 |
+
including data preparation, validation, and Snowflake storage.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Dict, Any, List, Optional
|
| 9 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 10 |
+
|
| 11 |
+
from .feedback_schema import UserFeedback, create_feedback_from_dict, generate_snowflake_schema_sql
|
| 12 |
+
from .snowflake_connector import SnowflakeFeedbackConnector, save_to_snowflake, get_snowflake_connector_from_env
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class FeedbackManager:
|
| 16 |
+
"""
|
| 17 |
+
Unified manager for feedback operations.
|
| 18 |
+
|
| 19 |
+
This class provides a single interface for all feedback-related functionality,
|
| 20 |
+
including data preparation, validation, and storage.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
"""Initialize the FeedbackManager"""
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
@staticmethod
|
| 28 |
+
def extract_transcript(messages: List[Any]) -> List[Dict[str, str]]:
|
| 29 |
+
"""Extract transcript from messages - only user and bot messages, no extra metadata"""
|
| 30 |
+
transcript = []
|
| 31 |
+
for msg in messages:
|
| 32 |
+
if isinstance(msg, HumanMessage):
|
| 33 |
+
transcript.append({
|
| 34 |
+
"role": "user",
|
| 35 |
+
"content": str(msg.content) if hasattr(msg, 'content') else str(msg)
|
| 36 |
+
})
|
| 37 |
+
elif isinstance(msg, AIMessage):
|
| 38 |
+
transcript.append({
|
| 39 |
+
"role": "assistant",
|
| 40 |
+
"content": str(msg.content) if hasattr(msg, 'content') else str(msg)
|
| 41 |
+
})
|
| 42 |
+
return transcript
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def build_retrievals_structure(rag_retrieval_history: List[Dict[str, Any]], messages: List[Any]) -> List[Dict[str, Any]]:
|
| 46 |
+
"""Build retrievals structure from retrieval history"""
|
| 47 |
+
retrievals = []
|
| 48 |
+
|
| 49 |
+
for entry in rag_retrieval_history:
|
| 50 |
+
# Get the user message that triggered this retrieval
|
| 51 |
+
# The entry has conversation_up_to which includes messages up to that point
|
| 52 |
+
conversation_up_to = entry.get("conversation_up_to", [])
|
| 53 |
+
|
| 54 |
+
# Find the last user message in conversation_up_to (this is the trigger)
|
| 55 |
+
user_message_trigger = ""
|
| 56 |
+
for msg_dict in reversed(conversation_up_to):
|
| 57 |
+
if msg_dict.get("type") == "HumanMessage":
|
| 58 |
+
user_message_trigger = msg_dict.get("content", "")
|
| 59 |
+
break
|
| 60 |
+
|
| 61 |
+
# Fallback: if not found in conversation_up_to, get from actual messages
|
| 62 |
+
# This handles edge cases where conversation_up_to might be incomplete
|
| 63 |
+
if not user_message_trigger:
|
| 64 |
+
# Find which retrieval this is (0-indexed)
|
| 65 |
+
retrieval_idx = rag_retrieval_history.index(entry)
|
| 66 |
+
# The user message that triggered this retrieval is at position (retrieval_idx * 2)
|
| 67 |
+
# because each retrieval is preceded by: user message, bot response, user message, ...
|
| 68 |
+
# But we need to account for the fact that the first retrieval happens after the first user message
|
| 69 |
+
user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
|
| 70 |
+
if retrieval_idx < len(user_msgs):
|
| 71 |
+
user_message_trigger = str(user_msgs[retrieval_idx].content)
|
| 72 |
+
elif user_msgs:
|
| 73 |
+
# Fallback to last user message
|
| 74 |
+
user_message_trigger = str(user_msgs[-1].content)
|
| 75 |
+
|
| 76 |
+
# Get retrieved documents and truncate content to 100 chars
|
| 77 |
+
docs_retrieved = entry.get("docs_retrieved", [])
|
| 78 |
+
retrieved_docs = []
|
| 79 |
+
for doc in docs_retrieved:
|
| 80 |
+
doc_copy = doc.copy()
|
| 81 |
+
# Truncate content to 100 characters (keep all other fields)
|
| 82 |
+
if "content" in doc_copy:
|
| 83 |
+
doc_copy["content"] = doc_copy["content"][:100]
|
| 84 |
+
retrieved_docs.append(doc_copy)
|
| 85 |
+
|
| 86 |
+
retrievals.append({
|
| 87 |
+
"retrieved_docs": retrieved_docs,
|
| 88 |
+
"user_message_trigger": user_message_trigger
|
| 89 |
+
})
|
| 90 |
+
|
| 91 |
+
return retrievals
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def build_feedback_score_related_retrieval_docs(
|
| 95 |
+
is_feedback_about_last_retrieval: bool,
|
| 96 |
+
messages: List[Any],
|
| 97 |
+
rag_retrieval_history: List[Dict[str, Any]]
|
| 98 |
+
) -> Optional[Dict[str, Any]]:
|
| 99 |
+
"""Build feedback_score_related_retrieval_docs structure"""
|
| 100 |
+
if not rag_retrieval_history:
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
# Get the relevant retrieval entry
|
| 104 |
+
if is_feedback_about_last_retrieval:
|
| 105 |
+
relevant_entry = rag_retrieval_history[-1]
|
| 106 |
+
else:
|
| 107 |
+
# If feedback is about all retrievals, use the last one as default
|
| 108 |
+
relevant_entry = rag_retrieval_history[-1]
|
| 109 |
+
|
| 110 |
+
# Get conversation up to that point
|
| 111 |
+
conversation_up_to = relevant_entry.get("conversation_up_to", [])
|
| 112 |
+
|
| 113 |
+
# Convert to transcript format (role/content)
|
| 114 |
+
conversation_up_to_point = []
|
| 115 |
+
for msg_dict in conversation_up_to:
|
| 116 |
+
if msg_dict.get("type") == "HumanMessage":
|
| 117 |
+
conversation_up_to_point.append({
|
| 118 |
+
"role": "user",
|
| 119 |
+
"content": msg_dict.get("content", "")
|
| 120 |
+
})
|
| 121 |
+
elif msg_dict.get("type") == "AIMessage":
|
| 122 |
+
conversation_up_to_point.append({
|
| 123 |
+
"role": "assistant",
|
| 124 |
+
"content": msg_dict.get("content", "")
|
| 125 |
+
})
|
| 126 |
+
|
| 127 |
+
# Get retrieved docs with full content (not truncated)
|
| 128 |
+
retrieved_docs = relevant_entry.get("docs_retrieved", [])
|
| 129 |
+
|
| 130 |
+
return {
|
| 131 |
+
"conversation_up_to_point": conversation_up_to_point,
|
| 132 |
+
"retrieved_docs": retrieved_docs
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
@staticmethod
|
| 136 |
+
def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
|
| 137 |
+
"""Create UserFeedback instance from dictionary"""
|
| 138 |
+
return create_feedback_from_dict(data)
|
| 139 |
+
|
| 140 |
+
@staticmethod
|
| 141 |
+
def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
|
| 142 |
+
"""Save feedback to Snowflake"""
|
| 143 |
+
return save_to_snowflake(feedback, table_name)
|
| 144 |
+
|
| 145 |
+
@staticmethod
|
| 146 |
+
def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
|
| 147 |
+
"""Generate Snowflake schema SQL"""
|
| 148 |
+
return generate_snowflake_schema_sql(table_name)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
__all__ = ["FeedbackManager", "UserFeedback", "save_to_snowflake", "SnowflakeFeedbackConnector"]
|
| 152 |
+
|
src/feedback/feedback_schema.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feedback Schema for RAG Chatbot
|
| 3 |
+
|
| 4 |
+
This module defines dataclasses for feedback data structures
|
| 5 |
+
and provides Snowflake schema generation.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from dataclasses import dataclass, asdict, field
|
| 10 |
+
from typing import List, Optional, Dict, Any, Union
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class RetrievedDocument:
|
| 16 |
+
"""Single retrieved document metadata"""
|
| 17 |
+
doc_id: str
|
| 18 |
+
filename: str
|
| 19 |
+
page: int
|
| 20 |
+
score: float
|
| 21 |
+
content: str
|
| 22 |
+
metadata: Dict[str, Any]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class RetrievalEntry:
|
| 27 |
+
"""Single retrieval operation metadata"""
|
| 28 |
+
rag_query: str
|
| 29 |
+
documents_retrieved: List[RetrievedDocument]
|
| 30 |
+
conversation_length: int
|
| 31 |
+
filters_applied: Optional[Dict[str, Any]] = None
|
| 32 |
+
timestamp: Optional[float] = None
|
| 33 |
+
_raw_data: Optional[Dict[str, Any]] = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class UserFeedback:
|
| 38 |
+
"""User feedback submission data"""
|
| 39 |
+
feedback_id: str
|
| 40 |
+
open_ended_feedback: Optional[str]
|
| 41 |
+
score: int
|
| 42 |
+
is_feedback_about_last_retrieval: bool
|
| 43 |
+
conversation_id: str
|
| 44 |
+
timestamp: float
|
| 45 |
+
message_count: int
|
| 46 |
+
has_retrievals: bool
|
| 47 |
+
retrieval_count: int
|
| 48 |
+
transcript: List[Dict[str, str]] # List of {"role": "user"/"assistant", "content": "..."}
|
| 49 |
+
retrievals: List[Dict[str, Any]] # List of retrieval objects with retrieved_docs and user_message_trigger
|
| 50 |
+
feedback_score_related_retrieval_docs: Optional[Dict[str, Any]] = None # Conversation subset + retrieved docs
|
| 51 |
+
retrieved_data: Optional[List[Dict[str, Any]]] = None # Preserved old column for backward compatibility
|
| 52 |
+
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
| 53 |
+
|
| 54 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 55 |
+
"""Convert to dictionary with nested data structures"""
|
| 56 |
+
result = asdict(self)
|
| 57 |
+
return result
|
| 58 |
+
|
| 59 |
+
def to_snowflake_schema(self) -> Dict[str, Any]:
|
| 60 |
+
"""Generate Snowflake schema for this dataclass"""
|
| 61 |
+
schema = {
|
| 62 |
+
"feedback_id": "VARCHAR(255)",
|
| 63 |
+
"open_ended_feedback": "VARCHAR(16777216)", # Large text
|
| 64 |
+
"score": "INTEGER",
|
| 65 |
+
"is_feedback_about_last_retrieval": "BOOLEAN",
|
| 66 |
+
"conversation_id": "VARCHAR(255)",
|
| 67 |
+
"timestamp": "NUMBER(20, 0)",
|
| 68 |
+
"message_count": "INTEGER",
|
| 69 |
+
"has_retrievals": "BOOLEAN",
|
| 70 |
+
"retrieval_count": "INTEGER",
|
| 71 |
+
"transcript": "VARCHAR(16777216)", # JSON string of ARRAY of {"role": "user"/"assistant", "content": "..."}
|
| 72 |
+
"retrievals": "VARCHAR(16777216)", # JSON string of ARRAY of retrieval objects
|
| 73 |
+
"feedback_score_related_retrieval_docs": "VARCHAR(16777216)", # JSON string of OBJECT with conversation subset + retrieved docs
|
| 74 |
+
"retrieved_data": "VARCHAR(16777216)", # JSON string - preserved old column for backward compatibility
|
| 75 |
+
"created_at": "TIMESTAMP_NTZ",
|
| 76 |
+
# transcript structure: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
|
| 77 |
+
# retrievals structure: [
|
| 78 |
+
# {
|
| 79 |
+
# "retrieved_docs": [{"content": "...", "metadata": {...}, ...}], # content truncated to 100 chars
|
| 80 |
+
# "user_message_trigger": "final user message that triggered this retrieval"
|
| 81 |
+
# },
|
| 82 |
+
# ...
|
| 83 |
+
# ]
|
| 84 |
+
# feedback_score_related_retrieval_docs structure: {
|
| 85 |
+
# "conversation_up_to_point": [{"role": "user", "content": "..."}, ...], # subset of transcript
|
| 86 |
+
# "retrieved_docs": [{"content": "...", "metadata": {...}, ...}] # full chunks with all info
|
| 87 |
+
# }
|
| 88 |
+
}
|
| 89 |
+
return schema
|
| 90 |
+
|
| 91 |
+
@classmethod
|
| 92 |
+
def get_snowflake_create_table_sql(cls, table_name: str = "USER_FEEDBACK_V3") -> str:
|
| 93 |
+
"""Generate CREATE TABLE SQL for Snowflake"""
|
| 94 |
+
schema = cls.to_snowflake_schema(None)
|
| 95 |
+
|
| 96 |
+
columns = []
|
| 97 |
+
for col_name, col_type in schema.items():
|
| 98 |
+
nullable = "NULL" if col_name not in ["feedback_id", "score", "timestamp"] else "NOT NULL"
|
| 99 |
+
columns.append(f" {col_name} {col_type} {nullable}")
|
| 100 |
+
|
| 101 |
+
# Build SQL string properly
|
| 102 |
+
columns_str = ",\n".join(columns)
|
| 103 |
+
|
| 104 |
+
sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
|
| 105 |
+
{columns_str},
|
| 106 |
+
PRIMARY KEY (feedback_id)
|
| 107 |
+
)
|
| 108 |
+
CLUSTER BY (timestamp, conversation_id, score);
|
| 109 |
+
-- Note: Snowflake doesn't support traditional indexes on regular tables.
|
| 110 |
+
-- Instead, we use CLUSTER BY to optimize queries on these columns.
|
| 111 |
+
-- Snowflake automatically maintains clustering for efficient querying.
|
| 112 |
+
-- Note: transcript, retrievals, and feedback_score_related_retrieval_docs are stored as VARCHAR (JSON strings),
|
| 113 |
+
-- same approach as the old retrieved_data column. This allows easy storage and retrieval without VARIANT type complexity.
|
| 114 |
+
"""
|
| 115 |
+
return sql
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Snowflake variant schema for retrieved_data array
|
| 119 |
+
RETRIEVAL_ENTRY_SCHEMA = {
|
| 120 |
+
"rag_query": "VARCHAR",
|
| 121 |
+
"documents_retrieved": "ARRAY", # Array of document objects
|
| 122 |
+
"conversation_length": "INTEGER",
|
| 123 |
+
"filters_applied": "OBJECT",
|
| 124 |
+
"timestamp": "NUMBER"
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
DOCUMENT_SCHEMA = {
|
| 128 |
+
"doc_id": "VARCHAR",
|
| 129 |
+
"filename": "VARCHAR",
|
| 130 |
+
"page": "INTEGER",
|
| 131 |
+
"score": "DOUBLE",
|
| 132 |
+
"content": "VARCHAR(16777216)",
|
| 133 |
+
"metadata": "OBJECT"
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
|
| 138 |
+
"""Generate complete Snowflake schema SQL for feedback system"""
|
| 139 |
+
if table_name is None:
|
| 140 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 141 |
+
return UserFeedback.get_snowflake_create_table_sql(table_name)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
|
| 145 |
+
"""Create UserFeedback instance from dictionary"""
|
| 146 |
+
return UserFeedback(
|
| 147 |
+
feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
|
| 148 |
+
open_ended_feedback=data.get("open_ended_feedback"),
|
| 149 |
+
score=data["score"],
|
| 150 |
+
is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
|
| 151 |
+
conversation_id=data["conversation_id"],
|
| 152 |
+
timestamp=data["timestamp"],
|
| 153 |
+
message_count=data["message_count"],
|
| 154 |
+
has_retrievals=data["has_retrievals"],
|
| 155 |
+
retrieval_count=data["retrieval_count"],
|
| 156 |
+
transcript=data.get("transcript", []),
|
| 157 |
+
retrievals=data.get("retrievals", []),
|
| 158 |
+
feedback_score_related_retrieval_docs=data.get("feedback_score_related_retrieval_docs"),
|
| 159 |
+
retrieved_data=data.get("retrieved_data")
|
| 160 |
+
)
|
| 161 |
+
|
src/feedback/snowflake_connector.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Snowflake Connector for Feedback System
|
| 3 |
+
|
| 4 |
+
This module handles inserting user feedback into Snowflake.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Dict, Any, Optional
|
| 11 |
+
from .feedback_schema import UserFeedback
|
| 12 |
+
|
| 13 |
+
# Try to import snowflake connector
|
| 14 |
+
try:
|
| 15 |
+
import snowflake.connector
|
| 16 |
+
SNOWFLAKE_AVAILABLE = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
SNOWFLAKE_AVAILABLE = False
|
| 19 |
+
logging.warning("β οΈ snowflake-connector-python not installed. Install with: pip install snowflake-connector-python")
|
| 20 |
+
|
| 21 |
+
# Configure logging
|
| 22 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SnowflakeFeedbackConnector:
|
| 27 |
+
"""Connector for inserting feedback into Snowflake"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
user: str,
|
| 32 |
+
password: str,
|
| 33 |
+
account: str,
|
| 34 |
+
warehouse: str,
|
| 35 |
+
database: str = "SNOWFLAKE_LEARNING",
|
| 36 |
+
schema: str = "PUBLIC"
|
| 37 |
+
):
|
| 38 |
+
self.user = user
|
| 39 |
+
self.password = password
|
| 40 |
+
self.account = account
|
| 41 |
+
self.warehouse = warehouse
|
| 42 |
+
self.database = database
|
| 43 |
+
self.schema = schema
|
| 44 |
+
self._connection = None
|
| 45 |
+
|
| 46 |
+
def connect(self):
|
| 47 |
+
"""Establish Snowflake connection"""
|
| 48 |
+
if not SNOWFLAKE_AVAILABLE:
|
| 49 |
+
raise ImportError("snowflake-connector-python is not installed. Install with: pip install snowflake-connector-python")
|
| 50 |
+
|
| 51 |
+
logger.info("=" * 80)
|
| 52 |
+
logger.info("π SNOWFLAKE CONNECTION: Attempting to connect...")
|
| 53 |
+
logger.info(f" - Account: {self.account}")
|
| 54 |
+
logger.info(f" - Warehouse: {self.warehouse}")
|
| 55 |
+
logger.info(f" - Database: {self.database}")
|
| 56 |
+
logger.info(f" - Schema: {self.schema}")
|
| 57 |
+
logger.info(f" - User: {self.user}")
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
self._connection = snowflake.connector.connect(
|
| 61 |
+
user=self.user,
|
| 62 |
+
password=self.password,
|
| 63 |
+
account=self.account,
|
| 64 |
+
warehouse=self.warehouse
|
| 65 |
+
# Don't set database/schema in connection - we'll do it per query
|
| 66 |
+
)
|
| 67 |
+
logger.info("β
SNOWFLAKE CONNECTION: Successfully connected")
|
| 68 |
+
logger.info("=" * 80)
|
| 69 |
+
print(f"β
Connected to Snowflake: {self.database}.{self.schema}")
|
| 70 |
+
except Exception as e:
|
| 71 |
+
logger.error(f"β SNOWFLAKE CONNECTION FAILED: {e}")
|
| 72 |
+
logger.error("=" * 80)
|
| 73 |
+
print(f"β Failed to connect to Snowflake: {e}")
|
| 74 |
+
raise
|
| 75 |
+
|
| 76 |
+
def disconnect(self):
|
| 77 |
+
"""Close Snowflake connection"""
|
| 78 |
+
if self._connection:
|
| 79 |
+
self._connection.close()
|
| 80 |
+
print("β
Disconnected from Snowflake")
|
| 81 |
+
|
| 82 |
+
def insert_feedback(self, feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
|
| 83 |
+
"""Insert a single feedback record into Snowflake"""
|
| 84 |
+
logger.info("=" * 80)
|
| 85 |
+
logger.info("π SNOWFLAKE INSERT: Starting feedback insertion process")
|
| 86 |
+
logger.info(f"π Feedback ID: {feedback.feedback_id}")
|
| 87 |
+
|
| 88 |
+
# Get table name from parameter, env var, or default
|
| 89 |
+
if table_name is None:
|
| 90 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 91 |
+
|
| 92 |
+
if not self._connection:
|
| 93 |
+
logger.error("β Not connected to Snowflake. Call connect() first.")
|
| 94 |
+
raise RuntimeError("Not connected to Snowflake. Call connect() first.")
|
| 95 |
+
|
| 96 |
+
try:
|
| 97 |
+
logger.info("π VALIDATION: Validating feedback data structure...")
|
| 98 |
+
|
| 99 |
+
# Validate feedback object
|
| 100 |
+
validation_errors = []
|
| 101 |
+
if not feedback.feedback_id:
|
| 102 |
+
validation_errors.append("Missing feedback_id")
|
| 103 |
+
if feedback.score is None:
|
| 104 |
+
validation_errors.append("Missing score")
|
| 105 |
+
if feedback.timestamp is None:
|
| 106 |
+
validation_errors.append("Missing timestamp")
|
| 107 |
+
|
| 108 |
+
if validation_errors:
|
| 109 |
+
logger.error(f"β VALIDATION FAILED: {validation_errors}")
|
| 110 |
+
return False
|
| 111 |
+
else:
|
| 112 |
+
logger.info("β
VALIDATION PASSED: All required fields present")
|
| 113 |
+
|
| 114 |
+
logger.info("π Data Summary:")
|
| 115 |
+
logger.info(f" - Feedback ID: {feedback.feedback_id}")
|
| 116 |
+
logger.info(f" - Score: {feedback.score}")
|
| 117 |
+
logger.info(f" - Conversation ID: {feedback.conversation_id}")
|
| 118 |
+
logger.info(f" - Has Retrievals: {feedback.has_retrievals}")
|
| 119 |
+
logger.info(f" - Retrieval Count: {feedback.retrieval_count}")
|
| 120 |
+
logger.info(f" - Message Count: {feedback.message_count}")
|
| 121 |
+
logger.info(f" - Timestamp: {feedback.timestamp}")
|
| 122 |
+
|
| 123 |
+
cursor = self._connection.cursor()
|
| 124 |
+
logger.info("β
SNOWFLAKE CONNECTION: Cursor created")
|
| 125 |
+
|
| 126 |
+
# Set database and schema context
|
| 127 |
+
logger.info(f"π§ SETTING CONTEXT: Database={self.database}, Schema={self.schema}")
|
| 128 |
+
try:
|
| 129 |
+
cursor.execute(f'USE DATABASE "{self.database}"')
|
| 130 |
+
cursor.execute(f'USE SCHEMA "{self.schema}"')
|
| 131 |
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
| 132 |
+
current_db, current_schema = cursor.fetchone()
|
| 133 |
+
logger.info(f"β
Current context verified: Database={current_db}, Schema={current_schema}")
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.error(f"β Could not set context: {e}")
|
| 136 |
+
raise
|
| 137 |
+
|
| 138 |
+
# Prepare data - convert to JSON strings for VARIANT columns (same approach as old retrieved_data)
|
| 139 |
+
logger.info("π§ DATA PREPARATION: Preparing VARIANT columns...")
|
| 140 |
+
feedback_dict = feedback.to_dict()
|
| 141 |
+
|
| 142 |
+
# Prepare transcript (ARRAY) - convert to JSON string
|
| 143 |
+
transcript_raw = feedback_dict.get('transcript', [])
|
| 144 |
+
if transcript_raw:
|
| 145 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 146 |
+
transcript_for_db = json.dumps(transcript_raw)
|
| 147 |
+
logger.info(f" - Transcript: {len(transcript_raw)} messages, JSON length: {len(transcript_for_db)}")
|
| 148 |
+
else:
|
| 149 |
+
transcript_for_db = None
|
| 150 |
+
logger.info(" - Transcript: None")
|
| 151 |
+
|
| 152 |
+
# Prepare retrievals (ARRAY) - convert to JSON string
|
| 153 |
+
retrievals_raw = feedback_dict.get('retrievals', [])
|
| 154 |
+
if retrievals_raw:
|
| 155 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 156 |
+
retrievals_for_db = json.dumps(retrievals_raw)
|
| 157 |
+
logger.info(f" - Retrievals: {len(retrievals_raw)} entries, JSON length: {len(retrievals_for_db)}")
|
| 158 |
+
else:
|
| 159 |
+
retrievals_for_db = None
|
| 160 |
+
logger.info(" - Retrievals: None")
|
| 161 |
+
|
| 162 |
+
# Prepare feedback_score_related_retrieval_docs (OBJECT) - convert to JSON string
|
| 163 |
+
feedback_score_related_raw = feedback_dict.get('feedback_score_related_retrieval_docs')
|
| 164 |
+
if feedback_score_related_raw:
|
| 165 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 166 |
+
feedback_score_related_for_db = json.dumps(feedback_score_related_raw)
|
| 167 |
+
logger.info(f" - Feedback score related docs: present, JSON length: {len(feedback_score_related_for_db)}")
|
| 168 |
+
else:
|
| 169 |
+
feedback_score_related_for_db = None
|
| 170 |
+
logger.info(" - Feedback score related docs: None")
|
| 171 |
+
|
| 172 |
+
# Prepare retrieved_data (preserved old column) - convert to JSON string
|
| 173 |
+
retrieved_data_raw = feedback_dict.get('retrieved_data')
|
| 174 |
+
if retrieved_data_raw:
|
| 175 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 176 |
+
retrieved_data_for_db = json.dumps(retrieved_data_raw)
|
| 177 |
+
logger.info(f" - Retrieved data (preserved): present, JSON length: {len(retrieved_data_for_db)}")
|
| 178 |
+
else:
|
| 179 |
+
retrieved_data_for_db = None
|
| 180 |
+
logger.info(" - Retrieved data (preserved): None")
|
| 181 |
+
|
| 182 |
+
# Build SQL with new column structure
|
| 183 |
+
# Columns are VARCHAR (storing JSON strings), same approach as old retrieved_data
|
| 184 |
+
sql = f"""INSERT INTO {table_name} (
|
| 185 |
+
feedback_id,
|
| 186 |
+
open_ended_feedback,
|
| 187 |
+
score,
|
| 188 |
+
is_feedback_about_last_retrieval,
|
| 189 |
+
conversation_id,
|
| 190 |
+
timestamp,
|
| 191 |
+
message_count,
|
| 192 |
+
has_retrievals,
|
| 193 |
+
retrieval_count,
|
| 194 |
+
transcript,
|
| 195 |
+
retrievals,
|
| 196 |
+
feedback_score_related_retrieval_docs,
|
| 197 |
+
retrieved_data,
|
| 198 |
+
created_at
|
| 199 |
+
) VALUES (
|
| 200 |
+
%(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
|
| 201 |
+
%(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
|
| 202 |
+
%(retrieval_count)s, %(transcript)s, %(retrievals)s, %(feedback_score_related_retrieval_docs)s,
|
| 203 |
+
%(retrieved_data)s, %(created_at)s
|
| 204 |
+
)"""
|
| 205 |
+
|
| 206 |
+
logger.info("π SQL PREPARATION: Building INSERT statement...")
|
| 207 |
+
logger.info(f" - Target table: {table_name}")
|
| 208 |
+
logger.info(f" - Database: {self.database}")
|
| 209 |
+
logger.info(f" - Schema: {self.schema}")
|
| 210 |
+
|
| 211 |
+
# Prepare parameters
|
| 212 |
+
# Pass JSON strings for VARIANT columns (same approach as old retrieved_data)
|
| 213 |
+
params = {
|
| 214 |
+
'feedback_id': feedback.feedback_id,
|
| 215 |
+
'open_ended_feedback': feedback.open_ended_feedback,
|
| 216 |
+
'score': feedback.score,
|
| 217 |
+
'is_feedback_about_last_retrieval': feedback.is_feedback_about_last_retrieval,
|
| 218 |
+
'conversation_id': feedback.conversation_id,
|
| 219 |
+
'timestamp': int(feedback.timestamp),
|
| 220 |
+
'message_count': feedback.message_count,
|
| 221 |
+
'has_retrievals': feedback.has_retrievals,
|
| 222 |
+
'retrieval_count': feedback.retrieval_count,
|
| 223 |
+
'transcript': transcript_for_db, # JSON string
|
| 224 |
+
'retrievals': retrievals_for_db, # JSON string
|
| 225 |
+
'feedback_score_related_retrieval_docs': feedback_score_related_for_db, # JSON string
|
| 226 |
+
'retrieved_data': retrieved_data_for_db, # JSON string - preserved old column
|
| 227 |
+
'created_at': feedback.created_at
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
# Execute insert
|
| 231 |
+
logger.info("π SQL EXECUTION: Executing INSERT query...")
|
| 232 |
+
cursor.execute(sql, params)
|
| 233 |
+
|
| 234 |
+
logger.info("β
SQL EXECUTION: Query executed successfully")
|
| 235 |
+
logger.info(f" - Rows affected: 1")
|
| 236 |
+
logger.info(f" - Status: SUCCESS")
|
| 237 |
+
|
| 238 |
+
cursor.close()
|
| 239 |
+
logger.info("β
SNOWFLAKE INSERT: Feedback inserted successfully")
|
| 240 |
+
logger.info(f"π Inserted feedback: {feedback.feedback_id}")
|
| 241 |
+
logger.info("=" * 80)
|
| 242 |
+
return True
|
| 243 |
+
|
| 244 |
+
except Exception as e:
|
| 245 |
+
# Check if it's a Snowflake error
|
| 246 |
+
if SNOWFLAKE_AVAILABLE and "ProgrammingError" in str(type(e)):
|
| 247 |
+
logger.error(f"β SQL EXECUTION ERROR: {e}")
|
| 248 |
+
logger.error(f" - Error code: {getattr(e, 'errno', 'Unknown')}")
|
| 249 |
+
logger.error(f" - SQL state: {getattr(e, 'sqlstate', 'Unknown')}")
|
| 250 |
+
else:
|
| 251 |
+
logger.error(f"β SNOWFLAKE INSERT FAILED: {type(e).__name__}")
|
| 252 |
+
logger.error(f" - Error: {e}")
|
| 253 |
+
logger.error("=" * 80)
|
| 254 |
+
return False
|
| 255 |
+
|
| 256 |
+
def __enter__(self):
|
| 257 |
+
"""Context manager entry"""
|
| 258 |
+
self.connect()
|
| 259 |
+
return self
|
| 260 |
+
|
| 261 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 262 |
+
"""Context manager exit"""
|
| 263 |
+
self.disconnect()
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
|
| 267 |
+
"""Create Snowflake connector from environment variables"""
|
| 268 |
+
user = os.getenv("SNOWFLAKE_USER")
|
| 269 |
+
password = os.getenv("SNOWFLAKE_PASSWORD")
|
| 270 |
+
account = os.getenv("SNOWFLAKE_ACCOUNT")
|
| 271 |
+
warehouse = os.getenv("SNOWFLAKE_WAREHOUSE")
|
| 272 |
+
database = os.getenv("SNOWFLAKE_DATABASE", "SNOWFLAKE_LEARN")
|
| 273 |
+
schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
|
| 274 |
+
|
| 275 |
+
if not all([user, password, account, warehouse]):
|
| 276 |
+
print("β οΈ Snowflake credentials not found in environment variables")
|
| 277 |
+
print("Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
|
| 278 |
+
return None
|
| 279 |
+
|
| 280 |
+
return SnowflakeFeedbackConnector(
|
| 281 |
+
user=user,
|
| 282 |
+
password=password,
|
| 283 |
+
account=account,
|
| 284 |
+
warehouse=warehouse,
|
| 285 |
+
database=database,
|
| 286 |
+
schema=schema
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
|
| 291 |
+
"""Helper function to save feedback to Snowflake"""
|
| 292 |
+
logger.info("=" * 80)
|
| 293 |
+
logger.info("π΅ SNOWFLAKE SAVE: Starting save process")
|
| 294 |
+
logger.info(f"π Feedback ID: {feedback.feedback_id}")
|
| 295 |
+
|
| 296 |
+
# Get table name from parameter or env var
|
| 297 |
+
if table_name is None:
|
| 298 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 299 |
+
|
| 300 |
+
connector = get_snowflake_connector_from_env()
|
| 301 |
+
|
| 302 |
+
if not connector:
|
| 303 |
+
logger.warning("β οΈ SNOWFLAKE SAVE: Skipping insertion (credentials not configured)")
|
| 304 |
+
logger.warning(" Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
|
| 305 |
+
logger.info("=" * 80)
|
| 306 |
+
return False
|
| 307 |
+
|
| 308 |
+
try:
|
| 309 |
+
logger.info("π‘ SNOWFLAKE SAVE: Establishing connection...")
|
| 310 |
+
connector.connect()
|
| 311 |
+
logger.info("β
SNOWFLAKE SAVE: Connection established")
|
| 312 |
+
|
| 313 |
+
logger.info("π₯ SNOWFLAKE SAVE: Attempting to insert feedback...")
|
| 314 |
+
success = connector.insert_feedback(feedback, table_name=table_name)
|
| 315 |
+
|
| 316 |
+
logger.info("π SNOWFLAKE SAVE: Disconnecting...")
|
| 317 |
+
connector.disconnect()
|
| 318 |
+
|
| 319 |
+
if success:
|
| 320 |
+
logger.info("β
SNOWFLAKE SAVE: Successfully saved feedback")
|
| 321 |
+
else:
|
| 322 |
+
logger.error("β SNOWFLAKE SAVE: Failed to save feedback")
|
| 323 |
+
|
| 324 |
+
logger.info("=" * 80)
|
| 325 |
+
return success
|
| 326 |
+
except Exception as e:
|
| 327 |
+
logger.error(f"β SNOWFLAKE SAVE ERROR: {type(e).__name__}")
|
| 328 |
+
logger.error(f" - Error: {e}")
|
| 329 |
+
logger.info("=" * 80)
|
| 330 |
+
return False
|
| 331 |
+
|
src/gemini/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini File Search Integration Module
|
| 3 |
+
|
| 4 |
+
This module provides integration with Google Gemini File Search API
|
| 5 |
+
for RAG functionality using Gemini's built-in file search capabilities.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .file_search import GeminiFileSearchClient, GeminiFileSearchResult
|
| 9 |
+
|
| 10 |
+
__all__ = ["GeminiFileSearchClient", "GeminiFileSearchResult"]
|
| 11 |
+
|
src/gemini/file_search.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini File Search Client
|
| 3 |
+
|
| 4 |
+
Handles interaction with Google Gemini File Search API for RAG.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from google import genai
|
| 16 |
+
from google.genai import types
|
| 17 |
+
GEMINI_AVAILABLE = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
GEMINI_AVAILABLE = False
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class GeminiFileSearchResult:
|
| 24 |
+
"""Result from Gemini File Search query"""
|
| 25 |
+
answer: str
|
| 26 |
+
sources: List[Dict[str, Any]] # List of document references
|
| 27 |
+
grounding_metadata: Optional[Dict[str, Any]] = None
|
| 28 |
+
query: str = ""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class GeminiFileSearchClient:
|
| 32 |
+
"""Client for interacting with Gemini File Search API"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, api_key: Optional[str] = None, store_name: Optional[str] = None):
|
| 35 |
+
"""
|
| 36 |
+
Initialize Gemini File Search client.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
api_key: Gemini API key (defaults to GEMINI_API_KEY env var)
|
| 40 |
+
store_name: File search store name (defaults to GEMINI_FILESTORE_NAME env var)
|
| 41 |
+
"""
|
| 42 |
+
if not GEMINI_AVAILABLE:
|
| 43 |
+
raise ImportError("google-genai package not installed. Install with: pip install google-genai")
|
| 44 |
+
|
| 45 |
+
self.api_key = api_key or os.getenv("GEMINI_API_KEY")
|
| 46 |
+
if not self.api_key:
|
| 47 |
+
raise ValueError("GEMINI_API_KEY not found. Set it in .env file or pass as argument.")
|
| 48 |
+
|
| 49 |
+
self.store_name = store_name or os.getenv("GEMINI_FILESTORE_NAME")
|
| 50 |
+
if not self.store_name:
|
| 51 |
+
raise ValueError("GEMINI_FILESTORE_NAME not found. Set it in .env file or pass as argument.")
|
| 52 |
+
|
| 53 |
+
self.client = genai.Client(api_key=self.api_key)
|
| 54 |
+
self.model = "gemini-2.5-flash" # or "gemini-2.5-pro"
|
| 55 |
+
|
| 56 |
+
def search(
|
| 57 |
+
self,
|
| 58 |
+
query: str,
|
| 59 |
+
filters: Optional[Dict[str, Any]] = None,
|
| 60 |
+
model: Optional[str] = None
|
| 61 |
+
) -> GeminiFileSearchResult:
|
| 62 |
+
"""
|
| 63 |
+
Search using Gemini File Search.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
query: User query
|
| 67 |
+
filters: Optional filters (year, source, district, etc.)
|
| 68 |
+
model: Model to use (defaults to gemini-2.5-flash)
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
GeminiFileSearchResult with answer and sources
|
| 72 |
+
"""
|
| 73 |
+
model = model or self.model
|
| 74 |
+
|
| 75 |
+
# Build filter context for the query if filters are provided
|
| 76 |
+
# Gemini File Search doesn't support explicit filters in the API,
|
| 77 |
+
# so we add them as context in the query
|
| 78 |
+
filter_context = ""
|
| 79 |
+
if filters:
|
| 80 |
+
filter_parts = []
|
| 81 |
+
if filters.get("year"):
|
| 82 |
+
years = filters["year"] if isinstance(filters["year"], list) else [filters["year"]]
|
| 83 |
+
filter_parts.append(f"Year: {', '.join(years)}")
|
| 84 |
+
if filters.get("sources"):
|
| 85 |
+
sources = filters["sources"] if isinstance(filters["sources"], list) else [filters["sources"]]
|
| 86 |
+
filter_parts.append(f"Source: {', '.join(sources)}")
|
| 87 |
+
if filters.get("district"):
|
| 88 |
+
districts = filters["district"] if isinstance(filters["district"], list) else [filters["district"]]
|
| 89 |
+
filter_parts.append(f"District: {', '.join(districts)}")
|
| 90 |
+
if filters.get("filenames"):
|
| 91 |
+
filenames = filters["filenames"] if isinstance(filters["filenames"], list) else [filters["filenames"]]
|
| 92 |
+
filter_parts.append(f"Filename: {', '.join(filenames)}")
|
| 93 |
+
|
| 94 |
+
if filter_parts:
|
| 95 |
+
filter_context = f"\n\nPlease focus on documents matching these criteria: {', '.join(filter_parts)}"
|
| 96 |
+
|
| 97 |
+
# Combine query with filter context
|
| 98 |
+
# Add explicit instruction to only use information from retrieved documents
|
| 99 |
+
instruction = "\n\nIMPORTANT: Only use information from the retrieved documents. Do not use information from your training data unless it's explicitly mentioned in the retrieved documents. If the retrieved documents don't contain the requested information, clearly state that.\n\n"
|
| 100 |
+
full_query = query + filter_context + instruction
|
| 101 |
+
|
| 102 |
+
try:
|
| 103 |
+
# Generate content with file search
|
| 104 |
+
# Based on Gemini API docs: https://ai.google.dev/gemini-api/docs/file-search
|
| 105 |
+
try:
|
| 106 |
+
# Try the documented format first
|
| 107 |
+
response = self.client.models.generate_content(
|
| 108 |
+
model=model,
|
| 109 |
+
contents=full_query,
|
| 110 |
+
config=types.GenerateContentConfig(
|
| 111 |
+
tools=[
|
| 112 |
+
types.Tool(
|
| 113 |
+
file_search=types.FileSearch(
|
| 114 |
+
file_search_store_names=[self.store_name]
|
| 115 |
+
)
|
| 116 |
+
)
|
| 117 |
+
]
|
| 118 |
+
)
|
| 119 |
+
)
|
| 120 |
+
except (AttributeError, TypeError) as e:
|
| 121 |
+
# Fallback: try alternative format
|
| 122 |
+
logger.warning(f"Primary API format failed, trying alternative: {e}")
|
| 123 |
+
try:
|
| 124 |
+
response = self.client.models.generate_content(
|
| 125 |
+
model=model,
|
| 126 |
+
contents=full_query,
|
| 127 |
+
tools=[{
|
| 128 |
+
"file_search": {
|
| 129 |
+
"file_search_store_names": [self.store_name]
|
| 130 |
+
}
|
| 131 |
+
}]
|
| 132 |
+
)
|
| 133 |
+
except Exception as e2:
|
| 134 |
+
raise Exception(f"Failed to call Gemini API: {e2}")
|
| 135 |
+
|
| 136 |
+
# Extract answer
|
| 137 |
+
answer = ""
|
| 138 |
+
if hasattr(response, 'text'):
|
| 139 |
+
answer = response.text
|
| 140 |
+
elif hasattr(response, 'candidates') and response.candidates:
|
| 141 |
+
# Try to get text from first candidate
|
| 142 |
+
candidate = response.candidates[0]
|
| 143 |
+
if hasattr(candidate, 'content') and candidate.content:
|
| 144 |
+
if hasattr(candidate.content, 'parts'):
|
| 145 |
+
text_parts = []
|
| 146 |
+
for part in candidate.content.parts:
|
| 147 |
+
if hasattr(part, 'text'):
|
| 148 |
+
text_parts.append(part.text)
|
| 149 |
+
answer = " ".join(text_parts)
|
| 150 |
+
elif isinstance(candidate.content, str):
|
| 151 |
+
answer = candidate.content
|
| 152 |
+
else:
|
| 153 |
+
answer = str(response)
|
| 154 |
+
|
| 155 |
+
# Extract grounding metadata (document references)
|
| 156 |
+
sources = []
|
| 157 |
+
grounding_metadata = None
|
| 158 |
+
|
| 159 |
+
if hasattr(response, 'candidates') and response.candidates:
|
| 160 |
+
candidate = response.candidates[0]
|
| 161 |
+
|
| 162 |
+
# Get grounding metadata
|
| 163 |
+
if hasattr(candidate, 'grounding_metadata'):
|
| 164 |
+
grounding_metadata = candidate.grounding_metadata
|
| 165 |
+
|
| 166 |
+
# Extract source documents from grounding metadata
|
| 167 |
+
# Handle different response formats
|
| 168 |
+
grounding_chunks = None
|
| 169 |
+
if hasattr(grounding_metadata, 'grounding_chunks'):
|
| 170 |
+
grounding_chunks = grounding_metadata.grounding_chunks
|
| 171 |
+
elif isinstance(grounding_metadata, dict) and 'grounding_chunks' in grounding_metadata:
|
| 172 |
+
grounding_chunks = grounding_metadata['grounding_chunks']
|
| 173 |
+
|
| 174 |
+
if grounding_chunks:
|
| 175 |
+
for chunk in grounding_chunks:
|
| 176 |
+
# Handle both object and dict formats
|
| 177 |
+
try:
|
| 178 |
+
if isinstance(chunk, dict):
|
| 179 |
+
chunk_data = chunk
|
| 180 |
+
else:
|
| 181 |
+
# Object format - convert to dict-like access
|
| 182 |
+
chunk_data = {}
|
| 183 |
+
if hasattr(chunk, 'chunk'):
|
| 184 |
+
chunk_obj = chunk.chunk
|
| 185 |
+
chunk_data['chunk'] = {
|
| 186 |
+
'text': getattr(chunk_obj, 'text', ''),
|
| 187 |
+
'file_name': getattr(chunk_obj, 'file_name', '')
|
| 188 |
+
}
|
| 189 |
+
if hasattr(chunk, 'relevance_score'):
|
| 190 |
+
score_obj = chunk.relevance_score
|
| 191 |
+
chunk_data['relevance_score'] = {
|
| 192 |
+
'score': getattr(score_obj, 'score', 0.0)
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
chunk_info = chunk_data.get('chunk', {})
|
| 196 |
+
text = chunk_info.get('text', '') if isinstance(chunk_info, dict) else ''
|
| 197 |
+
file_name = chunk_info.get('file_name', '') if isinstance(chunk_info, dict) else ''
|
| 198 |
+
|
| 199 |
+
score_data = chunk_data.get('relevance_score', {})
|
| 200 |
+
score = score_data.get('score', 0.0) if isinstance(score_data, dict) else 0.0
|
| 201 |
+
|
| 202 |
+
if text or file_name: # Only add if we have content
|
| 203 |
+
source_info = {
|
| 204 |
+
"content": text,
|
| 205 |
+
"filename": file_name,
|
| 206 |
+
"score": score,
|
| 207 |
+
}
|
| 208 |
+
sources.append(source_info)
|
| 209 |
+
except Exception as e:
|
| 210 |
+
logger.warning(f"Error extracting chunk info: {e}")
|
| 211 |
+
continue
|
| 212 |
+
|
| 213 |
+
return GeminiFileSearchResult(
|
| 214 |
+
answer=answer,
|
| 215 |
+
sources=sources,
|
| 216 |
+
grounding_metadata=grounding_metadata,
|
| 217 |
+
query=query
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
except Exception as e:
|
| 221 |
+
# Return error result
|
| 222 |
+
return GeminiFileSearchResult(
|
| 223 |
+
answer=f"I apologize, but I encountered an error: {str(e)}",
|
| 224 |
+
sources=[],
|
| 225 |
+
query=query
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
def format_sources_for_display(self, result: GeminiFileSearchResult) -> List[Any]:
|
| 229 |
+
"""
|
| 230 |
+
Format Gemini sources to match the format expected by the UI.
|
| 231 |
+
|
| 232 |
+
Returns list of document-like objects compatible with existing display code.
|
| 233 |
+
"""
|
| 234 |
+
from langchain.docstore.document import Document
|
| 235 |
+
|
| 236 |
+
formatted_sources = []
|
| 237 |
+
|
| 238 |
+
for i, source in enumerate(result.sources):
|
| 239 |
+
# Create a Document object compatible with existing code
|
| 240 |
+
doc = Document(
|
| 241 |
+
page_content=source.get("content", ""),
|
| 242 |
+
metadata={
|
| 243 |
+
"filename": source.get("filename", "Unknown"),
|
| 244 |
+
"source": "Gemini File Search",
|
| 245 |
+
"score": source.get("score"),
|
| 246 |
+
"chunk_index": i,
|
| 247 |
+
# Add default fields that might be expected
|
| 248 |
+
"page": None,
|
| 249 |
+
"year": None,
|
| 250 |
+
"district": None,
|
| 251 |
+
}
|
| 252 |
+
)
|
| 253 |
+
formatted_sources.append(doc)
|
| 254 |
+
|
| 255 |
+
return formatted_sources
|
| 256 |
+
|
src/{loader.py β llm/loader.py}
RENAMED
|
File without changes
|
src/pipeline.py
CHANGED
|
@@ -14,7 +14,7 @@ except ModuleNotFoundError as me:
|
|
| 14 |
|
| 15 |
from .logging import log_error
|
| 16 |
|
| 17 |
-
from .loader import chunks_to_documents
|
| 18 |
from .vectorstore import VectorStoreManager
|
| 19 |
from .reporting.service import ReportService
|
| 20 |
from .retrieval.context import ContextRetriever
|
|
|
|
| 14 |
|
| 15 |
from .logging import log_error
|
| 16 |
|
| 17 |
+
from .llm.loader import chunks_to_documents
|
| 18 |
from .vectorstore import VectorStoreManager
|
| 19 |
from .reporting.service import ReportService
|
| 20 |
from .retrieval.context import ContextRetriever
|
src/reporting/__init__.py
CHANGED
|
@@ -1,4 +1,8 @@
|
|
| 1 |
-
"""Report metadata and utilities.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from .metadata import get_report_metadata, get_available_sources
|
| 4 |
from .service import ReportService
|
|
|
|
| 1 |
+
"""Report metadata and utilities.
|
| 2 |
+
|
| 3 |
+
This module is kept for backward compatibility with pipeline.py.
|
| 4 |
+
For feedback-related functionality, use src.feedback instead.
|
| 5 |
+
"""
|
| 6 |
|
| 7 |
from .metadata import get_report_metadata, get_available_sources
|
| 8 |
from .service import ReportService
|
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/ui_components/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UI Components Module
|
| 3 |
+
|
| 4 |
+
This module contains UI-related components including styles, visualizations,
|
| 5 |
+
and utility functions for the Streamlit application.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .styles import get_custom_css
|
| 9 |
+
from .components import (
|
| 10 |
+
display_chunk_statistics_charts,
|
| 11 |
+
display_chunk_statistics_table
|
| 12 |
+
)
|
| 13 |
+
from .utils import extract_chunk_statistics
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"get_custom_css",
|
| 17 |
+
"display_chunk_statistics_charts",
|
| 18 |
+
"display_chunk_statistics_table",
|
| 19 |
+
"extract_chunk_statistics"
|
| 20 |
+
]
|
| 21 |
+
|
src/ui_components/components.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UI components for displaying statistics and visualizations
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import streamlit as st
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import plotly.express as px
|
| 8 |
+
from typing import Dict, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieval Statistics"):
|
| 12 |
+
"""Display statistics as interactive charts for 10+ results."""
|
| 13 |
+
if not stats or stats.get('total_chunks', 0) == 0:
|
| 14 |
+
return
|
| 15 |
+
|
| 16 |
+
# Wrap everything in one styled container - open it
|
| 17 |
+
st.markdown(f"""
|
| 18 |
+
<div class="retrieval-distribution-container">
|
| 19 |
+
<h3 style="margin-top: 0;">π {title}</h3>
|
| 20 |
+
<div style="display: flex; justify-content: space-around; align-items: center; padding: 15px 0; border-bottom: 1px solid #e0e0e0; margin-bottom: 20px;">
|
| 21 |
+
<div class="metric-container">
|
| 22 |
+
<div class="metric-label">Total Chunks</div>
|
| 23 |
+
<div class="metric-value">{stats['total_chunks']}</div>
|
| 24 |
+
</div>
|
| 25 |
+
<div class="metric-container">
|
| 26 |
+
<div class="metric-label">Unique Sources</div>
|
| 27 |
+
<div class="metric-value">{stats['unique_sources']}</div>
|
| 28 |
+
</div>
|
| 29 |
+
<div class="metric-container">
|
| 30 |
+
<div class="metric-label">Unique Years</div>
|
| 31 |
+
<div class="metric-value">{stats['unique_years']}</div>
|
| 32 |
+
</div>
|
| 33 |
+
<div class="metric-container">
|
| 34 |
+
<div class="metric-label">Unique Files</div>
|
| 35 |
+
<div class="metric-value">{stats['unique_filenames']}</div>
|
| 36 |
+
</div>
|
| 37 |
+
</div>
|
| 38 |
+
""", unsafe_allow_html=True)
|
| 39 |
+
|
| 40 |
+
# Charts - three columns to include Districts
|
| 41 |
+
col1, col2, col3 = st.columns(3)
|
| 42 |
+
|
| 43 |
+
with col1:
|
| 44 |
+
# Source distribution chart
|
| 45 |
+
if stats['source_distribution']:
|
| 46 |
+
source_df = pd.DataFrame(
|
| 47 |
+
list(stats['source_distribution'].items()),
|
| 48 |
+
columns=['Source', 'Count']
|
| 49 |
+
)
|
| 50 |
+
fig_source = px.bar(
|
| 51 |
+
source_df,
|
| 52 |
+
x='Count',
|
| 53 |
+
y='Source',
|
| 54 |
+
orientation='h',
|
| 55 |
+
title='Distribution by Source',
|
| 56 |
+
color='Count',
|
| 57 |
+
color_continuous_scale='viridis'
|
| 58 |
+
)
|
| 59 |
+
fig_source.update_layout(height=400, showlegend=False)
|
| 60 |
+
st.plotly_chart(fig_source, use_container_width=True)
|
| 61 |
+
|
| 62 |
+
with col2:
|
| 63 |
+
# Year distribution chart
|
| 64 |
+
if stats['year_distribution']:
|
| 65 |
+
# Filter out 'Unknown' years for the chart
|
| 66 |
+
year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
|
| 67 |
+
if year_dist_filtered:
|
| 68 |
+
year_df = pd.DataFrame(
|
| 69 |
+
list(year_dist_filtered.items()),
|
| 70 |
+
columns=['Year', 'Count']
|
| 71 |
+
)
|
| 72 |
+
# Sort by year as integer but keep as string for categorical display
|
| 73 |
+
year_df['Year_Int'] = year_df['Year'].astype(int)
|
| 74 |
+
year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)
|
| 75 |
+
|
| 76 |
+
fig_year = px.bar(
|
| 77 |
+
year_df,
|
| 78 |
+
x='Year',
|
| 79 |
+
y='Count',
|
| 80 |
+
title='Distribution by Year',
|
| 81 |
+
color='Count',
|
| 82 |
+
color_continuous_scale='plasma'
|
| 83 |
+
)
|
| 84 |
+
# Ensure years are treated as categorical (discrete) not continuous
|
| 85 |
+
fig_year.update_xaxes(type='category')
|
| 86 |
+
fig_year.update_layout(height=400, showlegend=False)
|
| 87 |
+
st.plotly_chart(fig_year, use_container_width=True)
|
| 88 |
+
else:
|
| 89 |
+
st.info("No valid years found in the results")
|
| 90 |
+
|
| 91 |
+
with col3:
|
| 92 |
+
# District distribution chart
|
| 93 |
+
if stats.get('district_distribution'):
|
| 94 |
+
district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
|
| 95 |
+
if district_dist_filtered:
|
| 96 |
+
district_df = pd.DataFrame(
|
| 97 |
+
list(district_dist_filtered.items()),
|
| 98 |
+
columns=['District', 'Count']
|
| 99 |
+
)
|
| 100 |
+
district_df = district_df.sort_values('Count', ascending=False)
|
| 101 |
+
|
| 102 |
+
fig_district = px.bar(
|
| 103 |
+
district_df,
|
| 104 |
+
x='Count',
|
| 105 |
+
y='District',
|
| 106 |
+
orientation='h',
|
| 107 |
+
title='Distribution by District',
|
| 108 |
+
color='Count',
|
| 109 |
+
color_continuous_scale='blues'
|
| 110 |
+
)
|
| 111 |
+
fig_district.update_layout(height=400, showlegend=False)
|
| 112 |
+
st.plotly_chart(fig_district, use_container_width=True)
|
| 113 |
+
else:
|
| 114 |
+
st.info("No valid districts found in the results")
|
| 115 |
+
|
| 116 |
+
# Close the container
|
| 117 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def display_chunk_statistics_table(stats: Dict[str, Any], title: str = "Retrieval Distribution"):
|
| 121 |
+
"""Display statistics as tables for smaller results with fixed alignment."""
|
| 122 |
+
if not stats or stats.get('total_chunks', 0) == 0:
|
| 123 |
+
return
|
| 124 |
+
|
| 125 |
+
# Wrap in styled container
|
| 126 |
+
st.markdown('<div class="retrieval-distribution-container">', unsafe_allow_html=True)
|
| 127 |
+
|
| 128 |
+
st.subheader(f"π {title}")
|
| 129 |
+
|
| 130 |
+
# Create a container with fixed height for alignment
|
| 131 |
+
stats_container = st.container()
|
| 132 |
+
|
| 133 |
+
with stats_container:
|
| 134 |
+
# Create 4 equal columns for consistent alignment
|
| 135 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 136 |
+
|
| 137 |
+
with col1:
|
| 138 |
+
st.markdown("**ποΈ Districts**")
|
| 139 |
+
if stats.get('district_distribution'):
|
| 140 |
+
district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
|
| 141 |
+
if district_dist_filtered:
|
| 142 |
+
district_data = {
|
| 143 |
+
"District": list(district_dist_filtered.keys()),
|
| 144 |
+
"Count": list(district_dist_filtered.values())
|
| 145 |
+
}
|
| 146 |
+
district_df = pd.DataFrame(district_data).sort_values('Count', ascending=False)
|
| 147 |
+
st.dataframe(district_df, hide_index=True, use_container_width=True)
|
| 148 |
+
else:
|
| 149 |
+
st.write("No district data")
|
| 150 |
+
else:
|
| 151 |
+
st.write("No district data")
|
| 152 |
+
|
| 153 |
+
with col2:
|
| 154 |
+
st.markdown("**π Sources**")
|
| 155 |
+
if stats['source_distribution']:
|
| 156 |
+
source_data = {
|
| 157 |
+
"Source": list(stats['source_distribution'].keys()),
|
| 158 |
+
"Count": list(stats['source_distribution'].values())
|
| 159 |
+
}
|
| 160 |
+
source_df = pd.DataFrame(source_data).sort_values('Count', ascending=False)
|
| 161 |
+
st.dataframe(source_df, hide_index=True, use_container_width=True)
|
| 162 |
+
else:
|
| 163 |
+
st.write("No source data")
|
| 164 |
+
|
| 165 |
+
with col3:
|
| 166 |
+
st.markdown("**π
Years**")
|
| 167 |
+
if stats['year_distribution']:
|
| 168 |
+
year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
|
| 169 |
+
if year_dist_filtered:
|
| 170 |
+
year_data = {
|
| 171 |
+
"Year": list(year_dist_filtered.keys()),
|
| 172 |
+
"Count": list(year_dist_filtered.values())
|
| 173 |
+
}
|
| 174 |
+
year_df = pd.DataFrame(year_data)
|
| 175 |
+
# Sort by year as integer but display as string
|
| 176 |
+
year_df['Year_Int'] = year_df['Year'].astype(int)
|
| 177 |
+
year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
|
| 178 |
+
st.dataframe(year_df, hide_index=True, use_container_width=True)
|
| 179 |
+
else:
|
| 180 |
+
st.write("No year data")
|
| 181 |
+
else:
|
| 182 |
+
st.write("No year data")
|
| 183 |
+
|
| 184 |
+
with col4:
|
| 185 |
+
st.markdown("**π Files**")
|
| 186 |
+
if stats['filename_distribution']:
|
| 187 |
+
filename_items = list(stats['filename_distribution'].items())
|
| 188 |
+
filename_items.sort(key=lambda x: x[1], reverse=True)
|
| 189 |
+
|
| 190 |
+
# Show top files with truncated names
|
| 191 |
+
file_data = {
|
| 192 |
+
"File": [f[:30] + "..." if len(f) > 30 else f for f, c in filename_items[:5]],
|
| 193 |
+
"Count": [c for f, c in filename_items[:5]]
|
| 194 |
+
}
|
| 195 |
+
file_df = pd.DataFrame(file_data)
|
| 196 |
+
st.dataframe(file_df, hide_index=True, use_container_width=True)
|
| 197 |
+
else:
|
| 198 |
+
st.write("No file data")
|
| 199 |
+
|
| 200 |
+
# Close container
|
| 201 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 202 |
+
|
src/ui_components/styles.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom CSS styles for Streamlit application
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_custom_css() -> str:
|
| 7 |
+
"""Get custom CSS styles as a string"""
|
| 8 |
+
return """
|
| 9 |
+
<style>
|
| 10 |
+
.main-header {
|
| 11 |
+
font-size: 2.5rem;
|
| 12 |
+
font-weight: bold;
|
| 13 |
+
color: #1f77b4;
|
| 14 |
+
text-align: center;
|
| 15 |
+
margin-bottom: 1rem;
|
| 16 |
+
width: 100%;
|
| 17 |
+
display: block;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
.subtitle {
|
| 21 |
+
font-size: 1.2rem;
|
| 22 |
+
color: #666;
|
| 23 |
+
text-align: center;
|
| 24 |
+
margin-bottom: 2rem;
|
| 25 |
+
width: 100%;
|
| 26 |
+
display: block;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
.session-info {
|
| 30 |
+
background-color: #f0f2f6;
|
| 31 |
+
padding: 10px;
|
| 32 |
+
border-radius: 5px;
|
| 33 |
+
margin-bottom: 20px;
|
| 34 |
+
font-size: 0.9rem;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
.user-message {
|
| 38 |
+
background-color: #007bff;
|
| 39 |
+
color: white;
|
| 40 |
+
padding: 12px 16px;
|
| 41 |
+
border-radius: 18px 18px 4px 18px;
|
| 42 |
+
margin: 8px 0;
|
| 43 |
+
margin-left: 20%;
|
| 44 |
+
word-wrap: break-word;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
.bot-message {
|
| 48 |
+
background-color: #f1f3f4;
|
| 49 |
+
color: #333;
|
| 50 |
+
padding: 12px 16px;
|
| 51 |
+
border-radius: 18px 18px 18px 4px;
|
| 52 |
+
margin: 8px 0;
|
| 53 |
+
margin-right: 20%;
|
| 54 |
+
word-wrap: break-word;
|
| 55 |
+
border: 1px solid #e0e0e0;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
.filter-section {
|
| 59 |
+
margin-bottom: 20px;
|
| 60 |
+
padding: 15px;
|
| 61 |
+
background-color: #f8f9fa;
|
| 62 |
+
border-radius: 8px;
|
| 63 |
+
border: 1px solid #e9ecef;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
.filter-title {
|
| 67 |
+
font-weight: bold;
|
| 68 |
+
margin-bottom: 10px;
|
| 69 |
+
color: #495057;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
.feedback-section {
|
| 73 |
+
background-color: #f8f9fa;
|
| 74 |
+
padding: 20px;
|
| 75 |
+
border-radius: 10px;
|
| 76 |
+
margin-top: 30px;
|
| 77 |
+
border: 2px solid #dee2e6;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
.retrieval-history {
|
| 81 |
+
background-color: #ffffff;
|
| 82 |
+
padding: 15px;
|
| 83 |
+
border-radius: 5px;
|
| 84 |
+
margin: 10px 0;
|
| 85 |
+
border-left: 4px solid #007bff;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
.retrieval-distribution-container {
|
| 89 |
+
background-color: #ffffff;
|
| 90 |
+
padding: 25px;
|
| 91 |
+
border-radius: 10px;
|
| 92 |
+
margin: 20px 0;
|
| 93 |
+
border: 2px solid #e0e0e0;
|
| 94 |
+
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 2px 4px rgba(0, 0, 0, 0.06);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
.metric-label {
|
| 98 |
+
font-size: 0.9rem;
|
| 99 |
+
color: #555;
|
| 100 |
+
margin-bottom: 5px;
|
| 101 |
+
text-align: center;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
.metric-value {
|
| 105 |
+
font-size: 1.8rem;
|
| 106 |
+
font-weight: bold;
|
| 107 |
+
color: #000000;
|
| 108 |
+
text-align: center;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
.metric-container {
|
| 112 |
+
text-align: center;
|
| 113 |
+
padding: 10px;
|
| 114 |
+
}
|
| 115 |
+
</style>
|
| 116 |
+
"""
|
| 117 |
+
|
src/ui_components/utils.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UI utility functions for data processing and statistics
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Dict, Any, List
|
| 6 |
+
from collections import Counter
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def extract_chunk_statistics(sources: List[Any]) -> Dict[str, Any]:
|
| 10 |
+
"""Extract statistics from retrieved chunks."""
|
| 11 |
+
if not sources:
|
| 12 |
+
return {}
|
| 13 |
+
|
| 14 |
+
sources_list = []
|
| 15 |
+
years = []
|
| 16 |
+
filenames = []
|
| 17 |
+
districts = []
|
| 18 |
+
|
| 19 |
+
for doc in sources:
|
| 20 |
+
metadata = getattr(doc, 'metadata', {})
|
| 21 |
+
|
| 22 |
+
# Extract source
|
| 23 |
+
source = metadata.get('source', 'Unknown')
|
| 24 |
+
sources_list.append(source)
|
| 25 |
+
|
| 26 |
+
# Extract year
|
| 27 |
+
year = metadata.get('year', 'Unknown')
|
| 28 |
+
if year and year != 'Unknown':
|
| 29 |
+
try:
|
| 30 |
+
# Convert to int first, then back to string to ensure it's a proper year
|
| 31 |
+
year_int = int(float(year)) # Handle both int and float strings
|
| 32 |
+
if 1900 <= year_int <= 2030: # Reasonable year range
|
| 33 |
+
years.append(str(year_int))
|
| 34 |
+
else:
|
| 35 |
+
years.append('Unknown')
|
| 36 |
+
except (ValueError, TypeError):
|
| 37 |
+
years.append('Unknown')
|
| 38 |
+
else:
|
| 39 |
+
years.append('Unknown')
|
| 40 |
+
|
| 41 |
+
# Extract filename
|
| 42 |
+
filename = metadata.get('filename', 'Unknown')
|
| 43 |
+
filenames.append(filename)
|
| 44 |
+
|
| 45 |
+
# Extract district
|
| 46 |
+
district = metadata.get('district', 'Unknown')
|
| 47 |
+
if district and district != 'Unknown':
|
| 48 |
+
districts.append(district)
|
| 49 |
+
else:
|
| 50 |
+
districts.append('Unknown')
|
| 51 |
+
|
| 52 |
+
# Count occurrences
|
| 53 |
+
source_counts = Counter(sources_list)
|
| 54 |
+
year_counts = Counter(years)
|
| 55 |
+
filename_counts = Counter(filenames)
|
| 56 |
+
district_counts = Counter(districts)
|
| 57 |
+
|
| 58 |
+
return {
|
| 59 |
+
'total_chunks': len(sources),
|
| 60 |
+
'unique_sources': len(source_counts),
|
| 61 |
+
'unique_years': len([y for y in year_counts.keys() if y != 'Unknown']),
|
| 62 |
+
'unique_filenames': len(filename_counts),
|
| 63 |
+
'unique_districts': len([d for d in district_counts.keys() if d != 'Unknown']),
|
| 64 |
+
'source_distribution': dict(source_counts),
|
| 65 |
+
'year_distribution': dict(year_counts),
|
| 66 |
+
'filename_distribution': dict(filename_counts),
|
| 67 |
+
'district_distribution': dict(district_counts),
|
| 68 |
+
'sources': sources_list,
|
| 69 |
+
'years': years,
|
| 70 |
+
'filenames': filenames,
|
| 71 |
+
'districts': districts
|
| 72 |
+
}
|
| 73 |
+
|
utils.py β src/utils.py
RENAMED
|
File without changes
|
src/vectorstore.py
CHANGED
|
@@ -28,11 +28,19 @@ class MatryoshkaEmbeddings(Embeddings):
|
|
| 28 |
|
| 29 |
if truncate_dim and "matryoshka" in model_name.lower():
|
| 30 |
# Use SentenceTransformer directly for Matryoshka models
|
| 31 |
-
|
| 32 |
-
self.model = SentenceTransformer(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
print(f"π§ Matryoshka model configured for {truncate_dim} dimensions")
|
| 34 |
else:
|
| 35 |
# Use standard HuggingFaceEmbeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
| 37 |
|
| 38 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
@@ -76,12 +84,15 @@ class VectorStoreManager:
|
|
| 76 |
|
| 77 |
def _create_embeddings(self) -> HuggingFaceEmbeddings:
|
| 78 |
"""Create embeddings model from configuration."""
|
| 79 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 80 |
-
|
| 81 |
model_name = self.config["retriever"]["model"]
|
| 82 |
normalize = self.config["retriever"]["normalize"]
|
| 83 |
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
encode_kwargs = {
|
| 86 |
"normalize_embeddings": normalize,
|
| 87 |
"batch_size": 100,
|
|
|
|
| 28 |
|
| 29 |
if truncate_dim and "matryoshka" in model_name.lower():
|
| 30 |
# Use SentenceTransformer directly for Matryoshka models
|
| 31 |
+
# Explicitly load on CPU first to avoid meta tensor issues
|
| 32 |
+
self.model = SentenceTransformer(
|
| 33 |
+
model_name,
|
| 34 |
+
truncate_dim=truncate_dim,
|
| 35 |
+
device="cpu" # Load on CPU first, prevents meta tensor error
|
| 36 |
+
)
|
| 37 |
print(f"π§ Matryoshka model configured for {truncate_dim} dimensions")
|
| 38 |
else:
|
| 39 |
# Use standard HuggingFaceEmbeddings
|
| 40 |
+
# Pass device="cpu" to prevent meta tensor issues
|
| 41 |
+
if "model_kwargs" not in kwargs:
|
| 42 |
+
kwargs["model_kwargs"] = {}
|
| 43 |
+
kwargs["model_kwargs"]["device"] = "cpu"
|
| 44 |
self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
| 45 |
|
| 46 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
|
|
| 84 |
|
| 85 |
def _create_embeddings(self) -> HuggingFaceEmbeddings:
|
| 86 |
"""Create embeddings model from configuration."""
|
|
|
|
|
|
|
| 87 |
model_name = self.config["retriever"]["model"]
|
| 88 |
normalize = self.config["retriever"]["normalize"]
|
| 89 |
|
| 90 |
+
# Fix for meta tensor issue: explicitly load on CPU first
|
| 91 |
+
# This prevents HuggingFaceEmbeddings from trying to move meta tensors
|
| 92 |
+
# The model will be loaded on CPU and can be moved later if needed
|
| 93 |
+
model_kwargs = {
|
| 94 |
+
"device": "cpu" # Load on CPU first to avoid meta tensor issues
|
| 95 |
+
}
|
| 96 |
encode_kwargs = {
|
| 97 |
"normalize_embeddings": normalize,
|
| 98 |
"batch_size": 100,
|
upload_to_gemini_filestore.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Upload Documents to Google Gemini File Search Store
|
| 4 |
+
|
| 5 |
+
This script uploads PDF documents to a Gemini File Search store for RAG.
|
| 6 |
+
It processes documents from the reports directory and uploads them with metadata.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import json
|
| 12 |
+
import time
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import List, Dict, Any, Optional
|
| 15 |
+
from dotenv import load_dotenv
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from google import genai
|
| 19 |
+
from google.genai import types
|
| 20 |
+
GEMINI_AVAILABLE = True
|
| 21 |
+
except ImportError:
|
| 22 |
+
GEMINI_AVAILABLE = False
|
| 23 |
+
print("β google-genai package not installed. Install with: pip install google-genai")
|
| 24 |
+
|
| 25 |
+
# Load .env file
|
| 26 |
+
load_dotenv()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def extract_metadata_from_path(file_path: Path) -> Dict[str, Any]:
|
| 30 |
+
"""Extract metadata from file path structure."""
|
| 31 |
+
# Example: /path/to/reports/Annual Consolidated OAG audit reports 2018/Annual Consolidated OAG audit reports 2018.pdf
|
| 32 |
+
parts = file_path.parts
|
| 33 |
+
filename = file_path.stem # Without extension
|
| 34 |
+
|
| 35 |
+
metadata = {
|
| 36 |
+
"filename": file_path.name,
|
| 37 |
+
"filepath": str(file_path),
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# Extract year
|
| 41 |
+
year_match = None
|
| 42 |
+
for part in parts:
|
| 43 |
+
if any(year in part for year in ['2018', '2019', '2020', '2021', '2022', '2023', '2024', '2025']):
|
| 44 |
+
for year in ['2018', '2019', '2020', '2021', '2022', '2023', '2024', '2025']:
|
| 45 |
+
if year in part:
|
| 46 |
+
year_match = year
|
| 47 |
+
break
|
| 48 |
+
if year_match:
|
| 49 |
+
break
|
| 50 |
+
|
| 51 |
+
if year_match:
|
| 52 |
+
metadata["year"] = year_match
|
| 53 |
+
|
| 54 |
+
# Extract source/district
|
| 55 |
+
filename_lower = filename.lower()
|
| 56 |
+
if "consolidated" in filename_lower or "oag" in filename_lower:
|
| 57 |
+
metadata["source"] = "Consolidated"
|
| 58 |
+
elif "gulu" in filename_lower:
|
| 59 |
+
metadata["source"] = "Gulu DLG"
|
| 60 |
+
metadata["district"] = "Gulu"
|
| 61 |
+
elif "kalangala" in filename_lower:
|
| 62 |
+
metadata["source"] = "Kalangala DLG"
|
| 63 |
+
metadata["district"] = "Kalangala"
|
| 64 |
+
elif "kcca" in filename_lower:
|
| 65 |
+
metadata["source"] = "KCCA"
|
| 66 |
+
metadata["district"] = "Kampala"
|
| 67 |
+
elif "maaif" in filename_lower:
|
| 68 |
+
metadata["source"] = "MAAIF"
|
| 69 |
+
elif "mwts" in filename_lower:
|
| 70 |
+
metadata["source"] = "MWTS"
|
| 71 |
+
|
| 72 |
+
return metadata
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_or_create_filestore(client: genai.Client, store_name: Optional[str] = None) -> str:
|
| 76 |
+
"""Get existing file search store or create a new one."""
|
| 77 |
+
if store_name:
|
| 78 |
+
# Try to get existing store
|
| 79 |
+
try:
|
| 80 |
+
stores = client.file_search_stores.list()
|
| 81 |
+
for store in stores:
|
| 82 |
+
if store.name == store_name or store.display_name == store_name:
|
| 83 |
+
print(f"β
Using existing store: {store.display_name} ({store.name})")
|
| 84 |
+
return store.name
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"β οΈ Could not list stores: {e}")
|
| 87 |
+
|
| 88 |
+
# Create new store
|
| 89 |
+
display_name = store_name or "Audit Reports"
|
| 90 |
+
print(f"π Creating new file search store: '{display_name}'...")
|
| 91 |
+
|
| 92 |
+
try:
|
| 93 |
+
file_search_store = client.file_search_stores.create(
|
| 94 |
+
config={'display_name': display_name}
|
| 95 |
+
)
|
| 96 |
+
print(f"β
Created store: {file_search_store.display_name} ({file_search_store.name})")
|
| 97 |
+
return file_search_store.name
|
| 98 |
+
except Exception as e:
|
| 99 |
+
print(f"β Failed to create store: {e}")
|
| 100 |
+
raise
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def format_metadata_for_gemini(metadata: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 104 |
+
"""Format metadata dictionary for Gemini API customMetadata format.
|
| 105 |
+
|
| 106 |
+
Based on Gemini API, customMetadata should use:
|
| 107 |
+
- string_value for string fields
|
| 108 |
+
- numeric_value for numeric fields
|
| 109 |
+
"""
|
| 110 |
+
custom_metadata = []
|
| 111 |
+
|
| 112 |
+
# Add year if available (as numeric_value)
|
| 113 |
+
if metadata.get('year'):
|
| 114 |
+
try:
|
| 115 |
+
year_int = int(metadata['year'])
|
| 116 |
+
custom_metadata.append({
|
| 117 |
+
'key': 'year',
|
| 118 |
+
'numeric_value': year_int
|
| 119 |
+
})
|
| 120 |
+
except (ValueError, TypeError):
|
| 121 |
+
# Fallback to string if not numeric
|
| 122 |
+
custom_metadata.append({
|
| 123 |
+
'key': 'year',
|
| 124 |
+
'string_value': str(metadata['year'])
|
| 125 |
+
})
|
| 126 |
+
|
| 127 |
+
# Add source if available (as string_value)
|
| 128 |
+
if metadata.get('source'):
|
| 129 |
+
custom_metadata.append({
|
| 130 |
+
'key': 'source',
|
| 131 |
+
'string_value': str(metadata['source'])
|
| 132 |
+
})
|
| 133 |
+
|
| 134 |
+
# Add district if available (as string_value)
|
| 135 |
+
if metadata.get('district'):
|
| 136 |
+
custom_metadata.append({
|
| 137 |
+
'key': 'district',
|
| 138 |
+
'string_value': str(metadata['district'])
|
| 139 |
+
})
|
| 140 |
+
|
| 141 |
+
# Add filename for reference (as string_value)
|
| 142 |
+
if metadata.get('filename'):
|
| 143 |
+
custom_metadata.append({
|
| 144 |
+
'key': 'filename',
|
| 145 |
+
'string_value': str(metadata['filename'])
|
| 146 |
+
})
|
| 147 |
+
|
| 148 |
+
return custom_metadata
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def check_file_exists(client: genai.Client, store_name: str, filename: str) -> bool:
|
| 152 |
+
"""Check if a file with the same name already exists in the store."""
|
| 153 |
+
try:
|
| 154 |
+
# List files in the store
|
| 155 |
+
store = client.file_search_stores.get(name=store_name)
|
| 156 |
+
# Note: The API might not have a direct list method, so we'll catch errors
|
| 157 |
+
return False # Assume not exists for now
|
| 158 |
+
except Exception:
|
| 159 |
+
return False # If we can't check, assume it doesn't exist
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def upload_file_to_store(
|
| 163 |
+
client: genai.Client,
|
| 164 |
+
file_path: Path,
|
| 165 |
+
store_name: str,
|
| 166 |
+
metadata: Dict[str, Any],
|
| 167 |
+
skip_existing: bool = True
|
| 168 |
+
) -> Optional[bool]:
|
| 169 |
+
"""Upload a single file to the file search store with metadata."""
|
| 170 |
+
try:
|
| 171 |
+
print(f" π€ Uploading: {file_path.name}...")
|
| 172 |
+
|
| 173 |
+
# Format metadata for Gemini API
|
| 174 |
+
custom_metadata = format_metadata_for_gemini(metadata)
|
| 175 |
+
|
| 176 |
+
# Display metadata being uploaded
|
| 177 |
+
if custom_metadata:
|
| 178 |
+
metadata_parts = []
|
| 179 |
+
for m in custom_metadata:
|
| 180 |
+
if 'numeric_value' in m:
|
| 181 |
+
metadata_parts.append(f"{m['key']}={m['numeric_value']}")
|
| 182 |
+
elif 'string_value' in m:
|
| 183 |
+
metadata_parts.append(f"{m['key']}={m['string_value']}")
|
| 184 |
+
if metadata_parts:
|
| 185 |
+
print(f" π Metadata: {', '.join(metadata_parts)}")
|
| 186 |
+
|
| 187 |
+
# Check if file already exists (if skip_existing is True)
|
| 188 |
+
if skip_existing:
|
| 189 |
+
# Note: We'll handle duplicates via error messages
|
| 190 |
+
pass
|
| 191 |
+
|
| 192 |
+
# Upload and import file with metadata
|
| 193 |
+
# Note: Gemini API may not support customMetadata in upload_to_file_search_store
|
| 194 |
+
# We'll try with metadata first, then fallback without it if it fails
|
| 195 |
+
upload_params = {
|
| 196 |
+
'file': str(file_path),
|
| 197 |
+
'file_search_store_name': store_name,
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
# Build config
|
| 201 |
+
config = {
|
| 202 |
+
'display_name': metadata.get('filename', file_path.name),
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
# Upload file (metadata not supported in upload config per API)
|
| 206 |
+
# Note: Gemini File Search API doesn't support customMetadata in upload_to_file_search_store
|
| 207 |
+
# Metadata would need to be added via a separate API call after upload, if supported
|
| 208 |
+
# For now, we upload without metadata - the filename in display_name contains the info
|
| 209 |
+
upload_params['config'] = config
|
| 210 |
+
operation = client.file_search_stores.upload_to_file_search_store(**upload_params)
|
| 211 |
+
|
| 212 |
+
# Wait for import to complete
|
| 213 |
+
max_wait = 300 # 5 minutes max per file
|
| 214 |
+
start_time = time.time()
|
| 215 |
+
|
| 216 |
+
while not operation.done:
|
| 217 |
+
if time.time() - start_time > max_wait:
|
| 218 |
+
print(f" β οΈ Timeout waiting for upload to complete")
|
| 219 |
+
return False
|
| 220 |
+
|
| 221 |
+
time.sleep(2)
|
| 222 |
+
try:
|
| 223 |
+
operation = client.operations.get(operation)
|
| 224 |
+
except Exception as op_error:
|
| 225 |
+
# Check if it's a "terminated" error (file might already exist)
|
| 226 |
+
error_str = str(op_error).lower()
|
| 227 |
+
if 'terminated' in error_str or 'already' in error_str:
|
| 228 |
+
print(f" β οΈ File may already exist or upload was interrupted")
|
| 229 |
+
print(f" π‘ Skipping this file")
|
| 230 |
+
return None # Return None to indicate "skipped"
|
| 231 |
+
raise
|
| 232 |
+
|
| 233 |
+
# Check for errors in the operation result
|
| 234 |
+
if hasattr(operation, 'error') and operation.error:
|
| 235 |
+
error_msg = str(operation.error)
|
| 236 |
+
if 'terminated' in error_msg.lower() or 'already' in error_msg.lower():
|
| 237 |
+
print(f" β οΈ File may already exist in the store")
|
| 238 |
+
print(f" π‘ Skipping this file")
|
| 239 |
+
return None # Return None to indicate "skipped" vs False for "failed"
|
| 240 |
+
print(f" β Upload failed: {operation.error}")
|
| 241 |
+
return False
|
| 242 |
+
|
| 243 |
+
print(f" β
Uploaded successfully")
|
| 244 |
+
return True
|
| 245 |
+
|
| 246 |
+
except Exception as e:
|
| 247 |
+
error_str = str(e).lower()
|
| 248 |
+
# Handle specific error cases
|
| 249 |
+
if 'terminated' in error_str or 'already' in error_str or '400' in error_str:
|
| 250 |
+
print(f" β οΈ Upload error: File may already exist or upload was interrupted")
|
| 251 |
+
print(f" π‘ Error details: {e}")
|
| 252 |
+
print(f" π‘ Skipping this file")
|
| 253 |
+
return None # Return None to indicate "skipped"
|
| 254 |
+
print(f" β Error uploading {file_path.name}: {e}")
|
| 255 |
+
import traceback
|
| 256 |
+
traceback.print_exc()
|
| 257 |
+
return False
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def find_report_files(reports_dir: Path) -> List[Path]:
|
| 261 |
+
"""Find all PDF report files in the reports directory."""
|
| 262 |
+
pdf_files = []
|
| 263 |
+
|
| 264 |
+
if not reports_dir.exists():
|
| 265 |
+
print(f"β Reports directory not found: {reports_dir}")
|
| 266 |
+
return pdf_files
|
| 267 |
+
|
| 268 |
+
# Find all PDF files
|
| 269 |
+
for pdf_file in reports_dir.rglob("*.pdf"):
|
| 270 |
+
pdf_files.append(pdf_file)
|
| 271 |
+
|
| 272 |
+
return sorted(pdf_files)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def main():
|
| 276 |
+
"""Main function to upload documents to Gemini File Search store."""
|
| 277 |
+
print("=" * 60)
|
| 278 |
+
print("Gemini File Search Store Upload Tool")
|
| 279 |
+
print("=" * 60)
|
| 280 |
+
|
| 281 |
+
if not GEMINI_AVAILABLE:
|
| 282 |
+
print("\nβ Please install google-genai package:")
|
| 283 |
+
print(" pip install google-genai")
|
| 284 |
+
return 1
|
| 285 |
+
|
| 286 |
+
# Get API key
|
| 287 |
+
api_key = os.getenv("GEMINI_API_KEY")
|
| 288 |
+
if not api_key:
|
| 289 |
+
print("\nβ GEMINI_API_KEY not found in environment variables")
|
| 290 |
+
print(" Please add GEMINI_API_KEY to your .env file")
|
| 291 |
+
return 1
|
| 292 |
+
|
| 293 |
+
# Get store name (optional)
|
| 294 |
+
store_name = os.getenv("GEMINI_FILESTORE_NAME")
|
| 295 |
+
|
| 296 |
+
# Get reports directory - try multiple possible locations
|
| 297 |
+
reports_dir_str = os.getenv("REPORTS_DIR")
|
| 298 |
+
if not reports_dir_str:
|
| 299 |
+
# Try common locations
|
| 300 |
+
possible_paths = [
|
| 301 |
+
"/Users/ayeroyan/workspace/chatbot-rag/reports",
|
| 302 |
+
Path(__file__).parent / "reports",
|
| 303 |
+
Path.cwd() / "reports",
|
| 304 |
+
]
|
| 305 |
+
for path in possible_paths:
|
| 306 |
+
if Path(path).exists():
|
| 307 |
+
reports_dir_str = str(path)
|
| 308 |
+
break
|
| 309 |
+
|
| 310 |
+
if not reports_dir_str:
|
| 311 |
+
reports_dir_str = "/Users/ayeroyan/workspace/chatbot-rag/reports" # Default fallback
|
| 312 |
+
|
| 313 |
+
reports_dir = Path(reports_dir_str)
|
| 314 |
+
|
| 315 |
+
# Initialize Gemini client
|
| 316 |
+
print(f"\nπ Connecting to Gemini API...")
|
| 317 |
+
try:
|
| 318 |
+
client = genai.Client(api_key=api_key)
|
| 319 |
+
print(f" β
Connected")
|
| 320 |
+
except Exception as e:
|
| 321 |
+
print(f" β Failed to connect: {e}")
|
| 322 |
+
return 1
|
| 323 |
+
|
| 324 |
+
# Get or create file search store
|
| 325 |
+
print(f"\nπ¦ Setting up file search store...")
|
| 326 |
+
try:
|
| 327 |
+
store_name = get_or_create_filestore(client, store_name)
|
| 328 |
+
except Exception as e:
|
| 329 |
+
print(f" β Failed to setup store: {e}")
|
| 330 |
+
return 1
|
| 331 |
+
|
| 332 |
+
# Find all PDF files
|
| 333 |
+
print(f"\nπ Scanning for PDF files in: {reports_dir}")
|
| 334 |
+
pdf_files = find_report_files(reports_dir)
|
| 335 |
+
|
| 336 |
+
if not pdf_files:
|
| 337 |
+
print(f" β No PDF files found in {reports_dir}")
|
| 338 |
+
return 1
|
| 339 |
+
|
| 340 |
+
print(f" β
Found {len(pdf_files)} PDF files")
|
| 341 |
+
|
| 342 |
+
# Upload files
|
| 343 |
+
print(f"\nπ€ Uploading files to store...")
|
| 344 |
+
print(f" Store: {store_name}")
|
| 345 |
+
print(f" Files: {len(pdf_files)}")
|
| 346 |
+
|
| 347 |
+
uploaded = 0
|
| 348 |
+
failed = 0
|
| 349 |
+
skipped = 0
|
| 350 |
+
|
| 351 |
+
for i, pdf_file in enumerate(pdf_files, 1):
|
| 352 |
+
print(f"\n[{i}/{len(pdf_files)}] Processing: {pdf_file.name}")
|
| 353 |
+
|
| 354 |
+
# Extract metadata
|
| 355 |
+
metadata = extract_metadata_from_path(pdf_file)
|
| 356 |
+
|
| 357 |
+
# Display extracted metadata
|
| 358 |
+
metadata_info = []
|
| 359 |
+
if metadata.get('year'):
|
| 360 |
+
metadata_info.append(f"Year: {metadata['year']}")
|
| 361 |
+
if metadata.get('source'):
|
| 362 |
+
metadata_info.append(f"Source: {metadata['source']}")
|
| 363 |
+
if metadata.get('district'):
|
| 364 |
+
metadata_info.append(f"District: {metadata['district']}")
|
| 365 |
+
|
| 366 |
+
if metadata_info:
|
| 367 |
+
print(f" π Extracted metadata: {', '.join(metadata_info)}")
|
| 368 |
+
|
| 369 |
+
# Upload file with metadata
|
| 370 |
+
result = upload_file_to_store(client, pdf_file, store_name, metadata, skip_existing=True)
|
| 371 |
+
|
| 372 |
+
if result is True:
|
| 373 |
+
uploaded += 1
|
| 374 |
+
elif result is None: # Skipped (already exists)
|
| 375 |
+
skipped += 1
|
| 376 |
+
else: # Failed
|
| 377 |
+
failed += 1
|
| 378 |
+
|
| 379 |
+
# Small delay between uploads to avoid rate limits
|
| 380 |
+
if i < len(pdf_files):
|
| 381 |
+
time.sleep(1)
|
| 382 |
+
|
| 383 |
+
# Summary
|
| 384 |
+
print(f"\n" + "=" * 60)
|
| 385 |
+
print(f"Upload Summary")
|
| 386 |
+
print(f"=" * 60)
|
| 387 |
+
print(f" β
Uploaded: {uploaded}")
|
| 388 |
+
if skipped > 0:
|
| 389 |
+
print(f" βοΈ Skipped (already exists): {skipped}")
|
| 390 |
+
print(f" β Failed: {failed}")
|
| 391 |
+
print(f" π¦ Store: {store_name}")
|
| 392 |
+
|
| 393 |
+
if uploaded > 0:
|
| 394 |
+
print(f"\nβ
Successfully uploaded {uploaded} files to Gemini File Search store!")
|
| 395 |
+
print(f" You can now use this store in the beta version of the chatbot.")
|
| 396 |
+
|
| 397 |
+
return 0 if failed == 0 else 1
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
if __name__ == "__main__":
|
| 401 |
+
sys.exit(main())
|
| 402 |
+
|
verify_qdrant_migration.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Qdrant Migration Verification Script
|
| 4 |
+
|
| 5 |
+
This script compares the source and destination Qdrant collections to verify
|
| 6 |
+
that the migration was successful. It:
|
| 7 |
+
1. Compares collection configurations
|
| 8 |
+
2. Fetches sample points from source
|
| 9 |
+
3. Retrieves same points from destination using IDs
|
| 10 |
+
4. Compares vectors, metadata, and all attributes
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from typing import List, Dict, Any, Optional
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from qdrant_client import QdrantClient
|
| 18 |
+
import json
|
| 19 |
+
|
| 20 |
+
# Try to import config loader and dotenv for automatic source detection
|
| 21 |
+
try:
|
| 22 |
+
from src.config.loader import load_config
|
| 23 |
+
CONFIG_AVAILABLE = True
|
| 24 |
+
except ImportError:
|
| 25 |
+
CONFIG_AVAILABLE = False
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from dotenv import load_dotenv
|
| 29 |
+
DOTENV_AVAILABLE = True
|
| 30 |
+
except ImportError:
|
| 31 |
+
DOTENV_AVAILABLE = False
|
| 32 |
+
|
| 33 |
+
# Load .env file automatically if available
|
| 34 |
+
if DOTENV_AVAILABLE:
|
| 35 |
+
project_root = Path(__file__).parent
|
| 36 |
+
env_file = project_root / ".env"
|
| 37 |
+
if env_file.exists():
|
| 38 |
+
load_dotenv(env_file, override=True)
|
| 39 |
+
else:
|
| 40 |
+
load_dotenv(override=True)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_collection_info(client: QdrantClient, collection_name: str) -> Dict[str, Any]:
|
| 44 |
+
"""Get collection information including vector size and point count."""
|
| 45 |
+
try:
|
| 46 |
+
collection_info = client.get_collection(collection_name)
|
| 47 |
+
|
| 48 |
+
# Handle different Qdrant client versions and response formats
|
| 49 |
+
if hasattr(collection_info, 'config'):
|
| 50 |
+
config = collection_info.config
|
| 51 |
+
if hasattr(config, 'params') and hasattr(config.params, 'vectors'):
|
| 52 |
+
vectors_config = config.params.vectors
|
| 53 |
+
if isinstance(vectors_config, dict):
|
| 54 |
+
vector_size = vectors_config.get('size')
|
| 55 |
+
distance = vectors_config.get('distance')
|
| 56 |
+
else:
|
| 57 |
+
vector_size = getattr(vectors_config, 'size', None)
|
| 58 |
+
distance = getattr(vectors_config, 'distance', None)
|
| 59 |
+
else:
|
| 60 |
+
vector_size = getattr(config, 'vector_size', None)
|
| 61 |
+
distance = getattr(config, 'distance', None)
|
| 62 |
+
else:
|
| 63 |
+
vector_size = getattr(collection_info, 'vector_size', None)
|
| 64 |
+
distance = getattr(collection_info, 'distance', None)
|
| 65 |
+
|
| 66 |
+
points_count = getattr(collection_info, 'points_count', 0)
|
| 67 |
+
indexed_vectors_count = getattr(collection_info, 'indexed_vectors_count', 0)
|
| 68 |
+
|
| 69 |
+
if vector_size is None:
|
| 70 |
+
try:
|
| 71 |
+
result, _ = client.scroll(collection_name=collection_name, limit=1, with_vectors=True)
|
| 72 |
+
if result and hasattr(result[0], 'vector') and result[0].vector:
|
| 73 |
+
vector_size = len(result[0].vector)
|
| 74 |
+
except Exception:
|
| 75 |
+
pass
|
| 76 |
+
|
| 77 |
+
return {
|
| 78 |
+
"vector_size": vector_size,
|
| 79 |
+
"distance": distance or "Cosine",
|
| 80 |
+
"points_count": points_count,
|
| 81 |
+
"indexed_vectors_count": indexed_vectors_count,
|
| 82 |
+
}
|
| 83 |
+
except Exception as e:
|
| 84 |
+
print(f"β Error getting collection info: {e}")
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def fetch_points_by_ids(client: QdrantClient, collection_name: str, point_ids: List) -> Dict:
|
| 89 |
+
"""Fetch points by their IDs from a collection."""
|
| 90 |
+
try:
|
| 91 |
+
points = client.retrieve(
|
| 92 |
+
collection_name=collection_name,
|
| 93 |
+
ids=point_ids,
|
| 94 |
+
with_payload=True,
|
| 95 |
+
with_vectors=True
|
| 96 |
+
)
|
| 97 |
+
return {point.id: point for point in points}
|
| 98 |
+
except Exception as e:
|
| 99 |
+
print(f"β Error fetching points by IDs: {e}")
|
| 100 |
+
return {}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def compare_points(source_point, dest_point, point_id) -> Dict[str, Any]:
|
| 104 |
+
"""Compare two points and return differences."""
|
| 105 |
+
differences = []
|
| 106 |
+
matches = []
|
| 107 |
+
|
| 108 |
+
# Compare IDs
|
| 109 |
+
if source_point.id == dest_point.id:
|
| 110 |
+
matches.append("ID")
|
| 111 |
+
else:
|
| 112 |
+
differences.append(f"ID: source={source_point.id}, dest={dest_point.id}")
|
| 113 |
+
|
| 114 |
+
# Compare vectors
|
| 115 |
+
source_vec = getattr(source_point, 'vector', None)
|
| 116 |
+
dest_vec = getattr(dest_point, 'vector', None)
|
| 117 |
+
|
| 118 |
+
if source_vec is None and dest_vec is None:
|
| 119 |
+
matches.append("Vector (both None)")
|
| 120 |
+
elif source_vec is None or dest_vec is None:
|
| 121 |
+
differences.append(f"Vector: source={'None' if source_vec is None else f'len={len(source_vec)}'}, dest={'None' if dest_vec is None else f'len={len(dest_vec)}'}")
|
| 122 |
+
elif len(source_vec) != len(dest_vec):
|
| 123 |
+
differences.append(f"Vector length: source={len(source_vec)}, dest={len(dest_vec)}")
|
| 124 |
+
else:
|
| 125 |
+
# Compare vector values (with tolerance for floating point)
|
| 126 |
+
import numpy as np
|
| 127 |
+
try:
|
| 128 |
+
vec_diff = np.abs(np.array(source_vec) - np.array(dest_vec))
|
| 129 |
+
max_diff = float(np.max(vec_diff))
|
| 130 |
+
if max_diff < 1e-6:
|
| 131 |
+
matches.append(f"Vector (max diff: {max_diff:.2e})")
|
| 132 |
+
else:
|
| 133 |
+
differences.append(f"Vector values differ (max diff: {max_diff:.2e})")
|
| 134 |
+
except Exception as e:
|
| 135 |
+
differences.append(f"Vector comparison error: {e}")
|
| 136 |
+
|
| 137 |
+
# Compare payloads
|
| 138 |
+
source_payload = getattr(source_point, 'payload', {}) or {}
|
| 139 |
+
dest_payload = getattr(dest_point, 'payload', {}) or {}
|
| 140 |
+
|
| 141 |
+
# Convert to dicts if needed
|
| 142 |
+
if hasattr(source_payload, '__dict__'):
|
| 143 |
+
source_payload = source_payload.__dict__
|
| 144 |
+
if hasattr(dest_payload, '__dict__'):
|
| 145 |
+
dest_payload = dest_payload.__dict__
|
| 146 |
+
|
| 147 |
+
source_keys = set(source_payload.keys())
|
| 148 |
+
dest_keys = set(dest_payload.keys())
|
| 149 |
+
|
| 150 |
+
if source_keys != dest_keys:
|
| 151 |
+
missing_in_dest = source_keys - dest_keys
|
| 152 |
+
extra_in_dest = dest_keys - source_keys
|
| 153 |
+
if missing_in_dest:
|
| 154 |
+
differences.append(f"Payload keys missing in dest: {missing_in_dest}")
|
| 155 |
+
if extra_in_dest:
|
| 156 |
+
differences.append(f"Payload keys extra in dest: {extra_in_dest}")
|
| 157 |
+
|
| 158 |
+
# Compare payload values
|
| 159 |
+
common_keys = source_keys & dest_keys
|
| 160 |
+
for key in common_keys:
|
| 161 |
+
source_val = source_payload[key]
|
| 162 |
+
dest_val = dest_payload[key]
|
| 163 |
+
|
| 164 |
+
if source_val == dest_val:
|
| 165 |
+
matches.append(f"Payload.{key}")
|
| 166 |
+
else:
|
| 167 |
+
# Handle nested structures
|
| 168 |
+
if isinstance(source_val, dict) and isinstance(dest_val, dict):
|
| 169 |
+
if source_val != dest_val:
|
| 170 |
+
differences.append(f"Payload.{key}: dicts differ")
|
| 171 |
+
elif isinstance(source_val, list) and isinstance(dest_val, list):
|
| 172 |
+
if source_val != dest_val:
|
| 173 |
+
differences.append(f"Payload.{key}: lists differ (len: {len(source_val)} vs {len(dest_val)})")
|
| 174 |
+
else:
|
| 175 |
+
differences.append(f"Payload.{key}: '{source_val}' != '{dest_val}'")
|
| 176 |
+
|
| 177 |
+
return {
|
| 178 |
+
"point_id": point_id,
|
| 179 |
+
"matches": matches,
|
| 180 |
+
"differences": differences,
|
| 181 |
+
"match_count": len(matches),
|
| 182 |
+
"diff_count": len(differences)
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def main():
|
| 187 |
+
print("="*70)
|
| 188 |
+
print("Qdrant Migration Verification Script")
|
| 189 |
+
print("="*70)
|
| 190 |
+
|
| 191 |
+
# Auto-detect source from config and .env file
|
| 192 |
+
source_url = os.getenv('QDRANT_URL')
|
| 193 |
+
source_key = os.getenv('QDRANT_API_KEY')
|
| 194 |
+
source_collection = os.getenv('QDRANT_COLLECTION', 'docling')
|
| 195 |
+
|
| 196 |
+
if CONFIG_AVAILABLE:
|
| 197 |
+
try:
|
| 198 |
+
config = load_config()
|
| 199 |
+
qdrant_config = config.get('qdrant', {})
|
| 200 |
+
if not source_url:
|
| 201 |
+
source_url = qdrant_config.get('url')
|
| 202 |
+
if not source_key:
|
| 203 |
+
source_key = qdrant_config.get('api_key')
|
| 204 |
+
if not source_collection:
|
| 205 |
+
source_collection = qdrant_config.get('collection_name', 'docling')
|
| 206 |
+
except Exception as e:
|
| 207 |
+
print(f"β οΈ Could not load config: {e}")
|
| 208 |
+
|
| 209 |
+
# Get destination from env
|
| 210 |
+
dest_url = os.getenv('DEST_QDRANT_URL')
|
| 211 |
+
dest_key = os.getenv('DEST_QDRANT_API_KEY')
|
| 212 |
+
dest_collection = os.getenv('DEST_COLLECTION') # Optional, will auto-detect
|
| 213 |
+
|
| 214 |
+
# Validate
|
| 215 |
+
if not source_url or not source_key:
|
| 216 |
+
print("β Source Qdrant credentials missing!")
|
| 217 |
+
print(" Set QDRANT_URL and QDRANT_API_KEY in .env or environment")
|
| 218 |
+
return 1
|
| 219 |
+
|
| 220 |
+
if not dest_url or not dest_key:
|
| 221 |
+
print("β Destination Qdrant credentials missing!")
|
| 222 |
+
print(" Set DEST_QDRANT_URL and DEST_QDRANT_API_KEY in .env or environment")
|
| 223 |
+
return 1
|
| 224 |
+
|
| 225 |
+
print(f"\nπ Configuration:")
|
| 226 |
+
print(f" Source: {source_url}")
|
| 227 |
+
print(f" Source Collection: {source_collection}")
|
| 228 |
+
print(f" Destination: {dest_url}")
|
| 229 |
+
if dest_collection:
|
| 230 |
+
print(f" Destination Collection: {dest_collection} (specified)")
|
| 231 |
+
else:
|
| 232 |
+
print(f" Destination Collection: (auto-detect)")
|
| 233 |
+
|
| 234 |
+
# Connect to Qdrant instances
|
| 235 |
+
print(f"\nπ Connecting to Qdrant instances...")
|
| 236 |
+
try:
|
| 237 |
+
source_client = QdrantClient(url=source_url, api_key=source_key, timeout=120)
|
| 238 |
+
print(f" β
Connected to source")
|
| 239 |
+
except Exception as e:
|
| 240 |
+
print(f" β Failed to connect to source: {e}")
|
| 241 |
+
return 1
|
| 242 |
+
|
| 243 |
+
try:
|
| 244 |
+
dest_client = QdrantClient(url=dest_url, api_key=dest_key, timeout=120)
|
| 245 |
+
print(f" β
Connected to destination")
|
| 246 |
+
except Exception as e:
|
| 247 |
+
print(f" β Failed to connect to destination: {e}")
|
| 248 |
+
return 1
|
| 249 |
+
|
| 250 |
+
# Auto-detect destination collection if not specified
|
| 251 |
+
if not dest_collection:
|
| 252 |
+
try:
|
| 253 |
+
collections = dest_client.get_collections().collections
|
| 254 |
+
collection_names = [c.name for c in collections]
|
| 255 |
+
if len(collection_names) == 1:
|
| 256 |
+
dest_collection = collection_names[0]
|
| 257 |
+
print(f"\nπ Auto-detected destination collection: '{dest_collection}'")
|
| 258 |
+
elif len(collection_names) > 1:
|
| 259 |
+
print(f"\nβ οΈ Found {len(collection_names)} collections in destination:")
|
| 260 |
+
for name in collection_names:
|
| 261 |
+
print(f" - {name}")
|
| 262 |
+
print(f"\n Using first collection: '{collection_names[0]}'")
|
| 263 |
+
dest_collection = collection_names[0]
|
| 264 |
+
else:
|
| 265 |
+
print("β No collections found in destination!")
|
| 266 |
+
return 1
|
| 267 |
+
except Exception as e:
|
| 268 |
+
print(f"β Could not list destination collections: {e}")
|
| 269 |
+
return 1
|
| 270 |
+
|
| 271 |
+
# Get collection info
|
| 272 |
+
print(f"\nπ Collection Information Comparison")
|
| 273 |
+
print("="*70)
|
| 274 |
+
|
| 275 |
+
source_info = get_collection_info(source_client, source_collection)
|
| 276 |
+
dest_info = get_collection_info(dest_client, dest_collection)
|
| 277 |
+
|
| 278 |
+
if not source_info:
|
| 279 |
+
print("β Could not get source collection info")
|
| 280 |
+
return 1
|
| 281 |
+
|
| 282 |
+
if not dest_info:
|
| 283 |
+
print("β Could not get destination collection info")
|
| 284 |
+
return 1
|
| 285 |
+
|
| 286 |
+
print(f"\nSource Collection ('{source_collection}'):")
|
| 287 |
+
print(f" Vector size: {source_info['vector_size']}")
|
| 288 |
+
print(f" Distance: {source_info['distance']}")
|
| 289 |
+
print(f" Points: {source_info['points_count']:,}")
|
| 290 |
+
print(f" Indexed: {source_info['indexed_vectors_count']:,}")
|
| 291 |
+
|
| 292 |
+
print(f"\nDestination Collection ('{dest_collection}'):")
|
| 293 |
+
print(f" Vector size: {dest_info['vector_size']}")
|
| 294 |
+
print(f" Distance: {dest_info['distance']}")
|
| 295 |
+
print(f" Points: {dest_info['points_count']:,}")
|
| 296 |
+
print(f" Indexed: {dest_info['indexed_vectors_count']:,}")
|
| 297 |
+
|
| 298 |
+
# Compare configs
|
| 299 |
+
print(f"\nπ Configuration Comparison:")
|
| 300 |
+
config_matches = []
|
| 301 |
+
config_diffs = []
|
| 302 |
+
|
| 303 |
+
if source_info['vector_size'] == dest_info['vector_size']:
|
| 304 |
+
config_matches.append(f"Vector size: {source_info['vector_size']}")
|
| 305 |
+
else:
|
| 306 |
+
config_diffs.append(f"Vector size: source={source_info['vector_size']}, dest={dest_info['vector_size']}")
|
| 307 |
+
|
| 308 |
+
if str(source_info['distance']) == str(dest_info['distance']):
|
| 309 |
+
config_matches.append(f"Distance: {source_info['distance']}")
|
| 310 |
+
else:
|
| 311 |
+
config_diffs.append(f"Distance: source={source_info['distance']}, dest={dest_info['distance']}")
|
| 312 |
+
|
| 313 |
+
if source_info['points_count'] == dest_info['points_count']:
|
| 314 |
+
config_matches.append(f"Points count: {source_info['points_count']:,}")
|
| 315 |
+
else:
|
| 316 |
+
config_diffs.append(f"Points count: source={source_info['points_count']:,}, dest={dest_info['points_count']:,}")
|
| 317 |
+
|
| 318 |
+
if config_matches:
|
| 319 |
+
print(f" β
Matches: {len(config_matches)}")
|
| 320 |
+
for match in config_matches:
|
| 321 |
+
print(f" - {match}")
|
| 322 |
+
|
| 323 |
+
if config_diffs:
|
| 324 |
+
print(f" β Differences: {len(config_diffs)}")
|
| 325 |
+
for diff in config_diffs:
|
| 326 |
+
print(f" - {diff}")
|
| 327 |
+
|
| 328 |
+
# Fetch sample points from source
|
| 329 |
+
print(f"\nπ₯ Fetching sample points from source...")
|
| 330 |
+
sample_size = 2000 # Fetch 20 sample points
|
| 331 |
+
|
| 332 |
+
try:
|
| 333 |
+
source_points_result, _ = source_client.scroll(
|
| 334 |
+
collection_name=source_collection,
|
| 335 |
+
limit=sample_size,
|
| 336 |
+
with_payload=True,
|
| 337 |
+
with_vectors=True
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
if not source_points_result:
|
| 341 |
+
print("β No points found in source collection!")
|
| 342 |
+
return 1
|
| 343 |
+
|
| 344 |
+
print(f" β
Fetched {len(source_points_result)} points from source")
|
| 345 |
+
|
| 346 |
+
# Extract point IDs
|
| 347 |
+
source_point_ids = [point.id for point in source_points_result]
|
| 348 |
+
print(f" Point IDs: {source_point_ids[:5]}{'...' if len(source_point_ids) > 5 else ''}")
|
| 349 |
+
|
| 350 |
+
except Exception as e:
|
| 351 |
+
print(f"β Error fetching source points: {e}")
|
| 352 |
+
import traceback
|
| 353 |
+
traceback.print_exc()
|
| 354 |
+
return 1
|
| 355 |
+
|
| 356 |
+
# Fetch same points from destination
|
| 357 |
+
print(f"\nπ₯ Fetching same points from destination by ID...")
|
| 358 |
+
try:
|
| 359 |
+
dest_points_dict = fetch_points_by_ids(dest_client, dest_collection, source_point_ids)
|
| 360 |
+
print(f" β
Fetched {len(dest_points_dict)} points from destination")
|
| 361 |
+
|
| 362 |
+
missing_ids = set(source_point_ids) - set(dest_points_dict.keys())
|
| 363 |
+
if missing_ids:
|
| 364 |
+
print(f" β οΈ Missing {len(missing_ids)} points in destination: {list(missing_ids)[:5]}{'...' if len(missing_ids) > 5 else ''}")
|
| 365 |
+
|
| 366 |
+
except Exception as e:
|
| 367 |
+
print(f"β Error fetching destination points: {e}")
|
| 368 |
+
import traceback
|
| 369 |
+
traceback.print_exc()
|
| 370 |
+
return 1
|
| 371 |
+
|
| 372 |
+
# Compare points
|
| 373 |
+
print(f"\nπ Point-by-Point Comparison")
|
| 374 |
+
print("="*70)
|
| 375 |
+
|
| 376 |
+
comparison_results = []
|
| 377 |
+
for source_point in source_points_result:
|
| 378 |
+
point_id = source_point.id
|
| 379 |
+
dest_point = dest_points_dict.get(point_id)
|
| 380 |
+
|
| 381 |
+
if dest_point is None:
|
| 382 |
+
comparison_results.append({
|
| 383 |
+
"point_id": point_id,
|
| 384 |
+
"status": "MISSING",
|
| 385 |
+
"matches": [],
|
| 386 |
+
"differences": [f"Point not found in destination"]
|
| 387 |
+
})
|
| 388 |
+
else:
|
| 389 |
+
comparison = compare_points(source_point, dest_point, point_id)
|
| 390 |
+
comparison["status"] = "MATCH" if comparison["diff_count"] == 0 else "DIFF"
|
| 391 |
+
comparison_results.append(comparison)
|
| 392 |
+
|
| 393 |
+
# Summary
|
| 394 |
+
matches = [r for r in comparison_results if r["status"] == "MATCH"]
|
| 395 |
+
diffs = [r for r in comparison_results if r["status"] == "DIFF"]
|
| 396 |
+
missing = [r for r in comparison_results if r["status"] == "MISSING"]
|
| 397 |
+
|
| 398 |
+
print(f"\nπ Comparison Summary:")
|
| 399 |
+
print(f" Total points compared: {len(comparison_results)}")
|
| 400 |
+
print(f" β
Perfect matches: {len(matches)}")
|
| 401 |
+
print(f" β οΈ Differences found: {len(diffs)}")
|
| 402 |
+
print(f" β Missing in destination: {len(missing)}")
|
| 403 |
+
|
| 404 |
+
# Show details for points with differences
|
| 405 |
+
if diffs:
|
| 406 |
+
print(f"\nβ οΈ Points with differences:")
|
| 407 |
+
for diff_result in diffs[:10]: # Show first 10
|
| 408 |
+
print(f"\n Point ID: {diff_result['point_id']}")
|
| 409 |
+
if diff_result['matches']:
|
| 410 |
+
print(f" β
Matches ({len(diff_result['matches'])}): {', '.join(diff_result['matches'][:5])}")
|
| 411 |
+
if diff_result['differences']:
|
| 412 |
+
print(f" β Differences ({len(diff_result['differences'])}):")
|
| 413 |
+
for d in diff_result['differences'][:5]:
|
| 414 |
+
print(f" - {d}")
|
| 415 |
+
|
| 416 |
+
if missing:
|
| 417 |
+
print(f"\nβ Missing points in destination:")
|
| 418 |
+
for missing_result in missing[:10]:
|
| 419 |
+
print(f" - Point ID: {missing_result['point_id']}")
|
| 420 |
+
|
| 421 |
+
# Final verdict
|
| 422 |
+
print(f"\n" + "="*70)
|
| 423 |
+
if len(missing) == 0 and len(diffs) == 0:
|
| 424 |
+
print("β
VERIFICATION PASSED: All points match perfectly!")
|
| 425 |
+
return 0
|
| 426 |
+
elif len(missing) == 0:
|
| 427 |
+
print(f"β οΈ VERIFICATION PARTIAL: All points present but {len(diffs)} have differences")
|
| 428 |
+
return 1
|
| 429 |
+
else:
|
| 430 |
+
print(f"β VERIFICATION FAILED: {len(missing)} points missing, {len(diffs)} have differences")
|
| 431 |
+
return 1
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
if __name__ == "__main__":
|
| 435 |
+
sys.exit(main())
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
|