Spaces:
Sleeping
Gemini FSA (#6)
Browse files- refactor (1154c8ddd84455509a7f915a8588d951f925bbc6)
- refactor (72318ee79a0a0bcafca07cc5be70aace39a25f0c)
- add district Metadata (7c8b7838d143b45686e8360bbf64d2ef3c4a5624)
- refactor and add sample questions (02d7f4f76ed3415d70f279e6ab188435ae60fe92)
- add retrieval visualisations (5262a14ec5cccd4d5e9796cf7c881b211809f9fc)
- add Retrieval Distribution stats (763a8b9f7efaf0b823138df29267434d72edd477)
- Merge branch 'main' of https://huggingface.co/spaces/akryldigital/audit_assistant (264ca849bc4217b186eb0f8de0d23bce39856de7)
- Merge branch 'main' of https://huggingface.co/spaces/akryldigital/audit_assistant (b4984e28f7a41b312d3eba90b68651ab903f7d08)
- refactor + add gemini (72eb0bfa173ea05e8fd8e3b63429ef4678a01663)
- fix use_container_width=False (f8a1d4171fe8d1b3b84938b7b14bbe497e4159fe)
- finalize gemini version (3fc1b5f53b40772ba3c8abf400f1d987c12c4ee1)
- add gemini traceability (6f5999e84c97f19e3d1ff442873b52ef6ca8208e)
- fix gemini chunk extraction (06faccdb62f42a514e9f6b8cc93ca73f5cab5fa1)
- fix gemini chunk extraction (54bf55f7a03022a79da5ead74a289cb893efa88f)
- add upload debug (39edab443db3982dfc9f582868c66b9ab787a208)
- Remove scripts and ignore local_* files (de1d74a230264cf4ae8df516291869434f265d9f)
- .gitignore +4 -1
- app.py +284 -512
- src/agents/__init__.py +10 -0
- src/agents/gemini_chatbot.py +392 -0
- multi_agent_chatbot.py β src/agents/multi_agent_chatbot.py +271 -40
- smart_chatbot.py β src/agents/smart_chatbot.py +0 -0
- src/feedback/__init__.py +152 -0
- src/feedback/feedback_schema.py +161 -0
- src/feedback/snowflake_connector.py +331 -0
- src/gemini/__init__.py +11 -0
- src/gemini/file_search.py +427 -0
- src/{loader.py β llm/loader.py} +0 -0
- src/pipeline.py +1 -1
- src/reporting/__init__.py +5 -1
- src/streamlit_app.py +0 -40
- src/ui_components/__init__.py +21 -0
- src/ui_components/components.py +202 -0
- src/ui_components/styles.py +117 -0
- src/ui_components/utils.py +73 -0
- utils.py β src/utils.py +0 -0
- src/vectorstore.py +35 -5
|
@@ -109,4 +109,7 @@ pytest_cache/
|
|
| 109 |
tmp/
|
| 110 |
temp/
|
| 111 |
*.tmp
|
| 112 |
-
*.temp
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
tmp/
|
| 110 |
temp/
|
| 111 |
*.tmp
|
| 112 |
+
*.temp
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
local_*
|
|
@@ -10,6 +10,7 @@ import uuid
|
|
| 10 |
import logging
|
| 11 |
import traceback
|
| 12 |
from pathlib import Path
|
|
|
|
| 13 |
from collections import Counter
|
| 14 |
from typing import List, Dict, Any, Optional
|
| 15 |
|
|
@@ -19,10 +20,11 @@ import streamlit as st
|
|
| 19 |
import plotly.express as px
|
| 20 |
from langchain_core.messages import HumanMessage, AIMessage
|
| 21 |
|
| 22 |
-
|
| 23 |
-
from
|
| 24 |
-
from src.
|
| 25 |
-
from src.
|
|
|
|
| 26 |
from src.config.paths import (
|
| 27 |
IS_DEPLOYED,
|
| 28 |
PROJECT_DIR,
|
|
@@ -31,6 +33,7 @@ from src.config.paths import (
|
|
| 31 |
CONVERSATIONS_DIR,
|
| 32 |
)
|
| 33 |
|
|
|
|
| 34 |
# ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
|
| 35 |
# Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
|
| 36 |
omp_threads = os.environ.get("OMP_NUM_THREADS", "")
|
|
@@ -70,6 +73,9 @@ if IS_DEPLOYED and HF_CACHE_DIR:
|
|
| 70 |
except (PermissionError, OSError):
|
| 71 |
# If we can't create it, log but continue (might already exist from Dockerfile)
|
| 72 |
pass
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
# Configure logging
|
| 75 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
@@ -90,116 +96,9 @@ st.set_page_config(
|
|
| 90 |
page_title="Intelligent Audit Report Chatbot"
|
| 91 |
)
|
| 92 |
|
| 93 |
-
|
| 94 |
-
st.markdown(
|
| 95 |
-
|
| 96 |
-
.main-header {
|
| 97 |
-
font-size: 2.5rem;
|
| 98 |
-
font-weight: bold;
|
| 99 |
-
color: #1f77b4;
|
| 100 |
-
text-align: center;
|
| 101 |
-
margin-bottom: 1rem;
|
| 102 |
-
width: 100%;
|
| 103 |
-
display: block;
|
| 104 |
-
}
|
| 105 |
-
|
| 106 |
-
.subtitle {
|
| 107 |
-
font-size: 1.2rem;
|
| 108 |
-
color: #666;
|
| 109 |
-
text-align: center;
|
| 110 |
-
margin-bottom: 2rem;
|
| 111 |
-
width: 100%;
|
| 112 |
-
display: block;
|
| 113 |
-
}
|
| 114 |
-
|
| 115 |
-
.session-info {
|
| 116 |
-
background-color: #f0f2f6;
|
| 117 |
-
padding: 10px;
|
| 118 |
-
border-radius: 5px;
|
| 119 |
-
margin-bottom: 20px;
|
| 120 |
-
font-size: 0.9rem;
|
| 121 |
-
}
|
| 122 |
-
|
| 123 |
-
.user-message {
|
| 124 |
-
background-color: #007bff;
|
| 125 |
-
color: white;
|
| 126 |
-
padding: 12px 16px;
|
| 127 |
-
border-radius: 18px 18px 4px 18px;
|
| 128 |
-
margin: 8px 0;
|
| 129 |
-
margin-left: 20%;
|
| 130 |
-
word-wrap: break-word;
|
| 131 |
-
}
|
| 132 |
-
|
| 133 |
-
.bot-message {
|
| 134 |
-
background-color: #f1f3f4;
|
| 135 |
-
color: #333;
|
| 136 |
-
padding: 12px 16px;
|
| 137 |
-
border-radius: 18px 18px 18px 4px;
|
| 138 |
-
margin: 8px 0;
|
| 139 |
-
margin-right: 20%;
|
| 140 |
-
word-wrap: break-word;
|
| 141 |
-
border: 1px solid #e0e0e0;
|
| 142 |
-
}
|
| 143 |
-
|
| 144 |
-
.filter-section {
|
| 145 |
-
margin-bottom: 20px;
|
| 146 |
-
padding: 15px;
|
| 147 |
-
background-color: #f8f9fa;
|
| 148 |
-
border-radius: 8px;
|
| 149 |
-
border: 1px solid #e9ecef;
|
| 150 |
-
}
|
| 151 |
-
|
| 152 |
-
.filter-title {
|
| 153 |
-
font-weight: bold;
|
| 154 |
-
margin-bottom: 10px;
|
| 155 |
-
color: #495057;
|
| 156 |
-
}
|
| 157 |
-
|
| 158 |
-
.feedback-section {
|
| 159 |
-
background-color: #f8f9fa;
|
| 160 |
-
padding: 20px;
|
| 161 |
-
border-radius: 10px;
|
| 162 |
-
margin-top: 30px;
|
| 163 |
-
border: 2px solid #dee2e6;
|
| 164 |
-
}
|
| 165 |
-
|
| 166 |
-
.retrieval-history {
|
| 167 |
-
background-color: #ffffff;
|
| 168 |
-
padding: 15px;
|
| 169 |
-
border-radius: 5px;
|
| 170 |
-
margin: 10px 0;
|
| 171 |
-
border-left: 4px solid #007bff;
|
| 172 |
-
}
|
| 173 |
-
|
| 174 |
-
.retrieval-distribution-container {
|
| 175 |
-
background-color: #ffffff;
|
| 176 |
-
padding: 25px;
|
| 177 |
-
border-radius: 10px;
|
| 178 |
-
margin: 20px 0;
|
| 179 |
-
border: 2px solid #e0e0e0;
|
| 180 |
-
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 2px 4px rgba(0, 0, 0, 0.06);
|
| 181 |
-
}
|
| 182 |
-
|
| 183 |
-
.metric-label {
|
| 184 |
-
font-size: 0.9rem;
|
| 185 |
-
color: #555;
|
| 186 |
-
margin-bottom: 5px;
|
| 187 |
-
text-align: center;
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
.metric-value {
|
| 191 |
-
font-size: 1.8rem;
|
| 192 |
-
font-weight: bold;
|
| 193 |
-
color: #000000;
|
| 194 |
-
text-align: center;
|
| 195 |
-
}
|
| 196 |
-
|
| 197 |
-
.metric-container {
|
| 198 |
-
text-align: center;
|
| 199 |
-
padding: 10px;
|
| 200 |
-
}
|
| 201 |
-
</style>
|
| 202 |
-
""", unsafe_allow_html=True)
|
| 203 |
|
| 204 |
def get_system_type():
|
| 205 |
"""Get the current system type"""
|
|
@@ -209,14 +108,17 @@ def get_system_type():
|
|
| 209 |
else:
|
| 210 |
return "Multi-Agent System"
|
| 211 |
|
| 212 |
-
def get_chatbot():
|
| 213 |
-
"""Initialize and return the chatbot based on
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
if system == 'smart':
|
| 217 |
-
return get_smart_chatbot()
|
| 218 |
else:
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
def serialize_messages(messages):
|
| 222 |
"""Serialize LangChain messages to dictionaries"""
|
|
@@ -262,368 +164,8 @@ def serialize_documents(sources):
|
|
| 262 |
return serialized
|
| 263 |
|
| 264 |
|
| 265 |
-
|
| 266 |
-
"""Extract transcript from messages - only user and bot messages, no extra metadata"""
|
| 267 |
-
transcript = []
|
| 268 |
-
for msg in messages:
|
| 269 |
-
if isinstance(msg, HumanMessage):
|
| 270 |
-
transcript.append({
|
| 271 |
-
"role": "user",
|
| 272 |
-
"content": str(msg.content) if hasattr(msg, 'content') else str(msg)
|
| 273 |
-
})
|
| 274 |
-
elif isinstance(msg, AIMessage):
|
| 275 |
-
transcript.append({
|
| 276 |
-
"role": "assistant",
|
| 277 |
-
"content": str(msg.content) if hasattr(msg, 'content') else str(msg)
|
| 278 |
-
})
|
| 279 |
-
return transcript
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
def build_retrievals_structure(rag_retrieval_history: List[Dict[str, Any]], messages: List[Any]) -> List[Dict[str, Any]]:
|
| 283 |
-
"""Build retrievals structure from retrieval history"""
|
| 284 |
-
retrievals = []
|
| 285 |
-
|
| 286 |
-
for entry in rag_retrieval_history:
|
| 287 |
-
# Get the user message that triggered this retrieval
|
| 288 |
-
# The entry has conversation_up_to which includes messages up to that point
|
| 289 |
-
conversation_up_to = entry.get("conversation_up_to", [])
|
| 290 |
-
|
| 291 |
-
# Find the last user message in conversation_up_to (this is the trigger)
|
| 292 |
-
user_message_trigger = ""
|
| 293 |
-
for msg_dict in reversed(conversation_up_to):
|
| 294 |
-
if msg_dict.get("type") == "HumanMessage":
|
| 295 |
-
user_message_trigger = msg_dict.get("content", "")
|
| 296 |
-
break
|
| 297 |
-
|
| 298 |
-
# Fallback: if not found in conversation_up_to, get from actual messages
|
| 299 |
-
# This handles edge cases where conversation_up_to might be incomplete
|
| 300 |
-
if not user_message_trigger:
|
| 301 |
-
# Find which retrieval this is (0-indexed)
|
| 302 |
-
retrieval_idx = rag_retrieval_history.index(entry)
|
| 303 |
-
# The user message that triggered this retrieval is at position (retrieval_idx * 2)
|
| 304 |
-
# because each retrieval is preceded by: user message, bot response, user message, ...
|
| 305 |
-
# But we need to account for the fact that the first retrieval happens after the first user message
|
| 306 |
-
user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
|
| 307 |
-
if retrieval_idx < len(user_msgs):
|
| 308 |
-
user_message_trigger = str(user_msgs[retrieval_idx].content)
|
| 309 |
-
elif user_msgs:
|
| 310 |
-
# Fallback to last user message
|
| 311 |
-
user_message_trigger = str(user_msgs[-1].content)
|
| 312 |
-
|
| 313 |
-
# Get retrieved documents and truncate content to 100 chars
|
| 314 |
-
docs_retrieved = entry.get("docs_retrieved", [])
|
| 315 |
-
retrieved_docs = []
|
| 316 |
-
for doc in docs_retrieved:
|
| 317 |
-
doc_copy = doc.copy()
|
| 318 |
-
# Truncate content to 100 characters (keep all other fields)
|
| 319 |
-
if "content" in doc_copy:
|
| 320 |
-
doc_copy["content"] = doc_copy["content"][:100]
|
| 321 |
-
retrieved_docs.append(doc_copy)
|
| 322 |
-
|
| 323 |
-
retrievals.append({
|
| 324 |
-
"retrieved_docs": retrieved_docs,
|
| 325 |
-
"user_message_trigger": user_message_trigger
|
| 326 |
-
})
|
| 327 |
-
|
| 328 |
-
return retrievals
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
def build_feedback_score_related_retrieval_docs(
|
| 332 |
-
is_feedback_about_last_retrieval: bool,
|
| 333 |
-
messages: List[Any],
|
| 334 |
-
rag_retrieval_history: List[Dict[str, Any]]
|
| 335 |
-
) -> Optional[Dict[str, Any]]:
|
| 336 |
-
"""Build feedback_score_related_retrieval_docs structure"""
|
| 337 |
-
if not rag_retrieval_history:
|
| 338 |
-
return None
|
| 339 |
-
|
| 340 |
-
# Get the relevant retrieval entry
|
| 341 |
-
if is_feedback_about_last_retrieval:
|
| 342 |
-
relevant_entry = rag_retrieval_history[-1]
|
| 343 |
-
else:
|
| 344 |
-
# If feedback is about all retrievals, use the last one as default
|
| 345 |
-
relevant_entry = rag_retrieval_history[-1]
|
| 346 |
-
|
| 347 |
-
# Get conversation up to that point
|
| 348 |
-
conversation_up_to = relevant_entry.get("conversation_up_to", [])
|
| 349 |
-
|
| 350 |
-
# Convert to transcript format (role/content)
|
| 351 |
-
conversation_up_to_point = []
|
| 352 |
-
for msg_dict in conversation_up_to:
|
| 353 |
-
if msg_dict.get("type") == "HumanMessage":
|
| 354 |
-
conversation_up_to_point.append({
|
| 355 |
-
"role": "user",
|
| 356 |
-
"content": msg_dict.get("content", "")
|
| 357 |
-
})
|
| 358 |
-
elif msg_dict.get("type") == "AIMessage":
|
| 359 |
-
conversation_up_to_point.append({
|
| 360 |
-
"role": "assistant",
|
| 361 |
-
"content": msg_dict.get("content", "")
|
| 362 |
-
})
|
| 363 |
-
|
| 364 |
-
# Get retrieved docs with full content (not truncated)
|
| 365 |
-
retrieved_docs = relevant_entry.get("docs_retrieved", [])
|
| 366 |
-
|
| 367 |
-
return {
|
| 368 |
-
"conversation_up_to_point": conversation_up_to_point,
|
| 369 |
-
"retrieved_docs": retrieved_docs
|
| 370 |
-
}
|
| 371 |
|
| 372 |
-
def extract_chunk_statistics(sources: List[Any]) -> Dict[str, Any]:
|
| 373 |
-
"""Extract statistics from retrieved chunks."""
|
| 374 |
-
if not sources:
|
| 375 |
-
return {}
|
| 376 |
-
|
| 377 |
-
sources_list = []
|
| 378 |
-
years = []
|
| 379 |
-
filenames = []
|
| 380 |
-
districts = []
|
| 381 |
-
|
| 382 |
-
for doc in sources:
|
| 383 |
-
metadata = getattr(doc, 'metadata', {})
|
| 384 |
-
|
| 385 |
-
# Extract source
|
| 386 |
-
source = metadata.get('source', 'Unknown')
|
| 387 |
-
sources_list.append(source)
|
| 388 |
-
|
| 389 |
-
# Extract year
|
| 390 |
-
year = metadata.get('year', 'Unknown')
|
| 391 |
-
if year and year != 'Unknown':
|
| 392 |
-
try:
|
| 393 |
-
# Convert to int first, then back to string to ensure it's a proper year
|
| 394 |
-
year_int = int(float(year)) # Handle both int and float strings
|
| 395 |
-
if 1900 <= year_int <= 2030: # Reasonable year range
|
| 396 |
-
years.append(str(year_int))
|
| 397 |
-
else:
|
| 398 |
-
years.append('Unknown')
|
| 399 |
-
except (ValueError, TypeError):
|
| 400 |
-
years.append('Unknown')
|
| 401 |
-
else:
|
| 402 |
-
years.append('Unknown')
|
| 403 |
-
|
| 404 |
-
# Extract filename
|
| 405 |
-
filename = metadata.get('filename', 'Unknown')
|
| 406 |
-
filenames.append(filename)
|
| 407 |
-
|
| 408 |
-
# Extract district
|
| 409 |
-
district = metadata.get('district', 'Unknown')
|
| 410 |
-
if district and district != 'Unknown':
|
| 411 |
-
districts.append(district)
|
| 412 |
-
else:
|
| 413 |
-
districts.append('Unknown')
|
| 414 |
-
|
| 415 |
-
# Count occurrences
|
| 416 |
-
source_counts = Counter(sources_list)
|
| 417 |
-
year_counts = Counter(years)
|
| 418 |
-
filename_counts = Counter(filenames)
|
| 419 |
-
district_counts = Counter(districts)
|
| 420 |
-
|
| 421 |
-
return {
|
| 422 |
-
'total_chunks': len(sources),
|
| 423 |
-
'unique_sources': len(source_counts),
|
| 424 |
-
'unique_years': len([y for y in year_counts.keys() if y != 'Unknown']),
|
| 425 |
-
'unique_filenames': len(filename_counts),
|
| 426 |
-
'unique_districts': len([d for d in district_counts.keys() if d != 'Unknown']),
|
| 427 |
-
'source_distribution': dict(source_counts),
|
| 428 |
-
'year_distribution': dict(year_counts),
|
| 429 |
-
'filename_distribution': dict(filename_counts),
|
| 430 |
-
'district_distribution': dict(district_counts),
|
| 431 |
-
'sources': sources_list,
|
| 432 |
-
'years': years,
|
| 433 |
-
'filenames': filenames,
|
| 434 |
-
'districts': districts
|
| 435 |
-
}
|
| 436 |
-
|
| 437 |
-
def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieval Statistics"):
|
| 438 |
-
"""Display statistics as interactive charts for 10+ results."""
|
| 439 |
-
if not stats or stats.get('total_chunks', 0) == 0:
|
| 440 |
-
return
|
| 441 |
-
|
| 442 |
-
# Wrap everything in one styled container - open it
|
| 443 |
-
st.markdown(f"""
|
| 444 |
-
<div class="retrieval-distribution-container">
|
| 445 |
-
<h3 style="margin-top: 0;">π {title}</h3>
|
| 446 |
-
<div style="display: flex; justify-content: space-around; align-items: center; padding: 15px 0; border-bottom: 1px solid #e0e0e0; margin-bottom: 20px;">
|
| 447 |
-
<div class="metric-container">
|
| 448 |
-
<div class="metric-label">Total Chunks</div>
|
| 449 |
-
<div class="metric-value">{stats['total_chunks']}</div>
|
| 450 |
-
</div>
|
| 451 |
-
<div class="metric-container">
|
| 452 |
-
<div class="metric-label">Unique Sources</div>
|
| 453 |
-
<div class="metric-value">{stats['unique_sources']}</div>
|
| 454 |
-
</div>
|
| 455 |
-
<div class="metric-container">
|
| 456 |
-
<div class="metric-label">Unique Years</div>
|
| 457 |
-
<div class="metric-value">{stats['unique_years']}</div>
|
| 458 |
-
</div>
|
| 459 |
-
<div class="metric-container">
|
| 460 |
-
<div class="metric-label">Unique Files</div>
|
| 461 |
-
<div class="metric-value">{stats['unique_filenames']}</div>
|
| 462 |
-
</div>
|
| 463 |
-
</div>
|
| 464 |
-
""", unsafe_allow_html=True)
|
| 465 |
-
|
| 466 |
-
# Charts - three columns to include Districts
|
| 467 |
-
col1, col2, col3 = st.columns(3)
|
| 468 |
-
|
| 469 |
-
with col1:
|
| 470 |
-
# Source distribution chart
|
| 471 |
-
if stats['source_distribution']:
|
| 472 |
-
source_df = pd.DataFrame(
|
| 473 |
-
list(stats['source_distribution'].items()),
|
| 474 |
-
columns=['Source', 'Count']
|
| 475 |
-
)
|
| 476 |
-
fig_source = px.bar(
|
| 477 |
-
source_df,
|
| 478 |
-
x='Count',
|
| 479 |
-
y='Source',
|
| 480 |
-
orientation='h',
|
| 481 |
-
title='Distribution by Source',
|
| 482 |
-
color='Count',
|
| 483 |
-
color_continuous_scale='viridis'
|
| 484 |
-
)
|
| 485 |
-
fig_source.update_layout(height=400, showlegend=False)
|
| 486 |
-
st.plotly_chart(fig_source, use_container_width=True)
|
| 487 |
-
|
| 488 |
-
with col2:
|
| 489 |
-
# Year distribution chart
|
| 490 |
-
if stats['year_distribution']:
|
| 491 |
-
# Filter out 'Unknown' years for the chart
|
| 492 |
-
year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
|
| 493 |
-
if year_dist_filtered:
|
| 494 |
-
year_df = pd.DataFrame(
|
| 495 |
-
list(year_dist_filtered.items()),
|
| 496 |
-
columns=['Year', 'Count']
|
| 497 |
-
)
|
| 498 |
-
# Sort by year as integer but keep as string for categorical display
|
| 499 |
-
year_df['Year_Int'] = year_df['Year'].astype(int)
|
| 500 |
-
year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)
|
| 501 |
-
|
| 502 |
-
fig_year = px.bar(
|
| 503 |
-
year_df,
|
| 504 |
-
x='Year',
|
| 505 |
-
y='Count',
|
| 506 |
-
title='Distribution by Year',
|
| 507 |
-
color='Count',
|
| 508 |
-
color_continuous_scale='plasma'
|
| 509 |
-
)
|
| 510 |
-
# Ensure years are treated as categorical (discrete) not continuous
|
| 511 |
-
fig_year.update_xaxes(type='category')
|
| 512 |
-
fig_year.update_layout(height=400, showlegend=False)
|
| 513 |
-
st.plotly_chart(fig_year, use_container_width=True)
|
| 514 |
-
else:
|
| 515 |
-
st.info("No valid years found in the results")
|
| 516 |
-
|
| 517 |
-
with col3:
|
| 518 |
-
# District distribution chart
|
| 519 |
-
if stats.get('district_distribution'):
|
| 520 |
-
district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
|
| 521 |
-
if district_dist_filtered:
|
| 522 |
-
district_df = pd.DataFrame(
|
| 523 |
-
list(district_dist_filtered.items()),
|
| 524 |
-
columns=['District', 'Count']
|
| 525 |
-
)
|
| 526 |
-
district_df = district_df.sort_values('Count', ascending=False)
|
| 527 |
-
|
| 528 |
-
fig_district = px.bar(
|
| 529 |
-
district_df,
|
| 530 |
-
x='Count',
|
| 531 |
-
y='District',
|
| 532 |
-
orientation='h',
|
| 533 |
-
title='Distribution by District',
|
| 534 |
-
color='Count',
|
| 535 |
-
color_continuous_scale='blues'
|
| 536 |
-
)
|
| 537 |
-
fig_district.update_layout(height=400, showlegend=False)
|
| 538 |
-
st.plotly_chart(fig_district, use_container_width=True)
|
| 539 |
-
else:
|
| 540 |
-
st.info("No valid districts found in the results")
|
| 541 |
-
|
| 542 |
-
# Close the container
|
| 543 |
-
st.markdown('</div>', unsafe_allow_html=True)
|
| 544 |
-
|
| 545 |
-
def display_chunk_statistics_table(stats: Dict[str, Any], title: str = "Retrieval Distribution"):
|
| 546 |
-
"""Display statistics as tables for smaller results with fixed alignment."""
|
| 547 |
-
if not stats or stats.get('total_chunks', 0) == 0:
|
| 548 |
-
return
|
| 549 |
-
|
| 550 |
-
# Wrap in styled container
|
| 551 |
-
st.markdown('<div class="retrieval-distribution-container">', unsafe_allow_html=True)
|
| 552 |
-
|
| 553 |
-
st.subheader(f"π {title}")
|
| 554 |
-
|
| 555 |
-
# Create a container with fixed height for alignment
|
| 556 |
-
stats_container = st.container()
|
| 557 |
-
|
| 558 |
-
with stats_container:
|
| 559 |
-
# Create 4 equal columns for consistent alignment
|
| 560 |
-
col1, col2, col3, col4 = st.columns(4)
|
| 561 |
-
|
| 562 |
-
with col1:
|
| 563 |
-
st.markdown("**ποΈ Districts**")
|
| 564 |
-
if stats.get('district_distribution'):
|
| 565 |
-
district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
|
| 566 |
-
if district_dist_filtered:
|
| 567 |
-
district_data = {
|
| 568 |
-
"District": list(district_dist_filtered.keys()),
|
| 569 |
-
"Count": list(district_dist_filtered.values())
|
| 570 |
-
}
|
| 571 |
-
district_df = pd.DataFrame(district_data).sort_values('Count', ascending=False)
|
| 572 |
-
st.dataframe(district_df, hide_index=True, use_container_width=True)
|
| 573 |
-
else:
|
| 574 |
-
st.write("No district data")
|
| 575 |
-
else:
|
| 576 |
-
st.write("No district data")
|
| 577 |
-
|
| 578 |
-
with col2:
|
| 579 |
-
st.markdown("**π Sources**")
|
| 580 |
-
if stats['source_distribution']:
|
| 581 |
-
source_data = {
|
| 582 |
-
"Source": list(stats['source_distribution'].keys()),
|
| 583 |
-
"Count": list(stats['source_distribution'].values())
|
| 584 |
-
}
|
| 585 |
-
source_df = pd.DataFrame(source_data).sort_values('Count', ascending=False)
|
| 586 |
-
st.dataframe(source_df, hide_index=True, use_container_width=True)
|
| 587 |
-
else:
|
| 588 |
-
st.write("No source data")
|
| 589 |
-
|
| 590 |
-
with col3:
|
| 591 |
-
st.markdown("**π
Years**")
|
| 592 |
-
if stats['year_distribution']:
|
| 593 |
-
year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
|
| 594 |
-
if year_dist_filtered:
|
| 595 |
-
year_data = {
|
| 596 |
-
"Year": list(year_dist_filtered.keys()),
|
| 597 |
-
"Count": list(year_dist_filtered.values())
|
| 598 |
-
}
|
| 599 |
-
year_df = pd.DataFrame(year_data)
|
| 600 |
-
# Sort by year as integer but display as string
|
| 601 |
-
year_df['Year_Int'] = year_df['Year'].astype(int)
|
| 602 |
-
year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
|
| 603 |
-
st.dataframe(year_df, hide_index=True, use_container_width=True)
|
| 604 |
-
else:
|
| 605 |
-
st.write("No year data")
|
| 606 |
-
else:
|
| 607 |
-
st.write("No year data")
|
| 608 |
-
|
| 609 |
-
with col4:
|
| 610 |
-
st.markdown("**π Files**")
|
| 611 |
-
if stats['filename_distribution']:
|
| 612 |
-
filename_items = list(stats['filename_distribution'].items())
|
| 613 |
-
filename_items.sort(key=lambda x: x[1], reverse=True)
|
| 614 |
-
|
| 615 |
-
# Show top files with truncated names
|
| 616 |
-
file_data = {
|
| 617 |
-
"File": [f[:30] + "..." if len(f) > 30 else f for f, c in filename_items[:5]],
|
| 618 |
-
"Count": [c for f, c in filename_items[:5]]
|
| 619 |
-
}
|
| 620 |
-
file_df = pd.DataFrame(file_data)
|
| 621 |
-
st.dataframe(file_df, hide_index=True, use_container_width=True)
|
| 622 |
-
else:
|
| 623 |
-
st.write("No file data")
|
| 624 |
-
|
| 625 |
-
# Close container
|
| 626 |
-
st.markdown('</div>', unsafe_allow_html=True)
|
| 627 |
|
| 628 |
@st.cache_data
|
| 629 |
def load_filter_options():
|
|
@@ -649,11 +191,48 @@ def main():
|
|
| 649 |
# Track RAG retrieval history for feedback
|
| 650 |
if 'rag_retrieval_history' not in st.session_state:
|
| 651 |
st.session_state.rag_retrieval_history = []
|
| 652 |
-
#
|
| 653 |
-
if '
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 657 |
|
| 658 |
# Reset conversation history if needed (but keep chatbot cached)
|
| 659 |
if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
|
|
@@ -665,9 +244,43 @@ def main():
|
|
| 665 |
st.session_state.reset_conversation = False
|
| 666 |
st.rerun()
|
| 667 |
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 671 |
|
| 672 |
# Session info
|
| 673 |
duration = int(time.time() - st.session_state.session_start_time)
|
|
@@ -729,7 +342,7 @@ def main():
|
|
| 729 |
# Determine if filename filter is active
|
| 730 |
filename_mode = len(selected_filenames) > 0
|
| 731 |
# Sources filter
|
| 732 |
-
st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 733 |
st.markdown('<div class="filter-title">π Sources</div>', unsafe_allow_html=True)
|
| 734 |
selected_sources = st.multiselect(
|
| 735 |
"Select sources:",
|
|
@@ -826,7 +439,7 @@ def main():
|
|
| 826 |
)
|
| 827 |
|
| 828 |
with col2:
|
| 829 |
-
send_button = st.button("Send", key="send_button",
|
| 830 |
|
| 831 |
# Clear chat button
|
| 832 |
if st.button("ποΈ Clear Chat", key="clear_chat_button"):
|
|
@@ -878,6 +491,36 @@ def main():
|
|
| 878 |
if rag_result:
|
| 879 |
sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
|
| 880 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 881 |
# Get the actual RAG query
|
| 882 |
actual_rag_query = chat_result.get('actual_rag_query', '')
|
| 883 |
if actual_rag_query:
|
|
@@ -887,12 +530,25 @@ def main():
|
|
| 887 |
else:
|
| 888 |
formatted_query = "No RAG query available"
|
| 889 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 890 |
retrieval_entry = {
|
| 891 |
"conversation_up_to": serialize_messages(st.session_state.messages),
|
| 892 |
"rag_query_expansion": formatted_query,
|
| 893 |
-
"docs_retrieved": serialize_documents(sources)
|
|
|
|
|
|
|
| 894 |
}
|
| 895 |
st.session_state.rag_retrieval_history.append(retrieval_entry)
|
|
|
|
|
|
|
|
|
|
| 896 |
else:
|
| 897 |
response = chat_result
|
| 898 |
st.session_state.last_rag_result = None
|
|
@@ -922,6 +578,16 @@ def main():
|
|
| 922 |
# Dictionary format from multi-agent system
|
| 923 |
sources = rag_result['sources']
|
| 924 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 925 |
if sources and len(sources) > 0:
|
| 926 |
# Count unique filenames
|
| 927 |
unique_filenames = set()
|
|
@@ -951,9 +617,18 @@ def main():
|
|
| 951 |
for i, doc in enumerate(sources): # Show all documents
|
| 952 |
# Get relevance score and ID if available
|
| 953 |
metadata = getattr(doc, 'metadata', {})
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 957 |
|
| 958 |
with st.expander(f"π Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
|
| 959 |
# Display document metadata with emojis
|
|
@@ -1031,7 +706,7 @@ def main():
|
|
| 1031 |
|
| 1032 |
submitted = st.form_submit_button(
|
| 1033 |
"π€ Submit Feedback",
|
| 1034 |
-
|
| 1035 |
disabled=submit_disabled
|
| 1036 |
)
|
| 1037 |
|
|
@@ -1043,16 +718,18 @@ def main():
|
|
| 1043 |
st.write("π **Debug: Feedback Data Being Submitted:**")
|
| 1044 |
|
| 1045 |
# Extract transcript from messages
|
| 1046 |
-
transcript = extract_transcript(st.session_state.messages)
|
| 1047 |
|
| 1048 |
# Build retrievals structure
|
| 1049 |
-
retrievals = build_retrievals_structure(
|
|
|
|
| 1050 |
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
|
| 1051 |
st.session_state.messages
|
| 1052 |
)
|
| 1053 |
|
| 1054 |
# Build feedback_score_related_retrieval_docs
|
| 1055 |
-
|
|
|
|
| 1056 |
is_feedback_about_last_retrieval,
|
| 1057 |
st.session_state.messages,
|
| 1058 |
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
|
|
@@ -1082,7 +759,7 @@ def main():
|
|
| 1082 |
# Create UserFeedback dataclass instance
|
| 1083 |
feedback_obj = None # Initialize outside try block
|
| 1084 |
try:
|
| 1085 |
-
feedback_obj = create_feedback_from_dict(feedback_dict)
|
| 1086 |
print(f"β
FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
|
| 1087 |
st.write(f"β
**Feedback Object Created**")
|
| 1088 |
st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
|
|
@@ -1138,7 +815,11 @@ def main():
|
|
| 1138 |
logger.info("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 1139 |
print("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 1140 |
|
| 1141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1142 |
if snowflake_success:
|
| 1143 |
logger.info("β
SNOWFLAKE UI: Successfully saved to Snowflake")
|
| 1144 |
print("β
SNOWFLAKE UI: Successfully saved to Snowflake")
|
|
@@ -1193,7 +874,7 @@ def main():
|
|
| 1193 |
st.success("β
Feedback already submitted for this conversation!")
|
| 1194 |
col1, col2 = st.columns([1, 1])
|
| 1195 |
with col1:
|
| 1196 |
-
if st.button("π Submit New Feedback", key="new_feedback_button",
|
| 1197 |
try:
|
| 1198 |
st.session_state.feedback_submitted = False
|
| 1199 |
st.rerun()
|
|
@@ -1202,7 +883,7 @@ def main():
|
|
| 1202 |
logger.error(f"Error resetting feedback state: {e}")
|
| 1203 |
st.error(f"Error resetting feedback. Please refresh the page.")
|
| 1204 |
with col2:
|
| 1205 |
-
if st.button("π View Conversation", key="view_conversation_button",
|
| 1206 |
# Scroll to conversation - this is handled by the auto-scroll at bottom
|
| 1207 |
pass
|
| 1208 |
|
|
@@ -1211,20 +892,111 @@ def main():
|
|
| 1211 |
st.markdown("---")
|
| 1212 |
st.markdown("#### π Retrieval History")
|
| 1213 |
|
| 1214 |
-
with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=
|
| 1215 |
for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
|
| 1216 |
-
st.markdown(f"**Retrieval #{idx}**")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1217 |
|
| 1218 |
# Display the actual RAG query
|
| 1219 |
rag_query_expansion = entry.get("rag_query_expansion", "No query available")
|
|
|
|
| 1220 |
st.code(rag_query_expansion, language="text")
|
| 1221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1222 |
# Display summary stats
|
|
|
|
| 1223 |
st.json({
|
| 1224 |
-
"conversation_length": len(
|
| 1225 |
-
"documents_retrieved": len(
|
| 1226 |
})
|
| 1227 |
-
|
|
|
|
|
|
|
| 1228 |
|
| 1229 |
# Example Questions Section
|
| 1230 |
st.markdown("---")
|
|
@@ -1245,7 +1017,7 @@ def main():
|
|
| 1245 |
st.markdown(f"**Example:** `{example_q1}`")
|
| 1246 |
st.info("π‘ **Filter to apply:** Select a Filename from the sidebar panel before asking this question.")
|
| 1247 |
with col2:
|
| 1248 |
-
if st.button("π Use This Question", key="use_example_1",
|
| 1249 |
st.session_state.pending_question = example_q1
|
| 1250 |
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
|
| 1251 |
st.rerun()
|
|
@@ -1266,7 +1038,7 @@ def main():
|
|
| 1266 |
)
|
| 1267 |
col1, col2 = st.columns([1, 4])
|
| 1268 |
with col1:
|
| 1269 |
-
if st.button("π Use Question 2", key="use_custom_1",
|
| 1270 |
if custom_q1.strip():
|
| 1271 |
st.session_state.pending_question = custom_q1.strip()
|
| 1272 |
st.session_state.custom_question_1 = custom_q1.strip()
|
|
@@ -1292,7 +1064,7 @@ def main():
|
|
| 1292 |
)
|
| 1293 |
col1, col2 = st.columns([1, 4])
|
| 1294 |
with col1:
|
| 1295 |
-
if st.button("π Use Question 3", key="use_custom_2",
|
| 1296 |
if custom_q2.strip():
|
| 1297 |
st.session_state.pending_question = custom_q2.strip()
|
| 1298 |
st.session_state.custom_question_2 = custom_q2.strip()
|
|
|
|
| 10 |
import logging
|
| 11 |
import traceback
|
| 12 |
from pathlib import Path
|
| 13 |
+
|
| 14 |
from collections import Counter
|
| 15 |
from typing import List, Dict, Any, Optional
|
| 16 |
|
|
|
|
| 20 |
import plotly.express as px
|
| 21 |
from langchain_core.messages import HumanMessage, AIMessage
|
| 22 |
|
| 23 |
+
|
| 24 |
+
from src.agents import get_multi_agent_chatbot, get_smart_chatbot, get_gemini_chatbot
|
| 25 |
+
from src.feedback import FeedbackManager
|
| 26 |
+
from src.ui_components import get_custom_css, display_chunk_statistics_charts, display_chunk_statistics_table, extract_chunk_statistics
|
| 27 |
+
|
| 28 |
from src.config.paths import (
|
| 29 |
IS_DEPLOYED,
|
| 30 |
PROJECT_DIR,
|
|
|
|
| 33 |
CONVERSATIONS_DIR,
|
| 34 |
)
|
| 35 |
|
| 36 |
+
|
| 37 |
# ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
|
| 38 |
# Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
|
| 39 |
omp_threads = os.environ.get("OMP_NUM_THREADS", "")
|
|
|
|
| 73 |
except (PermissionError, OSError):
|
| 74 |
# If we can't create it, log but continue (might already exist from Dockerfile)
|
| 75 |
pass
|
| 76 |
+
else:
|
| 77 |
+
from dotenv import load_dotenv
|
| 78 |
+
load_dotenv()
|
| 79 |
|
| 80 |
# Configure logging
|
| 81 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
| 96 |
page_title="Intelligent Audit Report Chatbot"
|
| 97 |
)
|
| 98 |
|
| 99 |
+
|
| 100 |
+
st.markdown(get_custom_css(), unsafe_allow_html=True)
|
| 101 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def get_system_type():
|
| 104 |
"""Get the current system type"""
|
|
|
|
| 108 |
else:
|
| 109 |
return "Multi-Agent System"
|
| 110 |
|
| 111 |
+
def get_chatbot(version: str = "v1"):
|
| 112 |
+
"""Initialize and return the chatbot based on version"""
|
| 113 |
+
if version == "beta":
|
| 114 |
+
return get_gemini_chatbot()
|
|
|
|
|
|
|
| 115 |
else:
|
| 116 |
+
# Check environment variable for system type (v1)
|
| 117 |
+
system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
|
| 118 |
+
if system == 'smart':
|
| 119 |
+
return get_smart_chatbot()
|
| 120 |
+
else:
|
| 121 |
+
return get_multi_agent_chatbot()
|
| 122 |
|
| 123 |
def serialize_messages(messages):
|
| 124 |
"""Serialize LangChain messages to dictionaries"""
|
|
|
|
| 164 |
return serialized
|
| 165 |
|
| 166 |
|
| 167 |
+
feedback_manager = FeedbackManager()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
@st.cache_data
|
| 171 |
def load_filter_options():
|
|
|
|
| 191 |
# Track RAG retrieval history for feedback
|
| 192 |
if 'rag_retrieval_history' not in st.session_state:
|
| 193 |
st.session_state.rag_retrieval_history = []
|
| 194 |
+
# Version selection (v1 or beta)
|
| 195 |
+
if 'chatbot_version' not in st.session_state:
|
| 196 |
+
st.session_state.chatbot_version = "v1"
|
| 197 |
+
|
| 198 |
+
# Initialize chatbot based on version (only if not already initialized for this version)
|
| 199 |
+
chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}"
|
| 200 |
+
|
| 201 |
+
# Check if we need to initialize: chatbot doesn't exist OR version changed
|
| 202 |
+
needs_init = (
|
| 203 |
+
chatbot_version_key not in st.session_state or
|
| 204 |
+
st.session_state.get('_last_version') != st.session_state.chatbot_version
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
if needs_init:
|
| 208 |
+
try:
|
| 209 |
+
# Different spinner messages for different versions
|
| 210 |
+
if st.session_state.chatbot_version == "beta":
|
| 211 |
+
spinner_msg = "π Initializing Gemini FSA"
|
| 212 |
+
else:
|
| 213 |
+
spinner_msg = "π Loading AI models and connecting to database..."
|
| 214 |
+
|
| 215 |
+
with st.spinner(spinner_msg):
|
| 216 |
+
st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version)
|
| 217 |
+
st.session_state['_last_version'] = st.session_state.chatbot_version
|
| 218 |
+
st.session_state.chatbot = st.session_state[chatbot_version_key]
|
| 219 |
+
print("β
AI system ready!")
|
| 220 |
+
except Exception as e:
|
| 221 |
+
st.error(f"β Failed to initialize chatbot: {str(e)}")
|
| 222 |
+
# Only show Gemini-specific error message for beta version
|
| 223 |
+
if st.session_state.chatbot_version == "beta":
|
| 224 |
+
st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)")
|
| 225 |
+
else:
|
| 226 |
+
st.error("Please check your configuration and ensure all required models and databases are accessible.")
|
| 227 |
+
# Reset to v1 to prevent infinite loop
|
| 228 |
+
st.session_state.chatbot_version = "v1"
|
| 229 |
+
st.session_state['_last_version'] = "v1"
|
| 230 |
+
if 'chatbot' in st.session_state:
|
| 231 |
+
del st.session_state['chatbot']
|
| 232 |
+
st.stop() # Stop execution to prevent infinite loop
|
| 233 |
+
else:
|
| 234 |
+
# Chatbot already initialized for this version, just use it
|
| 235 |
+
st.session_state.chatbot = st.session_state[chatbot_version_key]
|
| 236 |
|
| 237 |
# Reset conversation history if needed (but keep chatbot cached)
|
| 238 |
if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
|
|
|
|
| 244 |
st.session_state.reset_conversation = False
|
| 245 |
st.rerun()
|
| 246 |
|
| 247 |
+
|
| 248 |
+
# Version selection radio button (top right)
|
| 249 |
+
col1, col2 = st.columns([3, 1])
|
| 250 |
+
with col1:
|
| 251 |
+
st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
|
| 252 |
+
with col2:
|
| 253 |
+
st.markdown("<br>", unsafe_allow_html=True) # Add some spacing
|
| 254 |
+
selected_version = st.radio(
|
| 255 |
+
"**Version:**",
|
| 256 |
+
options=["v1", "beta"],
|
| 257 |
+
index=0 if st.session_state.chatbot_version == "v1" else 1,
|
| 258 |
+
horizontal=True,
|
| 259 |
+
key="version_selector",
|
| 260 |
+
help="Select v1 (default RAG system) or beta (Gemini FSA)"
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# Update version if changed
|
| 264 |
+
if selected_version != st.session_state.chatbot_version:
|
| 265 |
+
# Store the old version to check if we need to switch
|
| 266 |
+
old_version = st.session_state.chatbot_version
|
| 267 |
+
st.session_state.chatbot_version = selected_version
|
| 268 |
+
|
| 269 |
+
# If chatbot for new version already exists, just switch to it
|
| 270 |
+
new_chatbot_key = f"chatbot_{selected_version}"
|
| 271 |
+
if new_chatbot_key in st.session_state:
|
| 272 |
+
# Chatbot already exists, just switch
|
| 273 |
+
st.session_state.chatbot = st.session_state[new_chatbot_key]
|
| 274 |
+
st.session_state['_last_version'] = selected_version
|
| 275 |
+
else:
|
| 276 |
+
# Need to initialize new version - will be handled by initialization logic above
|
| 277 |
+
st.session_state['_last_version'] = old_version # Set to old to trigger init check
|
| 278 |
+
|
| 279 |
+
st.rerun()
|
| 280 |
+
|
| 281 |
+
# Show version info
|
| 282 |
+
if st.session_state.chatbot_version == "beta":
|
| 283 |
+
st.info("π¬ **Beta Mode**: Using Google Gemini FSA")
|
| 284 |
|
| 285 |
# Session info
|
| 286 |
duration = int(time.time() - st.session_state.session_start_time)
|
|
|
|
| 342 |
# Determine if filename filter is active
|
| 343 |
filename_mode = len(selected_filenames) > 0
|
| 344 |
# Sources filter
|
| 345 |
+
# st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 346 |
st.markdown('<div class="filter-title">π Sources</div>', unsafe_allow_html=True)
|
| 347 |
selected_sources = st.multiselect(
|
| 348 |
"Select sources:",
|
|
|
|
| 439 |
)
|
| 440 |
|
| 441 |
with col2:
|
| 442 |
+
send_button = st.button("Send", key="send_button", width='stretch')
|
| 443 |
|
| 444 |
# Clear chat button
|
| 445 |
if st.button("ποΈ Clear Chat", key="clear_chat_button"):
|
|
|
|
| 491 |
if rag_result:
|
| 492 |
sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
|
| 493 |
|
| 494 |
+
# For Gemini, also check gemini_result for sources
|
| 495 |
+
if not sources or len(sources) == 0:
|
| 496 |
+
gemini_result = chat_result.get('gemini_result')
|
| 497 |
+
print(f"π DEBUG: Checking gemini_result for sources...")
|
| 498 |
+
print(f" gemini_result exists: {gemini_result is not None}")
|
| 499 |
+
if gemini_result:
|
| 500 |
+
print(f" gemini_result type: {type(gemini_result)}")
|
| 501 |
+
print(f" has sources attr: {hasattr(gemini_result, 'sources')}")
|
| 502 |
+
if hasattr(gemini_result, 'sources'):
|
| 503 |
+
print(f" sources length: {len(gemini_result.sources) if gemini_result.sources else 0}")
|
| 504 |
+
|
| 505 |
+
if gemini_result and hasattr(gemini_result, 'sources'):
|
| 506 |
+
# Format Gemini sources for display
|
| 507 |
+
if hasattr(st.session_state.chatbot, 'gemini_client'):
|
| 508 |
+
sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result)
|
| 509 |
+
print(f"β
Formatted {len(sources)} sources from gemini_client")
|
| 510 |
+
elif hasattr(st.session_state.chatbot, '_format_gemini_sources'):
|
| 511 |
+
sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
|
| 512 |
+
print(f"β
Formatted {len(sources)} sources from _format_gemini_sources")
|
| 513 |
+
|
| 514 |
+
# Update rag_result with sources if we found them
|
| 515 |
+
if sources and len(sources) > 0:
|
| 516 |
+
if isinstance(rag_result, dict):
|
| 517 |
+
rag_result['sources'] = sources
|
| 518 |
+
elif hasattr(rag_result, 'sources'):
|
| 519 |
+
rag_result.sources = sources
|
| 520 |
+
# Update last_rag_result with sources
|
| 521 |
+
st.session_state.last_rag_result = rag_result
|
| 522 |
+
print(f"β
Updated rag_result with {len(sources)} sources")
|
| 523 |
+
|
| 524 |
# Get the actual RAG query
|
| 525 |
actual_rag_query = chat_result.get('actual_rag_query', '')
|
| 526 |
if actual_rag_query:
|
|
|
|
| 530 |
else:
|
| 531 |
formatted_query = "No RAG query available"
|
| 532 |
|
| 533 |
+
# Extract filters from active filters
|
| 534 |
+
filters_used = {
|
| 535 |
+
"sources": st.session_state.active_filters.get('sources', []),
|
| 536 |
+
"years": st.session_state.active_filters.get('years', []),
|
| 537 |
+
"districts": st.session_state.active_filters.get('districts', []),
|
| 538 |
+
"filenames": st.session_state.active_filters.get('filenames', [])
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
retrieval_entry = {
|
| 542 |
"conversation_up_to": serialize_messages(st.session_state.messages),
|
| 543 |
"rag_query_expansion": formatted_query,
|
| 544 |
+
"docs_retrieved": serialize_documents(sources),
|
| 545 |
+
"filters_applied": filters_used,
|
| 546 |
+
"timestamp": time.time()
|
| 547 |
}
|
| 548 |
st.session_state.rag_retrieval_history.append(retrieval_entry)
|
| 549 |
+
|
| 550 |
+
# Debug logging
|
| 551 |
+
print(f"π RETRIEVAL TRACKING: {len(sources)} sources stored in retrieval history")
|
| 552 |
else:
|
| 553 |
response = chat_result
|
| 554 |
st.session_state.last_rag_result = None
|
|
|
|
| 578 |
# Dictionary format from multi-agent system
|
| 579 |
sources = rag_result['sources']
|
| 580 |
|
| 581 |
+
# For Gemini, also check if we need to format sources from gemini_result
|
| 582 |
+
if (not sources or len(sources) == 0) and isinstance(rag_result, dict):
|
| 583 |
+
gemini_result = rag_result.get('gemini_result')
|
| 584 |
+
if gemini_result and hasattr(gemini_result, 'sources'):
|
| 585 |
+
# Format Gemini sources for display
|
| 586 |
+
if hasattr(st.session_state.chatbot, 'gemini_client'):
|
| 587 |
+
sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result)
|
| 588 |
+
elif hasattr(st.session_state.chatbot, '_format_gemini_sources'):
|
| 589 |
+
sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
|
| 590 |
+
|
| 591 |
if sources and len(sources) > 0:
|
| 592 |
# Count unique filenames
|
| 593 |
unique_filenames = set()
|
|
|
|
| 617 |
for i, doc in enumerate(sources): # Show all documents
|
| 618 |
# Get relevance score and ID if available
|
| 619 |
metadata = getattr(doc, 'metadata', {})
|
| 620 |
+
# Handle both standard RAG scores and Gemini scores
|
| 621 |
+
score = metadata.get('reranked_score') or metadata.get('original_score') or metadata.get('score')
|
| 622 |
+
chunk_id = metadata.get('_id') or metadata.get('chunk_id', 'Unknown')
|
| 623 |
+
if score is not None:
|
| 624 |
+
try:
|
| 625 |
+
score_text = f" (Score: {float(score):.3f})"
|
| 626 |
+
except (ValueError, TypeError):
|
| 627 |
+
score_text = ""
|
| 628 |
+
else:
|
| 629 |
+
score_text = ""
|
| 630 |
+
if chunk_id and chunk_id != 'Unknown':
|
| 631 |
+
score_text += f" (ID: {str(chunk_id)[:8]}...)" if score_text else f" (ID: {str(chunk_id)[:8]}...)"
|
| 632 |
|
| 633 |
with st.expander(f"π Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
|
| 634 |
# Display document metadata with emojis
|
|
|
|
| 706 |
|
| 707 |
submitted = st.form_submit_button(
|
| 708 |
"π€ Submit Feedback",
|
| 709 |
+
width='stretch',
|
| 710 |
disabled=submit_disabled
|
| 711 |
)
|
| 712 |
|
|
|
|
| 718 |
st.write("π **Debug: Feedback Data Being Submitted:**")
|
| 719 |
|
| 720 |
# Extract transcript from messages
|
| 721 |
+
transcript = feedback_manager.extract_transcript(st.session_state.messages)
|
| 722 |
|
| 723 |
# Build retrievals structure
|
| 724 |
+
retrievals = feedback_manager.build_retrievals_structure(
|
| 725 |
+
|
| 726 |
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
|
| 727 |
st.session_state.messages
|
| 728 |
)
|
| 729 |
|
| 730 |
# Build feedback_score_related_retrieval_docs
|
| 731 |
+
|
| 732 |
+
feedback_score_related_retrieval_docs = feedback_manager.build_feedback_score_related_retrieval_docs(
|
| 733 |
is_feedback_about_last_retrieval,
|
| 734 |
st.session_state.messages,
|
| 735 |
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
|
|
|
|
| 759 |
# Create UserFeedback dataclass instance
|
| 760 |
feedback_obj = None # Initialize outside try block
|
| 761 |
try:
|
| 762 |
+
feedback_obj = feedback_manager.create_feedback_from_dict(feedback_dict)
|
| 763 |
print(f"β
FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
|
| 764 |
st.write(f"β
**Feedback Object Created**")
|
| 765 |
st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
|
|
|
|
| 815 |
logger.info("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 816 |
print("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 817 |
|
| 818 |
+
# Show spinner while saving to Snowflake (can take 10-15 seconds)
|
| 819 |
+
# This includes: connection establishment (~5s), data preparation, and SQL execution (~5s)
|
| 820 |
+
with st.spinner("πΎ Saving feedback to Snowflake... This may take 10-15 seconds (connecting to database, preparing data, and executing query)"):
|
| 821 |
+
snowflake_success = feedback_manager.save_to_snowflake(feedback_obj)
|
| 822 |
+
|
| 823 |
if snowflake_success:
|
| 824 |
logger.info("β
SNOWFLAKE UI: Successfully saved to Snowflake")
|
| 825 |
print("β
SNOWFLAKE UI: Successfully saved to Snowflake")
|
|
|
|
| 874 |
st.success("β
Feedback already submitted for this conversation!")
|
| 875 |
col1, col2 = st.columns([1, 1])
|
| 876 |
with col1:
|
| 877 |
+
if st.button("π Submit New Feedback", key="new_feedback_button", width='stretch'):
|
| 878 |
try:
|
| 879 |
st.session_state.feedback_submitted = False
|
| 880 |
st.rerun()
|
|
|
|
| 883 |
logger.error(f"Error resetting feedback state: {e}")
|
| 884 |
st.error(f"Error resetting feedback. Please refresh the page.")
|
| 885 |
with col2:
|
| 886 |
+
if st.button("π View Conversation", key="view_conversation_button", width='stretch'):
|
| 887 |
# Scroll to conversation - this is handled by the auto-scroll at bottom
|
| 888 |
pass
|
| 889 |
|
|
|
|
| 892 |
st.markdown("---")
|
| 893 |
st.markdown("#### π Retrieval History")
|
| 894 |
|
| 895 |
+
with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=True):
|
| 896 |
for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
|
| 897 |
+
st.markdown(f"### **Retrieval #{idx}**")
|
| 898 |
+
|
| 899 |
+
# Display timestamp if available
|
| 900 |
+
if entry.get("timestamp"):
|
| 901 |
+
timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(entry["timestamp"]))
|
| 902 |
+
st.caption(f"π {timestamp_str}")
|
| 903 |
|
| 904 |
# Display the actual RAG query
|
| 905 |
rag_query_expansion = entry.get("rag_query_expansion", "No query available")
|
| 906 |
+
st.markdown("**π RAG Query:**")
|
| 907 |
st.code(rag_query_expansion, language="text")
|
| 908 |
|
| 909 |
+
# Display filters used
|
| 910 |
+
filters_applied = entry.get("filters_applied", {})
|
| 911 |
+
if filters_applied and any(filters_applied.values()):
|
| 912 |
+
st.markdown("**π― Filters Applied:**")
|
| 913 |
+
filter_display = {}
|
| 914 |
+
if filters_applied.get("sources"):
|
| 915 |
+
filter_display["Sources"] = filters_applied["sources"]
|
| 916 |
+
if filters_applied.get("years"):
|
| 917 |
+
filter_display["Years"] = filters_applied["years"]
|
| 918 |
+
if filters_applied.get("districts"):
|
| 919 |
+
filter_display["Districts"] = filters_applied["districts"]
|
| 920 |
+
if filters_applied.get("filenames"):
|
| 921 |
+
filter_display["Filenames"] = filters_applied["filenames"]
|
| 922 |
+
|
| 923 |
+
if filter_display:
|
| 924 |
+
st.json(filter_display)
|
| 925 |
+
else:
|
| 926 |
+
st.info("No filters applied")
|
| 927 |
+
else:
|
| 928 |
+
st.info("No filters applied")
|
| 929 |
+
|
| 930 |
+
# Display conversation history up to retrieval point
|
| 931 |
+
conversation_up_to = entry.get("conversation_up_to", [])
|
| 932 |
+
if conversation_up_to:
|
| 933 |
+
st.markdown("**π¬ Conversation History (up to retrieval point):**")
|
| 934 |
+
with st.expander(f"View {len(conversation_up_to)} messages", expanded=False):
|
| 935 |
+
for msg_idx, msg in enumerate(conversation_up_to, 1):
|
| 936 |
+
role = msg.get("type", "unknown")
|
| 937 |
+
content = msg.get("content", "")
|
| 938 |
+
|
| 939 |
+
if role == "HumanMessage" or role == "human":
|
| 940 |
+
st.markdown(f"**π€ User {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
|
| 941 |
+
elif role == "AIMessage" or role == "ai":
|
| 942 |
+
st.markdown(f"**π€ Assistant {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
|
| 943 |
+
else:
|
| 944 |
+
st.info("No conversation history available")
|
| 945 |
+
|
| 946 |
+
# Display documents retrieved
|
| 947 |
+
docs_retrieved = entry.get("docs_retrieved", [])
|
| 948 |
+
if docs_retrieved:
|
| 949 |
+
st.markdown(f"**π Documents Retrieved ({len(docs_retrieved)}):**")
|
| 950 |
+
with st.expander(f"View {len(docs_retrieved)} documents", expanded=False):
|
| 951 |
+
for doc_idx, doc in enumerate(docs_retrieved, 1):
|
| 952 |
+
st.markdown(f"**Document {doc_idx}:**")
|
| 953 |
+
|
| 954 |
+
# Display metadata
|
| 955 |
+
metadata = doc.get("metadata", {})
|
| 956 |
+
if metadata:
|
| 957 |
+
col1, col2, col3 = st.columns(3)
|
| 958 |
+
with col1:
|
| 959 |
+
st.write(f"π **File:** {metadata.get('filename', 'Unknown')}")
|
| 960 |
+
with col2:
|
| 961 |
+
st.write(f"ποΈ **Source:** {metadata.get('source', 'Unknown')}")
|
| 962 |
+
with col3:
|
| 963 |
+
st.write(f"π
**Year:** {metadata.get('year', 'Unknown')}")
|
| 964 |
+
|
| 965 |
+
# Additional metadata
|
| 966 |
+
if metadata.get('district'):
|
| 967 |
+
st.write(f"π **District:** {metadata.get('district')}")
|
| 968 |
+
if metadata.get('page'):
|
| 969 |
+
st.write(f"π **Page:** {metadata.get('page')}")
|
| 970 |
+
if metadata.get('score') is not None:
|
| 971 |
+
st.write(f"β **Score:** {metadata.get('score'):.3f}" if isinstance(metadata.get('score'), (int, float)) else f"β **Score:** {metadata.get('score')}")
|
| 972 |
+
|
| 973 |
+
# Display content preview (first 200 chars)
|
| 974 |
+
content = doc.get("content", doc.get("page_content", ""))
|
| 975 |
+
if content:
|
| 976 |
+
st.markdown("**Content Preview:**")
|
| 977 |
+
st.text_area(
|
| 978 |
+
"Content Preview",
|
| 979 |
+
value=content[:200] + ("..." if len(content) > 200 else ""),
|
| 980 |
+
height=100,
|
| 981 |
+
disabled=True,
|
| 982 |
+
label_visibility="collapsed",
|
| 983 |
+
key=f"retrieval_{idx}_doc_{doc_idx}_preview"
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
if doc_idx < len(docs_retrieved):
|
| 987 |
+
st.markdown("---")
|
| 988 |
+
else:
|
| 989 |
+
st.info("No documents retrieved")
|
| 990 |
+
|
| 991 |
# Display summary stats
|
| 992 |
+
st.markdown("**π Summary:**")
|
| 993 |
st.json({
|
| 994 |
+
"conversation_length": len(conversation_up_to),
|
| 995 |
+
"documents_retrieved": len(docs_retrieved)
|
| 996 |
})
|
| 997 |
+
|
| 998 |
+
if idx < len(st.session_state.rag_retrieval_history):
|
| 999 |
+
st.markdown("---")
|
| 1000 |
|
| 1001 |
# Example Questions Section
|
| 1002 |
st.markdown("---")
|
|
|
|
| 1017 |
st.markdown(f"**Example:** `{example_q1}`")
|
| 1018 |
st.info("π‘ **Filter to apply:** Select a Filename from the sidebar panel before asking this question.")
|
| 1019 |
with col2:
|
| 1020 |
+
if st.button("π Use This Question", key="use_example_1", width='stretch'):
|
| 1021 |
st.session_state.pending_question = example_q1
|
| 1022 |
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
|
| 1023 |
st.rerun()
|
|
|
|
| 1038 |
)
|
| 1039 |
col1, col2 = st.columns([1, 4])
|
| 1040 |
with col1:
|
| 1041 |
+
if st.button("π Use Question 2", key="use_custom_1", width='stretch'):
|
| 1042 |
if custom_q1.strip():
|
| 1043 |
st.session_state.pending_question = custom_q1.strip()
|
| 1044 |
st.session_state.custom_question_1 = custom_q1.strip()
|
|
|
|
| 1064 |
)
|
| 1065 |
col1, col2 = st.columns([1, 4])
|
| 1066 |
with col1:
|
| 1067 |
+
if st.button("π Use Question 3", key="use_custom_2", width='stretch'):
|
| 1068 |
if custom_q2.strip():
|
| 1069 |
st.session_state.pending_question = custom_q2.strip()
|
| 1070 |
st.session_state.custom_question_2 = custom_q2.strip()
|
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent modules for chatbot implementations
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .smart_chatbot import get_chatbot as get_smart_chatbot
|
| 6 |
+
from .multi_agent_chatbot import get_multi_agent_chatbot
|
| 7 |
+
from .gemini_chatbot import get_gemini_chatbot
|
| 8 |
+
|
| 9 |
+
__all__ = ["get_smart_chatbot", "get_multi_agent_chatbot", "get_gemini_chatbot"]
|
| 10 |
+
|
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini File Search Chatbot (Beta Version)
|
| 3 |
+
|
| 4 |
+
This chatbot uses Google Gemini File Search API for RAG.
|
| 5 |
+
It provides a simpler architecture: Main Agent + Gemini Agent
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
import logging
|
| 12 |
+
import traceback
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Dict, List, Any, Optional, TypedDict
|
| 15 |
+
|
| 16 |
+
from langgraph.graph import StateGraph, END
|
| 17 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 18 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 19 |
+
|
| 20 |
+
from src.gemini.file_search import GeminiFileSearchClient, GeminiFileSearchResult
|
| 21 |
+
from src.config.paths import CONVERSATIONS_DIR
|
| 22 |
+
|
| 23 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class GeminiState(TypedDict):
|
| 28 |
+
"""State for Gemini chatbot conversation flow"""
|
| 29 |
+
conversation_id: str
|
| 30 |
+
messages: List[Any]
|
| 31 |
+
current_query: str
|
| 32 |
+
query_context: Optional[Dict[str, Any]]
|
| 33 |
+
gemini_result: Optional[GeminiFileSearchResult]
|
| 34 |
+
final_response: Optional[str]
|
| 35 |
+
agent_logs: List[str]
|
| 36 |
+
conversation_context: Dict[str, Any]
|
| 37 |
+
session_start_time: float
|
| 38 |
+
last_ai_message_time: float
|
| 39 |
+
filters: Optional[Dict[str, Any]]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class GeminiRAGChatbot:
|
| 43 |
+
"""Gemini File Search RAG chatbot (Beta version)"""
|
| 44 |
+
|
| 45 |
+
def __init__(self):
|
| 46 |
+
"""Initialize the Gemini chatbot"""
|
| 47 |
+
logger.info("π€ INITIALIZING: Gemini File Search Chatbot (Beta)")
|
| 48 |
+
|
| 49 |
+
# Initialize Gemini File Search client
|
| 50 |
+
try:
|
| 51 |
+
self.gemini_client = GeminiFileSearchClient()
|
| 52 |
+
logger.info("β
Gemini File Search client initialized")
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.error(f"β Failed to initialize Gemini client: {e}")
|
| 55 |
+
raise RuntimeError(f"Gemini client initialization failed: {e}")
|
| 56 |
+
|
| 57 |
+
# Build the LangGraph with LangSmith tracing if enabled
|
| 58 |
+
self.graph = self._build_graph()
|
| 59 |
+
|
| 60 |
+
# Enable LangSmith tracing if configured
|
| 61 |
+
langsmith_enabled = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true"
|
| 62 |
+
if langsmith_enabled:
|
| 63 |
+
logger.info("π LangSmith tracing enabled")
|
| 64 |
+
langsmith_project = os.getenv("LANGCHAIN_PROJECT", "gemini-chatbot")
|
| 65 |
+
logger.info(f"π LangSmith project: {langsmith_project}")
|
| 66 |
+
|
| 67 |
+
# Conversations directory
|
| 68 |
+
self.conversations_dir = CONVERSATIONS_DIR
|
| 69 |
+
try:
|
| 70 |
+
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 71 |
+
except (PermissionError, OSError) as e:
|
| 72 |
+
logger.warning(f"Could not create conversations directory: {e}")
|
| 73 |
+
self.conversations_dir = Path("conversations")
|
| 74 |
+
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 75 |
+
|
| 76 |
+
logger.info("β
Gemini File Search Chatbot initialized")
|
| 77 |
+
|
| 78 |
+
def _build_graph(self) -> StateGraph:
|
| 79 |
+
"""Build the LangGraph for Gemini chatbot"""
|
| 80 |
+
graph = StateGraph(GeminiState)
|
| 81 |
+
|
| 82 |
+
# Add nodes
|
| 83 |
+
graph.add_node("main_agent", self._main_agent)
|
| 84 |
+
graph.add_node("gemini_agent", self._gemini_agent)
|
| 85 |
+
|
| 86 |
+
# Define the flow
|
| 87 |
+
graph.set_entry_point("main_agent")
|
| 88 |
+
graph.add_edge("main_agent", "gemini_agent")
|
| 89 |
+
graph.add_edge("gemini_agent", END)
|
| 90 |
+
|
| 91 |
+
return graph.compile()
|
| 92 |
+
|
| 93 |
+
def _main_agent(self, state: GeminiState) -> GeminiState:
|
| 94 |
+
"""Main Agent: Extracts filters and prepares query"""
|
| 95 |
+
logger.info("π― MAIN AGENT: Processing query")
|
| 96 |
+
|
| 97 |
+
query = state["current_query"]
|
| 98 |
+
messages = state["messages"]
|
| 99 |
+
|
| 100 |
+
# Extract UI filters if present in query
|
| 101 |
+
ui_filters = self._extract_ui_filters(query)
|
| 102 |
+
|
| 103 |
+
# Extract context from conversation
|
| 104 |
+
context = self._extract_context_from_conversation(messages, ui_filters)
|
| 105 |
+
|
| 106 |
+
# Store context and filters
|
| 107 |
+
state["query_context"] = context
|
| 108 |
+
state["filters"] = context.get("filters", {})
|
| 109 |
+
|
| 110 |
+
logger.info(f"π― MAIN AGENT: Filters extracted: {state['filters']}")
|
| 111 |
+
|
| 112 |
+
return state
|
| 113 |
+
|
| 114 |
+
def _gemini_agent(self, state: GeminiState) -> GeminiState:
|
| 115 |
+
"""Gemini Agent: Performs file search and generates response"""
|
| 116 |
+
logger.info("π GEMINI AGENT: Starting file search")
|
| 117 |
+
|
| 118 |
+
query = state["current_query"]
|
| 119 |
+
filters = state.get("filters", {})
|
| 120 |
+
|
| 121 |
+
# Perform Gemini file search
|
| 122 |
+
try:
|
| 123 |
+
result = self.gemini_client.search(query=query, filters=filters)
|
| 124 |
+
logger.info(f"β
GEMINI AGENT: Search completed, {len(result.sources)} sources found")
|
| 125 |
+
|
| 126 |
+
# Enhance response with document references
|
| 127 |
+
enhanced_response = self._enhance_response_with_references(
|
| 128 |
+
result.answer,
|
| 129 |
+
result.sources,
|
| 130 |
+
query
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
state["gemini_result"] = result
|
| 134 |
+
state["final_response"] = enhanced_response
|
| 135 |
+
state["last_ai_message_time"] = time.time()
|
| 136 |
+
|
| 137 |
+
state["agent_logs"].append(f"GEMINI AGENT: Found {len(result.sources)} sources")
|
| 138 |
+
|
| 139 |
+
except Exception as e:
|
| 140 |
+
logger.error(f"β GEMINI AGENT ERROR: {e}")
|
| 141 |
+
traceback.print_exc()
|
| 142 |
+
state["final_response"] = "I apologize, but I encountered an error while searching. Please try again."
|
| 143 |
+
state["last_ai_message_time"] = time.time()
|
| 144 |
+
|
| 145 |
+
return state
|
| 146 |
+
|
| 147 |
+
def _enhance_response_with_references(self, answer: str, sources: List[Any], query: str) -> str:
|
| 148 |
+
"""Enhance Gemini response to include document references and format nicely"""
|
| 149 |
+
if not sources or not answer:
|
| 150 |
+
return answer
|
| 151 |
+
|
| 152 |
+
# Use LLM to intelligently add document references and format nicely
|
| 153 |
+
try:
|
| 154 |
+
from src.llm.adapters import get_llm_client
|
| 155 |
+
llm = get_llm_client()
|
| 156 |
+
|
| 157 |
+
# Prepare document summaries for the LLM
|
| 158 |
+
doc_summaries = []
|
| 159 |
+
for idx, doc in enumerate(sources, 1):
|
| 160 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 161 |
+
content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
|
| 162 |
+
|
| 163 |
+
filename = metadata.get('filename', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
|
| 164 |
+
year = metadata.get('year', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
|
| 165 |
+
source = metadata.get('source', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
|
| 166 |
+
district = metadata.get('district', '') if isinstance(metadata, dict) else ''
|
| 167 |
+
|
| 168 |
+
doc_info = f"{filename}"
|
| 169 |
+
if year and year != 'Unknown':
|
| 170 |
+
doc_info += f" ({year})"
|
| 171 |
+
if source and source != 'Unknown':
|
| 172 |
+
doc_info += f" - {source}"
|
| 173 |
+
if district:
|
| 174 |
+
doc_info += f" - {district}"
|
| 175 |
+
|
| 176 |
+
doc_summaries.append(f"[Doc {idx}] {doc_info}: {content[:300]}...")
|
| 177 |
+
|
| 178 |
+
prompt = f"""You are enhancing a response from a document search system. The original response is:
|
| 179 |
+
|
| 180 |
+
{answer}
|
| 181 |
+
|
| 182 |
+
The following documents were retrieved and used to generate this response:
|
| 183 |
+
|
| 184 |
+
{chr(10).join(doc_summaries)}
|
| 185 |
+
|
| 186 |
+
CRITICAL RULES:
|
| 187 |
+
1. Format the response nicely with proper paragraphs, bullet points, or structured sections where appropriate
|
| 188 |
+
2. The response should ONLY contain information from the retrieved documents listed above
|
| 189 |
+
3. If the response mentions information NOT found in the retrieved documents, you must REMOVE or CORRECT that information
|
| 190 |
+
4. Add document references [Doc i] at the end of sentences that use information from specific documents
|
| 191 |
+
5. Only reference documents that are actually used in the response
|
| 192 |
+
6. If the response mentions years, sources, or data that don't match the retrieved documents, you must correct it
|
| 193 |
+
7. Keep the response natural, conversational, and well-formatted
|
| 194 |
+
8. Use proper formatting: paragraphs, line breaks, and structure for readability
|
| 195 |
+
9. Don't change the core content that matches the documents, just add references where appropriate and improve formatting
|
| 196 |
+
10. If multiple documents support the same claim, use [Doc i, Doc j] format
|
| 197 |
+
11. If the response contains information that cannot be verified in the retrieved documents, add a note like: "Note: This information may not be in the retrieved documents."
|
| 198 |
+
|
| 199 |
+
Return ONLY the enhanced, well-formatted response with references added and any corrections made. Do not include any explanation or meta-commentary."""
|
| 200 |
+
|
| 201 |
+
enhanced = llm.invoke(prompt).content if hasattr(llm.invoke(prompt), 'content') else str(llm.invoke(prompt))
|
| 202 |
+
|
| 203 |
+
# Fallback: if LLM fails, just return original with basic formatting
|
| 204 |
+
if not enhanced or len(enhanced) < len(answer) * 0.5:
|
| 205 |
+
logger.warning("LLM enhancement failed, using original response with basic formatting")
|
| 206 |
+
# Basic formatting: add line breaks after periods for readability
|
| 207 |
+
formatted = answer.replace('. ', '.\n\n')
|
| 208 |
+
if sources:
|
| 209 |
+
ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
|
| 210 |
+
formatted += f"\n\n*Based on documents: {ref_list}*"
|
| 211 |
+
return formatted
|
| 212 |
+
|
| 213 |
+
return enhanced
|
| 214 |
+
|
| 215 |
+
except Exception as e:
|
| 216 |
+
logger.warning(f"Failed to enhance response with references: {e}")
|
| 217 |
+
# Fallback: add basic formatting and references at the end
|
| 218 |
+
formatted = answer.replace('. ', '.\n\n') # Basic paragraph formatting
|
| 219 |
+
if sources:
|
| 220 |
+
ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
|
| 221 |
+
formatted += f"\n\n*Based on documents: {ref_list}*"
|
| 222 |
+
return formatted
|
| 223 |
+
|
| 224 |
+
def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
|
| 225 |
+
"""Extract UI filters from query if present"""
|
| 226 |
+
filters = {}
|
| 227 |
+
|
| 228 |
+
if "FILTER CONTEXT:" in query:
|
| 229 |
+
filter_section = query.split("FILTER CONTEXT:")[1]
|
| 230 |
+
if "USER QUERY:" in filter_section:
|
| 231 |
+
filter_section = filter_section.split("USER QUERY:")[0]
|
| 232 |
+
filter_section = filter_section.strip()
|
| 233 |
+
|
| 234 |
+
if "Sources:" in filter_section:
|
| 235 |
+
sources_line = [line for line in filter_section.split('\n') if line.strip().startswith('Sources:')]
|
| 236 |
+
if sources_line:
|
| 237 |
+
sources_str = sources_line[0].split("Sources:")[1].strip()
|
| 238 |
+
if sources_str and sources_str != "None":
|
| 239 |
+
filters["sources"] = [s.strip() for s in sources_str.split(",")]
|
| 240 |
+
|
| 241 |
+
if "Years:" in filter_section:
|
| 242 |
+
years_line = [line for line in filter_section.split('\n') if line.strip().startswith('Years:')]
|
| 243 |
+
if years_line:
|
| 244 |
+
years_str = years_line[0].split("Years:")[1].strip()
|
| 245 |
+
if years_str and years_str != "None":
|
| 246 |
+
filters["year"] = [y.strip() for y in years_str.split(",")]
|
| 247 |
+
|
| 248 |
+
if "Districts:" in filter_section:
|
| 249 |
+
districts_line = [line for line in filter_section.split('\n') if line.strip().startswith('Districts:')]
|
| 250 |
+
if districts_line:
|
| 251 |
+
districts_str = districts_line[0].split("Districts:")[1].strip()
|
| 252 |
+
if districts_str and districts_str != "None":
|
| 253 |
+
filters["district"] = [d.strip() for d in districts_str.split(",")]
|
| 254 |
+
|
| 255 |
+
if "Filenames:" in filter_section:
|
| 256 |
+
filenames_line = [line for line in filter_section.split('\n') if line.strip().startswith('Filenames:')]
|
| 257 |
+
if filenames_line:
|
| 258 |
+
filenames_str = filenames_line[0].split("Filenames:")[1].strip()
|
| 259 |
+
if filenames_str and filenames_str != "None":
|
| 260 |
+
filters["filenames"] = [f.strip() for f in filenames_str.split(",")]
|
| 261 |
+
|
| 262 |
+
return filters
|
| 263 |
+
|
| 264 |
+
def _extract_context_from_conversation(
|
| 265 |
+
self,
|
| 266 |
+
messages: List[Any],
|
| 267 |
+
ui_filters: Dict[str, List[str]]
|
| 268 |
+
) -> Dict[str, Any]:
|
| 269 |
+
"""Extract context from conversation history"""
|
| 270 |
+
# Use UI filters if available
|
| 271 |
+
filters = ui_filters.copy() if ui_filters else {}
|
| 272 |
+
|
| 273 |
+
# For Gemini, we pass filters directly to the search function
|
| 274 |
+
# The filters will be used to add context to the query
|
| 275 |
+
|
| 276 |
+
return {
|
| 277 |
+
"filters": filters,
|
| 278 |
+
"has_filters": bool(filters)
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
|
| 282 |
+
"""Main chat interface"""
|
| 283 |
+
logger.info(f"π¬ GEMINI CHAT: Processing '{user_input[:50]}...'")
|
| 284 |
+
|
| 285 |
+
# Load conversation
|
| 286 |
+
conversation_file = self.conversations_dir / f"{conversation_id}.json"
|
| 287 |
+
conversation = self._load_conversation(conversation_file)
|
| 288 |
+
|
| 289 |
+
# Add user message
|
| 290 |
+
conversation["messages"].append(HumanMessage(content=user_input))
|
| 291 |
+
|
| 292 |
+
# Prepare state
|
| 293 |
+
state = GeminiState(
|
| 294 |
+
conversation_id=conversation_id,
|
| 295 |
+
messages=conversation["messages"],
|
| 296 |
+
current_query=user_input,
|
| 297 |
+
query_context=None,
|
| 298 |
+
gemini_result=None,
|
| 299 |
+
final_response=None,
|
| 300 |
+
agent_logs=[],
|
| 301 |
+
conversation_context=conversation.get("context", {}),
|
| 302 |
+
session_start_time=conversation["session_start_time"],
|
| 303 |
+
last_ai_message_time=conversation["last_ai_message_time"],
|
| 304 |
+
filters=None
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Run graph
|
| 308 |
+
final_state = self.graph.invoke(state)
|
| 309 |
+
|
| 310 |
+
# Add AI response to conversation
|
| 311 |
+
if final_state["final_response"]:
|
| 312 |
+
conversation["messages"].append(AIMessage(content=final_state["final_response"]))
|
| 313 |
+
|
| 314 |
+
# Update conversation
|
| 315 |
+
conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
|
| 316 |
+
conversation["context"] = final_state["conversation_context"]
|
| 317 |
+
|
| 318 |
+
# Save conversation
|
| 319 |
+
self._save_conversation(conversation_file, conversation)
|
| 320 |
+
|
| 321 |
+
# Format sources for display
|
| 322 |
+
sources = []
|
| 323 |
+
gemini_result = final_state.get("gemini_result")
|
| 324 |
+
if gemini_result:
|
| 325 |
+
sources = self.gemini_client.format_sources_for_display(gemini_result)
|
| 326 |
+
logger.info(f"π GEMINI CHAT: Formatted {len(sources)} sources for display")
|
| 327 |
+
|
| 328 |
+
return {
|
| 329 |
+
'response': final_state["final_response"] or "I apologize, but I couldn't process your request.",
|
| 330 |
+
'rag_result': {
|
| 331 |
+
'sources': sources,
|
| 332 |
+
'answer': final_state["final_response"]
|
| 333 |
+
},
|
| 334 |
+
'agent_logs': final_state["agent_logs"],
|
| 335 |
+
'actual_rag_query': final_state["current_query"],
|
| 336 |
+
'gemini_result': gemini_result # Include raw result for tracking
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
|
| 340 |
+
"""Load conversation from file"""
|
| 341 |
+
if conversation_file.exists():
|
| 342 |
+
try:
|
| 343 |
+
with open(conversation_file) as f:
|
| 344 |
+
data = json.load(f)
|
| 345 |
+
messages = []
|
| 346 |
+
for msg_data in data.get("messages", []):
|
| 347 |
+
if msg_data["type"] == "human":
|
| 348 |
+
messages.append(HumanMessage(content=msg_data["content"]))
|
| 349 |
+
elif msg_data["type"] == "ai":
|
| 350 |
+
messages.append(AIMessage(content=msg_data["content"]))
|
| 351 |
+
data["messages"] = messages
|
| 352 |
+
return data
|
| 353 |
+
except Exception as e:
|
| 354 |
+
logger.warning(f"Could not load conversation: {e}")
|
| 355 |
+
|
| 356 |
+
return {
|
| 357 |
+
"messages": [],
|
| 358 |
+
"session_start_time": time.time(),
|
| 359 |
+
"last_ai_message_time": time.time(),
|
| 360 |
+
"context": {}
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
|
| 364 |
+
"""Save conversation to file"""
|
| 365 |
+
try:
|
| 366 |
+
conversation_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 367 |
+
|
| 368 |
+
messages_data = []
|
| 369 |
+
for msg in conversation["messages"]:
|
| 370 |
+
if isinstance(msg, HumanMessage):
|
| 371 |
+
messages_data.append({"type": "human", "content": msg.content})
|
| 372 |
+
elif isinstance(msg, AIMessage):
|
| 373 |
+
messages_data.append({"type": "ai", "content": msg.content})
|
| 374 |
+
|
| 375 |
+
conversation_data = {
|
| 376 |
+
"messages": messages_data,
|
| 377 |
+
"session_start_time": conversation["session_start_time"],
|
| 378 |
+
"last_ai_message_time": conversation["last_ai_message_time"],
|
| 379 |
+
"context": conversation.get("context", {})
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
with open(conversation_file, 'w') as f:
|
| 383 |
+
json.dump(conversation_data, f, indent=2)
|
| 384 |
+
|
| 385 |
+
except Exception as e:
|
| 386 |
+
logger.error(f"Could not save conversation: {e}")
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def get_gemini_chatbot():
|
| 390 |
+
"""Get Gemini chatbot instance"""
|
| 391 |
+
return GeminiRAGChatbot()
|
| 392 |
+
|
|
@@ -208,6 +208,59 @@ class MultiAgentRAGChatbot:
|
|
| 208 |
logger.info(f" Sources: {self.source_whitelist}")
|
| 209 |
logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
def _build_graph(self) -> StateGraph:
|
| 212 |
"""Build the multi-agent LangGraph"""
|
| 213 |
graph = StateGraph(MultiAgentState)
|
|
@@ -512,6 +565,10 @@ class MultiAgentRAGChatbot:
|
|
| 512 |
- If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
|
| 513 |
- If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
|
| 514 |
- Always return districts as JSON arrays when multiple districts are mentioned
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
- If no exact matches found, set extracted values to null
|
| 516 |
|
| 517 |
4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
|
|
@@ -656,13 +713,9 @@ Analyze this query using ONLY the exact values provided above:""")
|
|
| 656 |
# Validate each district in the array
|
| 657 |
valid_districts = []
|
| 658 |
for district in extracted_district:
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
# Try removing "District" suffix
|
| 663 |
-
district_name = district.replace(" District", "").replace(" district", "")
|
| 664 |
-
if district_name in self.district_whitelist:
|
| 665 |
-
valid_districts.append(district_name)
|
| 666 |
|
| 667 |
if valid_districts:
|
| 668 |
extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
|
|
@@ -671,16 +724,15 @@ Analyze this query using ONLY the exact values provided above:""")
|
|
| 671 |
logger.warning(f"β οΈ No valid districts found in: '{extracted_district}'")
|
| 672 |
extracted_district = None
|
| 673 |
else:
|
| 674 |
-
# Single district validation
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
extracted_district = None
|
| 684 |
|
| 685 |
# Validate source (handle both single values and arrays)
|
| 686 |
if extracted_source:
|
|
@@ -918,6 +970,23 @@ Rewrite the best retrieval query:""")
|
|
| 918 |
logger.info(f"π§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β normalized: {normalized_districts}")
|
| 919 |
|
| 920 |
# Merge with extracted context for missing filters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 921 |
if not filters.get("year") and context.extracted_year:
|
| 922 |
# Handle both single values and arrays
|
| 923 |
if isinstance(context.extracted_year, list):
|
|
@@ -926,16 +995,6 @@ Rewrite the best retrieval query:""")
|
|
| 926 |
filters["year"] = [context.extracted_year]
|
| 927 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
|
| 928 |
|
| 929 |
-
if not filters.get("district") and context.extracted_district:
|
| 930 |
-
# Handle both single values and arrays
|
| 931 |
-
if isinstance(context.extracted_district, list):
|
| 932 |
-
# Normalize district names to title case (match Qdrant metadata format)
|
| 933 |
-
normalized = [d.title() for d in context.extracted_district]
|
| 934 |
-
filters["district"] = normalized
|
| 935 |
-
else:
|
| 936 |
-
filters["district"] = [context.extracted_district.title()]
|
| 937 |
-
logger.info(f"π§ FILTER BUILDING: Added extracted district filter (UI missing): {context.extracted_district}")
|
| 938 |
-
|
| 939 |
if not filters.get("sources") and context.extracted_source:
|
| 940 |
# Handle both single values and arrays
|
| 941 |
if isinstance(context.extracted_source, list):
|
|
@@ -963,12 +1022,21 @@ Rewrite the best retrieval query:""")
|
|
| 963 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
|
| 964 |
|
| 965 |
if context.extracted_district:
|
| 966 |
-
#
|
| 967 |
if isinstance(context.extracted_district, list):
|
| 968 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 969 |
else:
|
| 970 |
-
|
| 971 |
-
|
|
|
|
|
|
|
| 972 |
|
| 973 |
logger.info(f"π§ FILTER BUILDING: Final filters: {filters}")
|
| 974 |
return filters
|
|
@@ -978,49 +1046,212 @@ Rewrite the best retrieval query:""")
|
|
| 978 |
logger.info("π¬ RESPONSE GENERATION: Starting conversational response generation")
|
| 979 |
logger.info(f"π¬ RESPONSE GENERATION: Processing {len(documents)} documents")
|
| 980 |
logger.info(f"π¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 981 |
|
| 982 |
# Create response prompt
|
| 983 |
logger.info(f"π¬ RESPONSE GENERATION: Building response prompt")
|
| 984 |
response_prompt = ChatPromptTemplate.from_messages([
|
| 985 |
SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
|
| 986 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 987 |
RULES:
|
| 988 |
1. Answer the user's question directly and clearly
|
| 989 |
-
2. Use the retrieved documents as evidence
|
| 990 |
3. Be conversational, not technical
|
| 991 |
4. Don't mention scores, retrieval details, or technical implementation
|
| 992 |
5. If relevant documents were found, reference them naturally
|
| 993 |
-
6. If no relevant documents,
|
| 994 |
-
7. If the passages have useful facts or numbers, use them in your answer
|
| 995 |
-
8. When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
|
| 996 |
9. Do not use the sentence 'Doc i says ...' to say where information came from.
|
| 997 |
10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
|
| 998 |
11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
|
| 999 |
12. If it makes sense, use bullet points and lists to make your answers easier to understand.
|
| 1000 |
13. You do not need to use every passage. Only use the ones that help answer the question.
|
| 1001 |
-
14.
|
| 1002 |
-
|
|
|
|
| 1003 |
|
| 1004 |
TONE: Professional but friendly, like talking to a colleague."""),
|
| 1005 |
-
HumanMessage(content=f"""
|
|
|
|
|
|
|
|
|
|
| 1006 |
|
| 1007 |
Retrieved Documents: {len(documents)} documents found
|
| 1008 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1009 |
RAG Answer: {rag_answer}
|
| 1010 |
|
| 1011 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1012 |
])
|
| 1013 |
|
| 1014 |
try:
|
| 1015 |
logger.info(f"π¬ RESPONSE GENERATION: Calling LLM for final response")
|
| 1016 |
response = self.llm.invoke(response_prompt.format_messages())
|
| 1017 |
logger.info(f"π¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
|
| 1018 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1019 |
except Exception as e:
|
| 1020 |
logger.error(f"β RESPONSE GENERATION: Error during generation: {e}")
|
| 1021 |
logger.info(f"π¬ RESPONSE GENERATION: Using RAG answer as fallback")
|
| 1022 |
return rag_answer # Fallback to RAG answer
|
| 1023 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1024 |
def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
|
| 1025 |
"""Generate conversational response using only LLM knowledge and conversation history"""
|
| 1026 |
logger.info("π¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
|
|
|
|
| 208 |
logger.info(f" Sources: {self.source_whitelist}")
|
| 209 |
logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
|
| 210 |
|
| 211 |
+
def _normalize_district_name(self, district: str) -> Optional[str]:
|
| 212 |
+
"""Normalize district name with fuzzy matching for common misspellings."""
|
| 213 |
+
if not district:
|
| 214 |
+
return None
|
| 215 |
+
|
| 216 |
+
district = district.strip()
|
| 217 |
+
|
| 218 |
+
# Direct match
|
| 219 |
+
if district in self.district_whitelist:
|
| 220 |
+
return district
|
| 221 |
+
|
| 222 |
+
# Remove "District" suffix
|
| 223 |
+
district_name = district.replace(" District", "").replace(" district", "").strip()
|
| 224 |
+
if district_name in self.district_whitelist:
|
| 225 |
+
return district_name
|
| 226 |
+
|
| 227 |
+
# Common misspellings mapping
|
| 228 |
+
misspelling_map = {
|
| 229 |
+
"kalagala": "Kalangala",
|
| 230 |
+
"Kalagala": "Kalangala",
|
| 231 |
+
"KALAGALA": "Kalangala",
|
| 232 |
+
"kalangala": "Kalangala",
|
| 233 |
+
"gulu": "Gulu",
|
| 234 |
+
"GULU": "Gulu",
|
| 235 |
+
"kampala": "Kampala",
|
| 236 |
+
"KAMPALA": "Kampala",
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
# Check misspelling map (case-insensitive)
|
| 240 |
+
district_lower = district_name.lower()
|
| 241 |
+
if district_lower in misspelling_map:
|
| 242 |
+
corrected = misspelling_map[district_lower]
|
| 243 |
+
if corrected in self.district_whitelist:
|
| 244 |
+
return corrected
|
| 245 |
+
|
| 246 |
+
# Fuzzy matching for similar names (simple Levenshtein-like check)
|
| 247 |
+
# Check if the district name is very similar to any whitelist entry
|
| 248 |
+
for whitelist_district in self.district_whitelist:
|
| 249 |
+
# Case-insensitive comparison
|
| 250 |
+
if district_name.lower() == whitelist_district.lower():
|
| 251 |
+
return whitelist_district
|
| 252 |
+
|
| 253 |
+
# Check if one is a substring of the other (for partial matches)
|
| 254 |
+
if len(district_name) >= 4 and len(whitelist_district) >= 4:
|
| 255 |
+
if district_name.lower() in whitelist_district.lower() or whitelist_district.lower() in district_name.lower():
|
| 256 |
+
# Only return if it's a strong match (at least 80% of characters match)
|
| 257 |
+
min_len = min(len(district_name), len(whitelist_district))
|
| 258 |
+
max_len = max(len(district_name), len(whitelist_district))
|
| 259 |
+
if min_len / max_len >= 0.8:
|
| 260 |
+
return whitelist_district
|
| 261 |
+
|
| 262 |
+
return None
|
| 263 |
+
|
| 264 |
def _build_graph(self) -> StateGraph:
|
| 265 |
"""Build the multi-agent LangGraph"""
|
| 266 |
graph = StateGraph(MultiAgentState)
|
|
|
|
| 565 |
- If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
|
| 566 |
- If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
|
| 567 |
- Always return districts as JSON arrays when multiple districts are mentioned
|
| 568 |
+
- **COMMON MISSPELLINGS**: Handle common misspellings intelligently:
|
| 569 |
+
* "Kalagala" (missing 'n') should be extracted as "Kalangala"
|
| 570 |
+
* "kalagala", "Kalagala", "KALAGALA" should all be normalized to "Kalangala"
|
| 571 |
+
* Similar case-insensitive variations should be normalized to the correct district name
|
| 572 |
- If no exact matches found, set extracted values to null
|
| 573 |
|
| 574 |
4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
|
|
|
|
| 713 |
# Validate each district in the array
|
| 714 |
valid_districts = []
|
| 715 |
for district in extracted_district:
|
| 716 |
+
normalized = self._normalize_district_name(district)
|
| 717 |
+
if normalized:
|
| 718 |
+
valid_districts.append(normalized)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 719 |
|
| 720 |
if valid_districts:
|
| 721 |
extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
|
|
|
|
| 724 |
logger.warning(f"β οΈ No valid districts found in: '{extracted_district}'")
|
| 725 |
extracted_district = None
|
| 726 |
else:
|
| 727 |
+
# Single district validation with fuzzy matching
|
| 728 |
+
normalized = self._normalize_district_name(extracted_district)
|
| 729 |
+
if normalized:
|
| 730 |
+
if normalized != extracted_district:
|
| 731 |
+
logger.info(f"π QUERY ANALYSIS: Normalized district '{extracted_district}' to '{normalized}'")
|
| 732 |
+
extracted_district = normalized
|
| 733 |
+
else:
|
| 734 |
+
logger.warning(f"β οΈ Invalid district extracted: '{extracted_district}' not in whitelist")
|
| 735 |
+
extracted_district = None
|
|
|
|
| 736 |
|
| 737 |
# Validate source (handle both single values and arrays)
|
| 738 |
if extracted_source:
|
|
|
|
| 970 |
logger.info(f"π§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β normalized: {normalized_districts}")
|
| 971 |
|
| 972 |
# Merge with extracted context for missing filters
|
| 973 |
+
if not filters.get("district") and context.extracted_district:
|
| 974 |
+
# Normalize district names using the normalization function
|
| 975 |
+
if isinstance(context.extracted_district, list):
|
| 976 |
+
normalized_districts = []
|
| 977 |
+
for d in context.extracted_district:
|
| 978 |
+
normalized = self._normalize_district_name(d)
|
| 979 |
+
if normalized:
|
| 980 |
+
normalized_districts.append(normalized)
|
| 981 |
+
if normalized_districts:
|
| 982 |
+
filters["district"] = normalized_districts
|
| 983 |
+
logger.info(f"π§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β normalized: {normalized_districts}")
|
| 984 |
+
else:
|
| 985 |
+
normalized = self._normalize_district_name(context.extracted_district)
|
| 986 |
+
if normalized:
|
| 987 |
+
filters["district"] = [normalized]
|
| 988 |
+
logger.info(f"π§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β normalized: {normalized}")
|
| 989 |
+
|
| 990 |
if not filters.get("year") and context.extracted_year:
|
| 991 |
# Handle both single values and arrays
|
| 992 |
if isinstance(context.extracted_year, list):
|
|
|
|
| 995 |
filters["year"] = [context.extracted_year]
|
| 996 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
|
| 997 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 998 |
if not filters.get("sources") and context.extracted_source:
|
| 999 |
# Handle both single values and arrays
|
| 1000 |
if isinstance(context.extracted_source, list):
|
|
|
|
| 1022 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
|
| 1023 |
|
| 1024 |
if context.extracted_district:
|
| 1025 |
+
# Normalize district names using the normalization function
|
| 1026 |
if isinstance(context.extracted_district, list):
|
| 1027 |
+
normalized_districts = []
|
| 1028 |
+
for d in context.extracted_district:
|
| 1029 |
+
normalized = self._normalize_district_name(d)
|
| 1030 |
+
if normalized:
|
| 1031 |
+
normalized_districts.append(normalized)
|
| 1032 |
+
if normalized_districts:
|
| 1033 |
+
filters["district"] = normalized_districts
|
| 1034 |
+
logger.info(f"π§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β normalized: {normalized_districts}")
|
| 1035 |
else:
|
| 1036 |
+
normalized = self._normalize_district_name(context.extracted_district)
|
| 1037 |
+
if normalized:
|
| 1038 |
+
filters["district"] = [normalized]
|
| 1039 |
+
logger.info(f"π§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β normalized: {normalized}")
|
| 1040 |
|
| 1041 |
logger.info(f"π§ FILTER BUILDING: Final filters: {filters}")
|
| 1042 |
return filters
|
|
|
|
| 1046 |
logger.info("π¬ RESPONSE GENERATION: Starting conversational response generation")
|
| 1047 |
logger.info(f"π¬ RESPONSE GENERATION: Processing {len(documents)} documents")
|
| 1048 |
logger.info(f"π¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
|
| 1049 |
+
logger.info(f"π¬ RESPONSE GENERATION: Conversation history: {len(messages)} messages")
|
| 1050 |
+
|
| 1051 |
+
# Build conversation history context
|
| 1052 |
+
conversation_context = self._build_conversation_context(messages)
|
| 1053 |
+
|
| 1054 |
+
# Build detailed document information
|
| 1055 |
+
document_details = self._build_document_details(documents)
|
| 1056 |
+
|
| 1057 |
+
# Extract correct district/source/year names from documents (to correct misspellings)
|
| 1058 |
+
correct_names = self._extract_correct_names_from_documents(documents)
|
| 1059 |
|
| 1060 |
# Create response prompt
|
| 1061 |
logger.info(f"π¬ RESPONSE GENERATION: Building response prompt")
|
| 1062 |
response_prompt = ChatPromptTemplate.from_messages([
|
| 1063 |
SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
|
| 1064 |
|
| 1065 |
+
CRITICAL RULES - NO HALLUCINATION:
|
| 1066 |
+
1. **ONLY use information from the retrieved documents provided below**
|
| 1067 |
+
2. **EVERY sentence with facts, numbers, or specific claims MUST have a [Doc i] reference**
|
| 1068 |
+
3. **If a document doesn't contain the information, DO NOT make it up**
|
| 1069 |
+
4. **If the user asks about a year/district that's NOT in the retrieved documents, explicitly state that**
|
| 1070 |
+
5. **Check the document years/districts before making any claims about them**
|
| 1071 |
+
6. **USE CORRECT NAMES**: If the conversation mentions a misspelled district/source name (e.g., "Kalagala"), use the CORRECT spelling from the document metadata (e.g., "Kalangala"). Always use the exact names from document metadata, not misspellings from conversation.
|
| 1072 |
+
|
| 1073 |
RULES:
|
| 1074 |
1. Answer the user's question directly and clearly
|
| 1075 |
+
2. Use ONLY the retrieved documents as evidence - DO NOT use your training data
|
| 1076 |
3. Be conversational, not technical
|
| 1077 |
4. Don't mention scores, retrieval details, or technical implementation
|
| 1078 |
5. If relevant documents were found, reference them naturally
|
| 1079 |
+
6. If no relevant documents, say you do not have enough information - DO NOT hallucinate
|
| 1080 |
+
7. If the passages have useful facts or numbers, use them in your answer WITH references
|
| 1081 |
+
8. **MANDATORY**: When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
|
| 1082 |
9. Do not use the sentence 'Doc i says ...' to say where information came from.
|
| 1083 |
10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
|
| 1084 |
11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
|
| 1085 |
12. If it makes sense, use bullet points and lists to make your answers easier to understand.
|
| 1086 |
13. You do not need to use every passage. Only use the ones that help answer the question.
|
| 1087 |
+
14. **VERIFY**: Before mentioning any year, district, or number, check that it exists in the retrieved documents. If it doesn't, say "I don't have information about [year/district] in the retrieved documents."
|
| 1088 |
+
15. **NO HALLUCINATION**: If documents show years 2021, 2022, 2023 but user asks about 2020, DO NOT provide 2020 data. Instead say "The retrieved documents cover 2021-2023, but I don't have information for 2020."
|
| 1089 |
+
16. **USE CORRECT SPELLING**: Always use the district/source names exactly as they appear in the document metadata below, even if the conversation history has misspellings.
|
| 1090 |
|
| 1091 |
TONE: Professional but friendly, like talking to a colleague."""),
|
| 1092 |
+
HumanMessage(content=f"""Conversation History:
|
| 1093 |
+
{conversation_context}
|
| 1094 |
+
|
| 1095 |
+
Current User Question: {query}
|
| 1096 |
|
| 1097 |
Retrieved Documents: {len(documents)} documents found
|
| 1098 |
|
| 1099 |
+
CORRECT NAMES TO USE (from document metadata - use these exact spellings):
|
| 1100 |
+
{correct_names}
|
| 1101 |
+
|
| 1102 |
+
Full Document Details:
|
| 1103 |
+
{document_details}
|
| 1104 |
+
|
| 1105 |
RAG Answer: {rag_answer}
|
| 1106 |
|
| 1107 |
+
CRITICAL:
|
| 1108 |
+
- Responses should be grounded to what is available in the retrieved documents
|
| 1109 |
+
- If user asks about a specific year but documents show other years, or districts or sources then explicitly state "can't provide response on ... because ..."
|
| 1110 |
+
- Every factual claim MUST have [Doc i] reference
|
| 1111 |
+
- If information is not in documents, explicitly state it's not available
|
| 1112 |
+
- **USE THE CORRECT DISTRICT/SOURCE NAMES from the document metadata above, not misspellings from conversation**
|
| 1113 |
+
|
| 1114 |
+
Generate a conversational response with proper document references:""")
|
| 1115 |
])
|
| 1116 |
|
| 1117 |
try:
|
| 1118 |
logger.info(f"π¬ RESPONSE GENERATION: Calling LLM for final response")
|
| 1119 |
response = self.llm.invoke(response_prompt.format_messages())
|
| 1120 |
logger.info(f"π¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
|
| 1121 |
+
|
| 1122 |
+
# Post-process response to ensure no hallucination
|
| 1123 |
+
final_response = self._validate_and_enhance_response(
|
| 1124 |
+
response.content.strip(),
|
| 1125 |
+
documents,
|
| 1126 |
+
query
|
| 1127 |
+
)
|
| 1128 |
+
|
| 1129 |
+
return final_response
|
| 1130 |
except Exception as e:
|
| 1131 |
logger.error(f"β RESPONSE GENERATION: Error during generation: {e}")
|
| 1132 |
logger.info(f"π¬ RESPONSE GENERATION: Using RAG answer as fallback")
|
| 1133 |
return rag_answer # Fallback to RAG answer
|
| 1134 |
|
| 1135 |
+
def _build_conversation_context(self, messages: List[Any]) -> str:
|
| 1136 |
+
"""Build conversation history context for response generation."""
|
| 1137 |
+
if not messages:
|
| 1138 |
+
return "No previous conversation."
|
| 1139 |
+
|
| 1140 |
+
context_lines = []
|
| 1141 |
+
# Show last 6 messages for context (to capture the current exchange)
|
| 1142 |
+
for msg in messages[-6:]:
|
| 1143 |
+
if isinstance(msg, HumanMessage):
|
| 1144 |
+
context_lines.append(f"User: {msg.content}")
|
| 1145 |
+
elif isinstance(msg, AIMessage):
|
| 1146 |
+
context_lines.append(f"Assistant: {msg.content}")
|
| 1147 |
+
|
| 1148 |
+
return "\n".join(context_lines) if context_lines else "No previous conversation."
|
| 1149 |
+
|
| 1150 |
+
def _build_document_details(self, documents: List[Any]) -> str:
|
| 1151 |
+
"""Build detailed document information for response generation."""
|
| 1152 |
+
if not documents:
|
| 1153 |
+
return "No documents retrieved."
|
| 1154 |
+
|
| 1155 |
+
details = []
|
| 1156 |
+
for i, doc in enumerate(documents[:15], 1): # Show up to 15 documents
|
| 1157 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 1158 |
+
content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
|
| 1159 |
+
|
| 1160 |
+
if isinstance(metadata, dict):
|
| 1161 |
+
filename = metadata.get('filename', 'Unknown')
|
| 1162 |
+
year = metadata.get('year', 'Unknown')
|
| 1163 |
+
district = metadata.get('district', 'Unknown')
|
| 1164 |
+
source = metadata.get('source', 'Unknown')
|
| 1165 |
+
page = metadata.get('page', metadata.get('page_label', 'Unknown'))
|
| 1166 |
+
|
| 1167 |
+
doc_info = f"[Doc {i}]"
|
| 1168 |
+
doc_info += f"\n Filename: {filename}"
|
| 1169 |
+
doc_info += f"\n Year: {year}"
|
| 1170 |
+
doc_info += f"\n District: {district}"
|
| 1171 |
+
doc_info += f"\n Source: {source}"
|
| 1172 |
+
if page != 'Unknown':
|
| 1173 |
+
doc_info += f"\n Page: {page}"
|
| 1174 |
+
doc_info += f"\n Content: {content[:300]}{'...' if len(content) > 300 else ''}"
|
| 1175 |
+
details.append(doc_info)
|
| 1176 |
+
|
| 1177 |
+
return "\n\n".join(details) if details else "No document details available."
|
| 1178 |
+
|
| 1179 |
+
def _extract_correct_names_from_documents(self, documents: List[Any]) -> str:
|
| 1180 |
+
"""Extract correct district/source names from documents to correct misspellings."""
|
| 1181 |
+
districts = set()
|
| 1182 |
+
sources = set()
|
| 1183 |
+
years = set()
|
| 1184 |
+
|
| 1185 |
+
for doc in documents:
|
| 1186 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 1187 |
+
if isinstance(metadata, dict):
|
| 1188 |
+
if metadata.get('district'):
|
| 1189 |
+
districts.add(str(metadata['district']))
|
| 1190 |
+
if metadata.get('source'):
|
| 1191 |
+
sources.add(str(metadata['source']))
|
| 1192 |
+
if metadata.get('year'):
|
| 1193 |
+
years.add(str(metadata['year']))
|
| 1194 |
+
|
| 1195 |
+
result = []
|
| 1196 |
+
if districts:
|
| 1197 |
+
result.append(f"Districts: {', '.join(sorted(districts))}")
|
| 1198 |
+
if sources:
|
| 1199 |
+
result.append(f"Sources: {', '.join(sorted(sources))}")
|
| 1200 |
+
if years:
|
| 1201 |
+
result.append(f"Years: {', '.join(sorted(years))}")
|
| 1202 |
+
|
| 1203 |
+
if result:
|
| 1204 |
+
return "\n".join(result) + "\n\nIMPORTANT: Use these EXACT spellings in your response, even if the conversation history has misspellings."
|
| 1205 |
+
return "No metadata available."
|
| 1206 |
+
|
| 1207 |
+
def _validate_and_enhance_response(self, response: str, documents: List[Any], query: str) -> str:
|
| 1208 |
+
"""Validate response and ensure all claims are referenced."""
|
| 1209 |
+
# Extract years and districts from documents
|
| 1210 |
+
doc_years = set()
|
| 1211 |
+
doc_districts = set()
|
| 1212 |
+
doc_sources = set()
|
| 1213 |
+
|
| 1214 |
+
for doc in documents:
|
| 1215 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 1216 |
+
if isinstance(metadata, dict):
|
| 1217 |
+
if metadata.get('year'):
|
| 1218 |
+
doc_years.add(str(metadata['year']))
|
| 1219 |
+
if metadata.get('district'):
|
| 1220 |
+
doc_districts.add(str(metadata['district']))
|
| 1221 |
+
if metadata.get('source'):
|
| 1222 |
+
doc_sources.add(str(metadata['source']))
|
| 1223 |
+
|
| 1224 |
+
# Correct misspellings in response using correct names from documents
|
| 1225 |
+
# response = self._correct_misspellings_in_response(response, doc_districts, doc_sources)
|
| 1226 |
+
|
| 1227 |
+
# Check if response mentions years not in documents
|
| 1228 |
+
year_pattern = r'\b(20\d{2})\b'
|
| 1229 |
+
mentioned_years = set(re.findall(year_pattern, response))
|
| 1230 |
+
|
| 1231 |
+
# Check if user query mentions a year
|
| 1232 |
+
query_years = set(re.findall(year_pattern, query))
|
| 1233 |
+
|
| 1234 |
+
# If user asks about a year not in documents, add a warning
|
| 1235 |
+
missing_years = query_years - doc_years
|
| 1236 |
+
if missing_years and doc_years:
|
| 1237 |
+
warning = f"\n\nβ οΈ Note: The retrieved documents cover years {', '.join(sorted(doc_years))}, but I don't have information for {', '.join(sorted(missing_years))} in the retrieved documents."
|
| 1238 |
+
if warning not in response:
|
| 1239 |
+
response = response + warning
|
| 1240 |
+
|
| 1241 |
+
# Check if response has document references
|
| 1242 |
+
doc_ref_pattern = r'\[Doc\s+\d+\]'
|
| 1243 |
+
has_refs = bool(re.search(doc_ref_pattern, response))
|
| 1244 |
+
|
| 1245 |
+
# If response has factual claims but no references, add a note
|
| 1246 |
+
if not has_refs and len(documents) > 0:
|
| 1247 |
+
# Check if response has numbers or specific claims (simple heuristic)
|
| 1248 |
+
has_numbers = bool(re.search(r'\d+', response))
|
| 1249 |
+
if has_numbers and len(response) > 50:
|
| 1250 |
+
logger.warning("β οΈ Response contains factual claims but no document references")
|
| 1251 |
+
# Don't modify response, but log the issue
|
| 1252 |
+
|
| 1253 |
+
return response
|
| 1254 |
+
|
| 1255 |
def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
|
| 1256 |
"""Generate conversational response using only LLM knowledge and conversation history"""
|
| 1257 |
logger.info("π¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
|
|
File without changes
|
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feedback Management Module
|
| 3 |
+
|
| 4 |
+
This module provides a unified interface for handling user feedback,
|
| 5 |
+
including data preparation, validation, and Snowflake storage.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Dict, Any, List, Optional
|
| 9 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 10 |
+
|
| 11 |
+
from .feedback_schema import UserFeedback, create_feedback_from_dict, generate_snowflake_schema_sql
|
| 12 |
+
from .snowflake_connector import SnowflakeFeedbackConnector, save_to_snowflake, get_snowflake_connector_from_env
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class FeedbackManager:
|
| 16 |
+
"""
|
| 17 |
+
Unified manager for feedback operations.
|
| 18 |
+
|
| 19 |
+
This class provides a single interface for all feedback-related functionality,
|
| 20 |
+
including data preparation, validation, and storage.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
"""Initialize the FeedbackManager"""
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
@staticmethod
|
| 28 |
+
def extract_transcript(messages: List[Any]) -> List[Dict[str, str]]:
|
| 29 |
+
"""Extract transcript from messages - only user and bot messages, no extra metadata"""
|
| 30 |
+
transcript = []
|
| 31 |
+
for msg in messages:
|
| 32 |
+
if isinstance(msg, HumanMessage):
|
| 33 |
+
transcript.append({
|
| 34 |
+
"role": "user",
|
| 35 |
+
"content": str(msg.content) if hasattr(msg, 'content') else str(msg)
|
| 36 |
+
})
|
| 37 |
+
elif isinstance(msg, AIMessage):
|
| 38 |
+
transcript.append({
|
| 39 |
+
"role": "assistant",
|
| 40 |
+
"content": str(msg.content) if hasattr(msg, 'content') else str(msg)
|
| 41 |
+
})
|
| 42 |
+
return transcript
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def build_retrievals_structure(rag_retrieval_history: List[Dict[str, Any]], messages: List[Any]) -> List[Dict[str, Any]]:
|
| 46 |
+
"""Build retrievals structure from retrieval history"""
|
| 47 |
+
retrievals = []
|
| 48 |
+
|
| 49 |
+
for entry in rag_retrieval_history:
|
| 50 |
+
# Get the user message that triggered this retrieval
|
| 51 |
+
# The entry has conversation_up_to which includes messages up to that point
|
| 52 |
+
conversation_up_to = entry.get("conversation_up_to", [])
|
| 53 |
+
|
| 54 |
+
# Find the last user message in conversation_up_to (this is the trigger)
|
| 55 |
+
user_message_trigger = ""
|
| 56 |
+
for msg_dict in reversed(conversation_up_to):
|
| 57 |
+
if msg_dict.get("type") == "HumanMessage":
|
| 58 |
+
user_message_trigger = msg_dict.get("content", "")
|
| 59 |
+
break
|
| 60 |
+
|
| 61 |
+
# Fallback: if not found in conversation_up_to, get from actual messages
|
| 62 |
+
# This handles edge cases where conversation_up_to might be incomplete
|
| 63 |
+
if not user_message_trigger:
|
| 64 |
+
# Find which retrieval this is (0-indexed)
|
| 65 |
+
retrieval_idx = rag_retrieval_history.index(entry)
|
| 66 |
+
# The user message that triggered this retrieval is at position (retrieval_idx * 2)
|
| 67 |
+
# because each retrieval is preceded by: user message, bot response, user message, ...
|
| 68 |
+
# But we need to account for the fact that the first retrieval happens after the first user message
|
| 69 |
+
user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
|
| 70 |
+
if retrieval_idx < len(user_msgs):
|
| 71 |
+
user_message_trigger = str(user_msgs[retrieval_idx].content)
|
| 72 |
+
elif user_msgs:
|
| 73 |
+
# Fallback to last user message
|
| 74 |
+
user_message_trigger = str(user_msgs[-1].content)
|
| 75 |
+
|
| 76 |
+
# Get retrieved documents and truncate content to 100 chars
|
| 77 |
+
docs_retrieved = entry.get("docs_retrieved", [])
|
| 78 |
+
retrieved_docs = []
|
| 79 |
+
for doc in docs_retrieved:
|
| 80 |
+
doc_copy = doc.copy()
|
| 81 |
+
# Truncate content to 100 characters (keep all other fields)
|
| 82 |
+
if "content" in doc_copy:
|
| 83 |
+
doc_copy["content"] = doc_copy["content"][:100]
|
| 84 |
+
retrieved_docs.append(doc_copy)
|
| 85 |
+
|
| 86 |
+
retrievals.append({
|
| 87 |
+
"retrieved_docs": retrieved_docs,
|
| 88 |
+
"user_message_trigger": user_message_trigger
|
| 89 |
+
})
|
| 90 |
+
|
| 91 |
+
return retrievals
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def build_feedback_score_related_retrieval_docs(
|
| 95 |
+
is_feedback_about_last_retrieval: bool,
|
| 96 |
+
messages: List[Any],
|
| 97 |
+
rag_retrieval_history: List[Dict[str, Any]]
|
| 98 |
+
) -> Optional[Dict[str, Any]]:
|
| 99 |
+
"""Build feedback_score_related_retrieval_docs structure"""
|
| 100 |
+
if not rag_retrieval_history:
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
# Get the relevant retrieval entry
|
| 104 |
+
if is_feedback_about_last_retrieval:
|
| 105 |
+
relevant_entry = rag_retrieval_history[-1]
|
| 106 |
+
else:
|
| 107 |
+
# If feedback is about all retrievals, use the last one as default
|
| 108 |
+
relevant_entry = rag_retrieval_history[-1]
|
| 109 |
+
|
| 110 |
+
# Get conversation up to that point
|
| 111 |
+
conversation_up_to = relevant_entry.get("conversation_up_to", [])
|
| 112 |
+
|
| 113 |
+
# Convert to transcript format (role/content)
|
| 114 |
+
conversation_up_to_point = []
|
| 115 |
+
for msg_dict in conversation_up_to:
|
| 116 |
+
if msg_dict.get("type") == "HumanMessage":
|
| 117 |
+
conversation_up_to_point.append({
|
| 118 |
+
"role": "user",
|
| 119 |
+
"content": msg_dict.get("content", "")
|
| 120 |
+
})
|
| 121 |
+
elif msg_dict.get("type") == "AIMessage":
|
| 122 |
+
conversation_up_to_point.append({
|
| 123 |
+
"role": "assistant",
|
| 124 |
+
"content": msg_dict.get("content", "")
|
| 125 |
+
})
|
| 126 |
+
|
| 127 |
+
# Get retrieved docs with full content (not truncated)
|
| 128 |
+
retrieved_docs = relevant_entry.get("docs_retrieved", [])
|
| 129 |
+
|
| 130 |
+
return {
|
| 131 |
+
"conversation_up_to_point": conversation_up_to_point,
|
| 132 |
+
"retrieved_docs": retrieved_docs
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
@staticmethod
|
| 136 |
+
def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
|
| 137 |
+
"""Create UserFeedback instance from dictionary"""
|
| 138 |
+
return create_feedback_from_dict(data)
|
| 139 |
+
|
| 140 |
+
@staticmethod
|
| 141 |
+
def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
|
| 142 |
+
"""Save feedback to Snowflake"""
|
| 143 |
+
return save_to_snowflake(feedback, table_name)
|
| 144 |
+
|
| 145 |
+
@staticmethod
|
| 146 |
+
def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
|
| 147 |
+
"""Generate Snowflake schema SQL"""
|
| 148 |
+
return generate_snowflake_schema_sql(table_name)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
__all__ = ["FeedbackManager", "UserFeedback", "save_to_snowflake", "SnowflakeFeedbackConnector"]
|
| 152 |
+
|
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feedback Schema for RAG Chatbot
|
| 3 |
+
|
| 4 |
+
This module defines dataclasses for feedback data structures
|
| 5 |
+
and provides Snowflake schema generation.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from dataclasses import dataclass, asdict, field
|
| 10 |
+
from typing import List, Optional, Dict, Any, Union
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class RetrievedDocument:
|
| 16 |
+
"""Single retrieved document metadata"""
|
| 17 |
+
doc_id: str
|
| 18 |
+
filename: str
|
| 19 |
+
page: int
|
| 20 |
+
score: float
|
| 21 |
+
content: str
|
| 22 |
+
metadata: Dict[str, Any]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class RetrievalEntry:
|
| 27 |
+
"""Single retrieval operation metadata"""
|
| 28 |
+
rag_query: str
|
| 29 |
+
documents_retrieved: List[RetrievedDocument]
|
| 30 |
+
conversation_length: int
|
| 31 |
+
filters_applied: Optional[Dict[str, Any]] = None
|
| 32 |
+
timestamp: Optional[float] = None
|
| 33 |
+
_raw_data: Optional[Dict[str, Any]] = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class UserFeedback:
|
| 38 |
+
"""User feedback submission data"""
|
| 39 |
+
feedback_id: str
|
| 40 |
+
open_ended_feedback: Optional[str]
|
| 41 |
+
score: int
|
| 42 |
+
is_feedback_about_last_retrieval: bool
|
| 43 |
+
conversation_id: str
|
| 44 |
+
timestamp: float
|
| 45 |
+
message_count: int
|
| 46 |
+
has_retrievals: bool
|
| 47 |
+
retrieval_count: int
|
| 48 |
+
transcript: List[Dict[str, str]] # List of {"role": "user"/"assistant", "content": "..."}
|
| 49 |
+
retrievals: List[Dict[str, Any]] # List of retrieval objects with retrieved_docs and user_message_trigger
|
| 50 |
+
feedback_score_related_retrieval_docs: Optional[Dict[str, Any]] = None # Conversation subset + retrieved docs
|
| 51 |
+
retrieved_data: Optional[List[Dict[str, Any]]] = None # Preserved old column for backward compatibility
|
| 52 |
+
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
| 53 |
+
|
| 54 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 55 |
+
"""Convert to dictionary with nested data structures"""
|
| 56 |
+
result = asdict(self)
|
| 57 |
+
return result
|
| 58 |
+
|
| 59 |
+
def to_snowflake_schema(self) -> Dict[str, Any]:
|
| 60 |
+
"""Generate Snowflake schema for this dataclass"""
|
| 61 |
+
schema = {
|
| 62 |
+
"feedback_id": "VARCHAR(255)",
|
| 63 |
+
"open_ended_feedback": "VARCHAR(16777216)", # Large text
|
| 64 |
+
"score": "INTEGER",
|
| 65 |
+
"is_feedback_about_last_retrieval": "BOOLEAN",
|
| 66 |
+
"conversation_id": "VARCHAR(255)",
|
| 67 |
+
"timestamp": "NUMBER(20, 0)",
|
| 68 |
+
"message_count": "INTEGER",
|
| 69 |
+
"has_retrievals": "BOOLEAN",
|
| 70 |
+
"retrieval_count": "INTEGER",
|
| 71 |
+
"transcript": "VARCHAR(16777216)", # JSON string of ARRAY of {"role": "user"/"assistant", "content": "..."}
|
| 72 |
+
"retrievals": "VARCHAR(16777216)", # JSON string of ARRAY of retrieval objects
|
| 73 |
+
"feedback_score_related_retrieval_docs": "VARCHAR(16777216)", # JSON string of OBJECT with conversation subset + retrieved docs
|
| 74 |
+
"retrieved_data": "VARCHAR(16777216)", # JSON string - preserved old column for backward compatibility
|
| 75 |
+
"created_at": "TIMESTAMP_NTZ",
|
| 76 |
+
# transcript structure: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
|
| 77 |
+
# retrievals structure: [
|
| 78 |
+
# {
|
| 79 |
+
# "retrieved_docs": [{"content": "...", "metadata": {...}, ...}], # content truncated to 100 chars
|
| 80 |
+
# "user_message_trigger": "final user message that triggered this retrieval"
|
| 81 |
+
# },
|
| 82 |
+
# ...
|
| 83 |
+
# ]
|
| 84 |
+
# feedback_score_related_retrieval_docs structure: {
|
| 85 |
+
# "conversation_up_to_point": [{"role": "user", "content": "..."}, ...], # subset of transcript
|
| 86 |
+
# "retrieved_docs": [{"content": "...", "metadata": {...}, ...}] # full chunks with all info
|
| 87 |
+
# }
|
| 88 |
+
}
|
| 89 |
+
return schema
|
| 90 |
+
|
| 91 |
+
@classmethod
|
| 92 |
+
def get_snowflake_create_table_sql(cls, table_name: str = "USER_FEEDBACK_V3") -> str:
|
| 93 |
+
"""Generate CREATE TABLE SQL for Snowflake"""
|
| 94 |
+
schema = cls.to_snowflake_schema(None)
|
| 95 |
+
|
| 96 |
+
columns = []
|
| 97 |
+
for col_name, col_type in schema.items():
|
| 98 |
+
nullable = "NULL" if col_name not in ["feedback_id", "score", "timestamp"] else "NOT NULL"
|
| 99 |
+
columns.append(f" {col_name} {col_type} {nullable}")
|
| 100 |
+
|
| 101 |
+
# Build SQL string properly
|
| 102 |
+
columns_str = ",\n".join(columns)
|
| 103 |
+
|
| 104 |
+
sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
|
| 105 |
+
{columns_str},
|
| 106 |
+
PRIMARY KEY (feedback_id)
|
| 107 |
+
)
|
| 108 |
+
CLUSTER BY (timestamp, conversation_id, score);
|
| 109 |
+
-- Note: Snowflake doesn't support traditional indexes on regular tables.
|
| 110 |
+
-- Instead, we use CLUSTER BY to optimize queries on these columns.
|
| 111 |
+
-- Snowflake automatically maintains clustering for efficient querying.
|
| 112 |
+
-- Note: transcript, retrievals, and feedback_score_related_retrieval_docs are stored as VARCHAR (JSON strings),
|
| 113 |
+
-- same approach as the old retrieved_data column. This allows easy storage and retrieval without VARIANT type complexity.
|
| 114 |
+
"""
|
| 115 |
+
return sql
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Snowflake variant schema for retrieved_data array
|
| 119 |
+
RETRIEVAL_ENTRY_SCHEMA = {
|
| 120 |
+
"rag_query": "VARCHAR",
|
| 121 |
+
"documents_retrieved": "ARRAY", # Array of document objects
|
| 122 |
+
"conversation_length": "INTEGER",
|
| 123 |
+
"filters_applied": "OBJECT",
|
| 124 |
+
"timestamp": "NUMBER"
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
DOCUMENT_SCHEMA = {
|
| 128 |
+
"doc_id": "VARCHAR",
|
| 129 |
+
"filename": "VARCHAR",
|
| 130 |
+
"page": "INTEGER",
|
| 131 |
+
"score": "DOUBLE",
|
| 132 |
+
"content": "VARCHAR(16777216)",
|
| 133 |
+
"metadata": "OBJECT"
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
|
| 138 |
+
"""Generate complete Snowflake schema SQL for feedback system"""
|
| 139 |
+
if table_name is None:
|
| 140 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 141 |
+
return UserFeedback.get_snowflake_create_table_sql(table_name)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
|
| 145 |
+
"""Create UserFeedback instance from dictionary"""
|
| 146 |
+
return UserFeedback(
|
| 147 |
+
feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
|
| 148 |
+
open_ended_feedback=data.get("open_ended_feedback"),
|
| 149 |
+
score=data["score"],
|
| 150 |
+
is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
|
| 151 |
+
conversation_id=data["conversation_id"],
|
| 152 |
+
timestamp=data["timestamp"],
|
| 153 |
+
message_count=data["message_count"],
|
| 154 |
+
has_retrievals=data["has_retrievals"],
|
| 155 |
+
retrieval_count=data["retrieval_count"],
|
| 156 |
+
transcript=data.get("transcript", []),
|
| 157 |
+
retrievals=data.get("retrievals", []),
|
| 158 |
+
feedback_score_related_retrieval_docs=data.get("feedback_score_related_retrieval_docs"),
|
| 159 |
+
retrieved_data=data.get("retrieved_data")
|
| 160 |
+
)
|
| 161 |
+
|
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Snowflake Connector for Feedback System
|
| 3 |
+
|
| 4 |
+
This module handles inserting user feedback into Snowflake.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Dict, Any, Optional
|
| 11 |
+
from .feedback_schema import UserFeedback
|
| 12 |
+
|
| 13 |
+
# Try to import snowflake connector
|
| 14 |
+
try:
|
| 15 |
+
import snowflake.connector
|
| 16 |
+
SNOWFLAKE_AVAILABLE = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
SNOWFLAKE_AVAILABLE = False
|
| 19 |
+
logging.warning("β οΈ snowflake-connector-python not installed. Install with: pip install snowflake-connector-python")
|
| 20 |
+
|
| 21 |
+
# Configure logging
|
| 22 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SnowflakeFeedbackConnector:
|
| 27 |
+
"""Connector for inserting feedback into Snowflake"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
user: str,
|
| 32 |
+
password: str,
|
| 33 |
+
account: str,
|
| 34 |
+
warehouse: str,
|
| 35 |
+
database: str = "SNOWFLAKE_LEARNING",
|
| 36 |
+
schema: str = "PUBLIC"
|
| 37 |
+
):
|
| 38 |
+
self.user = user
|
| 39 |
+
self.password = password
|
| 40 |
+
self.account = account
|
| 41 |
+
self.warehouse = warehouse
|
| 42 |
+
self.database = database
|
| 43 |
+
self.schema = schema
|
| 44 |
+
self._connection = None
|
| 45 |
+
|
| 46 |
+
def connect(self):
|
| 47 |
+
"""Establish Snowflake connection"""
|
| 48 |
+
if not SNOWFLAKE_AVAILABLE:
|
| 49 |
+
raise ImportError("snowflake-connector-python is not installed. Install with: pip install snowflake-connector-python")
|
| 50 |
+
|
| 51 |
+
logger.info("=" * 80)
|
| 52 |
+
logger.info("π SNOWFLAKE CONNECTION: Attempting to connect...")
|
| 53 |
+
logger.info(f" - Account: {self.account}")
|
| 54 |
+
logger.info(f" - Warehouse: {self.warehouse}")
|
| 55 |
+
logger.info(f" - Database: {self.database}")
|
| 56 |
+
logger.info(f" - Schema: {self.schema}")
|
| 57 |
+
logger.info(f" - User: {self.user}")
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
self._connection = snowflake.connector.connect(
|
| 61 |
+
user=self.user,
|
| 62 |
+
password=self.password,
|
| 63 |
+
account=self.account,
|
| 64 |
+
warehouse=self.warehouse
|
| 65 |
+
# Don't set database/schema in connection - we'll do it per query
|
| 66 |
+
)
|
| 67 |
+
logger.info("β
SNOWFLAKE CONNECTION: Successfully connected")
|
| 68 |
+
logger.info("=" * 80)
|
| 69 |
+
print(f"β
Connected to Snowflake: {self.database}.{self.schema}")
|
| 70 |
+
except Exception as e:
|
| 71 |
+
logger.error(f"β SNOWFLAKE CONNECTION FAILED: {e}")
|
| 72 |
+
logger.error("=" * 80)
|
| 73 |
+
print(f"β Failed to connect to Snowflake: {e}")
|
| 74 |
+
raise
|
| 75 |
+
|
| 76 |
+
def disconnect(self):
|
| 77 |
+
"""Close Snowflake connection"""
|
| 78 |
+
if self._connection:
|
| 79 |
+
self._connection.close()
|
| 80 |
+
print("β
Disconnected from Snowflake")
|
| 81 |
+
|
| 82 |
+
def insert_feedback(self, feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
|
| 83 |
+
"""Insert a single feedback record into Snowflake"""
|
| 84 |
+
logger.info("=" * 80)
|
| 85 |
+
logger.info("π SNOWFLAKE INSERT: Starting feedback insertion process")
|
| 86 |
+
logger.info(f"π Feedback ID: {feedback.feedback_id}")
|
| 87 |
+
|
| 88 |
+
# Get table name from parameter, env var, or default
|
| 89 |
+
if table_name is None:
|
| 90 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 91 |
+
|
| 92 |
+
if not self._connection:
|
| 93 |
+
logger.error("β Not connected to Snowflake. Call connect() first.")
|
| 94 |
+
raise RuntimeError("Not connected to Snowflake. Call connect() first.")
|
| 95 |
+
|
| 96 |
+
try:
|
| 97 |
+
logger.info("π VALIDATION: Validating feedback data structure...")
|
| 98 |
+
|
| 99 |
+
# Validate feedback object
|
| 100 |
+
validation_errors = []
|
| 101 |
+
if not feedback.feedback_id:
|
| 102 |
+
validation_errors.append("Missing feedback_id")
|
| 103 |
+
if feedback.score is None:
|
| 104 |
+
validation_errors.append("Missing score")
|
| 105 |
+
if feedback.timestamp is None:
|
| 106 |
+
validation_errors.append("Missing timestamp")
|
| 107 |
+
|
| 108 |
+
if validation_errors:
|
| 109 |
+
logger.error(f"β VALIDATION FAILED: {validation_errors}")
|
| 110 |
+
return False
|
| 111 |
+
else:
|
| 112 |
+
logger.info("β
VALIDATION PASSED: All required fields present")
|
| 113 |
+
|
| 114 |
+
logger.info("π Data Summary:")
|
| 115 |
+
logger.info(f" - Feedback ID: {feedback.feedback_id}")
|
| 116 |
+
logger.info(f" - Score: {feedback.score}")
|
| 117 |
+
logger.info(f" - Conversation ID: {feedback.conversation_id}")
|
| 118 |
+
logger.info(f" - Has Retrievals: {feedback.has_retrievals}")
|
| 119 |
+
logger.info(f" - Retrieval Count: {feedback.retrieval_count}")
|
| 120 |
+
logger.info(f" - Message Count: {feedback.message_count}")
|
| 121 |
+
logger.info(f" - Timestamp: {feedback.timestamp}")
|
| 122 |
+
|
| 123 |
+
cursor = self._connection.cursor()
|
| 124 |
+
logger.info("β
SNOWFLAKE CONNECTION: Cursor created")
|
| 125 |
+
|
| 126 |
+
# Set database and schema context
|
| 127 |
+
logger.info(f"π§ SETTING CONTEXT: Database={self.database}, Schema={self.schema}")
|
| 128 |
+
try:
|
| 129 |
+
cursor.execute(f'USE DATABASE "{self.database}"')
|
| 130 |
+
cursor.execute(f'USE SCHEMA "{self.schema}"')
|
| 131 |
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
| 132 |
+
current_db, current_schema = cursor.fetchone()
|
| 133 |
+
logger.info(f"β
Current context verified: Database={current_db}, Schema={current_schema}")
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.error(f"β Could not set context: {e}")
|
| 136 |
+
raise
|
| 137 |
+
|
| 138 |
+
# Prepare data - convert to JSON strings for VARIANT columns (same approach as old retrieved_data)
|
| 139 |
+
logger.info("π§ DATA PREPARATION: Preparing VARIANT columns...")
|
| 140 |
+
feedback_dict = feedback.to_dict()
|
| 141 |
+
|
| 142 |
+
# Prepare transcript (ARRAY) - convert to JSON string
|
| 143 |
+
transcript_raw = feedback_dict.get('transcript', [])
|
| 144 |
+
if transcript_raw:
|
| 145 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 146 |
+
transcript_for_db = json.dumps(transcript_raw)
|
| 147 |
+
logger.info(f" - Transcript: {len(transcript_raw)} messages, JSON length: {len(transcript_for_db)}")
|
| 148 |
+
else:
|
| 149 |
+
transcript_for_db = None
|
| 150 |
+
logger.info(" - Transcript: None")
|
| 151 |
+
|
| 152 |
+
# Prepare retrievals (ARRAY) - convert to JSON string
|
| 153 |
+
retrievals_raw = feedback_dict.get('retrievals', [])
|
| 154 |
+
if retrievals_raw:
|
| 155 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 156 |
+
retrievals_for_db = json.dumps(retrievals_raw)
|
| 157 |
+
logger.info(f" - Retrievals: {len(retrievals_raw)} entries, JSON length: {len(retrievals_for_db)}")
|
| 158 |
+
else:
|
| 159 |
+
retrievals_for_db = None
|
| 160 |
+
logger.info(" - Retrievals: None")
|
| 161 |
+
|
| 162 |
+
# Prepare feedback_score_related_retrieval_docs (OBJECT) - convert to JSON string
|
| 163 |
+
feedback_score_related_raw = feedback_dict.get('feedback_score_related_retrieval_docs')
|
| 164 |
+
if feedback_score_related_raw:
|
| 165 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 166 |
+
feedback_score_related_for_db = json.dumps(feedback_score_related_raw)
|
| 167 |
+
logger.info(f" - Feedback score related docs: present, JSON length: {len(feedback_score_related_for_db)}")
|
| 168 |
+
else:
|
| 169 |
+
feedback_score_related_for_db = None
|
| 170 |
+
logger.info(" - Feedback score related docs: None")
|
| 171 |
+
|
| 172 |
+
# Prepare retrieved_data (preserved old column) - convert to JSON string
|
| 173 |
+
retrieved_data_raw = feedback_dict.get('retrieved_data')
|
| 174 |
+
if retrieved_data_raw:
|
| 175 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 176 |
+
retrieved_data_for_db = json.dumps(retrieved_data_raw)
|
| 177 |
+
logger.info(f" - Retrieved data (preserved): present, JSON length: {len(retrieved_data_for_db)}")
|
| 178 |
+
else:
|
| 179 |
+
retrieved_data_for_db = None
|
| 180 |
+
logger.info(" - Retrieved data (preserved): None")
|
| 181 |
+
|
| 182 |
+
# Build SQL with new column structure
|
| 183 |
+
# Columns are VARCHAR (storing JSON strings), same approach as old retrieved_data
|
| 184 |
+
sql = f"""INSERT INTO {table_name} (
|
| 185 |
+
feedback_id,
|
| 186 |
+
open_ended_feedback,
|
| 187 |
+
score,
|
| 188 |
+
is_feedback_about_last_retrieval,
|
| 189 |
+
conversation_id,
|
| 190 |
+
timestamp,
|
| 191 |
+
message_count,
|
| 192 |
+
has_retrievals,
|
| 193 |
+
retrieval_count,
|
| 194 |
+
transcript,
|
| 195 |
+
retrievals,
|
| 196 |
+
feedback_score_related_retrieval_docs,
|
| 197 |
+
retrieved_data,
|
| 198 |
+
created_at
|
| 199 |
+
) VALUES (
|
| 200 |
+
%(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
|
| 201 |
+
%(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
|
| 202 |
+
%(retrieval_count)s, %(transcript)s, %(retrievals)s, %(feedback_score_related_retrieval_docs)s,
|
| 203 |
+
%(retrieved_data)s, %(created_at)s
|
| 204 |
+
)"""
|
| 205 |
+
|
| 206 |
+
logger.info("π SQL PREPARATION: Building INSERT statement...")
|
| 207 |
+
logger.info(f" - Target table: {table_name}")
|
| 208 |
+
logger.info(f" - Database: {self.database}")
|
| 209 |
+
logger.info(f" - Schema: {self.schema}")
|
| 210 |
+
|
| 211 |
+
# Prepare parameters
|
| 212 |
+
# Pass JSON strings for VARIANT columns (same approach as old retrieved_data)
|
| 213 |
+
params = {
|
| 214 |
+
'feedback_id': feedback.feedback_id,
|
| 215 |
+
'open_ended_feedback': feedback.open_ended_feedback,
|
| 216 |
+
'score': feedback.score,
|
| 217 |
+
'is_feedback_about_last_retrieval': feedback.is_feedback_about_last_retrieval,
|
| 218 |
+
'conversation_id': feedback.conversation_id,
|
| 219 |
+
'timestamp': int(feedback.timestamp),
|
| 220 |
+
'message_count': feedback.message_count,
|
| 221 |
+
'has_retrievals': feedback.has_retrievals,
|
| 222 |
+
'retrieval_count': feedback.retrieval_count,
|
| 223 |
+
'transcript': transcript_for_db, # JSON string
|
| 224 |
+
'retrievals': retrievals_for_db, # JSON string
|
| 225 |
+
'feedback_score_related_retrieval_docs': feedback_score_related_for_db, # JSON string
|
| 226 |
+
'retrieved_data': retrieved_data_for_db, # JSON string - preserved old column
|
| 227 |
+
'created_at': feedback.created_at
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
# Execute insert
|
| 231 |
+
logger.info("π SQL EXECUTION: Executing INSERT query...")
|
| 232 |
+
cursor.execute(sql, params)
|
| 233 |
+
|
| 234 |
+
logger.info("β
SQL EXECUTION: Query executed successfully")
|
| 235 |
+
logger.info(f" - Rows affected: 1")
|
| 236 |
+
logger.info(f" - Status: SUCCESS")
|
| 237 |
+
|
| 238 |
+
cursor.close()
|
| 239 |
+
logger.info("β
SNOWFLAKE INSERT: Feedback inserted successfully")
|
| 240 |
+
logger.info(f"π Inserted feedback: {feedback.feedback_id}")
|
| 241 |
+
logger.info("=" * 80)
|
| 242 |
+
return True
|
| 243 |
+
|
| 244 |
+
except Exception as e:
|
| 245 |
+
# Check if it's a Snowflake error
|
| 246 |
+
if SNOWFLAKE_AVAILABLE and "ProgrammingError" in str(type(e)):
|
| 247 |
+
logger.error(f"β SQL EXECUTION ERROR: {e}")
|
| 248 |
+
logger.error(f" - Error code: {getattr(e, 'errno', 'Unknown')}")
|
| 249 |
+
logger.error(f" - SQL state: {getattr(e, 'sqlstate', 'Unknown')}")
|
| 250 |
+
else:
|
| 251 |
+
logger.error(f"β SNOWFLAKE INSERT FAILED: {type(e).__name__}")
|
| 252 |
+
logger.error(f" - Error: {e}")
|
| 253 |
+
logger.error("=" * 80)
|
| 254 |
+
return False
|
| 255 |
+
|
| 256 |
+
def __enter__(self):
|
| 257 |
+
"""Context manager entry"""
|
| 258 |
+
self.connect()
|
| 259 |
+
return self
|
| 260 |
+
|
| 261 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 262 |
+
"""Context manager exit"""
|
| 263 |
+
self.disconnect()
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
|
| 267 |
+
"""Create Snowflake connector from environment variables"""
|
| 268 |
+
user = os.getenv("SNOWFLAKE_USER")
|
| 269 |
+
password = os.getenv("SNOWFLAKE_PASSWORD")
|
| 270 |
+
account = os.getenv("SNOWFLAKE_ACCOUNT")
|
| 271 |
+
warehouse = os.getenv("SNOWFLAKE_WAREHOUSE")
|
| 272 |
+
database = os.getenv("SNOWFLAKE_DATABASE", "SNOWFLAKE_LEARN")
|
| 273 |
+
schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
|
| 274 |
+
|
| 275 |
+
if not all([user, password, account, warehouse]):
|
| 276 |
+
print("β οΈ Snowflake credentials not found in environment variables")
|
| 277 |
+
print("Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
|
| 278 |
+
return None
|
| 279 |
+
|
| 280 |
+
return SnowflakeFeedbackConnector(
|
| 281 |
+
user=user,
|
| 282 |
+
password=password,
|
| 283 |
+
account=account,
|
| 284 |
+
warehouse=warehouse,
|
| 285 |
+
database=database,
|
| 286 |
+
schema=schema
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
|
| 291 |
+
"""Helper function to save feedback to Snowflake"""
|
| 292 |
+
logger.info("=" * 80)
|
| 293 |
+
logger.info("π΅ SNOWFLAKE SAVE: Starting save process")
|
| 294 |
+
logger.info(f"π Feedback ID: {feedback.feedback_id}")
|
| 295 |
+
|
| 296 |
+
# Get table name from parameter or env var
|
| 297 |
+
if table_name is None:
|
| 298 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 299 |
+
|
| 300 |
+
connector = get_snowflake_connector_from_env()
|
| 301 |
+
|
| 302 |
+
if not connector:
|
| 303 |
+
logger.warning("β οΈ SNOWFLAKE SAVE: Skipping insertion (credentials not configured)")
|
| 304 |
+
logger.warning(" Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
|
| 305 |
+
logger.info("=" * 80)
|
| 306 |
+
return False
|
| 307 |
+
|
| 308 |
+
try:
|
| 309 |
+
logger.info("π‘ SNOWFLAKE SAVE: Establishing connection...")
|
| 310 |
+
connector.connect()
|
| 311 |
+
logger.info("β
SNOWFLAKE SAVE: Connection established")
|
| 312 |
+
|
| 313 |
+
logger.info("π₯ SNOWFLAKE SAVE: Attempting to insert feedback...")
|
| 314 |
+
success = connector.insert_feedback(feedback, table_name=table_name)
|
| 315 |
+
|
| 316 |
+
logger.info("π SNOWFLAKE SAVE: Disconnecting...")
|
| 317 |
+
connector.disconnect()
|
| 318 |
+
|
| 319 |
+
if success:
|
| 320 |
+
logger.info("β
SNOWFLAKE SAVE: Successfully saved feedback")
|
| 321 |
+
else:
|
| 322 |
+
logger.error("β SNOWFLAKE SAVE: Failed to save feedback")
|
| 323 |
+
|
| 324 |
+
logger.info("=" * 80)
|
| 325 |
+
return success
|
| 326 |
+
except Exception as e:
|
| 327 |
+
logger.error(f"β SNOWFLAKE SAVE ERROR: {type(e).__name__}")
|
| 328 |
+
logger.error(f" - Error: {e}")
|
| 329 |
+
logger.info("=" * 80)
|
| 330 |
+
return False
|
| 331 |
+
|
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini File Search Integration Module
|
| 3 |
+
|
| 4 |
+
This module provides integration with Google Gemini File Search API
|
| 5 |
+
for RAG functionality using Gemini's built-in file search capabilities.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .file_search import GeminiFileSearchClient, GeminiFileSearchResult
|
| 9 |
+
|
| 10 |
+
__all__ = ["GeminiFileSearchClient", "GeminiFileSearchResult"]
|
| 11 |
+
|
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini File Search Client
|
| 3 |
+
|
| 4 |
+
Handles interaction with Google Gemini File Search API for RAG.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from google import genai
|
| 16 |
+
from google.genai import types
|
| 17 |
+
GEMINI_AVAILABLE = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
GEMINI_AVAILABLE = False
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class GeminiFileSearchResult:
|
| 24 |
+
"""Result from Gemini File Search query"""
|
| 25 |
+
answer: str
|
| 26 |
+
sources: List[Dict[str, Any]] # List of document references
|
| 27 |
+
grounding_metadata: Optional[Dict[str, Any]] = None
|
| 28 |
+
query: str = ""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class GeminiFileSearchClient:
|
| 32 |
+
"""Client for interacting with Gemini File Search API"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, api_key: Optional[str] = None, store_name: Optional[str] = None):
|
| 35 |
+
"""
|
| 36 |
+
Initialize Gemini File Search client.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
api_key: Gemini API key (defaults to GEMINI_API_KEY env var)
|
| 40 |
+
store_name: File search store name (defaults to GEMINI_FILESTORE_NAME env var)
|
| 41 |
+
"""
|
| 42 |
+
if not GEMINI_AVAILABLE:
|
| 43 |
+
raise ImportError("google-genai package not installed. Install with: pip install google-genai")
|
| 44 |
+
|
| 45 |
+
self.api_key = api_key or os.getenv("GEMINI_API_KEY")
|
| 46 |
+
if not self.api_key:
|
| 47 |
+
raise ValueError("GEMINI_API_KEY not found. Set it in .env file or pass as argument.")
|
| 48 |
+
|
| 49 |
+
store_name_raw = store_name or os.getenv("GEMINI_FILESTORE_NAME")
|
| 50 |
+
if not store_name_raw:
|
| 51 |
+
raise ValueError("GEMINI_FILESTORE_NAME not found. Set it in .env file or pass as argument.")
|
| 52 |
+
|
| 53 |
+
# Normalize store name: API expects the FULL path format (fileSearchStores/xxx)
|
| 54 |
+
# If just the ID is provided, construct the full path
|
| 55 |
+
if store_name_raw.startswith("fileSearchStores/"):
|
| 56 |
+
self.store_name = store_name_raw # Already full path
|
| 57 |
+
else:
|
| 58 |
+
# Just the ID provided, construct full path
|
| 59 |
+
self.store_name = f"fileSearchStores/{store_name_raw}"
|
| 60 |
+
|
| 61 |
+
logger.info(f"π¦ Using file search store: {self.store_name}")
|
| 62 |
+
|
| 63 |
+
self.client = genai.Client(api_key=self.api_key)
|
| 64 |
+
self.model = "gemini-2.5-flash" # or "gemini-2.5-pro"
|
| 65 |
+
|
| 66 |
+
def search(
|
| 67 |
+
self,
|
| 68 |
+
query: str,
|
| 69 |
+
filters: Optional[Dict[str, Any]] = None,
|
| 70 |
+
model: Optional[str] = None
|
| 71 |
+
) -> GeminiFileSearchResult:
|
| 72 |
+
"""
|
| 73 |
+
Search using Gemini File Search.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
query: User query
|
| 77 |
+
filters: Optional filters (year, source, district, etc.)
|
| 78 |
+
model: Model to use (defaults to gemini-2.5-flash)
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
GeminiFileSearchResult with answer and sources
|
| 82 |
+
"""
|
| 83 |
+
model = model or self.model
|
| 84 |
+
|
| 85 |
+
# Build filter context for the query if filters are provided
|
| 86 |
+
# Gemini File Search doesn't support explicit filters in the API,
|
| 87 |
+
# so we add them as context in the query
|
| 88 |
+
filter_context = ""
|
| 89 |
+
if filters:
|
| 90 |
+
filter_parts = []
|
| 91 |
+
if filters.get("year"):
|
| 92 |
+
years = filters["year"] if isinstance(filters["year"], list) else [filters["year"]]
|
| 93 |
+
filter_parts.append(f"Year: {', '.join(years)}")
|
| 94 |
+
if filters.get("sources"):
|
| 95 |
+
sources = filters["sources"] if isinstance(filters["sources"], list) else [filters["sources"]]
|
| 96 |
+
filter_parts.append(f"Source: {', '.join(sources)}")
|
| 97 |
+
if filters.get("district"):
|
| 98 |
+
districts = filters["district"] if isinstance(filters["district"], list) else [filters["district"]]
|
| 99 |
+
filter_parts.append(f"District: {', '.join(districts)}")
|
| 100 |
+
if filters.get("filenames"):
|
| 101 |
+
filenames = filters["filenames"] if isinstance(filters["filenames"], list) else [filters["filenames"]]
|
| 102 |
+
filter_parts.append(f"Filename: {', '.join(filenames)}")
|
| 103 |
+
|
| 104 |
+
if filter_parts:
|
| 105 |
+
filter_context = f"\n\nPlease focus on documents matching these criteria: {', '.join(filter_parts)}"
|
| 106 |
+
|
| 107 |
+
# Combine query with filter context
|
| 108 |
+
# Add comprehensive system instructions similar to multi-agent system
|
| 109 |
+
system_instructions = """You are a helpful audit report assistant specialized in analyzing government audit reports from Uganda's Office of the Auditor General.
|
| 110 |
+
|
| 111 |
+
CRITICAL RULES:
|
| 112 |
+
1. **NO HALLUCINATION**: Only use information that is explicitly stated in the retrieved documents. Do not make up facts, numbers, or details.
|
| 113 |
+
2. **Document References**: Always cite which documents you're using with [Doc i] references at the end of sentences that use specific information.
|
| 114 |
+
3. **Formatting**: Structure your response with clear paragraphs, bullet points, or sections for readability.
|
| 115 |
+
4. **Accuracy**: If the retrieved documents don't contain the requested information, explicitly state "The retrieved documents do not contain information about [topic]."
|
| 116 |
+
5. **Years and Data**: Pay careful attention to years mentioned in documents. If a user asks about a specific year but documents show different years, explicitly state this.
|
| 117 |
+
6. **District/Source Names**: Use the exact district and source names as they appear in the document metadata (e.g., "Kalangala" not "Kalagala").
|
| 118 |
+
7. **Financial Data**: When providing financial figures, include the currency (UGX) and be precise about amounts.
|
| 119 |
+
8. **Conversational Tone**: Be helpful, clear, and conversational while maintaining accuracy.
|
| 120 |
+
|
| 121 |
+
IMPORTANT: Only use information from the retrieved documents. Do not use information from your training data unless it's explicitly mentioned in the retrieved documents."""
|
| 122 |
+
|
| 123 |
+
# Combine system instructions with query
|
| 124 |
+
full_query = f"{system_instructions}\n\nUser Question: {query}{filter_context}\n\nPlease provide a detailed, well-formatted response with proper document references."
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
# Generate content with file search
|
| 128 |
+
# Based on Gemini API docs: https://ai.google.dev/gemini-api/docs/file-search
|
| 129 |
+
# Try with full path format first, then fallback to just ID if needed
|
| 130 |
+
store_name_to_try = self.store_name
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
# Try the documented format first with full path
|
| 134 |
+
response = self.client.models.generate_content(
|
| 135 |
+
model=model,
|
| 136 |
+
contents=full_query,
|
| 137 |
+
config=types.GenerateContentConfig(
|
| 138 |
+
tools=[
|
| 139 |
+
types.Tool(
|
| 140 |
+
file_search=types.FileSearch(
|
| 141 |
+
file_search_store_names=[store_name_to_try]
|
| 142 |
+
)
|
| 143 |
+
)
|
| 144 |
+
]
|
| 145 |
+
)
|
| 146 |
+
)
|
| 147 |
+
except Exception as api_error:
|
| 148 |
+
error_str = str(api_error).lower()
|
| 149 |
+
# If format error, try with just the ID (without fileSearchStores/ prefix)
|
| 150 |
+
if 'format' in error_str or 'invalid' in error_str or 'too long' in error_str:
|
| 151 |
+
logger.warning(f"Full path format failed, trying with just store ID: {api_error}")
|
| 152 |
+
# Extract just the ID part
|
| 153 |
+
if store_name_to_try.startswith("fileSearchStores/"):
|
| 154 |
+
store_id = store_name_to_try.split("/", 1)[1]
|
| 155 |
+
store_name_to_try = store_id
|
| 156 |
+
|
| 157 |
+
try:
|
| 158 |
+
response = self.client.models.generate_content(
|
| 159 |
+
model=model,
|
| 160 |
+
contents=full_query,
|
| 161 |
+
config=types.GenerateContentConfig(
|
| 162 |
+
tools=[
|
| 163 |
+
types.Tool(
|
| 164 |
+
file_search=types.FileSearch(
|
| 165 |
+
file_search_store_names=[store_name_to_try]
|
| 166 |
+
)
|
| 167 |
+
)
|
| 168 |
+
]
|
| 169 |
+
)
|
| 170 |
+
)
|
| 171 |
+
except Exception as e2:
|
| 172 |
+
raise Exception(f"Failed to call Gemini API with both formats. Full path error: {api_error}, ID-only error: {e2}")
|
| 173 |
+
else:
|
| 174 |
+
# Try alternative dict format
|
| 175 |
+
logger.warning(f"Primary API format failed, trying alternative: {api_error}")
|
| 176 |
+
try:
|
| 177 |
+
response = self.client.models.generate_content(
|
| 178 |
+
model=model,
|
| 179 |
+
contents=full_query,
|
| 180 |
+
tools=[{
|
| 181 |
+
"file_search": {
|
| 182 |
+
"file_search_store_names": [store_name_to_try]
|
| 183 |
+
}
|
| 184 |
+
}]
|
| 185 |
+
)
|
| 186 |
+
except Exception as e2:
|
| 187 |
+
raise Exception(f"Failed to call Gemini API: {e2}")
|
| 188 |
+
|
| 189 |
+
# Extract answer
|
| 190 |
+
answer = ""
|
| 191 |
+
if hasattr(response, 'text'):
|
| 192 |
+
answer = response.text
|
| 193 |
+
elif hasattr(response, 'candidates') and response.candidates:
|
| 194 |
+
# Try to get text from first candidate
|
| 195 |
+
candidate = response.candidates[0]
|
| 196 |
+
if hasattr(candidate, 'content') and candidate.content:
|
| 197 |
+
if hasattr(candidate.content, 'parts'):
|
| 198 |
+
text_parts = []
|
| 199 |
+
for part in candidate.content.parts:
|
| 200 |
+
if hasattr(part, 'text'):
|
| 201 |
+
text_parts.append(part.text)
|
| 202 |
+
answer = " ".join(text_parts)
|
| 203 |
+
elif isinstance(candidate.content, str):
|
| 204 |
+
answer = candidate.content
|
| 205 |
+
else:
|
| 206 |
+
answer = str(response)
|
| 207 |
+
|
| 208 |
+
# Extract grounding metadata (document references)
|
| 209 |
+
sources = []
|
| 210 |
+
grounding_metadata = None
|
| 211 |
+
|
| 212 |
+
logger.info(f"π Extracting sources from Gemini response...")
|
| 213 |
+
|
| 214 |
+
if hasattr(response, 'candidates') and response.candidates:
|
| 215 |
+
candidate = response.candidates[0]
|
| 216 |
+
logger.info(f" Found candidate, checking for grounding_metadata...")
|
| 217 |
+
|
| 218 |
+
# Get grounding metadata
|
| 219 |
+
if hasattr(candidate, 'grounding_metadata'):
|
| 220 |
+
grounding_metadata = candidate.grounding_metadata
|
| 221 |
+
logger.info(f" Found grounding_metadata: {type(grounding_metadata)}")
|
| 222 |
+
|
| 223 |
+
# Extract source documents from grounding metadata
|
| 224 |
+
# Handle different response formats
|
| 225 |
+
grounding_chunks = None
|
| 226 |
+
if hasattr(grounding_metadata, 'grounding_chunks'):
|
| 227 |
+
grounding_chunks = grounding_metadata.grounding_chunks
|
| 228 |
+
logger.info(f" Found grounding_chunks (attr): {len(grounding_chunks) if grounding_chunks else 0}")
|
| 229 |
+
elif isinstance(grounding_metadata, dict) and 'grounding_chunks' in grounding_metadata:
|
| 230 |
+
grounding_chunks = grounding_metadata['grounding_chunks']
|
| 231 |
+
logger.info(f" Found grounding_chunks (dict): {len(grounding_chunks) if grounding_chunks else 0}")
|
| 232 |
+
elif hasattr(grounding_metadata, '__dict__'):
|
| 233 |
+
# Try to access as object attributes
|
| 234 |
+
metadata_dict = grounding_metadata.__dict__
|
| 235 |
+
if 'grounding_chunks' in metadata_dict:
|
| 236 |
+
grounding_chunks = metadata_dict['grounding_chunks']
|
| 237 |
+
logger.info(f" Found grounding_chunks (__dict__): {len(grounding_chunks) if grounding_chunks else 0}")
|
| 238 |
+
|
| 239 |
+
if grounding_chunks:
|
| 240 |
+
logger.info(f" Processing {len(grounding_chunks)} grounding chunks...")
|
| 241 |
+
for idx, chunk in enumerate(grounding_chunks):
|
| 242 |
+
# Handle both object and dict formats
|
| 243 |
+
try:
|
| 244 |
+
if isinstance(chunk, dict):
|
| 245 |
+
chunk_data = chunk
|
| 246 |
+
else:
|
| 247 |
+
# Object format - convert to dict-like access
|
| 248 |
+
chunk_data = {}
|
| 249 |
+
if hasattr(chunk, 'chunk'):
|
| 250 |
+
chunk_obj = chunk.chunk
|
| 251 |
+
chunk_data['chunk'] = {
|
| 252 |
+
'text': getattr(chunk_obj, 'text', ''),
|
| 253 |
+
'file_name': getattr(chunk_obj, 'file_name', '')
|
| 254 |
+
}
|
| 255 |
+
if hasattr(chunk, 'relevance_score'):
|
| 256 |
+
score_obj = chunk.relevance_score
|
| 257 |
+
chunk_data['relevance_score'] = {
|
| 258 |
+
'score': getattr(score_obj, 'score', 0.0)
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
chunk_info = chunk_data.get('chunk', {})
|
| 262 |
+
text = chunk_info.get('text', '') if isinstance(chunk_info, dict) else ''
|
| 263 |
+
file_name = chunk_info.get('file_name', '') if isinstance(chunk_info, dict) else ''
|
| 264 |
+
|
| 265 |
+
# Try to extract file URI and parse metadata from it
|
| 266 |
+
file_uri = chunk_info.get('file_uri', '') if isinstance(chunk_info, dict) else ''
|
| 267 |
+
|
| 268 |
+
# Also check for 'web' attribute (GroundingChunkData structure)
|
| 269 |
+
if hasattr(chunk, 'web') and chunk.web:
|
| 270 |
+
web_data = chunk.web
|
| 271 |
+
file_uri = getattr(web_data, 'file_uri', '') or file_uri
|
| 272 |
+
file_name = getattr(web_data, 'title', '') or getattr(web_data, 'filename', '') or file_name
|
| 273 |
+
text = getattr(web_data, 'text', '') or getattr(web_data, 'content', '') or text
|
| 274 |
+
|
| 275 |
+
# Check retrieved_context - this is where the actual data seems to be!
|
| 276 |
+
if hasattr(chunk, 'retrieved_context') and chunk.retrieved_context:
|
| 277 |
+
rc = chunk.retrieved_context
|
| 278 |
+
# Get text content
|
| 279 |
+
if hasattr(rc, 'text'):
|
| 280 |
+
text = getattr(rc, 'text', '') or text
|
| 281 |
+
# Get document name
|
| 282 |
+
if hasattr(rc, 'document_name'):
|
| 283 |
+
doc_name = getattr(rc, 'document_name', '')
|
| 284 |
+
if doc_name:
|
| 285 |
+
file_name = doc_name or file_name
|
| 286 |
+
|
| 287 |
+
# Fallback: Parse from string representation if we still don't have filename
|
| 288 |
+
if not file_name:
|
| 289 |
+
chunk_str = str(chunk)
|
| 290 |
+
import re
|
| 291 |
+
# Look for PDF filenames
|
| 292 |
+
pdf_match = re.search(r"([A-Za-z0-9\s_-]+\.pdf)", chunk_str)
|
| 293 |
+
if pdf_match:
|
| 294 |
+
file_name = pdf_match.group(1)
|
| 295 |
+
# Or look for title= pattern
|
| 296 |
+
if not file_name and 'title=' in chunk_str:
|
| 297 |
+
title_match = re.search(r"title=['\"]([^'\"]+)['\"]", chunk_str)
|
| 298 |
+
if title_match:
|
| 299 |
+
file_name = title_match.group(1)
|
| 300 |
+
|
| 301 |
+
if not file_name and file_uri:
|
| 302 |
+
# Extract filename from URI if available
|
| 303 |
+
file_name = file_uri.split('/')[-1] if '/' in file_uri else file_uri
|
| 304 |
+
|
| 305 |
+
score_data = chunk_data.get('relevance_score', {})
|
| 306 |
+
score = score_data.get('score', 0.0) if isinstance(score_data, dict) else 0.0
|
| 307 |
+
|
| 308 |
+
if text or file_name: # Only add if we have content
|
| 309 |
+
source_info = {
|
| 310 |
+
"content": text,
|
| 311 |
+
"filename": file_name,
|
| 312 |
+
"score": score,
|
| 313 |
+
"file_uri": file_uri,
|
| 314 |
+
}
|
| 315 |
+
sources.append(source_info)
|
| 316 |
+
logger.info(f"π Extracted source {idx+1}: {file_name} (score: {score:.3f}, content length: {len(text)})")
|
| 317 |
+
except Exception as e:
|
| 318 |
+
logger.warning(f"Error extracting chunk {idx+1} info: {e}")
|
| 319 |
+
import traceback
|
| 320 |
+
logger.debug(traceback.format_exc())
|
| 321 |
+
continue
|
| 322 |
+
else:
|
| 323 |
+
logger.warning(f" No grounding_chunks found in grounding_metadata")
|
| 324 |
+
else:
|
| 325 |
+
logger.warning(f" Candidate does not have grounding_metadata attribute")
|
| 326 |
+
|
| 327 |
+
# Also try to get file references from other parts of the response
|
| 328 |
+
# Sometimes Gemini includes file references in the response itself
|
| 329 |
+
if not sources or len(sources) == 0:
|
| 330 |
+
logger.info(f" No sources from grounding_metadata, trying alternative extraction...")
|
| 331 |
+
# Check if response has file references in other attributes
|
| 332 |
+
if hasattr(candidate, 'content') and candidate.content:
|
| 333 |
+
if hasattr(candidate.content, 'parts'):
|
| 334 |
+
for part in candidate.content.parts:
|
| 335 |
+
if hasattr(part, 'file_data'):
|
| 336 |
+
file_data = part.file_data
|
| 337 |
+
if hasattr(file_data, 'file_uri') or (isinstance(file_data, dict) and 'file_uri' in file_data):
|
| 338 |
+
file_uri = getattr(file_data, 'file_uri', None) or (file_data.get('file_uri') if isinstance(file_data, dict) else None)
|
| 339 |
+
if file_uri:
|
| 340 |
+
file_name = file_uri.split('/')[-1] if '/' in file_uri else file_uri
|
| 341 |
+
sources.append({
|
| 342 |
+
"content": "",
|
| 343 |
+
"filename": file_name,
|
| 344 |
+
"score": 0.0,
|
| 345 |
+
"file_uri": file_uri,
|
| 346 |
+
})
|
| 347 |
+
logger.info(f"π Extracted source from file_data: {file_name}")
|
| 348 |
+
|
| 349 |
+
logger.info(f"β
Total sources extracted: {len(sources)}")
|
| 350 |
+
|
| 351 |
+
return GeminiFileSearchResult(
|
| 352 |
+
answer=answer,
|
| 353 |
+
sources=sources,
|
| 354 |
+
grounding_metadata=grounding_metadata,
|
| 355 |
+
query=query
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
except Exception as e:
|
| 359 |
+
# Return error result
|
| 360 |
+
return GeminiFileSearchResult(
|
| 361 |
+
answer=f"I apologize, but I encountered an error: {str(e)}",
|
| 362 |
+
sources=[],
|
| 363 |
+
query=query
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
def format_sources_for_display(self, result: GeminiFileSearchResult) -> List[Any]:
|
| 367 |
+
"""
|
| 368 |
+
Format Gemini sources to match the format expected by the UI.
|
| 369 |
+
|
| 370 |
+
Returns list of document-like objects compatible with existing display code.
|
| 371 |
+
"""
|
| 372 |
+
from langchain.docstore.document import Document
|
| 373 |
+
|
| 374 |
+
formatted_sources = []
|
| 375 |
+
|
| 376 |
+
for i, source in enumerate(result.sources):
|
| 377 |
+
filename = source.get("filename", "Unknown")
|
| 378 |
+
|
| 379 |
+
# Try to extract metadata from filename (e.g., "Kalangala DLG Report of Auditor General 2021.pdf")
|
| 380 |
+
year = None
|
| 381 |
+
district = None
|
| 382 |
+
source_name = "Gemini File Search"
|
| 383 |
+
|
| 384 |
+
# Parse filename for year
|
| 385 |
+
import re
|
| 386 |
+
year_match = re.search(r'\b(20\d{2})\b', filename)
|
| 387 |
+
if year_match:
|
| 388 |
+
year = int(year_match.group(1))
|
| 389 |
+
|
| 390 |
+
# Parse filename for district/source
|
| 391 |
+
if "Kalangala" in filename:
|
| 392 |
+
district = "Kalangala"
|
| 393 |
+
source_name = "Kalangala DLG"
|
| 394 |
+
elif "Gulu" in filename:
|
| 395 |
+
district = "Gulu"
|
| 396 |
+
source_name = "Gulu DLG"
|
| 397 |
+
elif "KCCA" in filename:
|
| 398 |
+
district = "Kampala"
|
| 399 |
+
source_name = "KCCA"
|
| 400 |
+
elif "MAAIF" in filename:
|
| 401 |
+
source_name = "MAAIF"
|
| 402 |
+
elif "MWTS" in filename:
|
| 403 |
+
source_name = "MWTS"
|
| 404 |
+
elif "Consolidated" in filename:
|
| 405 |
+
source_name = "Consolidated"
|
| 406 |
+
|
| 407 |
+
# Create a Document object compatible with existing code
|
| 408 |
+
doc = Document(
|
| 409 |
+
page_content=source.get("content", ""),
|
| 410 |
+
metadata={
|
| 411 |
+
"filename": filename,
|
| 412 |
+
"source": source_name,
|
| 413 |
+
"score": source.get("score"),
|
| 414 |
+
"chunk_index": i,
|
| 415 |
+
"page": None, # Gemini doesn't provide page numbers
|
| 416 |
+
"year": year,
|
| 417 |
+
"district": district,
|
| 418 |
+
"chunk_id": f"gemini_{i}",
|
| 419 |
+
"_id": f"gemini_{i}",
|
| 420 |
+
}
|
| 421 |
+
)
|
| 422 |
+
formatted_sources.append(doc)
|
| 423 |
+
logger.info(f"π Formatted source {i+1}: {filename} ({year}, {source_name})")
|
| 424 |
+
|
| 425 |
+
logger.info(f"β
Formatted {len(formatted_sources)} sources for display")
|
| 426 |
+
return formatted_sources
|
| 427 |
+
|
|
File without changes
|
|
@@ -14,7 +14,7 @@ except ModuleNotFoundError as me:
|
|
| 14 |
|
| 15 |
from .logging import log_error
|
| 16 |
|
| 17 |
-
from .loader import chunks_to_documents
|
| 18 |
from .vectorstore import VectorStoreManager
|
| 19 |
from .reporting.service import ReportService
|
| 20 |
from .retrieval.context import ContextRetriever
|
|
|
|
| 14 |
|
| 15 |
from .logging import log_error
|
| 16 |
|
| 17 |
+
from .llm.loader import chunks_to_documents
|
| 18 |
from .vectorstore import VectorStoreManager
|
| 19 |
from .reporting.service import ReportService
|
| 20 |
from .retrieval.context import ContextRetriever
|
|
@@ -1,4 +1,8 @@
|
|
| 1 |
-
"""Report metadata and utilities.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from .metadata import get_report_metadata, get_available_sources
|
| 4 |
from .service import ReportService
|
|
|
|
| 1 |
+
"""Report metadata and utilities.
|
| 2 |
+
|
| 3 |
+
This module is kept for backward compatibility with pipeline.py.
|
| 4 |
+
For feedback-related functionality, use src.feedback instead.
|
| 5 |
+
"""
|
| 6 |
|
| 7 |
from .metadata import get_report_metadata, get_available_sources
|
| 8 |
from .service import ReportService
|
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UI Components Module
|
| 3 |
+
|
| 4 |
+
This module contains UI-related components including styles, visualizations,
|
| 5 |
+
and utility functions for the Streamlit application.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .styles import get_custom_css
|
| 9 |
+
from .components import (
|
| 10 |
+
display_chunk_statistics_charts,
|
| 11 |
+
display_chunk_statistics_table
|
| 12 |
+
)
|
| 13 |
+
from .utils import extract_chunk_statistics
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"get_custom_css",
|
| 17 |
+
"display_chunk_statistics_charts",
|
| 18 |
+
"display_chunk_statistics_table",
|
| 19 |
+
"extract_chunk_statistics"
|
| 20 |
+
]
|
| 21 |
+
|
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UI components for displaying statistics and visualizations
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import streamlit as st
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import plotly.express as px
|
| 8 |
+
from typing import Dict, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieval Statistics"):
|
| 12 |
+
"""Display statistics as interactive charts for 10+ results."""
|
| 13 |
+
if not stats or stats.get('total_chunks', 0) == 0:
|
| 14 |
+
return
|
| 15 |
+
|
| 16 |
+
# Wrap everything in one styled container - open it
|
| 17 |
+
st.markdown(f"""
|
| 18 |
+
<div class="retrieval-distribution-container">
|
| 19 |
+
<h3 style="margin-top: 0;">π {title}</h3>
|
| 20 |
+
<div style="display: flex; justify-content: space-around; align-items: center; padding: 15px 0; border-bottom: 1px solid #e0e0e0; margin-bottom: 20px;">
|
| 21 |
+
<div class="metric-container">
|
| 22 |
+
<div class="metric-label">Total Chunks</div>
|
| 23 |
+
<div class="metric-value">{stats['total_chunks']}</div>
|
| 24 |
+
</div>
|
| 25 |
+
<div class="metric-container">
|
| 26 |
+
<div class="metric-label">Unique Sources</div>
|
| 27 |
+
<div class="metric-value">{stats['unique_sources']}</div>
|
| 28 |
+
</div>
|
| 29 |
+
<div class="metric-container">
|
| 30 |
+
<div class="metric-label">Unique Years</div>
|
| 31 |
+
<div class="metric-value">{stats['unique_years']}</div>
|
| 32 |
+
</div>
|
| 33 |
+
<div class="metric-container">
|
| 34 |
+
<div class="metric-label">Unique Files</div>
|
| 35 |
+
<div class="metric-value">{stats['unique_filenames']}</div>
|
| 36 |
+
</div>
|
| 37 |
+
</div>
|
| 38 |
+
""", unsafe_allow_html=True)
|
| 39 |
+
|
| 40 |
+
# Charts - three columns to include Districts
|
| 41 |
+
col1, col2, col3 = st.columns(3)
|
| 42 |
+
|
| 43 |
+
with col1:
|
| 44 |
+
# Source distribution chart
|
| 45 |
+
if stats['source_distribution']:
|
| 46 |
+
source_df = pd.DataFrame(
|
| 47 |
+
list(stats['source_distribution'].items()),
|
| 48 |
+
columns=['Source', 'Count']
|
| 49 |
+
)
|
| 50 |
+
fig_source = px.bar(
|
| 51 |
+
source_df,
|
| 52 |
+
x='Count',
|
| 53 |
+
y='Source',
|
| 54 |
+
orientation='h',
|
| 55 |
+
title='Distribution by Source',
|
| 56 |
+
color='Count',
|
| 57 |
+
color_continuous_scale='viridis'
|
| 58 |
+
)
|
| 59 |
+
fig_source.update_layout(height=400, showlegend=False)
|
| 60 |
+
st.plotly_chart(fig_source, use_container_width=True) # Note: plotly_chart still uses use_container_width
|
| 61 |
+
|
| 62 |
+
with col2:
|
| 63 |
+
# Year distribution chart
|
| 64 |
+
if stats['year_distribution']:
|
| 65 |
+
# Filter out 'Unknown' years for the chart
|
| 66 |
+
year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
|
| 67 |
+
if year_dist_filtered:
|
| 68 |
+
year_df = pd.DataFrame(
|
| 69 |
+
list(year_dist_filtered.items()),
|
| 70 |
+
columns=['Year', 'Count']
|
| 71 |
+
)
|
| 72 |
+
# Sort by year as integer but keep as string for categorical display
|
| 73 |
+
year_df['Year_Int'] = year_df['Year'].astype(int)
|
| 74 |
+
year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)
|
| 75 |
+
|
| 76 |
+
fig_year = px.bar(
|
| 77 |
+
year_df,
|
| 78 |
+
x='Year',
|
| 79 |
+
y='Count',
|
| 80 |
+
title='Distribution by Year',
|
| 81 |
+
color='Count',
|
| 82 |
+
color_continuous_scale='plasma'
|
| 83 |
+
)
|
| 84 |
+
# Ensure years are treated as categorical (discrete) not continuous
|
| 85 |
+
fig_year.update_xaxes(type='category')
|
| 86 |
+
fig_year.update_layout(height=400, showlegend=False)
|
| 87 |
+
st.plotly_chart(fig_year, use_container_width=True) # Note: plotly_chart still uses use_container_width
|
| 88 |
+
else:
|
| 89 |
+
st.info("No valid years found in the results")
|
| 90 |
+
|
| 91 |
+
with col3:
|
| 92 |
+
# District distribution chart
|
| 93 |
+
if stats.get('district_distribution'):
|
| 94 |
+
district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
|
| 95 |
+
if district_dist_filtered:
|
| 96 |
+
district_df = pd.DataFrame(
|
| 97 |
+
list(district_dist_filtered.items()),
|
| 98 |
+
columns=['District', 'Count']
|
| 99 |
+
)
|
| 100 |
+
district_df = district_df.sort_values('Count', ascending=False)
|
| 101 |
+
|
| 102 |
+
fig_district = px.bar(
|
| 103 |
+
district_df,
|
| 104 |
+
x='Count',
|
| 105 |
+
y='District',
|
| 106 |
+
orientation='h',
|
| 107 |
+
title='Distribution by District',
|
| 108 |
+
color='Count',
|
| 109 |
+
color_continuous_scale='blues'
|
| 110 |
+
)
|
| 111 |
+
fig_district.update_layout(height=400, showlegend=False)
|
| 112 |
+
st.plotly_chart(fig_district, use_container_width=True) # Note: plotly_chart still uses use_container_width
|
| 113 |
+
else:
|
| 114 |
+
st.info("No valid districts found in the results")
|
| 115 |
+
|
| 116 |
+
# Close the container
|
| 117 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def display_chunk_statistics_table(stats: Dict[str, Any], title: str = "Retrieval Distribution"):
|
| 121 |
+
"""Display statistics as tables for smaller results with fixed alignment."""
|
| 122 |
+
if not stats or stats.get('total_chunks', 0) == 0:
|
| 123 |
+
return
|
| 124 |
+
|
| 125 |
+
# Wrap in styled container
|
| 126 |
+
st.markdown('<div class="retrieval-distribution-container">', unsafe_allow_html=True)
|
| 127 |
+
|
| 128 |
+
st.subheader(f"π {title}")
|
| 129 |
+
|
| 130 |
+
# Create a container with fixed height for alignment
|
| 131 |
+
stats_container = st.container()
|
| 132 |
+
|
| 133 |
+
with stats_container:
|
| 134 |
+
# Create 4 equal columns for consistent alignment
|
| 135 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 136 |
+
|
| 137 |
+
with col1:
|
| 138 |
+
st.markdown("**ποΈ Districts**")
|
| 139 |
+
if stats.get('district_distribution'):
|
| 140 |
+
district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
|
| 141 |
+
if district_dist_filtered:
|
| 142 |
+
district_data = {
|
| 143 |
+
"District": list(district_dist_filtered.keys()),
|
| 144 |
+
"Count": list(district_dist_filtered.values())
|
| 145 |
+
}
|
| 146 |
+
district_df = pd.DataFrame(district_data).sort_values('Count', ascending=False)
|
| 147 |
+
st.dataframe(district_df, hide_index=True, width='stretch')
|
| 148 |
+
else:
|
| 149 |
+
st.write("No district data")
|
| 150 |
+
else:
|
| 151 |
+
st.write("No district data")
|
| 152 |
+
|
| 153 |
+
with col2:
|
| 154 |
+
st.markdown("**π Sources**")
|
| 155 |
+
if stats['source_distribution']:
|
| 156 |
+
source_data = {
|
| 157 |
+
"Source": list(stats['source_distribution'].keys()),
|
| 158 |
+
"Count": list(stats['source_distribution'].values())
|
| 159 |
+
}
|
| 160 |
+
source_df = pd.DataFrame(source_data).sort_values('Count', ascending=False)
|
| 161 |
+
st.dataframe(source_df, hide_index=True, width='stretch')
|
| 162 |
+
else:
|
| 163 |
+
st.write("No source data")
|
| 164 |
+
|
| 165 |
+
with col3:
|
| 166 |
+
st.markdown("**π
Years**")
|
| 167 |
+
if stats['year_distribution']:
|
| 168 |
+
year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
|
| 169 |
+
if year_dist_filtered:
|
| 170 |
+
year_data = {
|
| 171 |
+
"Year": list(year_dist_filtered.keys()),
|
| 172 |
+
"Count": list(year_dist_filtered.values())
|
| 173 |
+
}
|
| 174 |
+
year_df = pd.DataFrame(year_data)
|
| 175 |
+
# Sort by year as integer but display as string
|
| 176 |
+
year_df['Year_Int'] = year_df['Year'].astype(int)
|
| 177 |
+
year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
|
| 178 |
+
st.dataframe(year_df, hide_index=True, width='stretch')
|
| 179 |
+
else:
|
| 180 |
+
st.write("No year data")
|
| 181 |
+
else:
|
| 182 |
+
st.write("No year data")
|
| 183 |
+
|
| 184 |
+
with col4:
|
| 185 |
+
st.markdown("**π Files**")
|
| 186 |
+
if stats['filename_distribution']:
|
| 187 |
+
filename_items = list(stats['filename_distribution'].items())
|
| 188 |
+
filename_items.sort(key=lambda x: x[1], reverse=True)
|
| 189 |
+
|
| 190 |
+
# Show top files with truncated names
|
| 191 |
+
file_data = {
|
| 192 |
+
"File": [f[:30] + "..." if len(f) > 30 else f for f, c in filename_items[:5]],
|
| 193 |
+
"Count": [c for f, c in filename_items[:5]]
|
| 194 |
+
}
|
| 195 |
+
file_df = pd.DataFrame(file_data)
|
| 196 |
+
st.dataframe(file_df, hide_index=True, width='stretch')
|
| 197 |
+
else:
|
| 198 |
+
st.write("No file data")
|
| 199 |
+
|
| 200 |
+
# Close container
|
| 201 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 202 |
+
|
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom CSS styles for Streamlit application
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_custom_css() -> str:
|
| 7 |
+
"""Get custom CSS styles as a string"""
|
| 8 |
+
return """
|
| 9 |
+
<style>
|
| 10 |
+
.main-header {
|
| 11 |
+
font-size: 2.5rem;
|
| 12 |
+
font-weight: bold;
|
| 13 |
+
color: #1f77b4;
|
| 14 |
+
text-align: center;
|
| 15 |
+
margin-bottom: 1rem;
|
| 16 |
+
width: 100%;
|
| 17 |
+
display: block;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
.subtitle {
|
| 21 |
+
font-size: 1.2rem;
|
| 22 |
+
color: #666;
|
| 23 |
+
text-align: center;
|
| 24 |
+
margin-bottom: 2rem;
|
| 25 |
+
width: 100%;
|
| 26 |
+
display: block;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
.session-info {
|
| 30 |
+
background-color: #f0f2f6;
|
| 31 |
+
padding: 10px;
|
| 32 |
+
border-radius: 5px;
|
| 33 |
+
margin-bottom: 20px;
|
| 34 |
+
font-size: 0.9rem;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
.user-message {
|
| 38 |
+
background-color: #007bff;
|
| 39 |
+
color: white;
|
| 40 |
+
padding: 12px 16px;
|
| 41 |
+
border-radius: 18px 18px 4px 18px;
|
| 42 |
+
margin: 8px 0;
|
| 43 |
+
margin-left: 20%;
|
| 44 |
+
word-wrap: break-word;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
.bot-message {
|
| 48 |
+
background-color: #f1f3f4;
|
| 49 |
+
color: #333;
|
| 50 |
+
padding: 12px 16px;
|
| 51 |
+
border-radius: 18px 18px 18px 4px;
|
| 52 |
+
margin: 8px 0;
|
| 53 |
+
margin-right: 20%;
|
| 54 |
+
word-wrap: break-word;
|
| 55 |
+
border: 1px solid #e0e0e0;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
.filter-section {
|
| 59 |
+
margin-bottom: 20px;
|
| 60 |
+
padding: 15px;
|
| 61 |
+
background-color: #f8f9fa;
|
| 62 |
+
border-radius: 8px;
|
| 63 |
+
border: 1px solid #e9ecef;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
.filter-title {
|
| 67 |
+
font-weight: bold;
|
| 68 |
+
margin-bottom: 10px;
|
| 69 |
+
color: #495057;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
.feedback-section {
|
| 73 |
+
background-color: #f8f9fa;
|
| 74 |
+
padding: 20px;
|
| 75 |
+
border-radius: 10px;
|
| 76 |
+
margin-top: 30px;
|
| 77 |
+
border: 2px solid #dee2e6;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
.retrieval-history {
|
| 81 |
+
background-color: #ffffff;
|
| 82 |
+
padding: 15px;
|
| 83 |
+
border-radius: 5px;
|
| 84 |
+
margin: 10px 0;
|
| 85 |
+
border-left: 4px solid #007bff;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
.retrieval-distribution-container {
|
| 89 |
+
background-color: #ffffff;
|
| 90 |
+
padding: 25px;
|
| 91 |
+
border-radius: 10px;
|
| 92 |
+
margin: 20px 0;
|
| 93 |
+
border: 2px solid #e0e0e0;
|
| 94 |
+
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 2px 4px rgba(0, 0, 0, 0.06);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
.metric-label {
|
| 98 |
+
font-size: 0.9rem;
|
| 99 |
+
color: #555;
|
| 100 |
+
margin-bottom: 5px;
|
| 101 |
+
text-align: center;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
.metric-value {
|
| 105 |
+
font-size: 1.8rem;
|
| 106 |
+
font-weight: bold;
|
| 107 |
+
color: #000000;
|
| 108 |
+
text-align: center;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
.metric-container {
|
| 112 |
+
text-align: center;
|
| 113 |
+
padding: 10px;
|
| 114 |
+
}
|
| 115 |
+
</style>
|
| 116 |
+
"""
|
| 117 |
+
|
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UI utility functions for data processing and statistics
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Dict, Any, List
|
| 6 |
+
from collections import Counter
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def extract_chunk_statistics(sources: List[Any]) -> Dict[str, Any]:
|
| 10 |
+
"""Extract statistics from retrieved chunks."""
|
| 11 |
+
if not sources:
|
| 12 |
+
return {}
|
| 13 |
+
|
| 14 |
+
sources_list = []
|
| 15 |
+
years = []
|
| 16 |
+
filenames = []
|
| 17 |
+
districts = []
|
| 18 |
+
|
| 19 |
+
for doc in sources:
|
| 20 |
+
metadata = getattr(doc, 'metadata', {})
|
| 21 |
+
|
| 22 |
+
# Extract source
|
| 23 |
+
source = metadata.get('source', 'Unknown')
|
| 24 |
+
sources_list.append(source)
|
| 25 |
+
|
| 26 |
+
# Extract year
|
| 27 |
+
year = metadata.get('year', 'Unknown')
|
| 28 |
+
if year and year != 'Unknown':
|
| 29 |
+
try:
|
| 30 |
+
# Convert to int first, then back to string to ensure it's a proper year
|
| 31 |
+
year_int = int(float(year)) # Handle both int and float strings
|
| 32 |
+
if 1900 <= year_int <= 2030: # Reasonable year range
|
| 33 |
+
years.append(str(year_int))
|
| 34 |
+
else:
|
| 35 |
+
years.append('Unknown')
|
| 36 |
+
except (ValueError, TypeError):
|
| 37 |
+
years.append('Unknown')
|
| 38 |
+
else:
|
| 39 |
+
years.append('Unknown')
|
| 40 |
+
|
| 41 |
+
# Extract filename
|
| 42 |
+
filename = metadata.get('filename', 'Unknown')
|
| 43 |
+
filenames.append(filename)
|
| 44 |
+
|
| 45 |
+
# Extract district
|
| 46 |
+
district = metadata.get('district', 'Unknown')
|
| 47 |
+
if district and district != 'Unknown':
|
| 48 |
+
districts.append(district)
|
| 49 |
+
else:
|
| 50 |
+
districts.append('Unknown')
|
| 51 |
+
|
| 52 |
+
# Count occurrences
|
| 53 |
+
source_counts = Counter(sources_list)
|
| 54 |
+
year_counts = Counter(years)
|
| 55 |
+
filename_counts = Counter(filenames)
|
| 56 |
+
district_counts = Counter(districts)
|
| 57 |
+
|
| 58 |
+
return {
|
| 59 |
+
'total_chunks': len(sources),
|
| 60 |
+
'unique_sources': len(source_counts),
|
| 61 |
+
'unique_years': len([y for y in year_counts.keys() if y != 'Unknown']),
|
| 62 |
+
'unique_filenames': len(filename_counts),
|
| 63 |
+
'unique_districts': len([d for d in district_counts.keys() if d != 'Unknown']),
|
| 64 |
+
'source_distribution': dict(source_counts),
|
| 65 |
+
'year_distribution': dict(year_counts),
|
| 66 |
+
'filename_distribution': dict(filename_counts),
|
| 67 |
+
'district_distribution': dict(district_counts),
|
| 68 |
+
'sources': sources_list,
|
| 69 |
+
'years': years,
|
| 70 |
+
'filenames': filenames,
|
| 71 |
+
'districts': districts
|
| 72 |
+
}
|
| 73 |
+
|
|
File without changes
|
|
@@ -1,9 +1,20 @@
|
|
| 1 |
"""Vector store management and operations."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Dict, Any, List, Optional
|
| 4 |
|
| 5 |
|
| 6 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from langchain_qdrant import QdrantVectorStore
|
| 8 |
from langchain.docstore.document import Document
|
| 9 |
from langchain_core.embeddings import Embeddings
|
|
@@ -28,11 +39,23 @@ class MatryoshkaEmbeddings(Embeddings):
|
|
| 28 |
|
| 29 |
if truncate_dim and "matryoshka" in model_name.lower():
|
| 30 |
# Use SentenceTransformer directly for Matryoshka models
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
print(f"π§ Matryoshka model configured for {truncate_dim} dimensions")
|
| 34 |
else:
|
| 35 |
# Use standard HuggingFaceEmbeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
| 37 |
|
| 38 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
@@ -76,12 +99,17 @@ class VectorStoreManager:
|
|
| 76 |
|
| 77 |
def _create_embeddings(self) -> HuggingFaceEmbeddings:
|
| 78 |
"""Create embeddings model from configuration."""
|
| 79 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 80 |
-
|
| 81 |
model_name = self.config["retriever"]["model"]
|
| 82 |
normalize = self.config["retriever"]["normalize"]
|
| 83 |
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
encode_kwargs = {
|
| 86 |
"normalize_embeddings": normalize,
|
| 87 |
"batch_size": 100,
|
|
@@ -108,6 +136,8 @@ class VectorStoreManager:
|
|
| 108 |
return embeddings
|
| 109 |
|
| 110 |
# Use standard HuggingFaceEmbeddings for non-Matryoshka models
|
|
|
|
|
|
|
| 111 |
embeddings = HuggingFaceEmbeddings(
|
| 112 |
model_name=model_name,
|
| 113 |
model_kwargs=model_kwargs,
|
|
|
|
| 1 |
"""Vector store management and operations."""
|
| 2 |
+
import os
|
| 3 |
+
# Disable MPS before importing torch to prevent meta tensor issues on Mac
|
| 4 |
+
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
|
| 5 |
+
os.environ.setdefault("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0")
|
| 6 |
+
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Dict, Any, List, Optional
|
| 9 |
|
| 10 |
|
| 11 |
import torch
|
| 12 |
+
# Disable MPS backend explicitly to prevent meta tensor issues
|
| 13 |
+
if hasattr(torch.backends, 'mps'):
|
| 14 |
+
# Monkey patch to disable MPS
|
| 15 |
+
original_mps_available = torch.backends.mps.is_available
|
| 16 |
+
torch.backends.mps.is_available = lambda: False
|
| 17 |
+
|
| 18 |
from langchain_qdrant import QdrantVectorStore
|
| 19 |
from langchain.docstore.document import Document
|
| 20 |
from langchain_core.embeddings import Embeddings
|
|
|
|
| 39 |
|
| 40 |
if truncate_dim and "matryoshka" in model_name.lower():
|
| 41 |
# Use SentenceTransformer directly for Matryoshka models
|
| 42 |
+
# Fix for meta tensor issue: Explicitly force CPU
|
| 43 |
+
# MPS is already disabled at module level
|
| 44 |
+
# Explicitly pass device="cpu" to prevent MPS/CUDA detection
|
| 45 |
+
self.model = SentenceTransformer(
|
| 46 |
+
model_name,
|
| 47 |
+
truncate_dim=truncate_dim,
|
| 48 |
+
device="cpu" # Force CPU to prevent meta tensor issues
|
| 49 |
+
)
|
| 50 |
print(f"π§ Matryoshka model configured for {truncate_dim} dimensions")
|
| 51 |
else:
|
| 52 |
# Use standard HuggingFaceEmbeddings
|
| 53 |
+
# Don't pass device parameter - let it load naturally on CPU
|
| 54 |
+
# This prevents the meta tensor error
|
| 55 |
+
if "model_kwargs" not in kwargs:
|
| 56 |
+
kwargs["model_kwargs"] = {}
|
| 57 |
+
# Remove device from model_kwargs if present to prevent meta tensor issues
|
| 58 |
+
kwargs["model_kwargs"].pop("device", None)
|
| 59 |
self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
| 60 |
|
| 61 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
|
|
| 99 |
|
| 100 |
def _create_embeddings(self) -> HuggingFaceEmbeddings:
|
| 101 |
"""Create embeddings model from configuration."""
|
|
|
|
|
|
|
| 102 |
model_name = self.config["retriever"]["model"]
|
| 103 |
normalize = self.config["retriever"]["normalize"]
|
| 104 |
|
| 105 |
+
# Fix for meta tensor issue: Force CPU usage to prevent MPS/CUDA detection
|
| 106 |
+
# The error occurs when SentenceTransformer detects MPS/CUDA and tries to move meta tensors
|
| 107 |
+
# MPS is already disabled at module level, now we explicitly force CPU in model_kwargs
|
| 108 |
+
model_kwargs = {
|
| 109 |
+
"device": "cpu", # Explicitly force CPU to prevent MPS/CUDA detection
|
| 110 |
+
"trust_remote_code": True, # Some models need this
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
encode_kwargs = {
|
| 114 |
"normalize_embeddings": normalize,
|
| 115 |
"batch_size": 100,
|
|
|
|
| 136 |
return embeddings
|
| 137 |
|
| 138 |
# Use standard HuggingFaceEmbeddings for non-Matryoshka models
|
| 139 |
+
# Don't pass device in model_kwargs - let HuggingFaceEmbeddings handle it
|
| 140 |
+
# but ensure we're not using meta device
|
| 141 |
embeddings = HuggingFaceEmbeddings(
|
| 142 |
model_name=model_name,
|
| 143 |
model_kwargs=model_kwargs,
|