Spaces:
Sleeping
Sleeping
Gemini FSA
#6
by
Yeroyan
- opened
- .gitignore +115 -0
- Dockerfile +1 -0
- add_district_metadata.py +379 -0
- app.py +679 -311
- src/agents/__init__.py +10 -0
- src/agents/gemini_chatbot.py +392 -0
- multi_agent_chatbot.py β src/agents/multi_agent_chatbot.py +283 -53
- smart_chatbot.py β src/agents/smart_chatbot.py +4 -3
- src/config/paths.py +59 -0
- src/feedback/__init__.py +152 -0
- src/feedback/feedback_schema.py +161 -0
- src/feedback/snowflake_connector.py +331 -0
- src/gemini/__init__.py +11 -0
- src/gemini/file_search.py +427 -0
- src/{loader.py β llm/loader.py} +0 -0
- src/pipeline.py +33 -38
- src/reporting/__init__.py +5 -1
- src/reporting/feedback_schema.py +36 -71
- src/reporting/snowflake_connector.py +67 -39
- src/streamlit_app.py +0 -40
- src/ui_components/__init__.py +21 -0
- src/ui_components/components.py +202 -0
- src/ui_components/styles.py +117 -0
- src/ui_components/utils.py +73 -0
- utils.py β src/utils.py +0 -0
- src/vectorstore.py +35 -5
.gitignore
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ==========================================
|
| 2 |
+
# PYTHON
|
| 3 |
+
# ==========================================
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.py[cod]
|
| 6 |
+
*.pyo
|
| 7 |
+
*.pyd
|
| 8 |
+
*$py.class
|
| 9 |
+
|
| 10 |
+
# Virtual environments
|
| 11 |
+
.venv/
|
| 12 |
+
venv/
|
| 13 |
+
env/
|
| 14 |
+
ENV/
|
| 15 |
+
.conda/
|
| 16 |
+
.venv*/
|
| 17 |
+
|
| 18 |
+
# Byte-compiled / optimized / DLL files
|
| 19 |
+
*.so
|
| 20 |
+
*.dll
|
| 21 |
+
*.dylib
|
| 22 |
+
|
| 23 |
+
# Logs and debug
|
| 24 |
+
*.log
|
| 25 |
+
*.out
|
| 26 |
+
*.err
|
| 27 |
+
logs/
|
| 28 |
+
debug/
|
| 29 |
+
*.sqlite3
|
| 30 |
+
|
| 31 |
+
# ==========================================
|
| 32 |
+
# BUILD / PACKAGING
|
| 33 |
+
# ==========================================
|
| 34 |
+
build/
|
| 35 |
+
dist/
|
| 36 |
+
*.egg-info/
|
| 37 |
+
.eggs/
|
| 38 |
+
pip-wheel-metadata/
|
| 39 |
+
.wheels/
|
| 40 |
+
|
| 41 |
+
# ==========================================
|
| 42 |
+
# JUPYTER / NOTEBOOKS
|
| 43 |
+
# ==========================================
|
| 44 |
+
.ipynb_checkpoints/
|
| 45 |
+
*.ipynb_convert/
|
| 46 |
+
|
| 47 |
+
# ==========================================
|
| 48 |
+
# DATA / MODELS / CACHE
|
| 49 |
+
# ==========================================
|
| 50 |
+
data/
|
| 51 |
+
datasets/
|
| 52 |
+
.cache/
|
| 53 |
+
*.ckpt
|
| 54 |
+
*.h5
|
| 55 |
+
*.hdf5
|
| 56 |
+
*.tflite
|
| 57 |
+
*.onnx
|
| 58 |
+
*.pth
|
| 59 |
+
*.pt
|
| 60 |
+
*.joblib
|
| 61 |
+
*.pkl
|
| 62 |
+
*.pickle
|
| 63 |
+
*.npz
|
| 64 |
+
*.npy
|
| 65 |
+
outputs/
|
| 66 |
+
artifacts/
|
| 67 |
+
checkpoints/
|
| 68 |
+
runs/
|
| 69 |
+
wandb/
|
| 70 |
+
mlruns/
|
| 71 |
+
lightning_logs/
|
| 72 |
+
|
| 73 |
+
# Hugging Face
|
| 74 |
+
huggingface/
|
| 75 |
+
~/.cache/huggingface/
|
| 76 |
+
~/.cache/torch/
|
| 77 |
+
~/.cache/datasets/
|
| 78 |
+
~/.cache/transformers/
|
| 79 |
+
|
| 80 |
+
# ==========================================
|
| 81 |
+
# EDITORS / TOOLS
|
| 82 |
+
# ==========================================
|
| 83 |
+
.vscode/
|
| 84 |
+
.idea/
|
| 85 |
+
*.swp
|
| 86 |
+
*.swo
|
| 87 |
+
*.bak
|
| 88 |
+
.DS_Store
|
| 89 |
+
Thumbs.db
|
| 90 |
+
|
| 91 |
+
# ==========================================
|
| 92 |
+
# ENV FILES / CREDENTIALS
|
| 93 |
+
# ==========================================
|
| 94 |
+
.env
|
| 95 |
+
.env.*
|
| 96 |
+
*.env.local
|
| 97 |
+
secrets.*
|
| 98 |
+
config.json
|
| 99 |
+
token.json
|
| 100 |
+
|
| 101 |
+
# ==========================================
|
| 102 |
+
# TESTS / TEMP FILES
|
| 103 |
+
# ==========================================
|
| 104 |
+
__tests__/
|
| 105 |
+
.tox/
|
| 106 |
+
.coverage
|
| 107 |
+
.cache/
|
| 108 |
+
pytest_cache/
|
| 109 |
+
tmp/
|
| 110 |
+
temp/
|
| 111 |
+
*.tmp
|
| 112 |
+
*.temp
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
local_*
|
Dockerfile
CHANGED
|
@@ -59,6 +59,7 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
|
| 59 |
CMD curl --fail http://localhost:8501/_stcore/health || exit 1
|
| 60 |
|
| 61 |
#temp developement commands
|
|
|
|
| 62 |
# RUN mkdir /app/conversations && chmod -R 777 conversations
|
| 63 |
# RUN mkdir /app/feedback && chmod -R 777 feedback
|
| 64 |
|
|
|
|
| 59 |
CMD curl --fail http://localhost:8501/_stcore/health || exit 1
|
| 60 |
|
| 61 |
#temp developement commands
|
| 62 |
+
RUN pip3 install plotly
|
| 63 |
# RUN mkdir /app/conversations && chmod -R 777 conversations
|
| 64 |
# RUN mkdir /app/feedback && chmod -R 777 feedback
|
| 65 |
|
add_district_metadata.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Script to add District metadata to Qdrant chunks based on filename analysis.
|
| 4 |
+
Handles Uganda districts, ministry mappings, and LLM inference for ambiguous cases.
|
| 5 |
+
"""
|
| 6 |
+
import re
|
| 7 |
+
import yaml
|
| 8 |
+
import logging
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Dict, List, Optional
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
from qdrant_client import QdrantClient
|
| 14 |
+
|
| 15 |
+
# Configure logging
|
| 16 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class DistrictMapping:
|
| 22 |
+
"""Mapping for district-related entities"""
|
| 23 |
+
name: str
|
| 24 |
+
aliases: List[str]
|
| 25 |
+
is_district: bool = True
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DistrictMetadataProcessor:
|
| 29 |
+
def __init__(self, config_path: str = "src/config/settings.yaml"):
|
| 30 |
+
# Load config manually
|
| 31 |
+
with open(config_path, 'r') as f:
|
| 32 |
+
self.config = yaml.safe_load(f)
|
| 33 |
+
|
| 34 |
+
# Initialize Qdrant client (will be imported when needed)
|
| 35 |
+
self.llm_client = None
|
| 36 |
+
self.qdrant_client = None
|
| 37 |
+
self.collection_name = self.config["qdrant"]["collection_name"]
|
| 38 |
+
|
| 39 |
+
# Initialize district mappings
|
| 40 |
+
self.district_mappings = self._initialize_district_mappings()
|
| 41 |
+
self.ministry_mappings = self._initialize_ministry_mappings()
|
| 42 |
+
|
| 43 |
+
def _initialize_district_mappings(self) -> Dict[str, DistrictMapping]:
|
| 44 |
+
"""Initialize Uganda districts and their aliases"""
|
| 45 |
+
districts = [
|
| 46 |
+
# Central Region
|
| 47 |
+
DistrictMapping("Kampala", ["KCCA", "Kampala Capital City Authority"]),
|
| 48 |
+
DistrictMapping("Wakiso", ["Wakiso"]),
|
| 49 |
+
DistrictMapping("Mukono", ["Mukono"]),
|
| 50 |
+
DistrictMapping("Luweero", ["Luweero"]),
|
| 51 |
+
DistrictMapping("Nakaseke", ["Nakaseke"]),
|
| 52 |
+
DistrictMapping("Nakasongola", ["Nakasongola"]),
|
| 53 |
+
DistrictMapping("Kayunga", ["Kayunga"]),
|
| 54 |
+
DistrictMapping("Buikwe", ["Buikwe"]),
|
| 55 |
+
DistrictMapping("Buvuma", ["Buvuma"]),
|
| 56 |
+
|
| 57 |
+
# Northern Region
|
| 58 |
+
DistrictMapping("Gulu", ["Gulu", "Gulu DLG"]),
|
| 59 |
+
DistrictMapping("Kitgum", ["Kitgum"]),
|
| 60 |
+
DistrictMapping("Pader", ["Pader"]),
|
| 61 |
+
DistrictMapping("Agago", ["Agago"]),
|
| 62 |
+
DistrictMapping("Lamwo", ["Lamwo"]),
|
| 63 |
+
DistrictMapping("Nwoya", ["Nwoya"]),
|
| 64 |
+
DistrictMapping("Amuru", ["Amuru"]),
|
| 65 |
+
DistrictMapping("Omoro", ["Omoro"]),
|
| 66 |
+
DistrictMapping("Oyam", ["Oyam"]),
|
| 67 |
+
DistrictMapping("Kole", ["Kole"]),
|
| 68 |
+
DistrictMapping("Apac", ["Apac", "Apac District"]),
|
| 69 |
+
DistrictMapping("Lira", ["Lira"]),
|
| 70 |
+
DistrictMapping("Alebtong", ["Alebtong"]),
|
| 71 |
+
DistrictMapping("Amolatar", ["Amolatar"]),
|
| 72 |
+
DistrictMapping("Dokolo", ["Dokolo"]),
|
| 73 |
+
DistrictMapping("Otuke", ["Otuke"]),
|
| 74 |
+
DistrictMapping("Kwania", ["Kwania"]),
|
| 75 |
+
|
| 76 |
+
# Eastern Region
|
| 77 |
+
DistrictMapping("Jinja", ["Jinja"]),
|
| 78 |
+
DistrictMapping("Kamuli", ["Kamuli"]),
|
| 79 |
+
DistrictMapping("Iganga", ["Iganga"]),
|
| 80 |
+
DistrictMapping("Bugiri", ["Bugiri"]),
|
| 81 |
+
DistrictMapping("Mayuge", ["Mayuge"]),
|
| 82 |
+
DistrictMapping("Namayingo", ["Namayingo"]),
|
| 83 |
+
DistrictMapping("Busia", ["Busia"]),
|
| 84 |
+
DistrictMapping("Tororo", ["Tororo"]),
|
| 85 |
+
DistrictMapping("Pallisa", ["Pallisa"]),
|
| 86 |
+
DistrictMapping("Kumi", ["Kumi"]),
|
| 87 |
+
DistrictMapping("Bukedea", ["Bukedea"]),
|
| 88 |
+
DistrictMapping("Soroti", ["Soroti"]),
|
| 89 |
+
DistrictMapping("Serere", ["Serere"]),
|
| 90 |
+
DistrictMapping("Ngora", ["Ngora"]),
|
| 91 |
+
DistrictMapping("Kaberamaido", ["Kaberamaido"]),
|
| 92 |
+
DistrictMapping("Kalaki", ["Kalaki"]),
|
| 93 |
+
DistrictMapping("Kapelebyong", ["Kapelebyong"]),
|
| 94 |
+
DistrictMapping("Amuria", ["Amuria"]),
|
| 95 |
+
DistrictMapping("Katakwi", ["Katakwi"]),
|
| 96 |
+
DistrictMapping("Kotido", ["Kotido"]),
|
| 97 |
+
DistrictMapping("Abim", ["Abim"]),
|
| 98 |
+
DistrictMapping("Kaabong", ["Kaabong", "Kaabong District"]),
|
| 99 |
+
DistrictMapping("Karenga", ["Karenga"]),
|
| 100 |
+
DistrictMapping("Moroto", ["Moroto"]),
|
| 101 |
+
DistrictMapping("Napak", ["Napak"]),
|
| 102 |
+
DistrictMapping("Nabilatuk", ["Nabilatuk"]),
|
| 103 |
+
DistrictMapping("Amudat", ["Amudat"]),
|
| 104 |
+
DistrictMapping("Nakapiripirit", ["Nakapiripirit"]),
|
| 105 |
+
DistrictMapping("Bukwo", ["Bukwo"]),
|
| 106 |
+
DistrictMapping("Kween", ["Kween"]),
|
| 107 |
+
DistrictMapping("Kapchorwa", ["Kapchorwa"]),
|
| 108 |
+
DistrictMapping("Sironko", ["Sironko"]),
|
| 109 |
+
DistrictMapping("Manafwa", ["Manafwa"]),
|
| 110 |
+
DistrictMapping("Bududa", ["Bududa"]),
|
| 111 |
+
DistrictMapping("Mbale", ["Mbale"]),
|
| 112 |
+
DistrictMapping("Butaleja", ["Butaleja"]),
|
| 113 |
+
DistrictMapping("Namisindwa", ["Namisindwa"]),
|
| 114 |
+
DistrictMapping("Bulambuli", ["Bulambuli"]),
|
| 115 |
+
|
| 116 |
+
# Western Region
|
| 117 |
+
DistrictMapping("Masaka", ["Masaka"]),
|
| 118 |
+
DistrictMapping("Kalungu", ["Kalungu"]),
|
| 119 |
+
DistrictMapping("Bukomansimbi", ["Bukomansimbi"]),
|
| 120 |
+
DistrictMapping("Lwengo", ["Lwengo"]),
|
| 121 |
+
DistrictMapping("Sembabule", ["Sembabule"]),
|
| 122 |
+
DistrictMapping("Rakai", ["Rakai"]),
|
| 123 |
+
DistrictMapping("Kyotera", ["Kyotera"]),
|
| 124 |
+
DistrictMapping("Mpigi", ["Mpigi"]),
|
| 125 |
+
DistrictMapping("Butambala", ["Butambala"]),
|
| 126 |
+
DistrictMapping("Gomba", ["Gomba"]),
|
| 127 |
+
DistrictMapping("Mityana", ["Mityana"]),
|
| 128 |
+
DistrictMapping("Mubende", ["Mubende"]),
|
| 129 |
+
DistrictMapping("Kassanda", ["Kassanda"]),
|
| 130 |
+
DistrictMapping("Kiboga", ["Kiboga"]),
|
| 131 |
+
DistrictMapping("Kyankwanzi", ["Kyankwanzi"]),
|
| 132 |
+
DistrictMapping("Hoima", ["Hoima"]),
|
| 133 |
+
DistrictMapping("Kikuube", ["Kikuube"]),
|
| 134 |
+
DistrictMapping("Kakumiro", ["Kakumiro"]),
|
| 135 |
+
DistrictMapping("Kibaale", ["Kibaale"]),
|
| 136 |
+
DistrictMapping("Kagadi", ["Kagadi"]),
|
| 137 |
+
DistrictMapping("Buliisa", ["Buliisa"]),
|
| 138 |
+
DistrictMapping("Masindi", ["Masindi"]),
|
| 139 |
+
DistrictMapping("Kiryandongo", ["Kiryandongo"]),
|
| 140 |
+
DistrictMapping("Buliisa", ["Buliisa"]),
|
| 141 |
+
DistrictMapping("Pakwach", ["Pakwach"]),
|
| 142 |
+
DistrictMapping("Nebbi", ["Nebbi"]),
|
| 143 |
+
DistrictMapping("Zombo", ["Zombo"]),
|
| 144 |
+
DistrictMapping("Arua", ["Arua"]),
|
| 145 |
+
DistrictMapping("Terego", ["Terego"]),
|
| 146 |
+
DistrictMapping("Madi-Okollo", ["Madi-Okollo"]),
|
| 147 |
+
DistrictMapping("Obongi", ["Obongi"]),
|
| 148 |
+
DistrictMapping("Moyo", ["Moyo"]),
|
| 149 |
+
DistrictMapping("Yumbe", ["Yumbe"]),
|
| 150 |
+
DistrictMapping("Koboko", ["Koboko"]),
|
| 151 |
+
DistrictMapping("Maracha", ["Maracha"]),
|
| 152 |
+
DistrictMapping("Adjumani", ["Adjumani"]),
|
| 153 |
+
|
| 154 |
+
# South Western Region
|
| 155 |
+
DistrictMapping("Mbarara", ["Mbarara"]),
|
| 156 |
+
DistrictMapping("Ibanda", ["Ibanda"]),
|
| 157 |
+
DistrictMapping("Isingiro", ["Isingiro"]),
|
| 158 |
+
DistrictMapping("Kiruhura", ["Kiruhura"]),
|
| 159 |
+
DistrictMapping("Kazo", ["Kazo"]),
|
| 160 |
+
DistrictMapping("Ntungamo", ["Ntungamo"]),
|
| 161 |
+
DistrictMapping("Rwampara", ["Rwampara"]),
|
| 162 |
+
DistrictMapping("Rubanda", ["Rubanda"]),
|
| 163 |
+
DistrictMapping("Rukiga", ["Rukiga"]),
|
| 164 |
+
DistrictMapping("Kanungu", ["Kanungu"]),
|
| 165 |
+
DistrictMapping("Rukungiri", ["Rukungiri"]),
|
| 166 |
+
DistrictMapping("Kisoro", ["Kisoro"]),
|
| 167 |
+
DistrictMapping("Bundibugyo", ["Bundibugyo"]),
|
| 168 |
+
DistrictMapping("Ntoroko", ["Ntoroko"]),
|
| 169 |
+
DistrictMapping("Kasese", ["Kasese"]),
|
| 170 |
+
DistrictMapping("Bunyangabu", ["Bunyangabu"]),
|
| 171 |
+
DistrictMapping("Fort Portal", ["Fort Portal"]),
|
| 172 |
+
DistrictMapping("Kabarole", ["Kabarole"]),
|
| 173 |
+
DistrictMapping("Kyenjojo", ["Kyenjojo"]),
|
| 174 |
+
DistrictMapping("Kamwenge", ["Kamwenge"]),
|
| 175 |
+
DistrictMapping("Kitagwenda", ["Kitagwenda"]),
|
| 176 |
+
DistrictMapping("Kyegegwa", ["Kyegegwa"]),
|
| 177 |
+
DistrictMapping("Mitooma", ["Mitooma"]),
|
| 178 |
+
DistrictMapping("Rubirizi", ["Rubirizi"]),
|
| 179 |
+
DistrictMapping("Sheema", ["Sheema"]),
|
| 180 |
+
DistrictMapping("Bushenyi", ["Bushenyi"]),
|
| 181 |
+
|
| 182 |
+
# Special cases
|
| 183 |
+
DistrictMapping("Kalangala", ["Kalangala", "Kalangala DLG"]),
|
| 184 |
+
]
|
| 185 |
+
|
| 186 |
+
# Create mapping dictionary
|
| 187 |
+
mapping_dict = {}
|
| 188 |
+
for district in districts:
|
| 189 |
+
mapping_dict[district.name.lower()] = district
|
| 190 |
+
for alias in district.aliases:
|
| 191 |
+
mapping_dict[alias.lower()] = district
|
| 192 |
+
return mapping_dict
|
| 193 |
+
|
| 194 |
+
def _initialize_ministry_mappings(self) -> Dict[str, str]:
|
| 195 |
+
"""Initialize ministry and organization mappings"""
|
| 196 |
+
return {
|
| 197 |
+
"maaif": "Ministry of Agriculture, Animal Industry and Fisheries",
|
| 198 |
+
"mwts": "Ministry of Works and Transport",
|
| 199 |
+
"kcca": "Kampala Capital City Authority",
|
| 200 |
+
"oag": "Office of the Auditor General",
|
| 201 |
+
"arsdp": "Albertine Regional Sustainable Development Project",
|
| 202 |
+
"avcdp": "Agriculture Value Chain Development Project",
|
| 203 |
+
"ida": "International Development Association",
|
| 204 |
+
"dlg": "District Local Government",
|
| 205 |
+
"lg": "Local Government",
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
def _extract_district_from_filename(self, filename: str) -> Optional[str]:
|
| 209 |
+
"""Extract district from filename using pattern matching"""
|
| 210 |
+
filename_lower = filename.lower()
|
| 211 |
+
|
| 212 |
+
# Check for explicit district mentions
|
| 213 |
+
for key, district_mapping in self.district_mappings.items():
|
| 214 |
+
if key in filename_lower:
|
| 215 |
+
return district_mapping.name
|
| 216 |
+
|
| 217 |
+
# Check for ministry/organization patterns that are NOT districts
|
| 218 |
+
for ministry_key in self.ministry_mappings.keys():
|
| 219 |
+
if ministry_key in filename_lower:
|
| 220 |
+
return None # This is a ministry, not a district
|
| 221 |
+
|
| 222 |
+
# Check for patterns like "District Local Government"
|
| 223 |
+
district_pattern = r'(\w+)\s+district\s+local\s+government'
|
| 224 |
+
match = re.search(district_pattern, filename_lower)
|
| 225 |
+
if match:
|
| 226 |
+
district_name = match.group(1).title()
|
| 227 |
+
if district_name.lower() in self.district_mappings:
|
| 228 |
+
return self.district_mappings[district_name.lower()].name
|
| 229 |
+
|
| 230 |
+
# Check for patterns like "DLG Report"
|
| 231 |
+
dlg_pattern = r'(\w+)\s+dlg\s+report'
|
| 232 |
+
match = re.search(dlg_pattern, filename_lower)
|
| 233 |
+
if match:
|
| 234 |
+
district_name = match.group(1).title()
|
| 235 |
+
if district_name.lower() in self.district_mappings:
|
| 236 |
+
return self.district_mappings[district_name.lower()].name
|
| 237 |
+
|
| 238 |
+
return None
|
| 239 |
+
|
| 240 |
+
def _infer_district_with_llm(self, filename: str) -> Optional[str]:
|
| 241 |
+
"""Use LLM to infer district from filename when pattern matching fails"""
|
| 242 |
+
# For now, return None - LLM integration can be added later
|
| 243 |
+
logger.info(f"LLM inference needed for filename: {filename}")
|
| 244 |
+
return None
|
| 245 |
+
|
| 246 |
+
def infer_district(self, filename: str) -> Optional[str]:
|
| 247 |
+
"""Main method to infer district from filename"""
|
| 248 |
+
# First try pattern matching
|
| 249 |
+
district = self._extract_district_from_filename(filename)
|
| 250 |
+
if district:
|
| 251 |
+
return district
|
| 252 |
+
|
| 253 |
+
# If pattern matching fails, use LLM
|
| 254 |
+
return self._infer_district_with_llm(filename)
|
| 255 |
+
|
| 256 |
+
def fetch_chunks_batch(self, batch_size: int = 100, offset: int = 0) -> List[Dict]:
|
| 257 |
+
"""Fetch a batch of chunks from Qdrant (metadata only)"""
|
| 258 |
+
try:
|
| 259 |
+
# Import Qdrant client when needed
|
| 260 |
+
if self.qdrant_client is None:
|
| 261 |
+
self.qdrant_client = QdrantClient(
|
| 262 |
+
url=self.config["qdrant"]["url"],
|
| 263 |
+
api_key=self.config["qdrant"]["api_key"]
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# Get points with metadata only (no vectors)
|
| 267 |
+
points = self.qdrant_client.scroll(
|
| 268 |
+
collection_name=self.collection_name,
|
| 269 |
+
limit=batch_size,
|
| 270 |
+
offset=offset,
|
| 271 |
+
with_payload=True,
|
| 272 |
+
with_vectors=False
|
| 273 |
+
)[0]
|
| 274 |
+
|
| 275 |
+
return points
|
| 276 |
+
except Exception as e:
|
| 277 |
+
logger.error(f"Failed to fetch batch: {e}")
|
| 278 |
+
return []
|
| 279 |
+
|
| 280 |
+
def update_chunks_with_district(self, points: List[Dict]) -> int:
|
| 281 |
+
"""Update chunks with district metadata"""
|
| 282 |
+
updated_count = 0
|
| 283 |
+
|
| 284 |
+
# Import Qdrant client when needed
|
| 285 |
+
if self.qdrant_client is None:
|
| 286 |
+
from qdrant_client import QdrantClient
|
| 287 |
+
self.qdrant_client = QdrantClient(
|
| 288 |
+
url=self.config["qdrant"]["url"],
|
| 289 |
+
api_key=self.config["qdrant"]["api_key"]
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
for point in points:
|
| 293 |
+
try:
|
| 294 |
+
point_id = point.id
|
| 295 |
+
metadata = point.payload.get("metadata", {})
|
| 296 |
+
filename = metadata.get("filename", "")
|
| 297 |
+
|
| 298 |
+
if not filename:
|
| 299 |
+
logger.warning(f"Point {point_id} has no filename")
|
| 300 |
+
continue
|
| 301 |
+
|
| 302 |
+
# Infer district
|
| 303 |
+
district = self.infer_district(filename)
|
| 304 |
+
|
| 305 |
+
# Update metadata
|
| 306 |
+
updated_metadata = metadata.copy()
|
| 307 |
+
updated_metadata["district"] = district
|
| 308 |
+
|
| 309 |
+
# Update point in Qdrant
|
| 310 |
+
self.qdrant_client.set_payload(
|
| 311 |
+
collection_name=self.collection_name,
|
| 312 |
+
payload={"metadata": updated_metadata},
|
| 313 |
+
points=[point_id]
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
updated_count += 1
|
| 317 |
+
logger.info(f"Updated point {point_id}: {filename} -> {district}")
|
| 318 |
+
|
| 319 |
+
except Exception as e:
|
| 320 |
+
logger.error(f"Failed to update point {point_id}: {e}")
|
| 321 |
+
|
| 322 |
+
return updated_count
|
| 323 |
+
|
| 324 |
+
def process_all_chunks(self, batch_size: int = 100):
|
| 325 |
+
"""Process all chunks in batches"""
|
| 326 |
+
total_updated = 0
|
| 327 |
+
offset = 0
|
| 328 |
+
|
| 329 |
+
logger.info(f"Starting to process chunks in batches of {batch_size}")
|
| 330 |
+
|
| 331 |
+
while True:
|
| 332 |
+
# Fetch batch
|
| 333 |
+
points = self.fetch_chunks_batch(batch_size, offset)
|
| 334 |
+
if not points:
|
| 335 |
+
break
|
| 336 |
+
|
| 337 |
+
logger.info(f"Processing batch: {len(points)} points (offset: {offset})")
|
| 338 |
+
|
| 339 |
+
# Update batch
|
| 340 |
+
updated_count = self.update_chunks_with_district(points)
|
| 341 |
+
total_updated += updated_count
|
| 342 |
+
|
| 343 |
+
logger.info(f"Updated {updated_count} points in this batch")
|
| 344 |
+
|
| 345 |
+
# Move to next batch
|
| 346 |
+
offset += batch_size
|
| 347 |
+
|
| 348 |
+
logger.info(f"Total updated: {total_updated} points")
|
| 349 |
+
return total_updated
|
| 350 |
+
|
| 351 |
+
def main():
|
| 352 |
+
"""Main function to run the district metadata processor"""
|
| 353 |
+
try:
|
| 354 |
+
processor = DistrictMetadataProcessor()
|
| 355 |
+
|
| 356 |
+
# Test with a small batch first
|
| 357 |
+
logger.info("Testing with first 10 chunks...")
|
| 358 |
+
test_points = processor.fetch_chunks_batch(10, 0)
|
| 359 |
+
|
| 360 |
+
if test_points:
|
| 361 |
+
logger.info("Test batch fetched successfully. Processing...")
|
| 362 |
+
for point in test_points:
|
| 363 |
+
filename = point.payload.get("metadata", {}).get("filename", "")
|
| 364 |
+
district = processor.infer_district(filename)
|
| 365 |
+
logger.info(f"Test: {filename} -> {district}")
|
| 366 |
+
|
| 367 |
+
# Ask user if they want to proceed with full processing
|
| 368 |
+
response = input("\nProceed with full processing? (y/n): ")
|
| 369 |
+
if response.lower() == 'y':
|
| 370 |
+
processor.process_all_chunks(batch_size=100)
|
| 371 |
+
else:
|
| 372 |
+
logger.info("Processing cancelled by user")
|
| 373 |
+
|
| 374 |
+
except Exception as e:
|
| 375 |
+
logger.error(f"Error in main: {e}")
|
| 376 |
+
raise
|
| 377 |
+
|
| 378 |
+
if __name__ == "__main__":
|
| 379 |
+
main()
|
app.py
CHANGED
|
@@ -3,7 +3,36 @@ Intelligent Audit Report Chatbot UI
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
# ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
|
| 9 |
# Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
|
|
@@ -29,42 +58,33 @@ except (ValueError, TypeError):
|
|
| 29 |
|
| 30 |
# ===== Setup HuggingFace cache directories BEFORE any model imports =====
|
| 31 |
# CRITICAL: Set these before any imports that might use HuggingFace (like sentence-transformers)
|
| 32 |
-
#
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
os.environ["
|
| 36 |
-
os.environ["
|
| 37 |
-
os.environ["
|
| 38 |
-
os.environ["
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
import
|
| 49 |
-
|
| 50 |
-
import uuid
|
| 51 |
-
import logging
|
| 52 |
-
from pathlib import Path
|
| 53 |
-
|
| 54 |
-
import argparse
|
| 55 |
-
import streamlit as st
|
| 56 |
-
from langchain_core.messages import HumanMessage, AIMessage
|
| 57 |
-
|
| 58 |
-
from multi_agent_chatbot import get_multi_agent_chatbot
|
| 59 |
-
from smart_chatbot import get_chatbot as get_smart_chatbot
|
| 60 |
-
from src.reporting.feedback_schema import create_feedback_from_dict
|
| 61 |
|
| 62 |
# Configure logging
|
| 63 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 64 |
logger = logging.getLogger(__name__)
|
| 65 |
|
| 66 |
# Log environment setup for debugging
|
| 67 |
-
logger.info(f"
|
|
|
|
|
|
|
| 68 |
logger.info(f"π§ OMP_NUM_THREADS: {os.environ.get('OMP_NUM_THREADS', 'NOT SET')}")
|
| 69 |
|
| 70 |
|
|
@@ -76,84 +96,9 @@ st.set_page_config(
|
|
| 76 |
page_title="Intelligent Audit Report Chatbot"
|
| 77 |
)
|
| 78 |
|
| 79 |
-
|
| 80 |
-
st.markdown(
|
| 81 |
-
|
| 82 |
-
.main-header {
|
| 83 |
-
font-size: 2.5rem;
|
| 84 |
-
font-weight: bold;
|
| 85 |
-
color: #1f77b4;
|
| 86 |
-
text-align: center;
|
| 87 |
-
margin-bottom: 1rem;
|
| 88 |
-
}
|
| 89 |
-
|
| 90 |
-
.subtitle {
|
| 91 |
-
font-size: 1.2rem;
|
| 92 |
-
color: #666;
|
| 93 |
-
text-align: center;
|
| 94 |
-
margin-bottom: 2rem;
|
| 95 |
-
}
|
| 96 |
-
|
| 97 |
-
.session-info {
|
| 98 |
-
background-color: #f0f2f6;
|
| 99 |
-
padding: 10px;
|
| 100 |
-
border-radius: 5px;
|
| 101 |
-
margin-bottom: 20px;
|
| 102 |
-
font-size: 0.9rem;
|
| 103 |
-
}
|
| 104 |
-
|
| 105 |
-
.user-message {
|
| 106 |
-
background-color: #007bff;
|
| 107 |
-
color: white;
|
| 108 |
-
padding: 12px 16px;
|
| 109 |
-
border-radius: 18px 18px 4px 18px;
|
| 110 |
-
margin: 8px 0;
|
| 111 |
-
margin-left: 20%;
|
| 112 |
-
word-wrap: break-word;
|
| 113 |
-
}
|
| 114 |
-
|
| 115 |
-
.bot-message {
|
| 116 |
-
background-color: #f1f3f4;
|
| 117 |
-
color: #333;
|
| 118 |
-
padding: 12px 16px;
|
| 119 |
-
border-radius: 18px 18px 18px 4px;
|
| 120 |
-
margin: 8px 0;
|
| 121 |
-
margin-right: 20%;
|
| 122 |
-
word-wrap: break-word;
|
| 123 |
-
border: 1px solid #e0e0e0;
|
| 124 |
-
}
|
| 125 |
-
|
| 126 |
-
.filter-section {
|
| 127 |
-
margin-bottom: 20px;
|
| 128 |
-
padding: 15px;
|
| 129 |
-
background-color: #f8f9fa;
|
| 130 |
-
border-radius: 8px;
|
| 131 |
-
border: 1px solid #e9ecef;
|
| 132 |
-
}
|
| 133 |
-
|
| 134 |
-
.filter-title {
|
| 135 |
-
font-weight: bold;
|
| 136 |
-
margin-bottom: 10px;
|
| 137 |
-
color: #495057;
|
| 138 |
-
}
|
| 139 |
-
|
| 140 |
-
.feedback-section {
|
| 141 |
-
background-color: #f8f9fa;
|
| 142 |
-
padding: 20px;
|
| 143 |
-
border-radius: 10px;
|
| 144 |
-
margin-top: 30px;
|
| 145 |
-
border: 2px solid #dee2e6;
|
| 146 |
-
}
|
| 147 |
-
|
| 148 |
-
.retrieval-history {
|
| 149 |
-
background-color: #ffffff;
|
| 150 |
-
padding: 15px;
|
| 151 |
-
border-radius: 5px;
|
| 152 |
-
margin: 10px 0;
|
| 153 |
-
border-left: 4px solid #007bff;
|
| 154 |
-
}
|
| 155 |
-
</style>
|
| 156 |
-
""", unsafe_allow_html=True)
|
| 157 |
|
| 158 |
def get_system_type():
|
| 159 |
"""Get the current system type"""
|
|
@@ -163,14 +108,17 @@ def get_system_type():
|
|
| 163 |
else:
|
| 164 |
return "Multi-Agent System"
|
| 165 |
|
| 166 |
-
def get_chatbot():
|
| 167 |
-
"""Initialize and return the chatbot based on
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
if system == 'smart':
|
| 171 |
-
return get_smart_chatbot()
|
| 172 |
else:
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
def serialize_messages(messages):
|
| 176 |
"""Serialize LangChain messages to dictionaries"""
|
|
@@ -215,13 +163,18 @@ def serialize_documents(sources):
|
|
| 215 |
|
| 216 |
return serialized
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
@st.cache_data
|
| 219 |
def load_filter_options():
|
| 220 |
try:
|
| 221 |
-
|
|
|
|
| 222 |
return json.load(f)
|
| 223 |
except FileNotFoundError:
|
| 224 |
-
st.info(
|
| 225 |
st.error("filter_options.json not found. Please run the metadata analysis script.")
|
| 226 |
return {"sources": [], "years": [], "districts": [], 'filenames': []}
|
| 227 |
|
|
@@ -238,11 +191,48 @@ def main():
|
|
| 238 |
# Track RAG retrieval history for feedback
|
| 239 |
if 'rag_retrieval_history' not in st.session_state:
|
| 240 |
st.session_state.rag_retrieval_history = []
|
| 241 |
-
#
|
| 242 |
-
if '
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
# Reset conversation history if needed (but keep chatbot cached)
|
| 248 |
if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
|
|
@@ -254,17 +244,43 @@ def main():
|
|
| 254 |
st.session_state.reset_conversation = False
|
| 255 |
st.rerun()
|
| 256 |
|
| 257 |
-
|
|
|
|
| 258 |
col1, col2 = st.columns([3, 1])
|
| 259 |
with col1:
|
| 260 |
-
st.markdown('<
|
| 261 |
with col2:
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
st.
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
# Session info
|
| 270 |
duration = int(time.time() - st.session_state.session_start_time)
|
|
@@ -280,6 +296,34 @@ def main():
|
|
| 280 |
|
| 281 |
# Sidebar for filters
|
| 282 |
with st.sidebar:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
st.markdown("### π Search Filters")
|
| 284 |
st.markdown("Select filters to narrow down your search. Leave empty to search all data.")
|
| 285 |
|
|
@@ -298,7 +342,7 @@ def main():
|
|
| 298 |
# Determine if filename filter is active
|
| 299 |
filename_mode = len(selected_filenames) > 0
|
| 300 |
# Sources filter
|
| 301 |
-
st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 302 |
st.markdown('<div class="filter-title">π Sources</div>', unsafe_allow_html=True)
|
| 303 |
selected_sources = st.multiselect(
|
| 304 |
"Select sources:",
|
|
@@ -311,7 +355,7 @@ def main():
|
|
| 311 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 312 |
|
| 313 |
# Years filter
|
| 314 |
-
st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 315 |
st.markdown('<div class="filter-title">π
Years</div>', unsafe_allow_html=True)
|
| 316 |
selected_years = st.multiselect(
|
| 317 |
"Select years:",
|
|
@@ -324,7 +368,7 @@ def main():
|
|
| 324 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 325 |
|
| 326 |
# Districts filter
|
| 327 |
-
st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 328 |
st.markdown('<div class="filter-title">ποΈ Districts</div>', unsafe_allow_html=True)
|
| 329 |
selected_districts = st.multiselect(
|
| 330 |
"Select districts:",
|
|
@@ -375,26 +419,37 @@ def main():
|
|
| 375 |
if 'input_counter' not in st.session_state:
|
| 376 |
st.session_state.input_counter = 0
|
| 377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
user_input = st.text_input(
|
| 379 |
"Type your message here...",
|
| 380 |
placeholder="Ask about budget allocations, expenditures, or audit findings...",
|
| 381 |
-
key=f"user_input_{
|
| 382 |
-
label_visibility="collapsed"
|
|
|
|
| 383 |
)
|
| 384 |
|
| 385 |
with col2:
|
| 386 |
-
send_button = st.button("Send", key="send_button",
|
| 387 |
|
| 388 |
# Clear chat button
|
| 389 |
if st.button("ποΈ Clear Chat", key="clear_chat_button"):
|
| 390 |
st.session_state.reset_conversation = True
|
| 391 |
# Clear all conversation files
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
os.remove(os.path.join(conversations_dir, file))
|
| 398 |
st.rerun()
|
| 399 |
|
| 400 |
# Handle user input
|
|
@@ -436,6 +491,36 @@ def main():
|
|
| 436 |
if rag_result:
|
| 437 |
sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
|
| 438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
# Get the actual RAG query
|
| 440 |
actual_rag_query = chat_result.get('actual_rag_query', '')
|
| 441 |
if actual_rag_query:
|
|
@@ -445,12 +530,25 @@ def main():
|
|
| 445 |
else:
|
| 446 |
formatted_query = "No RAG query available"
|
| 447 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
retrieval_entry = {
|
| 449 |
"conversation_up_to": serialize_messages(st.session_state.messages),
|
| 450 |
"rag_query_expansion": formatted_query,
|
| 451 |
-
"docs_retrieved": serialize_documents(sources)
|
|
|
|
|
|
|
| 452 |
}
|
| 453 |
st.session_state.rag_retrieval_history.append(retrieval_entry)
|
|
|
|
|
|
|
|
|
|
| 454 |
else:
|
| 455 |
response = chat_result
|
| 456 |
st.session_state.last_rag_result = None
|
|
@@ -480,6 +578,16 @@ def main():
|
|
| 480 |
# Dictionary format from multi-agent system
|
| 481 |
sources = rag_result['sources']
|
| 482 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
if sources and len(sources) > 0:
|
| 484 |
# Count unique filenames
|
| 485 |
unique_filenames = set()
|
|
@@ -487,16 +595,40 @@ def main():
|
|
| 487 |
filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown')
|
| 488 |
unique_filenames.add(filename)
|
| 489 |
|
| 490 |
-
st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top
|
| 491 |
if len(unique_filenames) < len(sources):
|
| 492 |
st.info(f"π‘ **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
|
| 493 |
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
# Get relevance score and ID if available
|
| 496 |
metadata = getattr(doc, 'metadata', {})
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
|
| 501 |
with st.expander(f"π Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
|
| 502 |
# Display document metadata with emojis
|
|
@@ -543,200 +675,409 @@ def main():
|
|
| 543 |
if 'feedback_submitted' not in st.session_state:
|
| 544 |
st.session_state.feedback_submitted = False
|
| 545 |
|
| 546 |
-
# Feedback form
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
with col1:
|
| 551 |
-
feedback_score = st.slider(
|
| 552 |
-
"Rate this conversation (1-5)",
|
| 553 |
-
min_value=1,
|
| 554 |
-
max_value=5,
|
| 555 |
-
help="How satisfied are you with the conversation?"
|
| 556 |
-
)
|
| 557 |
-
|
| 558 |
-
with col2:
|
| 559 |
-
is_feedback_about_last_retrieval = st.checkbox(
|
| 560 |
-
"Feedback about last retrieval only",
|
| 561 |
-
value=True,
|
| 562 |
-
help="If checked, feedback applies to the most recent document retrieval"
|
| 563 |
-
)
|
| 564 |
-
|
| 565 |
-
open_ended_feedback = st.text_area(
|
| 566 |
-
"Your feedback (optional)",
|
| 567 |
-
placeholder="Tell us what went well or what could be improved...",
|
| 568 |
-
height=100
|
| 569 |
-
)
|
| 570 |
-
|
| 571 |
-
# Disable submit if no score selected
|
| 572 |
-
submit_disabled = feedback_score is None
|
| 573 |
-
|
| 574 |
-
submitted = st.form_submit_button(
|
| 575 |
-
"π€ Submit Feedback",
|
| 576 |
-
use_container_width=True,
|
| 577 |
-
disabled=submit_disabled
|
| 578 |
-
)
|
| 579 |
-
|
| 580 |
-
if submitted and not st.session_state.feedback_submitted:
|
| 581 |
-
# Log the feedback data being submitted
|
| 582 |
-
print("=" * 80)
|
| 583 |
-
print("π FEEDBACK SUBMISSION: Starting...")
|
| 584 |
-
print("=" * 80)
|
| 585 |
-
st.write("π **Debug: Feedback Data Being Submitted:**")
|
| 586 |
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
"timestamp": time.time(),
|
| 595 |
-
"message_count": len(st.session_state.messages),
|
| 596 |
-
"has_retrievals": has_retrievals,
|
| 597 |
-
"retrieval_count": len(st.session_state.rag_retrieval_history)
|
| 598 |
-
}
|
| 599 |
|
| 600 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
st.write(f"β
**Feedback Object Created**")
|
| 608 |
-
st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
|
| 609 |
-
st.write(f"- Score: {feedback_obj.score}/5")
|
| 610 |
-
st.write(f"- Has Retrievals: {feedback_obj.has_retrievals}")
|
| 611 |
-
|
| 612 |
-
# Convert back to dict for JSON serialization
|
| 613 |
-
feedback_data = feedback_obj.to_dict()
|
| 614 |
-
except Exception as e:
|
| 615 |
-
print(f"β FEEDBACK SUBMISSION: Failed to create feedback object: {e}")
|
| 616 |
-
st.error(f"Failed to create feedback object: {e}")
|
| 617 |
-
feedback_data = feedback_dict
|
| 618 |
-
|
| 619 |
-
# Display the data being submitted
|
| 620 |
-
st.json(feedback_data)
|
| 621 |
|
| 622 |
-
#
|
| 623 |
-
|
| 624 |
-
try:
|
| 625 |
-
# Ensure directory exists with write permissions (777 for compatibility)
|
| 626 |
-
feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 627 |
-
except (PermissionError, OSError) as e:
|
| 628 |
-
logger.warning(f"Could not create feedback directory at {feedback_dir}: {e}")
|
| 629 |
-
# Fallback to relative path
|
| 630 |
-
feedback_dir = Path("feedback")
|
| 631 |
-
feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 632 |
|
| 633 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 634 |
|
| 635 |
-
|
| 636 |
-
#
|
| 637 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
|
| 639 |
-
#
|
| 640 |
-
|
| 641 |
-
with open(feedback_file, 'w') as f:
|
| 642 |
-
json.dump(feedback_data, f, indent=2, default=str)
|
| 643 |
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 647 |
|
| 648 |
-
|
| 649 |
-
logger.info("π FEEDBACK SAVE: Starting Snowflake save process...")
|
| 650 |
-
logger.info(f"π FEEDBACK SAVE: feedback_obj={'exists' if feedback_obj else 'None'}")
|
| 651 |
|
|
|
|
|
|
|
| 652 |
try:
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
|
|
|
|
|
|
|
|
|
| 656 |
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
try:
|
| 660 |
-
from src.reporting.snowflake_connector import save_to_snowflake
|
| 661 |
-
logger.info("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 662 |
-
print("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...") # Also print to terminal
|
| 663 |
-
|
| 664 |
-
if save_to_snowflake(feedback_obj):
|
| 665 |
-
logger.info("β
SNOWFLAKE UI: Successfully saved to Snowflake")
|
| 666 |
-
print("β
SNOWFLAKE UI: Successfully saved to Snowflake") # Also print to terminal
|
| 667 |
-
st.success("β
Feedback also saved to Snowflake!")
|
| 668 |
-
else:
|
| 669 |
-
logger.warning("β οΈ SNOWFLAKE UI: Save failed")
|
| 670 |
-
print("β οΈ SNOWFLAKE UI: Save failed") # Also print to terminal
|
| 671 |
-
st.warning("β οΈ Snowflake save failed, but local save succeeded")
|
| 672 |
-
except Exception as e:
|
| 673 |
-
logger.error(f"β SNOWFLAKE UI ERROR: {e}")
|
| 674 |
-
print(f"β SNOWFLAKE UI ERROR: {e}") # Also print to terminal
|
| 675 |
-
import traceback
|
| 676 |
-
traceback.print_exc() # Print full traceback to terminal
|
| 677 |
-
st.warning(f"β οΈ Could not save to Snowflake: {e}")
|
| 678 |
-
else:
|
| 679 |
-
logger.warning("β οΈ SNOWFLAKE UI: Skipping (feedback object not created)")
|
| 680 |
-
print("β οΈ SNOWFLAKE UI: Skipping (feedback object not created)") # Also print to terminal
|
| 681 |
-
st.warning("β οΈ Skipping Snowflake save (feedback object not created)")
|
| 682 |
-
else:
|
| 683 |
-
logger.info("π‘ SNOWFLAKE UI: Integration disabled")
|
| 684 |
-
print("π‘ SNOWFLAKE UI: Integration disabled") # Also print to terminal
|
| 685 |
-
st.info("π‘ Snowflake integration disabled (set SNOWFLAKE_ENABLED=true to enable)")
|
| 686 |
-
except NameError as e:
|
| 687 |
-
import traceback
|
| 688 |
-
traceback.print_exc()
|
| 689 |
-
logger.error(f"β NameError in Snowflake save: {e}")
|
| 690 |
-
print(f"β NameError in Snowflake save: {e}") # Also print to terminal
|
| 691 |
-
st.warning(f"β οΈ Snowflake save error: {e}")
|
| 692 |
except Exception as e:
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
|
| 697 |
-
#
|
| 698 |
-
st.
|
| 699 |
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
|
| 704 |
-
|
| 705 |
-
st.info(f"π Feedback saved to: {feedback_file}")
|
| 706 |
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
|
| 721 |
# Display retrieval history stats
|
| 722 |
if st.session_state.rag_retrieval_history:
|
| 723 |
st.markdown("---")
|
| 724 |
st.markdown("#### π Retrieval History")
|
| 725 |
|
| 726 |
-
with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=
|
| 727 |
for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
|
| 728 |
-
st.markdown(f"**Retrieval #{idx}**")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
|
| 730 |
# Display the actual RAG query
|
| 731 |
rag_query_expansion = entry.get("rag_query_expansion", "No query available")
|
|
|
|
| 732 |
st.code(rag_query_expansion, language="text")
|
| 733 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 734 |
# Display summary stats
|
|
|
|
| 735 |
st.json({
|
| 736 |
-
"conversation_length": len(
|
| 737 |
-
"documents_retrieved": len(
|
| 738 |
})
|
| 739 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 740 |
|
| 741 |
# Auto-scroll to bottom
|
| 742 |
st.markdown("""
|
|
@@ -745,5 +1086,32 @@ def main():
|
|
| 745 |
</script>
|
| 746 |
""", unsafe_allow_html=True)
|
| 747 |
|
|
|
|
| 748 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 749 |
main()
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os
|
| 6 |
+
|
| 7 |
+
import time
|
| 8 |
+
import json
|
| 9 |
+
import uuid
|
| 10 |
+
import logging
|
| 11 |
+
import traceback
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
from collections import Counter
|
| 15 |
+
from typing import List, Dict, Any, Optional
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import pandas as pd
|
| 19 |
+
import streamlit as st
|
| 20 |
+
import plotly.express as px
|
| 21 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
from src.agents import get_multi_agent_chatbot, get_smart_chatbot, get_gemini_chatbot
|
| 25 |
+
from src.feedback import FeedbackManager
|
| 26 |
+
from src.ui_components import get_custom_css, display_chunk_statistics_charts, display_chunk_statistics_table, extract_chunk_statistics
|
| 27 |
+
|
| 28 |
+
from src.config.paths import (
|
| 29 |
+
IS_DEPLOYED,
|
| 30 |
+
PROJECT_DIR,
|
| 31 |
+
HF_CACHE_DIR,
|
| 32 |
+
FEEDBACK_DIR,
|
| 33 |
+
CONVERSATIONS_DIR,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
|
| 37 |
# ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
|
| 38 |
# Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
|
|
|
|
| 58 |
|
| 59 |
# ===== Setup HuggingFace cache directories BEFORE any model imports =====
|
| 60 |
# CRITICAL: Set these before any imports that might use HuggingFace (like sentence-transformers)
|
| 61 |
+
# Only override cache directories in deployed environment (local uses defaults)
|
| 62 |
+
if IS_DEPLOYED and HF_CACHE_DIR:
|
| 63 |
+
cache_dir = str(HF_CACHE_DIR)
|
| 64 |
+
os.environ["HF_HOME"] = cache_dir
|
| 65 |
+
os.environ["TRANSFORMERS_CACHE"] = cache_dir
|
| 66 |
+
os.environ["HF_DATASETS_CACHE"] = cache_dir
|
| 67 |
+
os.environ["HF_HUB_CACHE"] = cache_dir
|
| 68 |
+
os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_dir
|
| 69 |
+
|
| 70 |
+
# Ensure cache directory exists (created in Dockerfile, but ensure it's there)
|
| 71 |
+
try:
|
| 72 |
+
os.makedirs(cache_dir, mode=0o755, exist_ok=True)
|
| 73 |
+
except (PermissionError, OSError):
|
| 74 |
+
# If we can't create it, log but continue (might already exist from Dockerfile)
|
| 75 |
+
pass
|
| 76 |
+
else:
|
| 77 |
+
from dotenv import load_dotenv
|
| 78 |
+
load_dotenv()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
# Configure logging
|
| 81 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 82 |
logger = logging.getLogger(__name__)
|
| 83 |
|
| 84 |
# Log environment setup for debugging
|
| 85 |
+
logger.info(f"π Environment: {'DEPLOYED' if IS_DEPLOYED else 'LOCAL'}")
|
| 86 |
+
logger.info(f"π PROJECT_DIR: {PROJECT_DIR}")
|
| 87 |
+
logger.info(f"π HuggingFace cache: {os.environ.get('HF_HOME', 'DEFAULT (not overridden)')}")
|
| 88 |
logger.info(f"π§ OMP_NUM_THREADS: {os.environ.get('OMP_NUM_THREADS', 'NOT SET')}")
|
| 89 |
|
| 90 |
|
|
|
|
| 96 |
page_title="Intelligent Audit Report Chatbot"
|
| 97 |
)
|
| 98 |
|
| 99 |
+
|
| 100 |
+
st.markdown(get_custom_css(), unsafe_allow_html=True)
|
| 101 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def get_system_type():
|
| 104 |
"""Get the current system type"""
|
|
|
|
| 108 |
else:
|
| 109 |
return "Multi-Agent System"
|
| 110 |
|
| 111 |
+
def get_chatbot(version: str = "v1"):
|
| 112 |
+
"""Initialize and return the chatbot based on version"""
|
| 113 |
+
if version == "beta":
|
| 114 |
+
return get_gemini_chatbot()
|
|
|
|
|
|
|
| 115 |
else:
|
| 116 |
+
# Check environment variable for system type (v1)
|
| 117 |
+
system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
|
| 118 |
+
if system == 'smart':
|
| 119 |
+
return get_smart_chatbot()
|
| 120 |
+
else:
|
| 121 |
+
return get_multi_agent_chatbot()
|
| 122 |
|
| 123 |
def serialize_messages(messages):
|
| 124 |
"""Serialize LangChain messages to dictionaries"""
|
|
|
|
| 163 |
|
| 164 |
return serialized
|
| 165 |
|
| 166 |
+
|
| 167 |
+
feedback_manager = FeedbackManager()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
@st.cache_data
|
| 171 |
def load_filter_options():
|
| 172 |
try:
|
| 173 |
+
filter_options_path = PROJECT_DIR / "src" / "config" / "filter_options.json"
|
| 174 |
+
with open(filter_options_path, "r") as f:
|
| 175 |
return json.load(f)
|
| 176 |
except FileNotFoundError:
|
| 177 |
+
st.info(f"Looking for filter_options.json in: {PROJECT_DIR / 'src' / 'config'}")
|
| 178 |
st.error("filter_options.json not found. Please run the metadata analysis script.")
|
| 179 |
return {"sources": [], "years": [], "districts": [], 'filenames': []}
|
| 180 |
|
|
|
|
| 191 |
# Track RAG retrieval history for feedback
|
| 192 |
if 'rag_retrieval_history' not in st.session_state:
|
| 193 |
st.session_state.rag_retrieval_history = []
|
| 194 |
+
# Version selection (v1 or beta)
|
| 195 |
+
if 'chatbot_version' not in st.session_state:
|
| 196 |
+
st.session_state.chatbot_version = "v1"
|
| 197 |
+
|
| 198 |
+
# Initialize chatbot based on version (only if not already initialized for this version)
|
| 199 |
+
chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}"
|
| 200 |
+
|
| 201 |
+
# Check if we need to initialize: chatbot doesn't exist OR version changed
|
| 202 |
+
needs_init = (
|
| 203 |
+
chatbot_version_key not in st.session_state or
|
| 204 |
+
st.session_state.get('_last_version') != st.session_state.chatbot_version
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
if needs_init:
|
| 208 |
+
try:
|
| 209 |
+
# Different spinner messages for different versions
|
| 210 |
+
if st.session_state.chatbot_version == "beta":
|
| 211 |
+
spinner_msg = "π Initializing Gemini FSA"
|
| 212 |
+
else:
|
| 213 |
+
spinner_msg = "π Loading AI models and connecting to database..."
|
| 214 |
+
|
| 215 |
+
with st.spinner(spinner_msg):
|
| 216 |
+
st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version)
|
| 217 |
+
st.session_state['_last_version'] = st.session_state.chatbot_version
|
| 218 |
+
st.session_state.chatbot = st.session_state[chatbot_version_key]
|
| 219 |
+
print("β
AI system ready!")
|
| 220 |
+
except Exception as e:
|
| 221 |
+
st.error(f"β Failed to initialize chatbot: {str(e)}")
|
| 222 |
+
# Only show Gemini-specific error message for beta version
|
| 223 |
+
if st.session_state.chatbot_version == "beta":
|
| 224 |
+
st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)")
|
| 225 |
+
else:
|
| 226 |
+
st.error("Please check your configuration and ensure all required models and databases are accessible.")
|
| 227 |
+
# Reset to v1 to prevent infinite loop
|
| 228 |
+
st.session_state.chatbot_version = "v1"
|
| 229 |
+
st.session_state['_last_version'] = "v1"
|
| 230 |
+
if 'chatbot' in st.session_state:
|
| 231 |
+
del st.session_state['chatbot']
|
| 232 |
+
st.stop() # Stop execution to prevent infinite loop
|
| 233 |
+
else:
|
| 234 |
+
# Chatbot already initialized for this version, just use it
|
| 235 |
+
st.session_state.chatbot = st.session_state[chatbot_version_key]
|
| 236 |
|
| 237 |
# Reset conversation history if needed (but keep chatbot cached)
|
| 238 |
if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
|
|
|
|
| 244 |
st.session_state.reset_conversation = False
|
| 245 |
st.rerun()
|
| 246 |
|
| 247 |
+
|
| 248 |
+
# Version selection radio button (top right)
|
| 249 |
col1, col2 = st.columns([3, 1])
|
| 250 |
with col1:
|
| 251 |
+
st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
|
| 252 |
with col2:
|
| 253 |
+
st.markdown("<br>", unsafe_allow_html=True) # Add some spacing
|
| 254 |
+
selected_version = st.radio(
|
| 255 |
+
"**Version:**",
|
| 256 |
+
options=["v1", "beta"],
|
| 257 |
+
index=0 if st.session_state.chatbot_version == "v1" else 1,
|
| 258 |
+
horizontal=True,
|
| 259 |
+
key="version_selector",
|
| 260 |
+
help="Select v1 (default RAG system) or beta (Gemini FSA)"
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# Update version if changed
|
| 264 |
+
if selected_version != st.session_state.chatbot_version:
|
| 265 |
+
# Store the old version to check if we need to switch
|
| 266 |
+
old_version = st.session_state.chatbot_version
|
| 267 |
+
st.session_state.chatbot_version = selected_version
|
| 268 |
+
|
| 269 |
+
# If chatbot for new version already exists, just switch to it
|
| 270 |
+
new_chatbot_key = f"chatbot_{selected_version}"
|
| 271 |
+
if new_chatbot_key in st.session_state:
|
| 272 |
+
# Chatbot already exists, just switch
|
| 273 |
+
st.session_state.chatbot = st.session_state[new_chatbot_key]
|
| 274 |
+
st.session_state['_last_version'] = selected_version
|
| 275 |
+
else:
|
| 276 |
+
# Need to initialize new version - will be handled by initialization logic above
|
| 277 |
+
st.session_state['_last_version'] = old_version # Set to old to trigger init check
|
| 278 |
+
|
| 279 |
+
st.rerun()
|
| 280 |
+
|
| 281 |
+
# Show version info
|
| 282 |
+
if st.session_state.chatbot_version == "beta":
|
| 283 |
+
st.info("π¬ **Beta Mode**: Using Google Gemini FSA")
|
| 284 |
|
| 285 |
# Session info
|
| 286 |
duration = int(time.time() - st.session_state.session_start_time)
|
|
|
|
| 296 |
|
| 297 |
# Sidebar for filters
|
| 298 |
with st.sidebar:
|
| 299 |
+
# Instructions section (collapsible)
|
| 300 |
+
with st.expander("π How to Use", expanded=False):
|
| 301 |
+
st.markdown("""
|
| 302 |
+
#### π― Using Filters
|
| 303 |
+
|
| 304 |
+
1. **Select filters** from the sidebar to narrow your search:
|
| 305 |
+
|
| 306 |
+
2. **Leave filters empty** to search across all data
|
| 307 |
+
|
| 308 |
+
3. **Type your question** in the chat input at the bottom
|
| 309 |
+
|
| 310 |
+
4. **Click "Send"** to submit your question
|
| 311 |
+
|
| 312 |
+
#### π‘ Tips
|
| 313 |
+
|
| 314 |
+
- Use specific questions for better results
|
| 315 |
+
- Combine multiple filters for precise searches
|
| 316 |
+
- Check the "Retrieved Documents" tab to see source material
|
| 317 |
+
|
| 318 |
+
#### β οΈ Important
|
| 319 |
+
|
| 320 |
+
**When finished, please close the browser window** to free up computational resources.
|
| 321 |
+
|
| 322 |
+
---
|
| 323 |
+
|
| 324 |
+
For more detailed help, see the example questions at the bottom of the page.
|
| 325 |
+
""")
|
| 326 |
+
|
| 327 |
st.markdown("### π Search Filters")
|
| 328 |
st.markdown("Select filters to narrow down your search. Leave empty to search all data.")
|
| 329 |
|
|
|
|
| 342 |
# Determine if filename filter is active
|
| 343 |
filename_mode = len(selected_filenames) > 0
|
| 344 |
# Sources filter
|
| 345 |
+
# st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 346 |
st.markdown('<div class="filter-title">π Sources</div>', unsafe_allow_html=True)
|
| 347 |
selected_sources = st.multiselect(
|
| 348 |
"Select sources:",
|
|
|
|
| 355 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 356 |
|
| 357 |
# Years filter
|
| 358 |
+
# st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 359 |
st.markdown('<div class="filter-title">π
Years</div>', unsafe_allow_html=True)
|
| 360 |
selected_years = st.multiselect(
|
| 361 |
"Select years:",
|
|
|
|
| 368 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 369 |
|
| 370 |
# Districts filter
|
| 371 |
+
# st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 372 |
st.markdown('<div class="filter-title">ποΈ Districts</div>', unsafe_allow_html=True)
|
| 373 |
selected_districts = st.multiselect(
|
| 374 |
"Select districts:",
|
|
|
|
| 419 |
if 'input_counter' not in st.session_state:
|
| 420 |
st.session_state.input_counter = 0
|
| 421 |
|
| 422 |
+
# Handle pending question from example questions section
|
| 423 |
+
if 'pending_question' in st.session_state and st.session_state.pending_question:
|
| 424 |
+
default_value = st.session_state.pending_question
|
| 425 |
+
# Increment counter to force new input widget
|
| 426 |
+
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
|
| 427 |
+
del st.session_state.pending_question
|
| 428 |
+
key_suffix = st.session_state.input_counter
|
| 429 |
+
else:
|
| 430 |
+
default_value = ""
|
| 431 |
+
key_suffix = st.session_state.input_counter
|
| 432 |
+
|
| 433 |
user_input = st.text_input(
|
| 434 |
"Type your message here...",
|
| 435 |
placeholder="Ask about budget allocations, expenditures, or audit findings...",
|
| 436 |
+
key=f"user_input_{key_suffix}",
|
| 437 |
+
label_visibility="collapsed",
|
| 438 |
+
value=default_value if default_value else None
|
| 439 |
)
|
| 440 |
|
| 441 |
with col2:
|
| 442 |
+
send_button = st.button("Send", key="send_button", width='stretch')
|
| 443 |
|
| 444 |
# Clear chat button
|
| 445 |
if st.button("ποΈ Clear Chat", key="clear_chat_button"):
|
| 446 |
st.session_state.reset_conversation = True
|
| 447 |
# Clear all conversation files
|
| 448 |
+
conversations_path = CONVERSATIONS_DIR
|
| 449 |
+
if conversations_path.exists():
|
| 450 |
+
for file in conversations_path.iterdir():
|
| 451 |
+
if file.suffix == '.json':
|
| 452 |
+
file.unlink()
|
|
|
|
| 453 |
st.rerun()
|
| 454 |
|
| 455 |
# Handle user input
|
|
|
|
| 491 |
if rag_result:
|
| 492 |
sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
|
| 493 |
|
| 494 |
+
# For Gemini, also check gemini_result for sources
|
| 495 |
+
if not sources or len(sources) == 0:
|
| 496 |
+
gemini_result = chat_result.get('gemini_result')
|
| 497 |
+
print(f"π DEBUG: Checking gemini_result for sources...")
|
| 498 |
+
print(f" gemini_result exists: {gemini_result is not None}")
|
| 499 |
+
if gemini_result:
|
| 500 |
+
print(f" gemini_result type: {type(gemini_result)}")
|
| 501 |
+
print(f" has sources attr: {hasattr(gemini_result, 'sources')}")
|
| 502 |
+
if hasattr(gemini_result, 'sources'):
|
| 503 |
+
print(f" sources length: {len(gemini_result.sources) if gemini_result.sources else 0}")
|
| 504 |
+
|
| 505 |
+
if gemini_result and hasattr(gemini_result, 'sources'):
|
| 506 |
+
# Format Gemini sources for display
|
| 507 |
+
if hasattr(st.session_state.chatbot, 'gemini_client'):
|
| 508 |
+
sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result)
|
| 509 |
+
print(f"β
Formatted {len(sources)} sources from gemini_client")
|
| 510 |
+
elif hasattr(st.session_state.chatbot, '_format_gemini_sources'):
|
| 511 |
+
sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
|
| 512 |
+
print(f"β
Formatted {len(sources)} sources from _format_gemini_sources")
|
| 513 |
+
|
| 514 |
+
# Update rag_result with sources if we found them
|
| 515 |
+
if sources and len(sources) > 0:
|
| 516 |
+
if isinstance(rag_result, dict):
|
| 517 |
+
rag_result['sources'] = sources
|
| 518 |
+
elif hasattr(rag_result, 'sources'):
|
| 519 |
+
rag_result.sources = sources
|
| 520 |
+
# Update last_rag_result with sources
|
| 521 |
+
st.session_state.last_rag_result = rag_result
|
| 522 |
+
print(f"β
Updated rag_result with {len(sources)} sources")
|
| 523 |
+
|
| 524 |
# Get the actual RAG query
|
| 525 |
actual_rag_query = chat_result.get('actual_rag_query', '')
|
| 526 |
if actual_rag_query:
|
|
|
|
| 530 |
else:
|
| 531 |
formatted_query = "No RAG query available"
|
| 532 |
|
| 533 |
+
# Extract filters from active filters
|
| 534 |
+
filters_used = {
|
| 535 |
+
"sources": st.session_state.active_filters.get('sources', []),
|
| 536 |
+
"years": st.session_state.active_filters.get('years', []),
|
| 537 |
+
"districts": st.session_state.active_filters.get('districts', []),
|
| 538 |
+
"filenames": st.session_state.active_filters.get('filenames', [])
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
retrieval_entry = {
|
| 542 |
"conversation_up_to": serialize_messages(st.session_state.messages),
|
| 543 |
"rag_query_expansion": formatted_query,
|
| 544 |
+
"docs_retrieved": serialize_documents(sources),
|
| 545 |
+
"filters_applied": filters_used,
|
| 546 |
+
"timestamp": time.time()
|
| 547 |
}
|
| 548 |
st.session_state.rag_retrieval_history.append(retrieval_entry)
|
| 549 |
+
|
| 550 |
+
# Debug logging
|
| 551 |
+
print(f"π RETRIEVAL TRACKING: {len(sources)} sources stored in retrieval history")
|
| 552 |
else:
|
| 553 |
response = chat_result
|
| 554 |
st.session_state.last_rag_result = None
|
|
|
|
| 578 |
# Dictionary format from multi-agent system
|
| 579 |
sources = rag_result['sources']
|
| 580 |
|
| 581 |
+
# For Gemini, also check if we need to format sources from gemini_result
|
| 582 |
+
if (not sources or len(sources) == 0) and isinstance(rag_result, dict):
|
| 583 |
+
gemini_result = rag_result.get('gemini_result')
|
| 584 |
+
if gemini_result and hasattr(gemini_result, 'sources'):
|
| 585 |
+
# Format Gemini sources for display
|
| 586 |
+
if hasattr(st.session_state.chatbot, 'gemini_client'):
|
| 587 |
+
sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result)
|
| 588 |
+
elif hasattr(st.session_state.chatbot, '_format_gemini_sources'):
|
| 589 |
+
sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
|
| 590 |
+
|
| 591 |
if sources and len(sources) > 0:
|
| 592 |
# Count unique filenames
|
| 593 |
unique_filenames = set()
|
|
|
|
| 595 |
filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown')
|
| 596 |
unique_filenames.add(filename)
|
| 597 |
|
| 598 |
+
st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 20):**")
|
| 599 |
if len(unique_filenames) < len(sources):
|
| 600 |
st.info(f"π‘ **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
|
| 601 |
|
| 602 |
+
# Extract and display statistics
|
| 603 |
+
stats = extract_chunk_statistics(sources)
|
| 604 |
+
|
| 605 |
+
# Show charts for 10+ results, tables for fewer
|
| 606 |
+
if len(sources) >= 10:
|
| 607 |
+
display_chunk_statistics_charts(stats, "Retrieval Statistics")
|
| 608 |
+
# Also show tables below charts for detailed view
|
| 609 |
+
st.markdown("---")
|
| 610 |
+
display_chunk_statistics_table(stats, "Retrieval Distribution")
|
| 611 |
+
else:
|
| 612 |
+
display_chunk_statistics_table(stats, "Retrieval Distribution")
|
| 613 |
+
|
| 614 |
+
st.markdown("---")
|
| 615 |
+
st.markdown("### π Document Details")
|
| 616 |
+
|
| 617 |
+
for i, doc in enumerate(sources): # Show all documents
|
| 618 |
# Get relevance score and ID if available
|
| 619 |
metadata = getattr(doc, 'metadata', {})
|
| 620 |
+
# Handle both standard RAG scores and Gemini scores
|
| 621 |
+
score = metadata.get('reranked_score') or metadata.get('original_score') or metadata.get('score')
|
| 622 |
+
chunk_id = metadata.get('_id') or metadata.get('chunk_id', 'Unknown')
|
| 623 |
+
if score is not None:
|
| 624 |
+
try:
|
| 625 |
+
score_text = f" (Score: {float(score):.3f})"
|
| 626 |
+
except (ValueError, TypeError):
|
| 627 |
+
score_text = ""
|
| 628 |
+
else:
|
| 629 |
+
score_text = ""
|
| 630 |
+
if chunk_id and chunk_id != 'Unknown':
|
| 631 |
+
score_text += f" (ID: {str(chunk_id)[:8]}...)" if score_text else f" (ID: {str(chunk_id)[:8]}...)"
|
| 632 |
|
| 633 |
with st.expander(f"π Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
|
| 634 |
# Display document metadata with emojis
|
|
|
|
| 675 |
if 'feedback_submitted' not in st.session_state:
|
| 676 |
st.session_state.feedback_submitted = False
|
| 677 |
|
| 678 |
+
# Feedback form - only show if feedback not already submitted
|
| 679 |
+
if not st.session_state.feedback_submitted:
|
| 680 |
+
with st.form("feedback_form", clear_on_submit=False):
|
| 681 |
+
col1, col2 = st.columns([1, 1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
|
| 683 |
+
with col1:
|
| 684 |
+
feedback_score = st.slider(
|
| 685 |
+
"Rate this conversation (1-5)",
|
| 686 |
+
min_value=1,
|
| 687 |
+
max_value=5,
|
| 688 |
+
help="How satisfied are you with the conversation?"
|
| 689 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
|
| 691 |
+
with col2:
|
| 692 |
+
is_feedback_about_last_retrieval = st.checkbox(
|
| 693 |
+
"Feedback about last retrieval only",
|
| 694 |
+
value=True,
|
| 695 |
+
help="If checked, feedback applies to the most recent document retrieval"
|
| 696 |
+
)
|
| 697 |
|
| 698 |
+
open_ended_feedback = st.text_area(
|
| 699 |
+
"Your feedback (optional)",
|
| 700 |
+
placeholder="Tell us what went well or what could be improved...",
|
| 701 |
+
height=100
|
| 702 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
|
| 704 |
+
# Disable submit if no score selected
|
| 705 |
+
submit_disabled = feedback_score is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 706 |
|
| 707 |
+
submitted = st.form_submit_button(
|
| 708 |
+
"π€ Submit Feedback",
|
| 709 |
+
width='stretch',
|
| 710 |
+
disabled=submit_disabled
|
| 711 |
+
)
|
| 712 |
|
| 713 |
+
if submitted:
|
| 714 |
+
# Log the feedback data being submitted
|
| 715 |
+
print("=" * 80)
|
| 716 |
+
print("π FEEDBACK SUBMISSION: Starting...")
|
| 717 |
+
print("=" * 80)
|
| 718 |
+
st.write("π **Debug: Feedback Data Being Submitted:**")
|
| 719 |
+
|
| 720 |
+
# Extract transcript from messages
|
| 721 |
+
transcript = feedback_manager.extract_transcript(st.session_state.messages)
|
| 722 |
+
|
| 723 |
+
# Build retrievals structure
|
| 724 |
+
retrievals = feedback_manager.build_retrievals_structure(
|
| 725 |
+
|
| 726 |
+
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
|
| 727 |
+
st.session_state.messages
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
# Build feedback_score_related_retrieval_docs
|
| 731 |
+
|
| 732 |
+
feedback_score_related_retrieval_docs = feedback_manager.build_feedback_score_related_retrieval_docs(
|
| 733 |
+
is_feedback_about_last_retrieval,
|
| 734 |
+
st.session_state.messages,
|
| 735 |
+
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
|
| 736 |
+
)
|
| 737 |
|
| 738 |
+
# Preserve old retrieved_data format for backward compatibility
|
| 739 |
+
retrieved_data_old_format = st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
|
|
|
|
|
|
|
| 740 |
|
| 741 |
+
# Create feedback data dictionary
|
| 742 |
+
feedback_dict = {
|
| 743 |
+
"open_ended_feedback": open_ended_feedback,
|
| 744 |
+
"score": feedback_score,
|
| 745 |
+
"is_feedback_about_last_retrieval": is_feedback_about_last_retrieval,
|
| 746 |
+
"conversation_id": st.session_state.conversation_id,
|
| 747 |
+
"timestamp": time.time(),
|
| 748 |
+
"message_count": len(st.session_state.messages),
|
| 749 |
+
"has_retrievals": has_retrievals,
|
| 750 |
+
"retrieval_count": len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0,
|
| 751 |
+
"transcript": transcript,
|
| 752 |
+
"retrievals": retrievals,
|
| 753 |
+
"feedback_score_related_retrieval_docs": feedback_score_related_retrieval_docs,
|
| 754 |
+
"retrieved_data": retrieved_data_old_format # Preserved old column
|
| 755 |
+
}
|
| 756 |
|
| 757 |
+
print(f"π FEEDBACK SUBMISSION: Score={feedback_score}, Retrievals={len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0}")
|
|
|
|
|
|
|
| 758 |
|
| 759 |
+
# Create UserFeedback dataclass instance
|
| 760 |
+
feedback_obj = None # Initialize outside try block
|
| 761 |
try:
|
| 762 |
+
feedback_obj = feedback_manager.create_feedback_from_dict(feedback_dict)
|
| 763 |
+
print(f"β
FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
|
| 764 |
+
st.write(f"β
**Feedback Object Created**")
|
| 765 |
+
st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
|
| 766 |
+
st.write(f"- Score: {feedback_obj.score}/5")
|
| 767 |
+
st.write(f"- Has Retrievals: {feedback_obj.has_retrievals}")
|
| 768 |
|
| 769 |
+
# Convert back to dict for JSON serialization
|
| 770 |
+
feedback_data = feedback_obj.to_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 771 |
except Exception as e:
|
| 772 |
+
print(f"β FEEDBACK SUBMISSION: Failed to create feedback object: {e}")
|
| 773 |
+
st.error(f"Failed to create feedback object: {e}")
|
| 774 |
+
feedback_data = feedback_dict
|
| 775 |
|
| 776 |
+
# Display the data being submitted
|
| 777 |
+
st.json(feedback_data)
|
| 778 |
|
| 779 |
+
# Save feedback to file - use PROJECT_DIR to ensure writability
|
| 780 |
+
feedback_dir = FEEDBACK_DIR
|
| 781 |
+
try:
|
| 782 |
+
# Ensure directory exists with write permissions (777 for compatibility)
|
| 783 |
+
feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 784 |
+
except (PermissionError, OSError) as e:
|
| 785 |
+
logger.warning(f"Could not create feedback directory at {feedback_dir}: {e}")
|
| 786 |
+
# Fallback to relative path
|
| 787 |
+
feedback_dir = Path("feedback")
|
| 788 |
+
feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 789 |
|
| 790 |
+
feedback_file = feedback_dir / f"feedback_{st.session_state.conversation_id}_{int(time.time())}.json"
|
|
|
|
| 791 |
|
| 792 |
+
try:
|
| 793 |
+
# Ensure parent directory exists before writing
|
| 794 |
+
feedback_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 795 |
+
|
| 796 |
+
# Save to local file first
|
| 797 |
+
print(f"πΎ FEEDBACK SAVE: Saving to local file: {feedback_file}")
|
| 798 |
+
with open(feedback_file, 'w') as f:
|
| 799 |
+
json.dump(feedback_data, f, indent=2, default=str)
|
| 800 |
+
|
| 801 |
+
print(f"β
FEEDBACK SAVE: Local file saved successfully")
|
| 802 |
+
|
| 803 |
+
# Save to Snowflake if enabled and credentials available
|
| 804 |
+
logger.info("π FEEDBACK SAVE: Starting Snowflake save process...")
|
| 805 |
+
logger.info(f"π FEEDBACK SAVE: feedback_obj={'exists' if feedback_obj else 'None'}")
|
| 806 |
+
|
| 807 |
+
snowflake_success = False
|
| 808 |
+
try:
|
| 809 |
+
snowflake_enabled = os.getenv("SNOWFLAKE_ENABLED", "false").lower() == "true"
|
| 810 |
+
logger.info(f"π SNOWFLAKE CHECK: enabled={snowflake_enabled}")
|
| 811 |
+
|
| 812 |
+
if snowflake_enabled:
|
| 813 |
+
if feedback_obj:
|
| 814 |
+
try:
|
| 815 |
+
logger.info("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 816 |
+
print("π€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 817 |
+
|
| 818 |
+
# Show spinner while saving to Snowflake (can take 10-15 seconds)
|
| 819 |
+
# This includes: connection establishment (~5s), data preparation, and SQL execution (~5s)
|
| 820 |
+
with st.spinner("πΎ Saving feedback to Snowflake... This may take 10-15 seconds (connecting to database, preparing data, and executing query)"):
|
| 821 |
+
snowflake_success = feedback_manager.save_to_snowflake(feedback_obj)
|
| 822 |
+
|
| 823 |
+
if snowflake_success:
|
| 824 |
+
logger.info("β
SNOWFLAKE UI: Successfully saved to Snowflake")
|
| 825 |
+
print("β
SNOWFLAKE UI: Successfully saved to Snowflake")
|
| 826 |
+
else:
|
| 827 |
+
logger.warning("β οΈ SNOWFLAKE UI: Save failed")
|
| 828 |
+
print("β οΈ SNOWFLAKE UI: Save failed")
|
| 829 |
+
except Exception as e:
|
| 830 |
+
logger.error(f"β SNOWFLAKE UI ERROR: {e}")
|
| 831 |
+
print(f"β SNOWFLAKE UI ERROR: {e}")
|
| 832 |
+
traceback.print_exc()
|
| 833 |
+
snowflake_success = False
|
| 834 |
+
else:
|
| 835 |
+
logger.warning("β οΈ SNOWFLAKE UI: Skipping (feedback object not created)")
|
| 836 |
+
print("β οΈ SNOWFLAKE UI: Skipping (feedback object not created)")
|
| 837 |
+
snowflake_success = False
|
| 838 |
+
else:
|
| 839 |
+
logger.info("π‘ SNOWFLAKE UI: Integration disabled")
|
| 840 |
+
print("π‘ SNOWFLAKE UI: Integration disabled")
|
| 841 |
+
# If Snowflake is disabled, consider it successful (local save only)
|
| 842 |
+
snowflake_success = True
|
| 843 |
+
|
| 844 |
+
except Exception as e:
|
| 845 |
+
logger.error(f"β Exception in Snowflake save: {type(e).__name__}: {e}")
|
| 846 |
+
print(f"β Exception in Snowflake save: {type(e).__name__}: {e}")
|
| 847 |
+
snowflake_success = False
|
| 848 |
+
|
| 849 |
+
# Only show success if Snowflake save succeeded (or if Snowflake is disabled)
|
| 850 |
+
if snowflake_success:
|
| 851 |
+
st.success("β
Thank you for your feedback! It has been saved successfully.")
|
| 852 |
+
st.balloons()
|
| 853 |
+
else:
|
| 854 |
+
st.warning("β οΈ Feedback saved locally, but Snowflake save failed. Please check logs.")
|
| 855 |
+
|
| 856 |
+
# Mark feedback as submitted to prevent resubmission
|
| 857 |
+
st.session_state.feedback_submitted = True
|
| 858 |
+
|
| 859 |
+
print("=" * 80)
|
| 860 |
+
print(f"β
FEEDBACK SUBMISSION: Completed successfully")
|
| 861 |
+
print("=" * 80)
|
| 862 |
+
|
| 863 |
+
# Log file location
|
| 864 |
+
st.info(f"π Feedback saved to: {feedback_file}")
|
| 865 |
+
|
| 866 |
+
except Exception as e:
|
| 867 |
+
print(f"β FEEDBACK SUBMISSION: Error saving feedback: {e}")
|
| 868 |
+
print(f"β FEEDBACK SUBMISSION: Error type: {type(e).__name__}")
|
| 869 |
+
traceback.print_exc()
|
| 870 |
+
st.error(f"β Error saving feedback: {e}")
|
| 871 |
+
st.write(f"Debug error: {str(e)}")
|
| 872 |
+
else:
|
| 873 |
+
# Feedback already submitted - show success message and reset option
|
| 874 |
+
st.success("β
Feedback already submitted for this conversation!")
|
| 875 |
+
col1, col2 = st.columns([1, 1])
|
| 876 |
+
with col1:
|
| 877 |
+
if st.button("π Submit New Feedback", key="new_feedback_button", width='stretch'):
|
| 878 |
+
try:
|
| 879 |
+
st.session_state.feedback_submitted = False
|
| 880 |
+
st.rerun()
|
| 881 |
+
except Exception as e:
|
| 882 |
+
# Handle any Streamlit API exceptions gracefully
|
| 883 |
+
logger.error(f"Error resetting feedback state: {e}")
|
| 884 |
+
st.error(f"Error resetting feedback. Please refresh the page.")
|
| 885 |
+
with col2:
|
| 886 |
+
if st.button("π View Conversation", key="view_conversation_button", width='stretch'):
|
| 887 |
+
# Scroll to conversation - this is handled by the auto-scroll at bottom
|
| 888 |
+
pass
|
| 889 |
|
| 890 |
# Display retrieval history stats
|
| 891 |
if st.session_state.rag_retrieval_history:
|
| 892 |
st.markdown("---")
|
| 893 |
st.markdown("#### π Retrieval History")
|
| 894 |
|
| 895 |
+
with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=True):
|
| 896 |
for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
|
| 897 |
+
st.markdown(f"### **Retrieval #{idx}**")
|
| 898 |
+
|
| 899 |
+
# Display timestamp if available
|
| 900 |
+
if entry.get("timestamp"):
|
| 901 |
+
timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(entry["timestamp"]))
|
| 902 |
+
st.caption(f"π {timestamp_str}")
|
| 903 |
|
| 904 |
# Display the actual RAG query
|
| 905 |
rag_query_expansion = entry.get("rag_query_expansion", "No query available")
|
| 906 |
+
st.markdown("**π RAG Query:**")
|
| 907 |
st.code(rag_query_expansion, language="text")
|
| 908 |
|
| 909 |
+
# Display filters used
|
| 910 |
+
filters_applied = entry.get("filters_applied", {})
|
| 911 |
+
if filters_applied and any(filters_applied.values()):
|
| 912 |
+
st.markdown("**π― Filters Applied:**")
|
| 913 |
+
filter_display = {}
|
| 914 |
+
if filters_applied.get("sources"):
|
| 915 |
+
filter_display["Sources"] = filters_applied["sources"]
|
| 916 |
+
if filters_applied.get("years"):
|
| 917 |
+
filter_display["Years"] = filters_applied["years"]
|
| 918 |
+
if filters_applied.get("districts"):
|
| 919 |
+
filter_display["Districts"] = filters_applied["districts"]
|
| 920 |
+
if filters_applied.get("filenames"):
|
| 921 |
+
filter_display["Filenames"] = filters_applied["filenames"]
|
| 922 |
+
|
| 923 |
+
if filter_display:
|
| 924 |
+
st.json(filter_display)
|
| 925 |
+
else:
|
| 926 |
+
st.info("No filters applied")
|
| 927 |
+
else:
|
| 928 |
+
st.info("No filters applied")
|
| 929 |
+
|
| 930 |
+
# Display conversation history up to retrieval point
|
| 931 |
+
conversation_up_to = entry.get("conversation_up_to", [])
|
| 932 |
+
if conversation_up_to:
|
| 933 |
+
st.markdown("**π¬ Conversation History (up to retrieval point):**")
|
| 934 |
+
with st.expander(f"View {len(conversation_up_to)} messages", expanded=False):
|
| 935 |
+
for msg_idx, msg in enumerate(conversation_up_to, 1):
|
| 936 |
+
role = msg.get("type", "unknown")
|
| 937 |
+
content = msg.get("content", "")
|
| 938 |
+
|
| 939 |
+
if role == "HumanMessage" or role == "human":
|
| 940 |
+
st.markdown(f"**π€ User {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
|
| 941 |
+
elif role == "AIMessage" or role == "ai":
|
| 942 |
+
st.markdown(f"**π€ Assistant {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
|
| 943 |
+
else:
|
| 944 |
+
st.info("No conversation history available")
|
| 945 |
+
|
| 946 |
+
# Display documents retrieved
|
| 947 |
+
docs_retrieved = entry.get("docs_retrieved", [])
|
| 948 |
+
if docs_retrieved:
|
| 949 |
+
st.markdown(f"**π Documents Retrieved ({len(docs_retrieved)}):**")
|
| 950 |
+
with st.expander(f"View {len(docs_retrieved)} documents", expanded=False):
|
| 951 |
+
for doc_idx, doc in enumerate(docs_retrieved, 1):
|
| 952 |
+
st.markdown(f"**Document {doc_idx}:**")
|
| 953 |
+
|
| 954 |
+
# Display metadata
|
| 955 |
+
metadata = doc.get("metadata", {})
|
| 956 |
+
if metadata:
|
| 957 |
+
col1, col2, col3 = st.columns(3)
|
| 958 |
+
with col1:
|
| 959 |
+
st.write(f"π **File:** {metadata.get('filename', 'Unknown')}")
|
| 960 |
+
with col2:
|
| 961 |
+
st.write(f"ποΈ **Source:** {metadata.get('source', 'Unknown')}")
|
| 962 |
+
with col3:
|
| 963 |
+
st.write(f"π
**Year:** {metadata.get('year', 'Unknown')}")
|
| 964 |
+
|
| 965 |
+
# Additional metadata
|
| 966 |
+
if metadata.get('district'):
|
| 967 |
+
st.write(f"π **District:** {metadata.get('district')}")
|
| 968 |
+
if metadata.get('page'):
|
| 969 |
+
st.write(f"π **Page:** {metadata.get('page')}")
|
| 970 |
+
if metadata.get('score') is not None:
|
| 971 |
+
st.write(f"β **Score:** {metadata.get('score'):.3f}" if isinstance(metadata.get('score'), (int, float)) else f"β **Score:** {metadata.get('score')}")
|
| 972 |
+
|
| 973 |
+
# Display content preview (first 200 chars)
|
| 974 |
+
content = doc.get("content", doc.get("page_content", ""))
|
| 975 |
+
if content:
|
| 976 |
+
st.markdown("**Content Preview:**")
|
| 977 |
+
st.text_area(
|
| 978 |
+
"Content Preview",
|
| 979 |
+
value=content[:200] + ("..." if len(content) > 200 else ""),
|
| 980 |
+
height=100,
|
| 981 |
+
disabled=True,
|
| 982 |
+
label_visibility="collapsed",
|
| 983 |
+
key=f"retrieval_{idx}_doc_{doc_idx}_preview"
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
if doc_idx < len(docs_retrieved):
|
| 987 |
+
st.markdown("---")
|
| 988 |
+
else:
|
| 989 |
+
st.info("No documents retrieved")
|
| 990 |
+
|
| 991 |
# Display summary stats
|
| 992 |
+
st.markdown("**π Summary:**")
|
| 993 |
st.json({
|
| 994 |
+
"conversation_length": len(conversation_up_to),
|
| 995 |
+
"documents_retrieved": len(docs_retrieved)
|
| 996 |
})
|
| 997 |
+
|
| 998 |
+
if idx < len(st.session_state.rag_retrieval_history):
|
| 999 |
+
st.markdown("---")
|
| 1000 |
+
|
| 1001 |
+
# Example Questions Section
|
| 1002 |
+
st.markdown("---")
|
| 1003 |
+
st.markdown("### π‘ Example Questions")
|
| 1004 |
+
st.markdown("Click on any question below to use it, or modify the editable examples:")
|
| 1005 |
+
|
| 1006 |
+
# Initialize example question state
|
| 1007 |
+
if 'custom_question_1' not in st.session_state:
|
| 1008 |
+
st.session_state.custom_question_1 = "How were administrative costs managed in the PDM implementation, and what issues arose with budget execution regarding staff salaries?"
|
| 1009 |
+
if 'custom_question_2' not in st.session_state:
|
| 1010 |
+
st.session_state.custom_question_2 = "What did the National Coordinator say about the release of funds for PDM administrative costs in the letter dated 29th September 2022 and how did the funding received affect the activities of the PDCs and PDM SACCOs in the FY 2022/23?"
|
| 1011 |
+
|
| 1012 |
+
# Question 1: Filename insights (fixed, clickable)
|
| 1013 |
+
st.markdown("#### π Question 1: List insights from a specific file")
|
| 1014 |
+
col1, col2 = st.columns([3, 1])
|
| 1015 |
+
with col1:
|
| 1016 |
+
example_q1 = "List couple of insights from the filename."
|
| 1017 |
+
st.markdown(f"**Example:** `{example_q1}`")
|
| 1018 |
+
st.info("π‘ **Filter to apply:** Select a Filename from the sidebar panel before asking this question.")
|
| 1019 |
+
with col2:
|
| 1020 |
+
if st.button("π Use This Question", key="use_example_1", width='stretch'):
|
| 1021 |
+
st.session_state.pending_question = example_q1
|
| 1022 |
+
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
|
| 1023 |
+
st.rerun()
|
| 1024 |
+
|
| 1025 |
+
st.markdown("---")
|
| 1026 |
+
|
| 1027 |
+
# Questions 2 & 3: Editable examples
|
| 1028 |
+
st.markdown("#### βοΈ Customizable Questions (Edit and use)")
|
| 1029 |
+
|
| 1030 |
+
# Question 2
|
| 1031 |
+
# st.markdown("**Question 2:**")
|
| 1032 |
+
custom_q1 = st.text_area(
|
| 1033 |
+
"Edit question 2:",
|
| 1034 |
+
value=st.session_state.custom_question_1,
|
| 1035 |
+
height=80,
|
| 1036 |
+
key="edit_question_2",
|
| 1037 |
+
help="Modify this question to fit your needs, then click 'Use This Question'"
|
| 1038 |
+
)
|
| 1039 |
+
col1, col2 = st.columns([1, 4])
|
| 1040 |
+
with col1:
|
| 1041 |
+
if st.button("π Use Question 2", key="use_custom_1", width='stretch'):
|
| 1042 |
+
if custom_q1.strip():
|
| 1043 |
+
st.session_state.pending_question = custom_q1.strip()
|
| 1044 |
+
st.session_state.custom_question_1 = custom_q1.strip()
|
| 1045 |
+
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
|
| 1046 |
+
st.rerun()
|
| 1047 |
+
else:
|
| 1048 |
+
st.warning("Please enter a question first!")
|
| 1049 |
+
with col2:
|
| 1050 |
+
st.caption("π‘ Tip: Add specific details like dates, names, or amounts to get more precise answers")
|
| 1051 |
+
|
| 1052 |
+
st.info("π‘ **Filter to apply:** Select District(s) and Year(s) sidebar panel before asking this question.")
|
| 1053 |
+
|
| 1054 |
+
st.markdown("---")
|
| 1055 |
+
|
| 1056 |
+
# Question 3
|
| 1057 |
+
# st.markdown("**Question 3:**")
|
| 1058 |
+
custom_q2 = st.text_area(
|
| 1059 |
+
"Edit question 3:",
|
| 1060 |
+
value=st.session_state.custom_question_2,
|
| 1061 |
+
height=80,
|
| 1062 |
+
key="edit_question_3",
|
| 1063 |
+
help="Modify this question to fit your needs, then click 'Use This Question'"
|
| 1064 |
+
)
|
| 1065 |
+
col1, col2 = st.columns([1, 4])
|
| 1066 |
+
with col1:
|
| 1067 |
+
if st.button("π Use Question 3", key="use_custom_2", width='stretch'):
|
| 1068 |
+
if custom_q2.strip():
|
| 1069 |
+
st.session_state.pending_question = custom_q2.strip()
|
| 1070 |
+
st.session_state.custom_question_2 = custom_q2.strip()
|
| 1071 |
+
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
|
| 1072 |
+
st.rerun()
|
| 1073 |
+
else:
|
| 1074 |
+
st.warning("Please enter a question first!")
|
| 1075 |
+
with col2:
|
| 1076 |
+
st.caption("π‘ Tip: Use specific terms from the documents (e.g., 'PDM', 'SACCOs', 'FY 2022/23')")
|
| 1077 |
+
|
| 1078 |
+
|
| 1079 |
+
# Store selected question for next render (handled in input section above)
|
| 1080 |
+
# This ensures the question populates the input field correctly
|
| 1081 |
|
| 1082 |
# Auto-scroll to bottom
|
| 1083 |
st.markdown("""
|
|
|
|
| 1086 |
</script>
|
| 1087 |
""", unsafe_allow_html=True)
|
| 1088 |
|
| 1089 |
+
|
| 1090 |
if __name__ == "__main__":
|
| 1091 |
+
# Check if running in Streamlit context
|
| 1092 |
+
try:
|
| 1093 |
+
from streamlit.runtime.scriptrunner import get_script_run_ctx
|
| 1094 |
+
if get_script_run_ctx() is None:
|
| 1095 |
+
# Not in Streamlit runtime - show helpful message
|
| 1096 |
+
print("=" * 80)
|
| 1097 |
+
print("β οΈ WARNING: This is a Streamlit app!")
|
| 1098 |
+
print("=" * 80)
|
| 1099 |
+
print("\nPlease run this app using:")
|
| 1100 |
+
print(" streamlit run app.py")
|
| 1101 |
+
print("\nNot: python app.py")
|
| 1102 |
+
print("\nThe app will not function correctly when run with 'python app.py'")
|
| 1103 |
+
print("=" * 80)
|
| 1104 |
+
import sys
|
| 1105 |
+
sys.exit(1)
|
| 1106 |
+
except ImportError:
|
| 1107 |
+
# Streamlit not installed or not in Streamlit context
|
| 1108 |
+
print("=" * 80)
|
| 1109 |
+
print("β οΈ WARNING: This is a Streamlit app!")
|
| 1110 |
+
print("=" * 80)
|
| 1111 |
+
print("\nPlease run this app using:")
|
| 1112 |
+
print(" streamlit run app.py")
|
| 1113 |
+
print("\nNot: python app.py")
|
| 1114 |
+
print("=" * 80)
|
| 1115 |
+
import sys
|
| 1116 |
+
sys.exit(1)
|
| 1117 |
main()
|
src/agents/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent modules for chatbot implementations
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .smart_chatbot import get_chatbot as get_smart_chatbot
|
| 6 |
+
from .multi_agent_chatbot import get_multi_agent_chatbot
|
| 7 |
+
from .gemini_chatbot import get_gemini_chatbot
|
| 8 |
+
|
| 9 |
+
__all__ = ["get_smart_chatbot", "get_multi_agent_chatbot", "get_gemini_chatbot"]
|
| 10 |
+
|
src/agents/gemini_chatbot.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini File Search Chatbot (Beta Version)
|
| 3 |
+
|
| 4 |
+
This chatbot uses Google Gemini File Search API for RAG.
|
| 5 |
+
It provides a simpler architecture: Main Agent + Gemini Agent
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
import logging
|
| 12 |
+
import traceback
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Dict, List, Any, Optional, TypedDict
|
| 15 |
+
|
| 16 |
+
from langgraph.graph import StateGraph, END
|
| 17 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 18 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 19 |
+
|
| 20 |
+
from src.gemini.file_search import GeminiFileSearchClient, GeminiFileSearchResult
|
| 21 |
+
from src.config.paths import CONVERSATIONS_DIR
|
| 22 |
+
|
| 23 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class GeminiState(TypedDict):
|
| 28 |
+
"""State for Gemini chatbot conversation flow"""
|
| 29 |
+
conversation_id: str
|
| 30 |
+
messages: List[Any]
|
| 31 |
+
current_query: str
|
| 32 |
+
query_context: Optional[Dict[str, Any]]
|
| 33 |
+
gemini_result: Optional[GeminiFileSearchResult]
|
| 34 |
+
final_response: Optional[str]
|
| 35 |
+
agent_logs: List[str]
|
| 36 |
+
conversation_context: Dict[str, Any]
|
| 37 |
+
session_start_time: float
|
| 38 |
+
last_ai_message_time: float
|
| 39 |
+
filters: Optional[Dict[str, Any]]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class GeminiRAGChatbot:
|
| 43 |
+
"""Gemini File Search RAG chatbot (Beta version)"""
|
| 44 |
+
|
| 45 |
+
def __init__(self):
|
| 46 |
+
"""Initialize the Gemini chatbot"""
|
| 47 |
+
logger.info("π€ INITIALIZING: Gemini File Search Chatbot (Beta)")
|
| 48 |
+
|
| 49 |
+
# Initialize Gemini File Search client
|
| 50 |
+
try:
|
| 51 |
+
self.gemini_client = GeminiFileSearchClient()
|
| 52 |
+
logger.info("β
Gemini File Search client initialized")
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.error(f"β Failed to initialize Gemini client: {e}")
|
| 55 |
+
raise RuntimeError(f"Gemini client initialization failed: {e}")
|
| 56 |
+
|
| 57 |
+
# Build the LangGraph with LangSmith tracing if enabled
|
| 58 |
+
self.graph = self._build_graph()
|
| 59 |
+
|
| 60 |
+
# Enable LangSmith tracing if configured
|
| 61 |
+
langsmith_enabled = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true"
|
| 62 |
+
if langsmith_enabled:
|
| 63 |
+
logger.info("π LangSmith tracing enabled")
|
| 64 |
+
langsmith_project = os.getenv("LANGCHAIN_PROJECT", "gemini-chatbot")
|
| 65 |
+
logger.info(f"π LangSmith project: {langsmith_project}")
|
| 66 |
+
|
| 67 |
+
# Conversations directory
|
| 68 |
+
self.conversations_dir = CONVERSATIONS_DIR
|
| 69 |
+
try:
|
| 70 |
+
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 71 |
+
except (PermissionError, OSError) as e:
|
| 72 |
+
logger.warning(f"Could not create conversations directory: {e}")
|
| 73 |
+
self.conversations_dir = Path("conversations")
|
| 74 |
+
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 75 |
+
|
| 76 |
+
logger.info("β
Gemini File Search Chatbot initialized")
|
| 77 |
+
|
| 78 |
+
def _build_graph(self) -> StateGraph:
|
| 79 |
+
"""Build the LangGraph for Gemini chatbot"""
|
| 80 |
+
graph = StateGraph(GeminiState)
|
| 81 |
+
|
| 82 |
+
# Add nodes
|
| 83 |
+
graph.add_node("main_agent", self._main_agent)
|
| 84 |
+
graph.add_node("gemini_agent", self._gemini_agent)
|
| 85 |
+
|
| 86 |
+
# Define the flow
|
| 87 |
+
graph.set_entry_point("main_agent")
|
| 88 |
+
graph.add_edge("main_agent", "gemini_agent")
|
| 89 |
+
graph.add_edge("gemini_agent", END)
|
| 90 |
+
|
| 91 |
+
return graph.compile()
|
| 92 |
+
|
| 93 |
+
def _main_agent(self, state: GeminiState) -> GeminiState:
|
| 94 |
+
"""Main Agent: Extracts filters and prepares query"""
|
| 95 |
+
logger.info("π― MAIN AGENT: Processing query")
|
| 96 |
+
|
| 97 |
+
query = state["current_query"]
|
| 98 |
+
messages = state["messages"]
|
| 99 |
+
|
| 100 |
+
# Extract UI filters if present in query
|
| 101 |
+
ui_filters = self._extract_ui_filters(query)
|
| 102 |
+
|
| 103 |
+
# Extract context from conversation
|
| 104 |
+
context = self._extract_context_from_conversation(messages, ui_filters)
|
| 105 |
+
|
| 106 |
+
# Store context and filters
|
| 107 |
+
state["query_context"] = context
|
| 108 |
+
state["filters"] = context.get("filters", {})
|
| 109 |
+
|
| 110 |
+
logger.info(f"π― MAIN AGENT: Filters extracted: {state['filters']}")
|
| 111 |
+
|
| 112 |
+
return state
|
| 113 |
+
|
| 114 |
+
def _gemini_agent(self, state: GeminiState) -> GeminiState:
|
| 115 |
+
"""Gemini Agent: Performs file search and generates response"""
|
| 116 |
+
logger.info("π GEMINI AGENT: Starting file search")
|
| 117 |
+
|
| 118 |
+
query = state["current_query"]
|
| 119 |
+
filters = state.get("filters", {})
|
| 120 |
+
|
| 121 |
+
# Perform Gemini file search
|
| 122 |
+
try:
|
| 123 |
+
result = self.gemini_client.search(query=query, filters=filters)
|
| 124 |
+
logger.info(f"β
GEMINI AGENT: Search completed, {len(result.sources)} sources found")
|
| 125 |
+
|
| 126 |
+
# Enhance response with document references
|
| 127 |
+
enhanced_response = self._enhance_response_with_references(
|
| 128 |
+
result.answer,
|
| 129 |
+
result.sources,
|
| 130 |
+
query
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
state["gemini_result"] = result
|
| 134 |
+
state["final_response"] = enhanced_response
|
| 135 |
+
state["last_ai_message_time"] = time.time()
|
| 136 |
+
|
| 137 |
+
state["agent_logs"].append(f"GEMINI AGENT: Found {len(result.sources)} sources")
|
| 138 |
+
|
| 139 |
+
except Exception as e:
|
| 140 |
+
logger.error(f"β GEMINI AGENT ERROR: {e}")
|
| 141 |
+
traceback.print_exc()
|
| 142 |
+
state["final_response"] = "I apologize, but I encountered an error while searching. Please try again."
|
| 143 |
+
state["last_ai_message_time"] = time.time()
|
| 144 |
+
|
| 145 |
+
return state
|
| 146 |
+
|
| 147 |
+
def _enhance_response_with_references(self, answer: str, sources: List[Any], query: str) -> str:
|
| 148 |
+
"""Enhance Gemini response to include document references and format nicely"""
|
| 149 |
+
if not sources or not answer:
|
| 150 |
+
return answer
|
| 151 |
+
|
| 152 |
+
# Use LLM to intelligently add document references and format nicely
|
| 153 |
+
try:
|
| 154 |
+
from src.llm.adapters import get_llm_client
|
| 155 |
+
llm = get_llm_client()
|
| 156 |
+
|
| 157 |
+
# Prepare document summaries for the LLM
|
| 158 |
+
doc_summaries = []
|
| 159 |
+
for idx, doc in enumerate(sources, 1):
|
| 160 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 161 |
+
content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
|
| 162 |
+
|
| 163 |
+
filename = metadata.get('filename', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
|
| 164 |
+
year = metadata.get('year', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
|
| 165 |
+
source = metadata.get('source', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
|
| 166 |
+
district = metadata.get('district', '') if isinstance(metadata, dict) else ''
|
| 167 |
+
|
| 168 |
+
doc_info = f"{filename}"
|
| 169 |
+
if year and year != 'Unknown':
|
| 170 |
+
doc_info += f" ({year})"
|
| 171 |
+
if source and source != 'Unknown':
|
| 172 |
+
doc_info += f" - {source}"
|
| 173 |
+
if district:
|
| 174 |
+
doc_info += f" - {district}"
|
| 175 |
+
|
| 176 |
+
doc_summaries.append(f"[Doc {idx}] {doc_info}: {content[:300]}...")
|
| 177 |
+
|
| 178 |
+
prompt = f"""You are enhancing a response from a document search system. The original response is:
|
| 179 |
+
|
| 180 |
+
{answer}
|
| 181 |
+
|
| 182 |
+
The following documents were retrieved and used to generate this response:
|
| 183 |
+
|
| 184 |
+
{chr(10).join(doc_summaries)}
|
| 185 |
+
|
| 186 |
+
CRITICAL RULES:
|
| 187 |
+
1. Format the response nicely with proper paragraphs, bullet points, or structured sections where appropriate
|
| 188 |
+
2. The response should ONLY contain information from the retrieved documents listed above
|
| 189 |
+
3. If the response mentions information NOT found in the retrieved documents, you must REMOVE or CORRECT that information
|
| 190 |
+
4. Add document references [Doc i] at the end of sentences that use information from specific documents
|
| 191 |
+
5. Only reference documents that are actually used in the response
|
| 192 |
+
6. If the response mentions years, sources, or data that don't match the retrieved documents, you must correct it
|
| 193 |
+
7. Keep the response natural, conversational, and well-formatted
|
| 194 |
+
8. Use proper formatting: paragraphs, line breaks, and structure for readability
|
| 195 |
+
9. Don't change the core content that matches the documents, just add references where appropriate and improve formatting
|
| 196 |
+
10. If multiple documents support the same claim, use [Doc i, Doc j] format
|
| 197 |
+
11. If the response contains information that cannot be verified in the retrieved documents, add a note like: "Note: This information may not be in the retrieved documents."
|
| 198 |
+
|
| 199 |
+
Return ONLY the enhanced, well-formatted response with references added and any corrections made. Do not include any explanation or meta-commentary."""
|
| 200 |
+
|
| 201 |
+
enhanced = llm.invoke(prompt).content if hasattr(llm.invoke(prompt), 'content') else str(llm.invoke(prompt))
|
| 202 |
+
|
| 203 |
+
# Fallback: if LLM fails, just return original with basic formatting
|
| 204 |
+
if not enhanced or len(enhanced) < len(answer) * 0.5:
|
| 205 |
+
logger.warning("LLM enhancement failed, using original response with basic formatting")
|
| 206 |
+
# Basic formatting: add line breaks after periods for readability
|
| 207 |
+
formatted = answer.replace('. ', '.\n\n')
|
| 208 |
+
if sources:
|
| 209 |
+
ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
|
| 210 |
+
formatted += f"\n\n*Based on documents: {ref_list}*"
|
| 211 |
+
return formatted
|
| 212 |
+
|
| 213 |
+
return enhanced
|
| 214 |
+
|
| 215 |
+
except Exception as e:
|
| 216 |
+
logger.warning(f"Failed to enhance response with references: {e}")
|
| 217 |
+
# Fallback: add basic formatting and references at the end
|
| 218 |
+
formatted = answer.replace('. ', '.\n\n') # Basic paragraph formatting
|
| 219 |
+
if sources:
|
| 220 |
+
ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
|
| 221 |
+
formatted += f"\n\n*Based on documents: {ref_list}*"
|
| 222 |
+
return formatted
|
| 223 |
+
|
| 224 |
+
def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
|
| 225 |
+
"""Extract UI filters from query if present"""
|
| 226 |
+
filters = {}
|
| 227 |
+
|
| 228 |
+
if "FILTER CONTEXT:" in query:
|
| 229 |
+
filter_section = query.split("FILTER CONTEXT:")[1]
|
| 230 |
+
if "USER QUERY:" in filter_section:
|
| 231 |
+
filter_section = filter_section.split("USER QUERY:")[0]
|
| 232 |
+
filter_section = filter_section.strip()
|
| 233 |
+
|
| 234 |
+
if "Sources:" in filter_section:
|
| 235 |
+
sources_line = [line for line in filter_section.split('\n') if line.strip().startswith('Sources:')]
|
| 236 |
+
if sources_line:
|
| 237 |
+
sources_str = sources_line[0].split("Sources:")[1].strip()
|
| 238 |
+
if sources_str and sources_str != "None":
|
| 239 |
+
filters["sources"] = [s.strip() for s in sources_str.split(",")]
|
| 240 |
+
|
| 241 |
+
if "Years:" in filter_section:
|
| 242 |
+
years_line = [line for line in filter_section.split('\n') if line.strip().startswith('Years:')]
|
| 243 |
+
if years_line:
|
| 244 |
+
years_str = years_line[0].split("Years:")[1].strip()
|
| 245 |
+
if years_str and years_str != "None":
|
| 246 |
+
filters["year"] = [y.strip() for y in years_str.split(",")]
|
| 247 |
+
|
| 248 |
+
if "Districts:" in filter_section:
|
| 249 |
+
districts_line = [line for line in filter_section.split('\n') if line.strip().startswith('Districts:')]
|
| 250 |
+
if districts_line:
|
| 251 |
+
districts_str = districts_line[0].split("Districts:")[1].strip()
|
| 252 |
+
if districts_str and districts_str != "None":
|
| 253 |
+
filters["district"] = [d.strip() for d in districts_str.split(",")]
|
| 254 |
+
|
| 255 |
+
if "Filenames:" in filter_section:
|
| 256 |
+
filenames_line = [line for line in filter_section.split('\n') if line.strip().startswith('Filenames:')]
|
| 257 |
+
if filenames_line:
|
| 258 |
+
filenames_str = filenames_line[0].split("Filenames:")[1].strip()
|
| 259 |
+
if filenames_str and filenames_str != "None":
|
| 260 |
+
filters["filenames"] = [f.strip() for f in filenames_str.split(",")]
|
| 261 |
+
|
| 262 |
+
return filters
|
| 263 |
+
|
| 264 |
+
def _extract_context_from_conversation(
|
| 265 |
+
self,
|
| 266 |
+
messages: List[Any],
|
| 267 |
+
ui_filters: Dict[str, List[str]]
|
| 268 |
+
) -> Dict[str, Any]:
|
| 269 |
+
"""Extract context from conversation history"""
|
| 270 |
+
# Use UI filters if available
|
| 271 |
+
filters = ui_filters.copy() if ui_filters else {}
|
| 272 |
+
|
| 273 |
+
# For Gemini, we pass filters directly to the search function
|
| 274 |
+
# The filters will be used to add context to the query
|
| 275 |
+
|
| 276 |
+
return {
|
| 277 |
+
"filters": filters,
|
| 278 |
+
"has_filters": bool(filters)
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
|
| 282 |
+
"""Main chat interface"""
|
| 283 |
+
logger.info(f"π¬ GEMINI CHAT: Processing '{user_input[:50]}...'")
|
| 284 |
+
|
| 285 |
+
# Load conversation
|
| 286 |
+
conversation_file = self.conversations_dir / f"{conversation_id}.json"
|
| 287 |
+
conversation = self._load_conversation(conversation_file)
|
| 288 |
+
|
| 289 |
+
# Add user message
|
| 290 |
+
conversation["messages"].append(HumanMessage(content=user_input))
|
| 291 |
+
|
| 292 |
+
# Prepare state
|
| 293 |
+
state = GeminiState(
|
| 294 |
+
conversation_id=conversation_id,
|
| 295 |
+
messages=conversation["messages"],
|
| 296 |
+
current_query=user_input,
|
| 297 |
+
query_context=None,
|
| 298 |
+
gemini_result=None,
|
| 299 |
+
final_response=None,
|
| 300 |
+
agent_logs=[],
|
| 301 |
+
conversation_context=conversation.get("context", {}),
|
| 302 |
+
session_start_time=conversation["session_start_time"],
|
| 303 |
+
last_ai_message_time=conversation["last_ai_message_time"],
|
| 304 |
+
filters=None
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Run graph
|
| 308 |
+
final_state = self.graph.invoke(state)
|
| 309 |
+
|
| 310 |
+
# Add AI response to conversation
|
| 311 |
+
if final_state["final_response"]:
|
| 312 |
+
conversation["messages"].append(AIMessage(content=final_state["final_response"]))
|
| 313 |
+
|
| 314 |
+
# Update conversation
|
| 315 |
+
conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
|
| 316 |
+
conversation["context"] = final_state["conversation_context"]
|
| 317 |
+
|
| 318 |
+
# Save conversation
|
| 319 |
+
self._save_conversation(conversation_file, conversation)
|
| 320 |
+
|
| 321 |
+
# Format sources for display
|
| 322 |
+
sources = []
|
| 323 |
+
gemini_result = final_state.get("gemini_result")
|
| 324 |
+
if gemini_result:
|
| 325 |
+
sources = self.gemini_client.format_sources_for_display(gemini_result)
|
| 326 |
+
logger.info(f"π GEMINI CHAT: Formatted {len(sources)} sources for display")
|
| 327 |
+
|
| 328 |
+
return {
|
| 329 |
+
'response': final_state["final_response"] or "I apologize, but I couldn't process your request.",
|
| 330 |
+
'rag_result': {
|
| 331 |
+
'sources': sources,
|
| 332 |
+
'answer': final_state["final_response"]
|
| 333 |
+
},
|
| 334 |
+
'agent_logs': final_state["agent_logs"],
|
| 335 |
+
'actual_rag_query': final_state["current_query"],
|
| 336 |
+
'gemini_result': gemini_result # Include raw result for tracking
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
|
| 340 |
+
"""Load conversation from file"""
|
| 341 |
+
if conversation_file.exists():
|
| 342 |
+
try:
|
| 343 |
+
with open(conversation_file) as f:
|
| 344 |
+
data = json.load(f)
|
| 345 |
+
messages = []
|
| 346 |
+
for msg_data in data.get("messages", []):
|
| 347 |
+
if msg_data["type"] == "human":
|
| 348 |
+
messages.append(HumanMessage(content=msg_data["content"]))
|
| 349 |
+
elif msg_data["type"] == "ai":
|
| 350 |
+
messages.append(AIMessage(content=msg_data["content"]))
|
| 351 |
+
data["messages"] = messages
|
| 352 |
+
return data
|
| 353 |
+
except Exception as e:
|
| 354 |
+
logger.warning(f"Could not load conversation: {e}")
|
| 355 |
+
|
| 356 |
+
return {
|
| 357 |
+
"messages": [],
|
| 358 |
+
"session_start_time": time.time(),
|
| 359 |
+
"last_ai_message_time": time.time(),
|
| 360 |
+
"context": {}
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
|
| 364 |
+
"""Save conversation to file"""
|
| 365 |
+
try:
|
| 366 |
+
conversation_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
|
| 367 |
+
|
| 368 |
+
messages_data = []
|
| 369 |
+
for msg in conversation["messages"]:
|
| 370 |
+
if isinstance(msg, HumanMessage):
|
| 371 |
+
messages_data.append({"type": "human", "content": msg.content})
|
| 372 |
+
elif isinstance(msg, AIMessage):
|
| 373 |
+
messages_data.append({"type": "ai", "content": msg.content})
|
| 374 |
+
|
| 375 |
+
conversation_data = {
|
| 376 |
+
"messages": messages_data,
|
| 377 |
+
"session_start_time": conversation["session_start_time"],
|
| 378 |
+
"last_ai_message_time": conversation["last_ai_message_time"],
|
| 379 |
+
"context": conversation.get("context", {})
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
with open(conversation_file, 'w') as f:
|
| 383 |
+
json.dump(conversation_data, f, indent=2)
|
| 384 |
+
|
| 385 |
+
except Exception as e:
|
| 386 |
+
logger.error(f"Could not save conversation: {e}")
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def get_gemini_chatbot():
|
| 390 |
+
"""Get Gemini chatbot instance"""
|
| 391 |
+
return GeminiRAGChatbot()
|
| 392 |
+
|
multi_agent_chatbot.py β src/agents/multi_agent_chatbot.py
RENAMED
|
@@ -8,24 +8,26 @@ This system implements a 3-agent architecture:
|
|
| 8 |
|
| 9 |
Each agent has specialized prompts and responsibilities.
|
| 10 |
"""
|
|
|
|
| 11 |
import json
|
| 12 |
import time
|
| 13 |
import logging
|
|
|
|
| 14 |
from pathlib import Path
|
| 15 |
from datetime import datetime
|
| 16 |
from dataclasses import dataclass
|
| 17 |
from typing import Dict, List, Any, Optional, TypedDict
|
| 18 |
|
| 19 |
-
|
| 20 |
from langchain_core.tools import tool
|
| 21 |
from langgraph.graph import StateGraph, END
|
| 22 |
-
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 23 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
from src.pipeline import PipelineManager
|
| 27 |
-
from src.config.loader import load_config
|
| 28 |
from src.llm.adapters import get_llm_client
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
@@ -46,6 +48,7 @@ class QueryContext:
|
|
| 46 |
needs_follow_up: bool = False
|
| 47 |
follow_up_question: Optional[str] = None
|
| 48 |
|
|
|
|
| 49 |
class MultiAgentState(TypedDict):
|
| 50 |
"""State for the multi-agent conversation flow"""
|
| 51 |
conversation_id: str
|
|
@@ -61,6 +64,7 @@ class MultiAgentState(TypedDict):
|
|
| 61 |
session_start_time: float
|
| 62 |
last_ai_message_time: float
|
| 63 |
|
|
|
|
| 64 |
class MultiAgentRAGChatbot:
|
| 65 |
"""Multi-agent RAG chatbot with specialized agents"""
|
| 66 |
|
|
@@ -112,7 +116,6 @@ class MultiAgentRAGChatbot:
|
|
| 112 |
logger.info("β
Pipeline manager initialized and models loaded")
|
| 113 |
except Exception as e:
|
| 114 |
logger.error(f"β Failed to initialize pipeline manager: {e}")
|
| 115 |
-
import traceback
|
| 116 |
traceback.print_exc()
|
| 117 |
raise RuntimeError(f"Pipeline manager initialization failed: {e}")
|
| 118 |
|
|
@@ -129,7 +132,6 @@ class MultiAgentRAGChatbot:
|
|
| 129 |
raise # Re-raise RuntimeError as-is
|
| 130 |
except Exception as e:
|
| 131 |
logger.error(f"β Error during vector store connection: {e}")
|
| 132 |
-
import traceback
|
| 133 |
traceback.print_exc()
|
| 134 |
raise RuntimeError(f"Vector store connection failed: {e}")
|
| 135 |
|
|
@@ -139,8 +141,8 @@ class MultiAgentRAGChatbot:
|
|
| 139 |
# Build the multi-agent graph
|
| 140 |
self.graph = self._build_graph()
|
| 141 |
|
| 142 |
-
# Conversations directory - use
|
| 143 |
-
self.conversations_dir =
|
| 144 |
try:
|
| 145 |
# Use 777 permissions for maximum compatibility (HF Spaces runs as different user)
|
| 146 |
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
|
@@ -158,9 +160,9 @@ class MultiAgentRAGChatbot:
|
|
| 158 |
|
| 159 |
def _load_dynamic_data(self):
|
| 160 |
"""Load dynamic data from filter_options.json and add_district_metadata.py"""
|
| 161 |
-
# Load filter options
|
| 162 |
try:
|
| 163 |
-
fo =
|
| 164 |
if fo.exists():
|
| 165 |
with open(fo) as f:
|
| 166 |
data = json.load(f)
|
|
@@ -178,7 +180,7 @@ class MultiAgentRAGChatbot:
|
|
| 178 |
self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
|
| 179 |
self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
|
| 180 |
|
| 181 |
-
# Enrich district list from add_district_metadata.py
|
| 182 |
try:
|
| 183 |
from add_district_metadata import DistrictMetadataProcessor
|
| 184 |
proc = DistrictMetadataProcessor()
|
|
@@ -206,6 +208,59 @@ class MultiAgentRAGChatbot:
|
|
| 206 |
logger.info(f" Sources: {self.source_whitelist}")
|
| 207 |
logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
def _build_graph(self) -> StateGraph:
|
| 210 |
"""Build the multi-agent LangGraph"""
|
| 211 |
graph = StateGraph(MultiAgentState)
|
|
@@ -510,6 +565,10 @@ class MultiAgentRAGChatbot:
|
|
| 510 |
- If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
|
| 511 |
- If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
|
| 512 |
- Always return districts as JSON arrays when multiple districts are mentioned
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
- If no exact matches found, set extracted values to null
|
| 514 |
|
| 515 |
4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
|
|
@@ -590,7 +649,6 @@ Analyze this query using ONLY the exact values provided above:""")
|
|
| 590 |
# Clean and parse JSON with better error handling
|
| 591 |
try:
|
| 592 |
# Remove comments (// and /* */) from JSON
|
| 593 |
-
import re
|
| 594 |
# Remove single-line comments
|
| 595 |
content = re.sub(r'//.*?$', '', content, flags=re.MULTILINE)
|
| 596 |
# Remove multi-line comments
|
|
@@ -603,7 +661,6 @@ Analyze this query using ONLY the exact values provided above:""")
|
|
| 603 |
logger.error(f"β Raw content: {content[:200]}...")
|
| 604 |
|
| 605 |
# Try to extract JSON from text if embedded
|
| 606 |
-
import re
|
| 607 |
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
| 608 |
if json_match:
|
| 609 |
try:
|
|
@@ -656,13 +713,9 @@ Analyze this query using ONLY the exact values provided above:""")
|
|
| 656 |
# Validate each district in the array
|
| 657 |
valid_districts = []
|
| 658 |
for district in extracted_district:
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
# Try removing "District" suffix
|
| 663 |
-
district_name = district.replace(" District", "").replace(" district", "")
|
| 664 |
-
if district_name in self.district_whitelist:
|
| 665 |
-
valid_districts.append(district_name)
|
| 666 |
|
| 667 |
if valid_districts:
|
| 668 |
extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
|
|
@@ -671,16 +724,15 @@ Analyze this query using ONLY the exact values provided above:""")
|
|
| 671 |
logger.warning(f"β οΈ No valid districts found in: '{extracted_district}'")
|
| 672 |
extracted_district = None
|
| 673 |
else:
|
| 674 |
-
# Single district validation
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
extracted_district = None
|
| 684 |
|
| 685 |
# Validate source (handle both single values and arrays)
|
| 686 |
if extracted_source:
|
|
@@ -918,6 +970,23 @@ Rewrite the best retrieval query:""")
|
|
| 918 |
logger.info(f"π§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β normalized: {normalized_districts}")
|
| 919 |
|
| 920 |
# Merge with extracted context for missing filters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 921 |
if not filters.get("year") and context.extracted_year:
|
| 922 |
# Handle both single values and arrays
|
| 923 |
if isinstance(context.extracted_year, list):
|
|
@@ -926,16 +995,6 @@ Rewrite the best retrieval query:""")
|
|
| 926 |
filters["year"] = [context.extracted_year]
|
| 927 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
|
| 928 |
|
| 929 |
-
if not filters.get("district") and context.extracted_district:
|
| 930 |
-
# Handle both single values and arrays
|
| 931 |
-
if isinstance(context.extracted_district, list):
|
| 932 |
-
# Normalize district names to title case (match Qdrant metadata format)
|
| 933 |
-
normalized = [d.title() for d in context.extracted_district]
|
| 934 |
-
filters["district"] = normalized
|
| 935 |
-
else:
|
| 936 |
-
filters["district"] = [context.extracted_district.title()]
|
| 937 |
-
logger.info(f"π§ FILTER BUILDING: Added extracted district filter (UI missing): {context.extracted_district}")
|
| 938 |
-
|
| 939 |
if not filters.get("sources") and context.extracted_source:
|
| 940 |
# Handle both single values and arrays
|
| 941 |
if isinstance(context.extracted_source, list):
|
|
@@ -963,12 +1022,21 @@ Rewrite the best retrieval query:""")
|
|
| 963 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
|
| 964 |
|
| 965 |
if context.extracted_district:
|
| 966 |
-
#
|
| 967 |
if isinstance(context.extracted_district, list):
|
| 968 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 969 |
else:
|
| 970 |
-
|
| 971 |
-
|
|
|
|
|
|
|
| 972 |
|
| 973 |
logger.info(f"π§ FILTER BUILDING: Final filters: {filters}")
|
| 974 |
return filters
|
|
@@ -978,49 +1046,212 @@ Rewrite the best retrieval query:""")
|
|
| 978 |
logger.info("π¬ RESPONSE GENERATION: Starting conversational response generation")
|
| 979 |
logger.info(f"π¬ RESPONSE GENERATION: Processing {len(documents)} documents")
|
| 980 |
logger.info(f"π¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 981 |
|
| 982 |
# Create response prompt
|
| 983 |
logger.info(f"π¬ RESPONSE GENERATION: Building response prompt")
|
| 984 |
response_prompt = ChatPromptTemplate.from_messages([
|
| 985 |
SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
|
| 986 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 987 |
RULES:
|
| 988 |
1. Answer the user's question directly and clearly
|
| 989 |
-
2. Use the retrieved documents as evidence
|
| 990 |
3. Be conversational, not technical
|
| 991 |
4. Don't mention scores, retrieval details, or technical implementation
|
| 992 |
5. If relevant documents were found, reference them naturally
|
| 993 |
-
6. If no relevant documents,
|
| 994 |
-
7. If the passages have useful facts or numbers, use them in your answer
|
| 995 |
-
8. When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
|
| 996 |
9. Do not use the sentence 'Doc i says ...' to say where information came from.
|
| 997 |
10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
|
| 998 |
11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
|
| 999 |
12. If it makes sense, use bullet points and lists to make your answers easier to understand.
|
| 1000 |
13. You do not need to use every passage. Only use the ones that help answer the question.
|
| 1001 |
-
14.
|
| 1002 |
-
|
|
|
|
| 1003 |
|
| 1004 |
TONE: Professional but friendly, like talking to a colleague."""),
|
| 1005 |
-
HumanMessage(content=f"""
|
|
|
|
|
|
|
|
|
|
| 1006 |
|
| 1007 |
Retrieved Documents: {len(documents)} documents found
|
| 1008 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1009 |
RAG Answer: {rag_answer}
|
| 1010 |
|
| 1011 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1012 |
])
|
| 1013 |
|
| 1014 |
try:
|
| 1015 |
logger.info(f"π¬ RESPONSE GENERATION: Calling LLM for final response")
|
| 1016 |
response = self.llm.invoke(response_prompt.format_messages())
|
| 1017 |
logger.info(f"π¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
|
| 1018 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1019 |
except Exception as e:
|
| 1020 |
logger.error(f"β RESPONSE GENERATION: Error during generation: {e}")
|
| 1021 |
logger.info(f"π¬ RESPONSE GENERATION: Using RAG answer as fallback")
|
| 1022 |
return rag_answer # Fallback to RAG answer
|
| 1023 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1024 |
def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
|
| 1025 |
"""Generate conversational response using only LLM knowledge and conversation history"""
|
| 1026 |
logger.info("π¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
|
|
@@ -1178,7 +1409,6 @@ Generate a conversational response based on your knowledge:""")
|
|
| 1178 |
|
| 1179 |
except Exception as e:
|
| 1180 |
logger.error(f"Could not save conversation: {e}")
|
| 1181 |
-
import traceback
|
| 1182 |
logger.error(f"Traceback: {traceback.format_exc()}")
|
| 1183 |
|
| 1184 |
|
|
|
|
| 8 |
|
| 9 |
Each agent has specialized prompts and responsibilities.
|
| 10 |
"""
|
| 11 |
+
import re
|
| 12 |
import json
|
| 13 |
import time
|
| 14 |
import logging
|
| 15 |
+
import traceback
|
| 16 |
from pathlib import Path
|
| 17 |
from datetime import datetime
|
| 18 |
from dataclasses import dataclass
|
| 19 |
from typing import Dict, List, Any, Optional, TypedDict
|
| 20 |
|
|
|
|
| 21 |
from langchain_core.tools import tool
|
| 22 |
from langgraph.graph import StateGraph, END
|
|
|
|
| 23 |
from langchain_core.prompts import ChatPromptTemplate
|
| 24 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 25 |
|
| 26 |
|
| 27 |
from src.pipeline import PipelineManager
|
|
|
|
| 28 |
from src.llm.adapters import get_llm_client
|
| 29 |
+
from src.config.paths import PROJECT_DIR, CONVERSATIONS_DIR
|
| 30 |
+
from src.config.loader import load_config, get_embedding_model_for_collection
|
| 31 |
|
| 32 |
|
| 33 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
| 48 |
needs_follow_up: bool = False
|
| 49 |
follow_up_question: Optional[str] = None
|
| 50 |
|
| 51 |
+
|
| 52 |
class MultiAgentState(TypedDict):
|
| 53 |
"""State for the multi-agent conversation flow"""
|
| 54 |
conversation_id: str
|
|
|
|
| 64 |
session_start_time: float
|
| 65 |
last_ai_message_time: float
|
| 66 |
|
| 67 |
+
|
| 68 |
class MultiAgentRAGChatbot:
|
| 69 |
"""Multi-agent RAG chatbot with specialized agents"""
|
| 70 |
|
|
|
|
| 116 |
logger.info("β
Pipeline manager initialized and models loaded")
|
| 117 |
except Exception as e:
|
| 118 |
logger.error(f"β Failed to initialize pipeline manager: {e}")
|
|
|
|
| 119 |
traceback.print_exc()
|
| 120 |
raise RuntimeError(f"Pipeline manager initialization failed: {e}")
|
| 121 |
|
|
|
|
| 132 |
raise # Re-raise RuntimeError as-is
|
| 133 |
except Exception as e:
|
| 134 |
logger.error(f"β Error during vector store connection: {e}")
|
|
|
|
| 135 |
traceback.print_exc()
|
| 136 |
raise RuntimeError(f"Vector store connection failed: {e}")
|
| 137 |
|
|
|
|
| 141 |
# Build the multi-agent graph
|
| 142 |
self.graph = self._build_graph()
|
| 143 |
|
| 144 |
+
# Conversations directory - use PROJECT_DIR for local vs deployed compatibility
|
| 145 |
+
self.conversations_dir = CONVERSATIONS_DIR
|
| 146 |
try:
|
| 147 |
# Use 777 permissions for maximum compatibility (HF Spaces runs as different user)
|
| 148 |
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
|
|
|
|
| 160 |
|
| 161 |
def _load_dynamic_data(self):
|
| 162 |
"""Load dynamic data from filter_options.json and add_district_metadata.py"""
|
| 163 |
+
# Load filter options - use PROJECT_DIR relative path
|
| 164 |
try:
|
| 165 |
+
fo = PROJECT_DIR / "src" / "config" / "filter_options.json"
|
| 166 |
if fo.exists():
|
| 167 |
with open(fo) as f:
|
| 168 |
data = json.load(f)
|
|
|
|
| 180 |
self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
|
| 181 |
self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
|
| 182 |
|
| 183 |
+
# Enrich district list from add_district_metadata.py (if available)
|
| 184 |
try:
|
| 185 |
from add_district_metadata import DistrictMetadataProcessor
|
| 186 |
proc = DistrictMetadataProcessor()
|
|
|
|
| 208 |
logger.info(f" Sources: {self.source_whitelist}")
|
| 209 |
logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
|
| 210 |
|
| 211 |
+
def _normalize_district_name(self, district: str) -> Optional[str]:
|
| 212 |
+
"""Normalize district name with fuzzy matching for common misspellings."""
|
| 213 |
+
if not district:
|
| 214 |
+
return None
|
| 215 |
+
|
| 216 |
+
district = district.strip()
|
| 217 |
+
|
| 218 |
+
# Direct match
|
| 219 |
+
if district in self.district_whitelist:
|
| 220 |
+
return district
|
| 221 |
+
|
| 222 |
+
# Remove "District" suffix
|
| 223 |
+
district_name = district.replace(" District", "").replace(" district", "").strip()
|
| 224 |
+
if district_name in self.district_whitelist:
|
| 225 |
+
return district_name
|
| 226 |
+
|
| 227 |
+
# Common misspellings mapping
|
| 228 |
+
misspelling_map = {
|
| 229 |
+
"kalagala": "Kalangala",
|
| 230 |
+
"Kalagala": "Kalangala",
|
| 231 |
+
"KALAGALA": "Kalangala",
|
| 232 |
+
"kalangala": "Kalangala",
|
| 233 |
+
"gulu": "Gulu",
|
| 234 |
+
"GULU": "Gulu",
|
| 235 |
+
"kampala": "Kampala",
|
| 236 |
+
"KAMPALA": "Kampala",
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
# Check misspelling map (case-insensitive)
|
| 240 |
+
district_lower = district_name.lower()
|
| 241 |
+
if district_lower in misspelling_map:
|
| 242 |
+
corrected = misspelling_map[district_lower]
|
| 243 |
+
if corrected in self.district_whitelist:
|
| 244 |
+
return corrected
|
| 245 |
+
|
| 246 |
+
# Fuzzy matching for similar names (simple Levenshtein-like check)
|
| 247 |
+
# Check if the district name is very similar to any whitelist entry
|
| 248 |
+
for whitelist_district in self.district_whitelist:
|
| 249 |
+
# Case-insensitive comparison
|
| 250 |
+
if district_name.lower() == whitelist_district.lower():
|
| 251 |
+
return whitelist_district
|
| 252 |
+
|
| 253 |
+
# Check if one is a substring of the other (for partial matches)
|
| 254 |
+
if len(district_name) >= 4 and len(whitelist_district) >= 4:
|
| 255 |
+
if district_name.lower() in whitelist_district.lower() or whitelist_district.lower() in district_name.lower():
|
| 256 |
+
# Only return if it's a strong match (at least 80% of characters match)
|
| 257 |
+
min_len = min(len(district_name), len(whitelist_district))
|
| 258 |
+
max_len = max(len(district_name), len(whitelist_district))
|
| 259 |
+
if min_len / max_len >= 0.8:
|
| 260 |
+
return whitelist_district
|
| 261 |
+
|
| 262 |
+
return None
|
| 263 |
+
|
| 264 |
def _build_graph(self) -> StateGraph:
|
| 265 |
"""Build the multi-agent LangGraph"""
|
| 266 |
graph = StateGraph(MultiAgentState)
|
|
|
|
| 565 |
- If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
|
| 566 |
- If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
|
| 567 |
- Always return districts as JSON arrays when multiple districts are mentioned
|
| 568 |
+
- **COMMON MISSPELLINGS**: Handle common misspellings intelligently:
|
| 569 |
+
* "Kalagala" (missing 'n') should be extracted as "Kalangala"
|
| 570 |
+
* "kalagala", "Kalagala", "KALAGALA" should all be normalized to "Kalangala"
|
| 571 |
+
* Similar case-insensitive variations should be normalized to the correct district name
|
| 572 |
- If no exact matches found, set extracted values to null
|
| 573 |
|
| 574 |
4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
|
|
|
|
| 649 |
# Clean and parse JSON with better error handling
|
| 650 |
try:
|
| 651 |
# Remove comments (// and /* */) from JSON
|
|
|
|
| 652 |
# Remove single-line comments
|
| 653 |
content = re.sub(r'//.*?$', '', content, flags=re.MULTILINE)
|
| 654 |
# Remove multi-line comments
|
|
|
|
| 661 |
logger.error(f"β Raw content: {content[:200]}...")
|
| 662 |
|
| 663 |
# Try to extract JSON from text if embedded
|
|
|
|
| 664 |
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
| 665 |
if json_match:
|
| 666 |
try:
|
|
|
|
| 713 |
# Validate each district in the array
|
| 714 |
valid_districts = []
|
| 715 |
for district in extracted_district:
|
| 716 |
+
normalized = self._normalize_district_name(district)
|
| 717 |
+
if normalized:
|
| 718 |
+
valid_districts.append(normalized)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 719 |
|
| 720 |
if valid_districts:
|
| 721 |
extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
|
|
|
|
| 724 |
logger.warning(f"β οΈ No valid districts found in: '{extracted_district}'")
|
| 725 |
extracted_district = None
|
| 726 |
else:
|
| 727 |
+
# Single district validation with fuzzy matching
|
| 728 |
+
normalized = self._normalize_district_name(extracted_district)
|
| 729 |
+
if normalized:
|
| 730 |
+
if normalized != extracted_district:
|
| 731 |
+
logger.info(f"π QUERY ANALYSIS: Normalized district '{extracted_district}' to '{normalized}'")
|
| 732 |
+
extracted_district = normalized
|
| 733 |
+
else:
|
| 734 |
+
logger.warning(f"β οΈ Invalid district extracted: '{extracted_district}' not in whitelist")
|
| 735 |
+
extracted_district = None
|
|
|
|
| 736 |
|
| 737 |
# Validate source (handle both single values and arrays)
|
| 738 |
if extracted_source:
|
|
|
|
| 970 |
logger.info(f"π§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β normalized: {normalized_districts}")
|
| 971 |
|
| 972 |
# Merge with extracted context for missing filters
|
| 973 |
+
if not filters.get("district") and context.extracted_district:
|
| 974 |
+
# Normalize district names using the normalization function
|
| 975 |
+
if isinstance(context.extracted_district, list):
|
| 976 |
+
normalized_districts = []
|
| 977 |
+
for d in context.extracted_district:
|
| 978 |
+
normalized = self._normalize_district_name(d)
|
| 979 |
+
if normalized:
|
| 980 |
+
normalized_districts.append(normalized)
|
| 981 |
+
if normalized_districts:
|
| 982 |
+
filters["district"] = normalized_districts
|
| 983 |
+
logger.info(f"π§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β normalized: {normalized_districts}")
|
| 984 |
+
else:
|
| 985 |
+
normalized = self._normalize_district_name(context.extracted_district)
|
| 986 |
+
if normalized:
|
| 987 |
+
filters["district"] = [normalized]
|
| 988 |
+
logger.info(f"π§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β normalized: {normalized}")
|
| 989 |
+
|
| 990 |
if not filters.get("year") and context.extracted_year:
|
| 991 |
# Handle both single values and arrays
|
| 992 |
if isinstance(context.extracted_year, list):
|
|
|
|
| 995 |
filters["year"] = [context.extracted_year]
|
| 996 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
|
| 997 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 998 |
if not filters.get("sources") and context.extracted_source:
|
| 999 |
# Handle both single values and arrays
|
| 1000 |
if isinstance(context.extracted_source, list):
|
|
|
|
| 1022 |
logger.info(f"π§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
|
| 1023 |
|
| 1024 |
if context.extracted_district:
|
| 1025 |
+
# Normalize district names using the normalization function
|
| 1026 |
if isinstance(context.extracted_district, list):
|
| 1027 |
+
normalized_districts = []
|
| 1028 |
+
for d in context.extracted_district:
|
| 1029 |
+
normalized = self._normalize_district_name(d)
|
| 1030 |
+
if normalized:
|
| 1031 |
+
normalized_districts.append(normalized)
|
| 1032 |
+
if normalized_districts:
|
| 1033 |
+
filters["district"] = normalized_districts
|
| 1034 |
+
logger.info(f"π§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β normalized: {normalized_districts}")
|
| 1035 |
else:
|
| 1036 |
+
normalized = self._normalize_district_name(context.extracted_district)
|
| 1037 |
+
if normalized:
|
| 1038 |
+
filters["district"] = [normalized]
|
| 1039 |
+
logger.info(f"π§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β normalized: {normalized}")
|
| 1040 |
|
| 1041 |
logger.info(f"π§ FILTER BUILDING: Final filters: {filters}")
|
| 1042 |
return filters
|
|
|
|
| 1046 |
logger.info("π¬ RESPONSE GENERATION: Starting conversational response generation")
|
| 1047 |
logger.info(f"π¬ RESPONSE GENERATION: Processing {len(documents)} documents")
|
| 1048 |
logger.info(f"π¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
|
| 1049 |
+
logger.info(f"π¬ RESPONSE GENERATION: Conversation history: {len(messages)} messages")
|
| 1050 |
+
|
| 1051 |
+
# Build conversation history context
|
| 1052 |
+
conversation_context = self._build_conversation_context(messages)
|
| 1053 |
+
|
| 1054 |
+
# Build detailed document information
|
| 1055 |
+
document_details = self._build_document_details(documents)
|
| 1056 |
+
|
| 1057 |
+
# Extract correct district/source/year names from documents (to correct misspellings)
|
| 1058 |
+
correct_names = self._extract_correct_names_from_documents(documents)
|
| 1059 |
|
| 1060 |
# Create response prompt
|
| 1061 |
logger.info(f"π¬ RESPONSE GENERATION: Building response prompt")
|
| 1062 |
response_prompt = ChatPromptTemplate.from_messages([
|
| 1063 |
SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
|
| 1064 |
|
| 1065 |
+
CRITICAL RULES - NO HALLUCINATION:
|
| 1066 |
+
1. **ONLY use information from the retrieved documents provided below**
|
| 1067 |
+
2. **EVERY sentence with facts, numbers, or specific claims MUST have a [Doc i] reference**
|
| 1068 |
+
3. **If a document doesn't contain the information, DO NOT make it up**
|
| 1069 |
+
4. **If the user asks about a year/district that's NOT in the retrieved documents, explicitly state that**
|
| 1070 |
+
5. **Check the document years/districts before making any claims about them**
|
| 1071 |
+
6. **USE CORRECT NAMES**: If the conversation mentions a misspelled district/source name (e.g., "Kalagala"), use the CORRECT spelling from the document metadata (e.g., "Kalangala"). Always use the exact names from document metadata, not misspellings from conversation.
|
| 1072 |
+
|
| 1073 |
RULES:
|
| 1074 |
1. Answer the user's question directly and clearly
|
| 1075 |
+
2. Use ONLY the retrieved documents as evidence - DO NOT use your training data
|
| 1076 |
3. Be conversational, not technical
|
| 1077 |
4. Don't mention scores, retrieval details, or technical implementation
|
| 1078 |
5. If relevant documents were found, reference them naturally
|
| 1079 |
+
6. If no relevant documents, say you do not have enough information - DO NOT hallucinate
|
| 1080 |
+
7. If the passages have useful facts or numbers, use them in your answer WITH references
|
| 1081 |
+
8. **MANDATORY**: When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
|
| 1082 |
9. Do not use the sentence 'Doc i says ...' to say where information came from.
|
| 1083 |
10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
|
| 1084 |
11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
|
| 1085 |
12. If it makes sense, use bullet points and lists to make your answers easier to understand.
|
| 1086 |
13. You do not need to use every passage. Only use the ones that help answer the question.
|
| 1087 |
+
14. **VERIFY**: Before mentioning any year, district, or number, check that it exists in the retrieved documents. If it doesn't, say "I don't have information about [year/district] in the retrieved documents."
|
| 1088 |
+
15. **NO HALLUCINATION**: If documents show years 2021, 2022, 2023 but user asks about 2020, DO NOT provide 2020 data. Instead say "The retrieved documents cover 2021-2023, but I don't have information for 2020."
|
| 1089 |
+
16. **USE CORRECT SPELLING**: Always use the district/source names exactly as they appear in the document metadata below, even if the conversation history has misspellings.
|
| 1090 |
|
| 1091 |
TONE: Professional but friendly, like talking to a colleague."""),
|
| 1092 |
+
HumanMessage(content=f"""Conversation History:
|
| 1093 |
+
{conversation_context}
|
| 1094 |
+
|
| 1095 |
+
Current User Question: {query}
|
| 1096 |
|
| 1097 |
Retrieved Documents: {len(documents)} documents found
|
| 1098 |
|
| 1099 |
+
CORRECT NAMES TO USE (from document metadata - use these exact spellings):
|
| 1100 |
+
{correct_names}
|
| 1101 |
+
|
| 1102 |
+
Full Document Details:
|
| 1103 |
+
{document_details}
|
| 1104 |
+
|
| 1105 |
RAG Answer: {rag_answer}
|
| 1106 |
|
| 1107 |
+
CRITICAL:
|
| 1108 |
+
- Responses should be grounded to what is available in the retrieved documents
|
| 1109 |
+
- If user asks about a specific year but documents show other years, or districts or sources then explicitly state "can't provide response on ... because ..."
|
| 1110 |
+
- Every factual claim MUST have [Doc i] reference
|
| 1111 |
+
- If information is not in documents, explicitly state it's not available
|
| 1112 |
+
- **USE THE CORRECT DISTRICT/SOURCE NAMES from the document metadata above, not misspellings from conversation**
|
| 1113 |
+
|
| 1114 |
+
Generate a conversational response with proper document references:""")
|
| 1115 |
])
|
| 1116 |
|
| 1117 |
try:
|
| 1118 |
logger.info(f"π¬ RESPONSE GENERATION: Calling LLM for final response")
|
| 1119 |
response = self.llm.invoke(response_prompt.format_messages())
|
| 1120 |
logger.info(f"π¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
|
| 1121 |
+
|
| 1122 |
+
# Post-process response to ensure no hallucination
|
| 1123 |
+
final_response = self._validate_and_enhance_response(
|
| 1124 |
+
response.content.strip(),
|
| 1125 |
+
documents,
|
| 1126 |
+
query
|
| 1127 |
+
)
|
| 1128 |
+
|
| 1129 |
+
return final_response
|
| 1130 |
except Exception as e:
|
| 1131 |
logger.error(f"β RESPONSE GENERATION: Error during generation: {e}")
|
| 1132 |
logger.info(f"π¬ RESPONSE GENERATION: Using RAG answer as fallback")
|
| 1133 |
return rag_answer # Fallback to RAG answer
|
| 1134 |
|
| 1135 |
+
def _build_conversation_context(self, messages: List[Any]) -> str:
|
| 1136 |
+
"""Build conversation history context for response generation."""
|
| 1137 |
+
if not messages:
|
| 1138 |
+
return "No previous conversation."
|
| 1139 |
+
|
| 1140 |
+
context_lines = []
|
| 1141 |
+
# Show last 6 messages for context (to capture the current exchange)
|
| 1142 |
+
for msg in messages[-6:]:
|
| 1143 |
+
if isinstance(msg, HumanMessage):
|
| 1144 |
+
context_lines.append(f"User: {msg.content}")
|
| 1145 |
+
elif isinstance(msg, AIMessage):
|
| 1146 |
+
context_lines.append(f"Assistant: {msg.content}")
|
| 1147 |
+
|
| 1148 |
+
return "\n".join(context_lines) if context_lines else "No previous conversation."
|
| 1149 |
+
|
| 1150 |
+
def _build_document_details(self, documents: List[Any]) -> str:
|
| 1151 |
+
"""Build detailed document information for response generation."""
|
| 1152 |
+
if not documents:
|
| 1153 |
+
return "No documents retrieved."
|
| 1154 |
+
|
| 1155 |
+
details = []
|
| 1156 |
+
for i, doc in enumerate(documents[:15], 1): # Show up to 15 documents
|
| 1157 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 1158 |
+
content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
|
| 1159 |
+
|
| 1160 |
+
if isinstance(metadata, dict):
|
| 1161 |
+
filename = metadata.get('filename', 'Unknown')
|
| 1162 |
+
year = metadata.get('year', 'Unknown')
|
| 1163 |
+
district = metadata.get('district', 'Unknown')
|
| 1164 |
+
source = metadata.get('source', 'Unknown')
|
| 1165 |
+
page = metadata.get('page', metadata.get('page_label', 'Unknown'))
|
| 1166 |
+
|
| 1167 |
+
doc_info = f"[Doc {i}]"
|
| 1168 |
+
doc_info += f"\n Filename: {filename}"
|
| 1169 |
+
doc_info += f"\n Year: {year}"
|
| 1170 |
+
doc_info += f"\n District: {district}"
|
| 1171 |
+
doc_info += f"\n Source: {source}"
|
| 1172 |
+
if page != 'Unknown':
|
| 1173 |
+
doc_info += f"\n Page: {page}"
|
| 1174 |
+
doc_info += f"\n Content: {content[:300]}{'...' if len(content) > 300 else ''}"
|
| 1175 |
+
details.append(doc_info)
|
| 1176 |
+
|
| 1177 |
+
return "\n\n".join(details) if details else "No document details available."
|
| 1178 |
+
|
| 1179 |
+
def _extract_correct_names_from_documents(self, documents: List[Any]) -> str:
|
| 1180 |
+
"""Extract correct district/source names from documents to correct misspellings."""
|
| 1181 |
+
districts = set()
|
| 1182 |
+
sources = set()
|
| 1183 |
+
years = set()
|
| 1184 |
+
|
| 1185 |
+
for doc in documents:
|
| 1186 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 1187 |
+
if isinstance(metadata, dict):
|
| 1188 |
+
if metadata.get('district'):
|
| 1189 |
+
districts.add(str(metadata['district']))
|
| 1190 |
+
if metadata.get('source'):
|
| 1191 |
+
sources.add(str(metadata['source']))
|
| 1192 |
+
if metadata.get('year'):
|
| 1193 |
+
years.add(str(metadata['year']))
|
| 1194 |
+
|
| 1195 |
+
result = []
|
| 1196 |
+
if districts:
|
| 1197 |
+
result.append(f"Districts: {', '.join(sorted(districts))}")
|
| 1198 |
+
if sources:
|
| 1199 |
+
result.append(f"Sources: {', '.join(sorted(sources))}")
|
| 1200 |
+
if years:
|
| 1201 |
+
result.append(f"Years: {', '.join(sorted(years))}")
|
| 1202 |
+
|
| 1203 |
+
if result:
|
| 1204 |
+
return "\n".join(result) + "\n\nIMPORTANT: Use these EXACT spellings in your response, even if the conversation history has misspellings."
|
| 1205 |
+
return "No metadata available."
|
| 1206 |
+
|
| 1207 |
+
def _validate_and_enhance_response(self, response: str, documents: List[Any], query: str) -> str:
|
| 1208 |
+
"""Validate response and ensure all claims are referenced."""
|
| 1209 |
+
# Extract years and districts from documents
|
| 1210 |
+
doc_years = set()
|
| 1211 |
+
doc_districts = set()
|
| 1212 |
+
doc_sources = set()
|
| 1213 |
+
|
| 1214 |
+
for doc in documents:
|
| 1215 |
+
metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
|
| 1216 |
+
if isinstance(metadata, dict):
|
| 1217 |
+
if metadata.get('year'):
|
| 1218 |
+
doc_years.add(str(metadata['year']))
|
| 1219 |
+
if metadata.get('district'):
|
| 1220 |
+
doc_districts.add(str(metadata['district']))
|
| 1221 |
+
if metadata.get('source'):
|
| 1222 |
+
doc_sources.add(str(metadata['source']))
|
| 1223 |
+
|
| 1224 |
+
# Correct misspellings in response using correct names from documents
|
| 1225 |
+
# response = self._correct_misspellings_in_response(response, doc_districts, doc_sources)
|
| 1226 |
+
|
| 1227 |
+
# Check if response mentions years not in documents
|
| 1228 |
+
year_pattern = r'\b(20\d{2})\b'
|
| 1229 |
+
mentioned_years = set(re.findall(year_pattern, response))
|
| 1230 |
+
|
| 1231 |
+
# Check if user query mentions a year
|
| 1232 |
+
query_years = set(re.findall(year_pattern, query))
|
| 1233 |
+
|
| 1234 |
+
# If user asks about a year not in documents, add a warning
|
| 1235 |
+
missing_years = query_years - doc_years
|
| 1236 |
+
if missing_years and doc_years:
|
| 1237 |
+
warning = f"\n\nβ οΈ Note: The retrieved documents cover years {', '.join(sorted(doc_years))}, but I don't have information for {', '.join(sorted(missing_years))} in the retrieved documents."
|
| 1238 |
+
if warning not in response:
|
| 1239 |
+
response = response + warning
|
| 1240 |
+
|
| 1241 |
+
# Check if response has document references
|
| 1242 |
+
doc_ref_pattern = r'\[Doc\s+\d+\]'
|
| 1243 |
+
has_refs = bool(re.search(doc_ref_pattern, response))
|
| 1244 |
+
|
| 1245 |
+
# If response has factual claims but no references, add a note
|
| 1246 |
+
if not has_refs and len(documents) > 0:
|
| 1247 |
+
# Check if response has numbers or specific claims (simple heuristic)
|
| 1248 |
+
has_numbers = bool(re.search(r'\d+', response))
|
| 1249 |
+
if has_numbers and len(response) > 50:
|
| 1250 |
+
logger.warning("β οΈ Response contains factual claims but no document references")
|
| 1251 |
+
# Don't modify response, but log the issue
|
| 1252 |
+
|
| 1253 |
+
return response
|
| 1254 |
+
|
| 1255 |
def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
|
| 1256 |
"""Generate conversational response using only LLM knowledge and conversation history"""
|
| 1257 |
logger.info("π¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
|
|
|
|
| 1409 |
|
| 1410 |
except Exception as e:
|
| 1411 |
logger.error(f"Could not save conversation: {e}")
|
|
|
|
| 1412 |
logger.error(f"Traceback: {traceback.format_exc()}")
|
| 1413 |
|
| 1414 |
|
smart_chatbot.py β src/agents/smart_chatbot.py
RENAMED
|
@@ -26,6 +26,7 @@ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
|
| 26 |
|
| 27 |
from src.pipeline import PipelineManager
|
| 28 |
from src.config.loader import load_config
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
@dataclass
|
|
@@ -161,7 +162,7 @@ class IntelligentRAGChatbot:
|
|
| 161 |
|
| 162 |
# Try to load district whitelist from filter_options.json
|
| 163 |
try:
|
| 164 |
-
fo =
|
| 165 |
if fo.exists():
|
| 166 |
with open(fo) as f:
|
| 167 |
data = json.load(f)
|
|
@@ -174,7 +175,7 @@ class IntelligentRAGChatbot:
|
|
| 174 |
except Exception:
|
| 175 |
self.district_whitelist = self.available_metadata['districts']
|
| 176 |
|
| 177 |
-
# Enrich whitelist from add_district_metadata.py if available
|
| 178 |
try:
|
| 179 |
from add_district_metadata import DistrictMetadataProcessor
|
| 180 |
proc = DistrictMetadataProcessor()
|
|
@@ -195,7 +196,7 @@ class IntelligentRAGChatbot:
|
|
| 195 |
|
| 196 |
# Get dynamic year list from filter_options.json
|
| 197 |
try:
|
| 198 |
-
fo =
|
| 199 |
if fo.exists():
|
| 200 |
with open(fo) as f:
|
| 201 |
data = json.load(f)
|
|
|
|
| 26 |
|
| 27 |
from src.pipeline import PipelineManager
|
| 28 |
from src.config.loader import load_config
|
| 29 |
+
from src.config.paths import PROJECT_DIR
|
| 30 |
|
| 31 |
|
| 32 |
@dataclass
|
|
|
|
| 162 |
|
| 163 |
# Try to load district whitelist from filter_options.json
|
| 164 |
try:
|
| 165 |
+
fo = PROJECT_DIR / "src" / "config" / "filter_options.json"
|
| 166 |
if fo.exists():
|
| 167 |
with open(fo) as f:
|
| 168 |
data = json.load(f)
|
|
|
|
| 175 |
except Exception:
|
| 176 |
self.district_whitelist = self.available_metadata['districts']
|
| 177 |
|
| 178 |
+
# Enrich whitelist from add_district_metadata.py if available (optional module)
|
| 179 |
try:
|
| 180 |
from add_district_metadata import DistrictMetadataProcessor
|
| 181 |
proc = DistrictMetadataProcessor()
|
|
|
|
| 196 |
|
| 197 |
# Get dynamic year list from filter_options.json
|
| 198 |
try:
|
| 199 |
+
fo = PROJECT_DIR / "src" / "config" / "filter_options.json"
|
| 200 |
if fo.exists():
|
| 201 |
with open(fo) as f:
|
| 202 |
data = json.load(f)
|
src/config/paths.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Path configuration for local vs deployed environments.
|
| 3 |
+
|
| 4 |
+
This module handles different paths for local development vs deployed (HF Spaces) environments.
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
# Determine if we're in a deployed environment (HF Spaces/Docker) or local
|
| 10 |
+
# Check for environment variable or Docker-like paths
|
| 11 |
+
IS_DEPLOYED = (
|
| 12 |
+
os.getenv("DEPLOYED", "false").lower() == "true" or
|
| 13 |
+
os.path.exists("/app") or
|
| 14 |
+
os.getenv("SPACES_ID") is not None or
|
| 15 |
+
os.path.exists("/.dockerenv")
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
# PROJECT_DIR: Base directory for application files
|
| 19 |
+
# In deployed: /app, in local: current working directory or project root
|
| 20 |
+
if IS_DEPLOYED:
|
| 21 |
+
PROJECT_DIR = Path("/app")
|
| 22 |
+
else:
|
| 23 |
+
# For local development, use current working directory or find project root
|
| 24 |
+
cwd = Path.cwd()
|
| 25 |
+
# Try to find project root (directory containing this src/ folder)
|
| 26 |
+
project_root = cwd
|
| 27 |
+
while project_root != project_root.parent:
|
| 28 |
+
if (project_root / "src" / "config").exists():
|
| 29 |
+
break
|
| 30 |
+
project_root = project_root.parent
|
| 31 |
+
PROJECT_DIR = project_root
|
| 32 |
+
|
| 33 |
+
# Cache directories - different for local vs deployed
|
| 34 |
+
# Local: Use default user cache locations (don't override)
|
| 35 |
+
# Deployed: Use PROJECT_DIR/.cache
|
| 36 |
+
if IS_DEPLOYED:
|
| 37 |
+
CACHE_DIR = PROJECT_DIR / ".cache"
|
| 38 |
+
HF_CACHE_DIR = CACHE_DIR / "huggingface"
|
| 39 |
+
STREAMLIT_CACHE_DIR = CACHE_DIR / "streamlit"
|
| 40 |
+
else:
|
| 41 |
+
# For local, use default user cache (let libraries use their defaults)
|
| 42 |
+
HF_CACHE_DIR = None # Will use HF defaults (~/.cache/huggingface)
|
| 43 |
+
STREAMLIT_CACHE_DIR = None # Will use Streamlit defaults
|
| 44 |
+
|
| 45 |
+
# Application directories
|
| 46 |
+
FEEDBACK_DIR = PROJECT_DIR / "feedback"
|
| 47 |
+
CONVERSATIONS_DIR = PROJECT_DIR / "conversations"
|
| 48 |
+
STREAMLIT_CONFIG_DIR = PROJECT_DIR / ".streamlit"
|
| 49 |
+
|
| 50 |
+
# Log the configuration
|
| 51 |
+
if __name__ == "__main__":
|
| 52 |
+
print(f"IS_DEPLOYED: {IS_DEPLOYED}")
|
| 53 |
+
print(f"PROJECT_DIR: {PROJECT_DIR}")
|
| 54 |
+
print(f"HF_CACHE_DIR: {HF_CACHE_DIR}")
|
| 55 |
+
print(f"FEEDBACK_DIR: {FEEDBACK_DIR}")
|
| 56 |
+
print(f"CONVERSATIONS_DIR: {CONVERSATIONS_DIR}")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
src/feedback/__init__.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feedback Management Module
|
| 3 |
+
|
| 4 |
+
This module provides a unified interface for handling user feedback,
|
| 5 |
+
including data preparation, validation, and Snowflake storage.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Dict, Any, List, Optional
|
| 9 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 10 |
+
|
| 11 |
+
from .feedback_schema import UserFeedback, create_feedback_from_dict, generate_snowflake_schema_sql
|
| 12 |
+
from .snowflake_connector import SnowflakeFeedbackConnector, save_to_snowflake, get_snowflake_connector_from_env
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class FeedbackManager:
|
| 16 |
+
"""
|
| 17 |
+
Unified manager for feedback operations.
|
| 18 |
+
|
| 19 |
+
This class provides a single interface for all feedback-related functionality,
|
| 20 |
+
including data preparation, validation, and storage.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
"""Initialize the FeedbackManager"""
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
@staticmethod
|
| 28 |
+
def extract_transcript(messages: List[Any]) -> List[Dict[str, str]]:
|
| 29 |
+
"""Extract transcript from messages - only user and bot messages, no extra metadata"""
|
| 30 |
+
transcript = []
|
| 31 |
+
for msg in messages:
|
| 32 |
+
if isinstance(msg, HumanMessage):
|
| 33 |
+
transcript.append({
|
| 34 |
+
"role": "user",
|
| 35 |
+
"content": str(msg.content) if hasattr(msg, 'content') else str(msg)
|
| 36 |
+
})
|
| 37 |
+
elif isinstance(msg, AIMessage):
|
| 38 |
+
transcript.append({
|
| 39 |
+
"role": "assistant",
|
| 40 |
+
"content": str(msg.content) if hasattr(msg, 'content') else str(msg)
|
| 41 |
+
})
|
| 42 |
+
return transcript
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def build_retrievals_structure(rag_retrieval_history: List[Dict[str, Any]], messages: List[Any]) -> List[Dict[str, Any]]:
|
| 46 |
+
"""Build retrievals structure from retrieval history"""
|
| 47 |
+
retrievals = []
|
| 48 |
+
|
| 49 |
+
for entry in rag_retrieval_history:
|
| 50 |
+
# Get the user message that triggered this retrieval
|
| 51 |
+
# The entry has conversation_up_to which includes messages up to that point
|
| 52 |
+
conversation_up_to = entry.get("conversation_up_to", [])
|
| 53 |
+
|
| 54 |
+
# Find the last user message in conversation_up_to (this is the trigger)
|
| 55 |
+
user_message_trigger = ""
|
| 56 |
+
for msg_dict in reversed(conversation_up_to):
|
| 57 |
+
if msg_dict.get("type") == "HumanMessage":
|
| 58 |
+
user_message_trigger = msg_dict.get("content", "")
|
| 59 |
+
break
|
| 60 |
+
|
| 61 |
+
# Fallback: if not found in conversation_up_to, get from actual messages
|
| 62 |
+
# This handles edge cases where conversation_up_to might be incomplete
|
| 63 |
+
if not user_message_trigger:
|
| 64 |
+
# Find which retrieval this is (0-indexed)
|
| 65 |
+
retrieval_idx = rag_retrieval_history.index(entry)
|
| 66 |
+
# The user message that triggered this retrieval is at position (retrieval_idx * 2)
|
| 67 |
+
# because each retrieval is preceded by: user message, bot response, user message, ...
|
| 68 |
+
# But we need to account for the fact that the first retrieval happens after the first user message
|
| 69 |
+
user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
|
| 70 |
+
if retrieval_idx < len(user_msgs):
|
| 71 |
+
user_message_trigger = str(user_msgs[retrieval_idx].content)
|
| 72 |
+
elif user_msgs:
|
| 73 |
+
# Fallback to last user message
|
| 74 |
+
user_message_trigger = str(user_msgs[-1].content)
|
| 75 |
+
|
| 76 |
+
# Get retrieved documents and truncate content to 100 chars
|
| 77 |
+
docs_retrieved = entry.get("docs_retrieved", [])
|
| 78 |
+
retrieved_docs = []
|
| 79 |
+
for doc in docs_retrieved:
|
| 80 |
+
doc_copy = doc.copy()
|
| 81 |
+
# Truncate content to 100 characters (keep all other fields)
|
| 82 |
+
if "content" in doc_copy:
|
| 83 |
+
doc_copy["content"] = doc_copy["content"][:100]
|
| 84 |
+
retrieved_docs.append(doc_copy)
|
| 85 |
+
|
| 86 |
+
retrievals.append({
|
| 87 |
+
"retrieved_docs": retrieved_docs,
|
| 88 |
+
"user_message_trigger": user_message_trigger
|
| 89 |
+
})
|
| 90 |
+
|
| 91 |
+
return retrievals
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def build_feedback_score_related_retrieval_docs(
|
| 95 |
+
is_feedback_about_last_retrieval: bool,
|
| 96 |
+
messages: List[Any],
|
| 97 |
+
rag_retrieval_history: List[Dict[str, Any]]
|
| 98 |
+
) -> Optional[Dict[str, Any]]:
|
| 99 |
+
"""Build feedback_score_related_retrieval_docs structure"""
|
| 100 |
+
if not rag_retrieval_history:
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
# Get the relevant retrieval entry
|
| 104 |
+
if is_feedback_about_last_retrieval:
|
| 105 |
+
relevant_entry = rag_retrieval_history[-1]
|
| 106 |
+
else:
|
| 107 |
+
# If feedback is about all retrievals, use the last one as default
|
| 108 |
+
relevant_entry = rag_retrieval_history[-1]
|
| 109 |
+
|
| 110 |
+
# Get conversation up to that point
|
| 111 |
+
conversation_up_to = relevant_entry.get("conversation_up_to", [])
|
| 112 |
+
|
| 113 |
+
# Convert to transcript format (role/content)
|
| 114 |
+
conversation_up_to_point = []
|
| 115 |
+
for msg_dict in conversation_up_to:
|
| 116 |
+
if msg_dict.get("type") == "HumanMessage":
|
| 117 |
+
conversation_up_to_point.append({
|
| 118 |
+
"role": "user",
|
| 119 |
+
"content": msg_dict.get("content", "")
|
| 120 |
+
})
|
| 121 |
+
elif msg_dict.get("type") == "AIMessage":
|
| 122 |
+
conversation_up_to_point.append({
|
| 123 |
+
"role": "assistant",
|
| 124 |
+
"content": msg_dict.get("content", "")
|
| 125 |
+
})
|
| 126 |
+
|
| 127 |
+
# Get retrieved docs with full content (not truncated)
|
| 128 |
+
retrieved_docs = relevant_entry.get("docs_retrieved", [])
|
| 129 |
+
|
| 130 |
+
return {
|
| 131 |
+
"conversation_up_to_point": conversation_up_to_point,
|
| 132 |
+
"retrieved_docs": retrieved_docs
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
@staticmethod
|
| 136 |
+
def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
|
| 137 |
+
"""Create UserFeedback instance from dictionary"""
|
| 138 |
+
return create_feedback_from_dict(data)
|
| 139 |
+
|
| 140 |
+
@staticmethod
|
| 141 |
+
def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
|
| 142 |
+
"""Save feedback to Snowflake"""
|
| 143 |
+
return save_to_snowflake(feedback, table_name)
|
| 144 |
+
|
| 145 |
+
@staticmethod
|
| 146 |
+
def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
|
| 147 |
+
"""Generate Snowflake schema SQL"""
|
| 148 |
+
return generate_snowflake_schema_sql(table_name)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
__all__ = ["FeedbackManager", "UserFeedback", "save_to_snowflake", "SnowflakeFeedbackConnector"]
|
| 152 |
+
|
src/feedback/feedback_schema.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feedback Schema for RAG Chatbot
|
| 3 |
+
|
| 4 |
+
This module defines dataclasses for feedback data structures
|
| 5 |
+
and provides Snowflake schema generation.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from dataclasses import dataclass, asdict, field
|
| 10 |
+
from typing import List, Optional, Dict, Any, Union
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class RetrievedDocument:
|
| 16 |
+
"""Single retrieved document metadata"""
|
| 17 |
+
doc_id: str
|
| 18 |
+
filename: str
|
| 19 |
+
page: int
|
| 20 |
+
score: float
|
| 21 |
+
content: str
|
| 22 |
+
metadata: Dict[str, Any]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class RetrievalEntry:
|
| 27 |
+
"""Single retrieval operation metadata"""
|
| 28 |
+
rag_query: str
|
| 29 |
+
documents_retrieved: List[RetrievedDocument]
|
| 30 |
+
conversation_length: int
|
| 31 |
+
filters_applied: Optional[Dict[str, Any]] = None
|
| 32 |
+
timestamp: Optional[float] = None
|
| 33 |
+
_raw_data: Optional[Dict[str, Any]] = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class UserFeedback:
|
| 38 |
+
"""User feedback submission data"""
|
| 39 |
+
feedback_id: str
|
| 40 |
+
open_ended_feedback: Optional[str]
|
| 41 |
+
score: int
|
| 42 |
+
is_feedback_about_last_retrieval: bool
|
| 43 |
+
conversation_id: str
|
| 44 |
+
timestamp: float
|
| 45 |
+
message_count: int
|
| 46 |
+
has_retrievals: bool
|
| 47 |
+
retrieval_count: int
|
| 48 |
+
transcript: List[Dict[str, str]] # List of {"role": "user"/"assistant", "content": "..."}
|
| 49 |
+
retrievals: List[Dict[str, Any]] # List of retrieval objects with retrieved_docs and user_message_trigger
|
| 50 |
+
feedback_score_related_retrieval_docs: Optional[Dict[str, Any]] = None # Conversation subset + retrieved docs
|
| 51 |
+
retrieved_data: Optional[List[Dict[str, Any]]] = None # Preserved old column for backward compatibility
|
| 52 |
+
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
| 53 |
+
|
| 54 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 55 |
+
"""Convert to dictionary with nested data structures"""
|
| 56 |
+
result = asdict(self)
|
| 57 |
+
return result
|
| 58 |
+
|
| 59 |
+
def to_snowflake_schema(self) -> Dict[str, Any]:
|
| 60 |
+
"""Generate Snowflake schema for this dataclass"""
|
| 61 |
+
schema = {
|
| 62 |
+
"feedback_id": "VARCHAR(255)",
|
| 63 |
+
"open_ended_feedback": "VARCHAR(16777216)", # Large text
|
| 64 |
+
"score": "INTEGER",
|
| 65 |
+
"is_feedback_about_last_retrieval": "BOOLEAN",
|
| 66 |
+
"conversation_id": "VARCHAR(255)",
|
| 67 |
+
"timestamp": "NUMBER(20, 0)",
|
| 68 |
+
"message_count": "INTEGER",
|
| 69 |
+
"has_retrievals": "BOOLEAN",
|
| 70 |
+
"retrieval_count": "INTEGER",
|
| 71 |
+
"transcript": "VARCHAR(16777216)", # JSON string of ARRAY of {"role": "user"/"assistant", "content": "..."}
|
| 72 |
+
"retrievals": "VARCHAR(16777216)", # JSON string of ARRAY of retrieval objects
|
| 73 |
+
"feedback_score_related_retrieval_docs": "VARCHAR(16777216)", # JSON string of OBJECT with conversation subset + retrieved docs
|
| 74 |
+
"retrieved_data": "VARCHAR(16777216)", # JSON string - preserved old column for backward compatibility
|
| 75 |
+
"created_at": "TIMESTAMP_NTZ",
|
| 76 |
+
# transcript structure: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
|
| 77 |
+
# retrievals structure: [
|
| 78 |
+
# {
|
| 79 |
+
# "retrieved_docs": [{"content": "...", "metadata": {...}, ...}], # content truncated to 100 chars
|
| 80 |
+
# "user_message_trigger": "final user message that triggered this retrieval"
|
| 81 |
+
# },
|
| 82 |
+
# ...
|
| 83 |
+
# ]
|
| 84 |
+
# feedback_score_related_retrieval_docs structure: {
|
| 85 |
+
# "conversation_up_to_point": [{"role": "user", "content": "..."}, ...], # subset of transcript
|
| 86 |
+
# "retrieved_docs": [{"content": "...", "metadata": {...}, ...}] # full chunks with all info
|
| 87 |
+
# }
|
| 88 |
+
}
|
| 89 |
+
return schema
|
| 90 |
+
|
| 91 |
+
@classmethod
|
| 92 |
+
def get_snowflake_create_table_sql(cls, table_name: str = "USER_FEEDBACK_V3") -> str:
|
| 93 |
+
"""Generate CREATE TABLE SQL for Snowflake"""
|
| 94 |
+
schema = cls.to_snowflake_schema(None)
|
| 95 |
+
|
| 96 |
+
columns = []
|
| 97 |
+
for col_name, col_type in schema.items():
|
| 98 |
+
nullable = "NULL" if col_name not in ["feedback_id", "score", "timestamp"] else "NOT NULL"
|
| 99 |
+
columns.append(f" {col_name} {col_type} {nullable}")
|
| 100 |
+
|
| 101 |
+
# Build SQL string properly
|
| 102 |
+
columns_str = ",\n".join(columns)
|
| 103 |
+
|
| 104 |
+
sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
|
| 105 |
+
{columns_str},
|
| 106 |
+
PRIMARY KEY (feedback_id)
|
| 107 |
+
)
|
| 108 |
+
CLUSTER BY (timestamp, conversation_id, score);
|
| 109 |
+
-- Note: Snowflake doesn't support traditional indexes on regular tables.
|
| 110 |
+
-- Instead, we use CLUSTER BY to optimize queries on these columns.
|
| 111 |
+
-- Snowflake automatically maintains clustering for efficient querying.
|
| 112 |
+
-- Note: transcript, retrievals, and feedback_score_related_retrieval_docs are stored as VARCHAR (JSON strings),
|
| 113 |
+
-- same approach as the old retrieved_data column. This allows easy storage and retrieval without VARIANT type complexity.
|
| 114 |
+
"""
|
| 115 |
+
return sql
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Snowflake variant schema for retrieved_data array
|
| 119 |
+
RETRIEVAL_ENTRY_SCHEMA = {
|
| 120 |
+
"rag_query": "VARCHAR",
|
| 121 |
+
"documents_retrieved": "ARRAY", # Array of document objects
|
| 122 |
+
"conversation_length": "INTEGER",
|
| 123 |
+
"filters_applied": "OBJECT",
|
| 124 |
+
"timestamp": "NUMBER"
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
DOCUMENT_SCHEMA = {
|
| 128 |
+
"doc_id": "VARCHAR",
|
| 129 |
+
"filename": "VARCHAR",
|
| 130 |
+
"page": "INTEGER",
|
| 131 |
+
"score": "DOUBLE",
|
| 132 |
+
"content": "VARCHAR(16777216)",
|
| 133 |
+
"metadata": "OBJECT"
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
|
| 138 |
+
"""Generate complete Snowflake schema SQL for feedback system"""
|
| 139 |
+
if table_name is None:
|
| 140 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 141 |
+
return UserFeedback.get_snowflake_create_table_sql(table_name)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
|
| 145 |
+
"""Create UserFeedback instance from dictionary"""
|
| 146 |
+
return UserFeedback(
|
| 147 |
+
feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
|
| 148 |
+
open_ended_feedback=data.get("open_ended_feedback"),
|
| 149 |
+
score=data["score"],
|
| 150 |
+
is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
|
| 151 |
+
conversation_id=data["conversation_id"],
|
| 152 |
+
timestamp=data["timestamp"],
|
| 153 |
+
message_count=data["message_count"],
|
| 154 |
+
has_retrievals=data["has_retrievals"],
|
| 155 |
+
retrieval_count=data["retrieval_count"],
|
| 156 |
+
transcript=data.get("transcript", []),
|
| 157 |
+
retrievals=data.get("retrievals", []),
|
| 158 |
+
feedback_score_related_retrieval_docs=data.get("feedback_score_related_retrieval_docs"),
|
| 159 |
+
retrieved_data=data.get("retrieved_data")
|
| 160 |
+
)
|
| 161 |
+
|
src/feedback/snowflake_connector.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Snowflake Connector for Feedback System
|
| 3 |
+
|
| 4 |
+
This module handles inserting user feedback into Snowflake.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Dict, Any, Optional
|
| 11 |
+
from .feedback_schema import UserFeedback
|
| 12 |
+
|
| 13 |
+
# Try to import snowflake connector
|
| 14 |
+
try:
|
| 15 |
+
import snowflake.connector
|
| 16 |
+
SNOWFLAKE_AVAILABLE = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
SNOWFLAKE_AVAILABLE = False
|
| 19 |
+
logging.warning("β οΈ snowflake-connector-python not installed. Install with: pip install snowflake-connector-python")
|
| 20 |
+
|
| 21 |
+
# Configure logging
|
| 22 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SnowflakeFeedbackConnector:
|
| 27 |
+
"""Connector for inserting feedback into Snowflake"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
user: str,
|
| 32 |
+
password: str,
|
| 33 |
+
account: str,
|
| 34 |
+
warehouse: str,
|
| 35 |
+
database: str = "SNOWFLAKE_LEARNING",
|
| 36 |
+
schema: str = "PUBLIC"
|
| 37 |
+
):
|
| 38 |
+
self.user = user
|
| 39 |
+
self.password = password
|
| 40 |
+
self.account = account
|
| 41 |
+
self.warehouse = warehouse
|
| 42 |
+
self.database = database
|
| 43 |
+
self.schema = schema
|
| 44 |
+
self._connection = None
|
| 45 |
+
|
| 46 |
+
def connect(self):
|
| 47 |
+
"""Establish Snowflake connection"""
|
| 48 |
+
if not SNOWFLAKE_AVAILABLE:
|
| 49 |
+
raise ImportError("snowflake-connector-python is not installed. Install with: pip install snowflake-connector-python")
|
| 50 |
+
|
| 51 |
+
logger.info("=" * 80)
|
| 52 |
+
logger.info("π SNOWFLAKE CONNECTION: Attempting to connect...")
|
| 53 |
+
logger.info(f" - Account: {self.account}")
|
| 54 |
+
logger.info(f" - Warehouse: {self.warehouse}")
|
| 55 |
+
logger.info(f" - Database: {self.database}")
|
| 56 |
+
logger.info(f" - Schema: {self.schema}")
|
| 57 |
+
logger.info(f" - User: {self.user}")
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
self._connection = snowflake.connector.connect(
|
| 61 |
+
user=self.user,
|
| 62 |
+
password=self.password,
|
| 63 |
+
account=self.account,
|
| 64 |
+
warehouse=self.warehouse
|
| 65 |
+
# Don't set database/schema in connection - we'll do it per query
|
| 66 |
+
)
|
| 67 |
+
logger.info("β
SNOWFLAKE CONNECTION: Successfully connected")
|
| 68 |
+
logger.info("=" * 80)
|
| 69 |
+
print(f"β
Connected to Snowflake: {self.database}.{self.schema}")
|
| 70 |
+
except Exception as e:
|
| 71 |
+
logger.error(f"β SNOWFLAKE CONNECTION FAILED: {e}")
|
| 72 |
+
logger.error("=" * 80)
|
| 73 |
+
print(f"β Failed to connect to Snowflake: {e}")
|
| 74 |
+
raise
|
| 75 |
+
|
| 76 |
+
def disconnect(self):
|
| 77 |
+
"""Close Snowflake connection"""
|
| 78 |
+
if self._connection:
|
| 79 |
+
self._connection.close()
|
| 80 |
+
print("β
Disconnected from Snowflake")
|
| 81 |
+
|
| 82 |
+
def insert_feedback(self, feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
|
| 83 |
+
"""Insert a single feedback record into Snowflake"""
|
| 84 |
+
logger.info("=" * 80)
|
| 85 |
+
logger.info("π SNOWFLAKE INSERT: Starting feedback insertion process")
|
| 86 |
+
logger.info(f"π Feedback ID: {feedback.feedback_id}")
|
| 87 |
+
|
| 88 |
+
# Get table name from parameter, env var, or default
|
| 89 |
+
if table_name is None:
|
| 90 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 91 |
+
|
| 92 |
+
if not self._connection:
|
| 93 |
+
logger.error("β Not connected to Snowflake. Call connect() first.")
|
| 94 |
+
raise RuntimeError("Not connected to Snowflake. Call connect() first.")
|
| 95 |
+
|
| 96 |
+
try:
|
| 97 |
+
logger.info("π VALIDATION: Validating feedback data structure...")
|
| 98 |
+
|
| 99 |
+
# Validate feedback object
|
| 100 |
+
validation_errors = []
|
| 101 |
+
if not feedback.feedback_id:
|
| 102 |
+
validation_errors.append("Missing feedback_id")
|
| 103 |
+
if feedback.score is None:
|
| 104 |
+
validation_errors.append("Missing score")
|
| 105 |
+
if feedback.timestamp is None:
|
| 106 |
+
validation_errors.append("Missing timestamp")
|
| 107 |
+
|
| 108 |
+
if validation_errors:
|
| 109 |
+
logger.error(f"β VALIDATION FAILED: {validation_errors}")
|
| 110 |
+
return False
|
| 111 |
+
else:
|
| 112 |
+
logger.info("β
VALIDATION PASSED: All required fields present")
|
| 113 |
+
|
| 114 |
+
logger.info("π Data Summary:")
|
| 115 |
+
logger.info(f" - Feedback ID: {feedback.feedback_id}")
|
| 116 |
+
logger.info(f" - Score: {feedback.score}")
|
| 117 |
+
logger.info(f" - Conversation ID: {feedback.conversation_id}")
|
| 118 |
+
logger.info(f" - Has Retrievals: {feedback.has_retrievals}")
|
| 119 |
+
logger.info(f" - Retrieval Count: {feedback.retrieval_count}")
|
| 120 |
+
logger.info(f" - Message Count: {feedback.message_count}")
|
| 121 |
+
logger.info(f" - Timestamp: {feedback.timestamp}")
|
| 122 |
+
|
| 123 |
+
cursor = self._connection.cursor()
|
| 124 |
+
logger.info("β
SNOWFLAKE CONNECTION: Cursor created")
|
| 125 |
+
|
| 126 |
+
# Set database and schema context
|
| 127 |
+
logger.info(f"π§ SETTING CONTEXT: Database={self.database}, Schema={self.schema}")
|
| 128 |
+
try:
|
| 129 |
+
cursor.execute(f'USE DATABASE "{self.database}"')
|
| 130 |
+
cursor.execute(f'USE SCHEMA "{self.schema}"')
|
| 131 |
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
| 132 |
+
current_db, current_schema = cursor.fetchone()
|
| 133 |
+
logger.info(f"β
Current context verified: Database={current_db}, Schema={current_schema}")
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.error(f"β Could not set context: {e}")
|
| 136 |
+
raise
|
| 137 |
+
|
| 138 |
+
# Prepare data - convert to JSON strings for VARIANT columns (same approach as old retrieved_data)
|
| 139 |
+
logger.info("π§ DATA PREPARATION: Preparing VARIANT columns...")
|
| 140 |
+
feedback_dict = feedback.to_dict()
|
| 141 |
+
|
| 142 |
+
# Prepare transcript (ARRAY) - convert to JSON string
|
| 143 |
+
transcript_raw = feedback_dict.get('transcript', [])
|
| 144 |
+
if transcript_raw:
|
| 145 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 146 |
+
transcript_for_db = json.dumps(transcript_raw)
|
| 147 |
+
logger.info(f" - Transcript: {len(transcript_raw)} messages, JSON length: {len(transcript_for_db)}")
|
| 148 |
+
else:
|
| 149 |
+
transcript_for_db = None
|
| 150 |
+
logger.info(" - Transcript: None")
|
| 151 |
+
|
| 152 |
+
# Prepare retrievals (ARRAY) - convert to JSON string
|
| 153 |
+
retrievals_raw = feedback_dict.get('retrievals', [])
|
| 154 |
+
if retrievals_raw:
|
| 155 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 156 |
+
retrievals_for_db = json.dumps(retrievals_raw)
|
| 157 |
+
logger.info(f" - Retrievals: {len(retrievals_raw)} entries, JSON length: {len(retrievals_for_db)}")
|
| 158 |
+
else:
|
| 159 |
+
retrievals_for_db = None
|
| 160 |
+
logger.info(" - Retrievals: None")
|
| 161 |
+
|
| 162 |
+
# Prepare feedback_score_related_retrieval_docs (OBJECT) - convert to JSON string
|
| 163 |
+
feedback_score_related_raw = feedback_dict.get('feedback_score_related_retrieval_docs')
|
| 164 |
+
if feedback_score_related_raw:
|
| 165 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 166 |
+
feedback_score_related_for_db = json.dumps(feedback_score_related_raw)
|
| 167 |
+
logger.info(f" - Feedback score related docs: present, JSON length: {len(feedback_score_related_for_db)}")
|
| 168 |
+
else:
|
| 169 |
+
feedback_score_related_for_db = None
|
| 170 |
+
logger.info(" - Feedback score related docs: None")
|
| 171 |
+
|
| 172 |
+
# Prepare retrieved_data (preserved old column) - convert to JSON string
|
| 173 |
+
retrieved_data_raw = feedback_dict.get('retrieved_data')
|
| 174 |
+
if retrieved_data_raw:
|
| 175 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 176 |
+
retrieved_data_for_db = json.dumps(retrieved_data_raw)
|
| 177 |
+
logger.info(f" - Retrieved data (preserved): present, JSON length: {len(retrieved_data_for_db)}")
|
| 178 |
+
else:
|
| 179 |
+
retrieved_data_for_db = None
|
| 180 |
+
logger.info(" - Retrieved data (preserved): None")
|
| 181 |
+
|
| 182 |
+
# Build SQL with new column structure
|
| 183 |
+
# Columns are VARCHAR (storing JSON strings), same approach as old retrieved_data
|
| 184 |
+
sql = f"""INSERT INTO {table_name} (
|
| 185 |
+
feedback_id,
|
| 186 |
+
open_ended_feedback,
|
| 187 |
+
score,
|
| 188 |
+
is_feedback_about_last_retrieval,
|
| 189 |
+
conversation_id,
|
| 190 |
+
timestamp,
|
| 191 |
+
message_count,
|
| 192 |
+
has_retrievals,
|
| 193 |
+
retrieval_count,
|
| 194 |
+
transcript,
|
| 195 |
+
retrievals,
|
| 196 |
+
feedback_score_related_retrieval_docs,
|
| 197 |
+
retrieved_data,
|
| 198 |
+
created_at
|
| 199 |
+
) VALUES (
|
| 200 |
+
%(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
|
| 201 |
+
%(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
|
| 202 |
+
%(retrieval_count)s, %(transcript)s, %(retrievals)s, %(feedback_score_related_retrieval_docs)s,
|
| 203 |
+
%(retrieved_data)s, %(created_at)s
|
| 204 |
+
)"""
|
| 205 |
+
|
| 206 |
+
logger.info("π SQL PREPARATION: Building INSERT statement...")
|
| 207 |
+
logger.info(f" - Target table: {table_name}")
|
| 208 |
+
logger.info(f" - Database: {self.database}")
|
| 209 |
+
logger.info(f" - Schema: {self.schema}")
|
| 210 |
+
|
| 211 |
+
# Prepare parameters
|
| 212 |
+
# Pass JSON strings for VARIANT columns (same approach as old retrieved_data)
|
| 213 |
+
params = {
|
| 214 |
+
'feedback_id': feedback.feedback_id,
|
| 215 |
+
'open_ended_feedback': feedback.open_ended_feedback,
|
| 216 |
+
'score': feedback.score,
|
| 217 |
+
'is_feedback_about_last_retrieval': feedback.is_feedback_about_last_retrieval,
|
| 218 |
+
'conversation_id': feedback.conversation_id,
|
| 219 |
+
'timestamp': int(feedback.timestamp),
|
| 220 |
+
'message_count': feedback.message_count,
|
| 221 |
+
'has_retrievals': feedback.has_retrievals,
|
| 222 |
+
'retrieval_count': feedback.retrieval_count,
|
| 223 |
+
'transcript': transcript_for_db, # JSON string
|
| 224 |
+
'retrievals': retrievals_for_db, # JSON string
|
| 225 |
+
'feedback_score_related_retrieval_docs': feedback_score_related_for_db, # JSON string
|
| 226 |
+
'retrieved_data': retrieved_data_for_db, # JSON string - preserved old column
|
| 227 |
+
'created_at': feedback.created_at
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
# Execute insert
|
| 231 |
+
logger.info("π SQL EXECUTION: Executing INSERT query...")
|
| 232 |
+
cursor.execute(sql, params)
|
| 233 |
+
|
| 234 |
+
logger.info("β
SQL EXECUTION: Query executed successfully")
|
| 235 |
+
logger.info(f" - Rows affected: 1")
|
| 236 |
+
logger.info(f" - Status: SUCCESS")
|
| 237 |
+
|
| 238 |
+
cursor.close()
|
| 239 |
+
logger.info("β
SNOWFLAKE INSERT: Feedback inserted successfully")
|
| 240 |
+
logger.info(f"π Inserted feedback: {feedback.feedback_id}")
|
| 241 |
+
logger.info("=" * 80)
|
| 242 |
+
return True
|
| 243 |
+
|
| 244 |
+
except Exception as e:
|
| 245 |
+
# Check if it's a Snowflake error
|
| 246 |
+
if SNOWFLAKE_AVAILABLE and "ProgrammingError" in str(type(e)):
|
| 247 |
+
logger.error(f"β SQL EXECUTION ERROR: {e}")
|
| 248 |
+
logger.error(f" - Error code: {getattr(e, 'errno', 'Unknown')}")
|
| 249 |
+
logger.error(f" - SQL state: {getattr(e, 'sqlstate', 'Unknown')}")
|
| 250 |
+
else:
|
| 251 |
+
logger.error(f"β SNOWFLAKE INSERT FAILED: {type(e).__name__}")
|
| 252 |
+
logger.error(f" - Error: {e}")
|
| 253 |
+
logger.error("=" * 80)
|
| 254 |
+
return False
|
| 255 |
+
|
| 256 |
+
def __enter__(self):
|
| 257 |
+
"""Context manager entry"""
|
| 258 |
+
self.connect()
|
| 259 |
+
return self
|
| 260 |
+
|
| 261 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 262 |
+
"""Context manager exit"""
|
| 263 |
+
self.disconnect()
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
|
| 267 |
+
"""Create Snowflake connector from environment variables"""
|
| 268 |
+
user = os.getenv("SNOWFLAKE_USER")
|
| 269 |
+
password = os.getenv("SNOWFLAKE_PASSWORD")
|
| 270 |
+
account = os.getenv("SNOWFLAKE_ACCOUNT")
|
| 271 |
+
warehouse = os.getenv("SNOWFLAKE_WAREHOUSE")
|
| 272 |
+
database = os.getenv("SNOWFLAKE_DATABASE", "SNOWFLAKE_LEARN")
|
| 273 |
+
schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
|
| 274 |
+
|
| 275 |
+
if not all([user, password, account, warehouse]):
|
| 276 |
+
print("β οΈ Snowflake credentials not found in environment variables")
|
| 277 |
+
print("Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
|
| 278 |
+
return None
|
| 279 |
+
|
| 280 |
+
return SnowflakeFeedbackConnector(
|
| 281 |
+
user=user,
|
| 282 |
+
password=password,
|
| 283 |
+
account=account,
|
| 284 |
+
warehouse=warehouse,
|
| 285 |
+
database=database,
|
| 286 |
+
schema=schema
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
|
| 291 |
+
"""Helper function to save feedback to Snowflake"""
|
| 292 |
+
logger.info("=" * 80)
|
| 293 |
+
logger.info("π΅ SNOWFLAKE SAVE: Starting save process")
|
| 294 |
+
logger.info(f"π Feedback ID: {feedback.feedback_id}")
|
| 295 |
+
|
| 296 |
+
# Get table name from parameter or env var
|
| 297 |
+
if table_name is None:
|
| 298 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 299 |
+
|
| 300 |
+
connector = get_snowflake_connector_from_env()
|
| 301 |
+
|
| 302 |
+
if not connector:
|
| 303 |
+
logger.warning("β οΈ SNOWFLAKE SAVE: Skipping insertion (credentials not configured)")
|
| 304 |
+
logger.warning(" Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
|
| 305 |
+
logger.info("=" * 80)
|
| 306 |
+
return False
|
| 307 |
+
|
| 308 |
+
try:
|
| 309 |
+
logger.info("π‘ SNOWFLAKE SAVE: Establishing connection...")
|
| 310 |
+
connector.connect()
|
| 311 |
+
logger.info("β
SNOWFLAKE SAVE: Connection established")
|
| 312 |
+
|
| 313 |
+
logger.info("π₯ SNOWFLAKE SAVE: Attempting to insert feedback...")
|
| 314 |
+
success = connector.insert_feedback(feedback, table_name=table_name)
|
| 315 |
+
|
| 316 |
+
logger.info("π SNOWFLAKE SAVE: Disconnecting...")
|
| 317 |
+
connector.disconnect()
|
| 318 |
+
|
| 319 |
+
if success:
|
| 320 |
+
logger.info("β
SNOWFLAKE SAVE: Successfully saved feedback")
|
| 321 |
+
else:
|
| 322 |
+
logger.error("β SNOWFLAKE SAVE: Failed to save feedback")
|
| 323 |
+
|
| 324 |
+
logger.info("=" * 80)
|
| 325 |
+
return success
|
| 326 |
+
except Exception as e:
|
| 327 |
+
logger.error(f"β SNOWFLAKE SAVE ERROR: {type(e).__name__}")
|
| 328 |
+
logger.error(f" - Error: {e}")
|
| 329 |
+
logger.info("=" * 80)
|
| 330 |
+
return False
|
| 331 |
+
|
src/gemini/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini File Search Integration Module
|
| 3 |
+
|
| 4 |
+
This module provides integration with Google Gemini File Search API
|
| 5 |
+
for RAG functionality using Gemini's built-in file search capabilities.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .file_search import GeminiFileSearchClient, GeminiFileSearchResult
|
| 9 |
+
|
| 10 |
+
__all__ = ["GeminiFileSearchClient", "GeminiFileSearchResult"]
|
| 11 |
+
|
src/gemini/file_search.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini File Search Client
|
| 3 |
+
|
| 4 |
+
Handles interaction with Google Gemini File Search API for RAG.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from google import genai
|
| 16 |
+
from google.genai import types
|
| 17 |
+
GEMINI_AVAILABLE = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
GEMINI_AVAILABLE = False
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class GeminiFileSearchResult:
|
| 24 |
+
"""Result from Gemini File Search query"""
|
| 25 |
+
answer: str
|
| 26 |
+
sources: List[Dict[str, Any]] # List of document references
|
| 27 |
+
grounding_metadata: Optional[Dict[str, Any]] = None
|
| 28 |
+
query: str = ""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class GeminiFileSearchClient:
|
| 32 |
+
"""Client for interacting with Gemini File Search API"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, api_key: Optional[str] = None, store_name: Optional[str] = None):
|
| 35 |
+
"""
|
| 36 |
+
Initialize Gemini File Search client.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
api_key: Gemini API key (defaults to GEMINI_API_KEY env var)
|
| 40 |
+
store_name: File search store name (defaults to GEMINI_FILESTORE_NAME env var)
|
| 41 |
+
"""
|
| 42 |
+
if not GEMINI_AVAILABLE:
|
| 43 |
+
raise ImportError("google-genai package not installed. Install with: pip install google-genai")
|
| 44 |
+
|
| 45 |
+
self.api_key = api_key or os.getenv("GEMINI_API_KEY")
|
| 46 |
+
if not self.api_key:
|
| 47 |
+
raise ValueError("GEMINI_API_KEY not found. Set it in .env file or pass as argument.")
|
| 48 |
+
|
| 49 |
+
store_name_raw = store_name or os.getenv("GEMINI_FILESTORE_NAME")
|
| 50 |
+
if not store_name_raw:
|
| 51 |
+
raise ValueError("GEMINI_FILESTORE_NAME not found. Set it in .env file or pass as argument.")
|
| 52 |
+
|
| 53 |
+
# Normalize store name: API expects the FULL path format (fileSearchStores/xxx)
|
| 54 |
+
# If just the ID is provided, construct the full path
|
| 55 |
+
if store_name_raw.startswith("fileSearchStores/"):
|
| 56 |
+
self.store_name = store_name_raw # Already full path
|
| 57 |
+
else:
|
| 58 |
+
# Just the ID provided, construct full path
|
| 59 |
+
self.store_name = f"fileSearchStores/{store_name_raw}"
|
| 60 |
+
|
| 61 |
+
logger.info(f"π¦ Using file search store: {self.store_name}")
|
| 62 |
+
|
| 63 |
+
self.client = genai.Client(api_key=self.api_key)
|
| 64 |
+
self.model = "gemini-2.5-flash" # or "gemini-2.5-pro"
|
| 65 |
+
|
| 66 |
+
def search(
|
| 67 |
+
self,
|
| 68 |
+
query: str,
|
| 69 |
+
filters: Optional[Dict[str, Any]] = None,
|
| 70 |
+
model: Optional[str] = None
|
| 71 |
+
) -> GeminiFileSearchResult:
|
| 72 |
+
"""
|
| 73 |
+
Search using Gemini File Search.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
query: User query
|
| 77 |
+
filters: Optional filters (year, source, district, etc.)
|
| 78 |
+
model: Model to use (defaults to gemini-2.5-flash)
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
GeminiFileSearchResult with answer and sources
|
| 82 |
+
"""
|
| 83 |
+
model = model or self.model
|
| 84 |
+
|
| 85 |
+
# Build filter context for the query if filters are provided
|
| 86 |
+
# Gemini File Search doesn't support explicit filters in the API,
|
| 87 |
+
# so we add them as context in the query
|
| 88 |
+
filter_context = ""
|
| 89 |
+
if filters:
|
| 90 |
+
filter_parts = []
|
| 91 |
+
if filters.get("year"):
|
| 92 |
+
years = filters["year"] if isinstance(filters["year"], list) else [filters["year"]]
|
| 93 |
+
filter_parts.append(f"Year: {', '.join(years)}")
|
| 94 |
+
if filters.get("sources"):
|
| 95 |
+
sources = filters["sources"] if isinstance(filters["sources"], list) else [filters["sources"]]
|
| 96 |
+
filter_parts.append(f"Source: {', '.join(sources)}")
|
| 97 |
+
if filters.get("district"):
|
| 98 |
+
districts = filters["district"] if isinstance(filters["district"], list) else [filters["district"]]
|
| 99 |
+
filter_parts.append(f"District: {', '.join(districts)}")
|
| 100 |
+
if filters.get("filenames"):
|
| 101 |
+
filenames = filters["filenames"] if isinstance(filters["filenames"], list) else [filters["filenames"]]
|
| 102 |
+
filter_parts.append(f"Filename: {', '.join(filenames)}")
|
| 103 |
+
|
| 104 |
+
if filter_parts:
|
| 105 |
+
filter_context = f"\n\nPlease focus on documents matching these criteria: {', '.join(filter_parts)}"
|
| 106 |
+
|
| 107 |
+
# Combine query with filter context
|
| 108 |
+
# Add comprehensive system instructions similar to multi-agent system
|
| 109 |
+
system_instructions = """You are a helpful audit report assistant specialized in analyzing government audit reports from Uganda's Office of the Auditor General.
|
| 110 |
+
|
| 111 |
+
CRITICAL RULES:
|
| 112 |
+
1. **NO HALLUCINATION**: Only use information that is explicitly stated in the retrieved documents. Do not make up facts, numbers, or details.
|
| 113 |
+
2. **Document References**: Always cite which documents you're using with [Doc i] references at the end of sentences that use specific information.
|
| 114 |
+
3. **Formatting**: Structure your response with clear paragraphs, bullet points, or sections for readability.
|
| 115 |
+
4. **Accuracy**: If the retrieved documents don't contain the requested information, explicitly state "The retrieved documents do not contain information about [topic]."
|
| 116 |
+
5. **Years and Data**: Pay careful attention to years mentioned in documents. If a user asks about a specific year but documents show different years, explicitly state this.
|
| 117 |
+
6. **District/Source Names**: Use the exact district and source names as they appear in the document metadata (e.g., "Kalangala" not "Kalagala").
|
| 118 |
+
7. **Financial Data**: When providing financial figures, include the currency (UGX) and be precise about amounts.
|
| 119 |
+
8. **Conversational Tone**: Be helpful, clear, and conversational while maintaining accuracy.
|
| 120 |
+
|
| 121 |
+
IMPORTANT: Only use information from the retrieved documents. Do not use information from your training data unless it's explicitly mentioned in the retrieved documents."""
|
| 122 |
+
|
| 123 |
+
# Combine system instructions with query
|
| 124 |
+
full_query = f"{system_instructions}\n\nUser Question: {query}{filter_context}\n\nPlease provide a detailed, well-formatted response with proper document references."
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
# Generate content with file search
|
| 128 |
+
# Based on Gemini API docs: https://ai.google.dev/gemini-api/docs/file-search
|
| 129 |
+
# Try with full path format first, then fallback to just ID if needed
|
| 130 |
+
store_name_to_try = self.store_name
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
# Try the documented format first with full path
|
| 134 |
+
response = self.client.models.generate_content(
|
| 135 |
+
model=model,
|
| 136 |
+
contents=full_query,
|
| 137 |
+
config=types.GenerateContentConfig(
|
| 138 |
+
tools=[
|
| 139 |
+
types.Tool(
|
| 140 |
+
file_search=types.FileSearch(
|
| 141 |
+
file_search_store_names=[store_name_to_try]
|
| 142 |
+
)
|
| 143 |
+
)
|
| 144 |
+
]
|
| 145 |
+
)
|
| 146 |
+
)
|
| 147 |
+
except Exception as api_error:
|
| 148 |
+
error_str = str(api_error).lower()
|
| 149 |
+
# If format error, try with just the ID (without fileSearchStores/ prefix)
|
| 150 |
+
if 'format' in error_str or 'invalid' in error_str or 'too long' in error_str:
|
| 151 |
+
logger.warning(f"Full path format failed, trying with just store ID: {api_error}")
|
| 152 |
+
# Extract just the ID part
|
| 153 |
+
if store_name_to_try.startswith("fileSearchStores/"):
|
| 154 |
+
store_id = store_name_to_try.split("/", 1)[1]
|
| 155 |
+
store_name_to_try = store_id
|
| 156 |
+
|
| 157 |
+
try:
|
| 158 |
+
response = self.client.models.generate_content(
|
| 159 |
+
model=model,
|
| 160 |
+
contents=full_query,
|
| 161 |
+
config=types.GenerateContentConfig(
|
| 162 |
+
tools=[
|
| 163 |
+
types.Tool(
|
| 164 |
+
file_search=types.FileSearch(
|
| 165 |
+
file_search_store_names=[store_name_to_try]
|
| 166 |
+
)
|
| 167 |
+
)
|
| 168 |
+
]
|
| 169 |
+
)
|
| 170 |
+
)
|
| 171 |
+
except Exception as e2:
|
| 172 |
+
raise Exception(f"Failed to call Gemini API with both formats. Full path error: {api_error}, ID-only error: {e2}")
|
| 173 |
+
else:
|
| 174 |
+
# Try alternative dict format
|
| 175 |
+
logger.warning(f"Primary API format failed, trying alternative: {api_error}")
|
| 176 |
+
try:
|
| 177 |
+
response = self.client.models.generate_content(
|
| 178 |
+
model=model,
|
| 179 |
+
contents=full_query,
|
| 180 |
+
tools=[{
|
| 181 |
+
"file_search": {
|
| 182 |
+
"file_search_store_names": [store_name_to_try]
|
| 183 |
+
}
|
| 184 |
+
}]
|
| 185 |
+
)
|
| 186 |
+
except Exception as e2:
|
| 187 |
+
raise Exception(f"Failed to call Gemini API: {e2}")
|
| 188 |
+
|
| 189 |
+
# Extract answer
|
| 190 |
+
answer = ""
|
| 191 |
+
if hasattr(response, 'text'):
|
| 192 |
+
answer = response.text
|
| 193 |
+
elif hasattr(response, 'candidates') and response.candidates:
|
| 194 |
+
# Try to get text from first candidate
|
| 195 |
+
candidate = response.candidates[0]
|
| 196 |
+
if hasattr(candidate, 'content') and candidate.content:
|
| 197 |
+
if hasattr(candidate.content, 'parts'):
|
| 198 |
+
text_parts = []
|
| 199 |
+
for part in candidate.content.parts:
|
| 200 |
+
if hasattr(part, 'text'):
|
| 201 |
+
text_parts.append(part.text)
|
| 202 |
+
answer = " ".join(text_parts)
|
| 203 |
+
elif isinstance(candidate.content, str):
|
| 204 |
+
answer = candidate.content
|
| 205 |
+
else:
|
| 206 |
+
answer = str(response)
|
| 207 |
+
|
| 208 |
+
# Extract grounding metadata (document references)
|
| 209 |
+
sources = []
|
| 210 |
+
grounding_metadata = None
|
| 211 |
+
|
| 212 |
+
logger.info(f"π Extracting sources from Gemini response...")
|
| 213 |
+
|
| 214 |
+
if hasattr(response, 'candidates') and response.candidates:
|
| 215 |
+
candidate = response.candidates[0]
|
| 216 |
+
logger.info(f" Found candidate, checking for grounding_metadata...")
|
| 217 |
+
|
| 218 |
+
# Get grounding metadata
|
| 219 |
+
if hasattr(candidate, 'grounding_metadata'):
|
| 220 |
+
grounding_metadata = candidate.grounding_metadata
|
| 221 |
+
logger.info(f" Found grounding_metadata: {type(grounding_metadata)}")
|
| 222 |
+
|
| 223 |
+
# Extract source documents from grounding metadata
|
| 224 |
+
# Handle different response formats
|
| 225 |
+
grounding_chunks = None
|
| 226 |
+
if hasattr(grounding_metadata, 'grounding_chunks'):
|
| 227 |
+
grounding_chunks = grounding_metadata.grounding_chunks
|
| 228 |
+
logger.info(f" Found grounding_chunks (attr): {len(grounding_chunks) if grounding_chunks else 0}")
|
| 229 |
+
elif isinstance(grounding_metadata, dict) and 'grounding_chunks' in grounding_metadata:
|
| 230 |
+
grounding_chunks = grounding_metadata['grounding_chunks']
|
| 231 |
+
logger.info(f" Found grounding_chunks (dict): {len(grounding_chunks) if grounding_chunks else 0}")
|
| 232 |
+
elif hasattr(grounding_metadata, '__dict__'):
|
| 233 |
+
# Try to access as object attributes
|
| 234 |
+
metadata_dict = grounding_metadata.__dict__
|
| 235 |
+
if 'grounding_chunks' in metadata_dict:
|
| 236 |
+
grounding_chunks = metadata_dict['grounding_chunks']
|
| 237 |
+
logger.info(f" Found grounding_chunks (__dict__): {len(grounding_chunks) if grounding_chunks else 0}")
|
| 238 |
+
|
| 239 |
+
if grounding_chunks:
|
| 240 |
+
logger.info(f" Processing {len(grounding_chunks)} grounding chunks...")
|
| 241 |
+
for idx, chunk in enumerate(grounding_chunks):
|
| 242 |
+
# Handle both object and dict formats
|
| 243 |
+
try:
|
| 244 |
+
if isinstance(chunk, dict):
|
| 245 |
+
chunk_data = chunk
|
| 246 |
+
else:
|
| 247 |
+
# Object format - convert to dict-like access
|
| 248 |
+
chunk_data = {}
|
| 249 |
+
if hasattr(chunk, 'chunk'):
|
| 250 |
+
chunk_obj = chunk.chunk
|
| 251 |
+
chunk_data['chunk'] = {
|
| 252 |
+
'text': getattr(chunk_obj, 'text', ''),
|
| 253 |
+
'file_name': getattr(chunk_obj, 'file_name', '')
|
| 254 |
+
}
|
| 255 |
+
if hasattr(chunk, 'relevance_score'):
|
| 256 |
+
score_obj = chunk.relevance_score
|
| 257 |
+
chunk_data['relevance_score'] = {
|
| 258 |
+
'score': getattr(score_obj, 'score', 0.0)
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
chunk_info = chunk_data.get('chunk', {})
|
| 262 |
+
text = chunk_info.get('text', '') if isinstance(chunk_info, dict) else ''
|
| 263 |
+
file_name = chunk_info.get('file_name', '') if isinstance(chunk_info, dict) else ''
|
| 264 |
+
|
| 265 |
+
# Try to extract file URI and parse metadata from it
|
| 266 |
+
file_uri = chunk_info.get('file_uri', '') if isinstance(chunk_info, dict) else ''
|
| 267 |
+
|
| 268 |
+
# Also check for 'web' attribute (GroundingChunkData structure)
|
| 269 |
+
if hasattr(chunk, 'web') and chunk.web:
|
| 270 |
+
web_data = chunk.web
|
| 271 |
+
file_uri = getattr(web_data, 'file_uri', '') or file_uri
|
| 272 |
+
file_name = getattr(web_data, 'title', '') or getattr(web_data, 'filename', '') or file_name
|
| 273 |
+
text = getattr(web_data, 'text', '') or getattr(web_data, 'content', '') or text
|
| 274 |
+
|
| 275 |
+
# Check retrieved_context - this is where the actual data seems to be!
|
| 276 |
+
if hasattr(chunk, 'retrieved_context') and chunk.retrieved_context:
|
| 277 |
+
rc = chunk.retrieved_context
|
| 278 |
+
# Get text content
|
| 279 |
+
if hasattr(rc, 'text'):
|
| 280 |
+
text = getattr(rc, 'text', '') or text
|
| 281 |
+
# Get document name
|
| 282 |
+
if hasattr(rc, 'document_name'):
|
| 283 |
+
doc_name = getattr(rc, 'document_name', '')
|
| 284 |
+
if doc_name:
|
| 285 |
+
file_name = doc_name or file_name
|
| 286 |
+
|
| 287 |
+
# Fallback: Parse from string representation if we still don't have filename
|
| 288 |
+
if not file_name:
|
| 289 |
+
chunk_str = str(chunk)
|
| 290 |
+
import re
|
| 291 |
+
# Look for PDF filenames
|
| 292 |
+
pdf_match = re.search(r"([A-Za-z0-9\s_-]+\.pdf)", chunk_str)
|
| 293 |
+
if pdf_match:
|
| 294 |
+
file_name = pdf_match.group(1)
|
| 295 |
+
# Or look for title= pattern
|
| 296 |
+
if not file_name and 'title=' in chunk_str:
|
| 297 |
+
title_match = re.search(r"title=['\"]([^'\"]+)['\"]", chunk_str)
|
| 298 |
+
if title_match:
|
| 299 |
+
file_name = title_match.group(1)
|
| 300 |
+
|
| 301 |
+
if not file_name and file_uri:
|
| 302 |
+
# Extract filename from URI if available
|
| 303 |
+
file_name = file_uri.split('/')[-1] if '/' in file_uri else file_uri
|
| 304 |
+
|
| 305 |
+
score_data = chunk_data.get('relevance_score', {})
|
| 306 |
+
score = score_data.get('score', 0.0) if isinstance(score_data, dict) else 0.0
|
| 307 |
+
|
| 308 |
+
if text or file_name: # Only add if we have content
|
| 309 |
+
source_info = {
|
| 310 |
+
"content": text,
|
| 311 |
+
"filename": file_name,
|
| 312 |
+
"score": score,
|
| 313 |
+
"file_uri": file_uri,
|
| 314 |
+
}
|
| 315 |
+
sources.append(source_info)
|
| 316 |
+
logger.info(f"π Extracted source {idx+1}: {file_name} (score: {score:.3f}, content length: {len(text)})")
|
| 317 |
+
except Exception as e:
|
| 318 |
+
logger.warning(f"Error extracting chunk {idx+1} info: {e}")
|
| 319 |
+
import traceback
|
| 320 |
+
logger.debug(traceback.format_exc())
|
| 321 |
+
continue
|
| 322 |
+
else:
|
| 323 |
+
logger.warning(f" No grounding_chunks found in grounding_metadata")
|
| 324 |
+
else:
|
| 325 |
+
logger.warning(f" Candidate does not have grounding_metadata attribute")
|
| 326 |
+
|
| 327 |
+
# Also try to get file references from other parts of the response
|
| 328 |
+
# Sometimes Gemini includes file references in the response itself
|
| 329 |
+
if not sources or len(sources) == 0:
|
| 330 |
+
logger.info(f" No sources from grounding_metadata, trying alternative extraction...")
|
| 331 |
+
# Check if response has file references in other attributes
|
| 332 |
+
if hasattr(candidate, 'content') and candidate.content:
|
| 333 |
+
if hasattr(candidate.content, 'parts'):
|
| 334 |
+
for part in candidate.content.parts:
|
| 335 |
+
if hasattr(part, 'file_data'):
|
| 336 |
+
file_data = part.file_data
|
| 337 |
+
if hasattr(file_data, 'file_uri') or (isinstance(file_data, dict) and 'file_uri' in file_data):
|
| 338 |
+
file_uri = getattr(file_data, 'file_uri', None) or (file_data.get('file_uri') if isinstance(file_data, dict) else None)
|
| 339 |
+
if file_uri:
|
| 340 |
+
file_name = file_uri.split('/')[-1] if '/' in file_uri else file_uri
|
| 341 |
+
sources.append({
|
| 342 |
+
"content": "",
|
| 343 |
+
"filename": file_name,
|
| 344 |
+
"score": 0.0,
|
| 345 |
+
"file_uri": file_uri,
|
| 346 |
+
})
|
| 347 |
+
logger.info(f"π Extracted source from file_data: {file_name}")
|
| 348 |
+
|
| 349 |
+
logger.info(f"β
Total sources extracted: {len(sources)}")
|
| 350 |
+
|
| 351 |
+
return GeminiFileSearchResult(
|
| 352 |
+
answer=answer,
|
| 353 |
+
sources=sources,
|
| 354 |
+
grounding_metadata=grounding_metadata,
|
| 355 |
+
query=query
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
except Exception as e:
|
| 359 |
+
# Return error result
|
| 360 |
+
return GeminiFileSearchResult(
|
| 361 |
+
answer=f"I apologize, but I encountered an error: {str(e)}",
|
| 362 |
+
sources=[],
|
| 363 |
+
query=query
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
def format_sources_for_display(self, result: GeminiFileSearchResult) -> List[Any]:
|
| 367 |
+
"""
|
| 368 |
+
Format Gemini sources to match the format expected by the UI.
|
| 369 |
+
|
| 370 |
+
Returns list of document-like objects compatible with existing display code.
|
| 371 |
+
"""
|
| 372 |
+
from langchain.docstore.document import Document
|
| 373 |
+
|
| 374 |
+
formatted_sources = []
|
| 375 |
+
|
| 376 |
+
for i, source in enumerate(result.sources):
|
| 377 |
+
filename = source.get("filename", "Unknown")
|
| 378 |
+
|
| 379 |
+
# Try to extract metadata from filename (e.g., "Kalangala DLG Report of Auditor General 2021.pdf")
|
| 380 |
+
year = None
|
| 381 |
+
district = None
|
| 382 |
+
source_name = "Gemini File Search"
|
| 383 |
+
|
| 384 |
+
# Parse filename for year
|
| 385 |
+
import re
|
| 386 |
+
year_match = re.search(r'\b(20\d{2})\b', filename)
|
| 387 |
+
if year_match:
|
| 388 |
+
year = int(year_match.group(1))
|
| 389 |
+
|
| 390 |
+
# Parse filename for district/source
|
| 391 |
+
if "Kalangala" in filename:
|
| 392 |
+
district = "Kalangala"
|
| 393 |
+
source_name = "Kalangala DLG"
|
| 394 |
+
elif "Gulu" in filename:
|
| 395 |
+
district = "Gulu"
|
| 396 |
+
source_name = "Gulu DLG"
|
| 397 |
+
elif "KCCA" in filename:
|
| 398 |
+
district = "Kampala"
|
| 399 |
+
source_name = "KCCA"
|
| 400 |
+
elif "MAAIF" in filename:
|
| 401 |
+
source_name = "MAAIF"
|
| 402 |
+
elif "MWTS" in filename:
|
| 403 |
+
source_name = "MWTS"
|
| 404 |
+
elif "Consolidated" in filename:
|
| 405 |
+
source_name = "Consolidated"
|
| 406 |
+
|
| 407 |
+
# Create a Document object compatible with existing code
|
| 408 |
+
doc = Document(
|
| 409 |
+
page_content=source.get("content", ""),
|
| 410 |
+
metadata={
|
| 411 |
+
"filename": filename,
|
| 412 |
+
"source": source_name,
|
| 413 |
+
"score": source.get("score"),
|
| 414 |
+
"chunk_index": i,
|
| 415 |
+
"page": None, # Gemini doesn't provide page numbers
|
| 416 |
+
"year": year,
|
| 417 |
+
"district": district,
|
| 418 |
+
"chunk_id": f"gemini_{i}",
|
| 419 |
+
"_id": f"gemini_{i}",
|
| 420 |
+
}
|
| 421 |
+
)
|
| 422 |
+
formatted_sources.append(doc)
|
| 423 |
+
logger.info(f"π Formatted source {i+1}: {filename} ({year}, {source_name})")
|
| 424 |
+
|
| 425 |
+
logger.info(f"β
Formatted {len(formatted_sources)} sources for display")
|
| 426 |
+
return formatted_sources
|
| 427 |
+
|
src/{loader.py β llm/loader.py}
RENAMED
|
File without changes
|
src/pipeline.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
"""Main pipeline orchestrator for the Audit QA system."""
|
|
|
|
| 2 |
import time
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
from dataclasses import dataclass
|
| 5 |
from typing import Dict, Any, List, Optional
|
|
@@ -11,11 +13,21 @@ except ModuleNotFoundError as me:
|
|
| 11 |
from langchain.schema import Document
|
| 12 |
|
| 13 |
from .logging import log_error
|
| 14 |
-
|
| 15 |
-
from .loader import chunks_to_documents
|
| 16 |
from .vectorstore import VectorStoreManager
|
|
|
|
| 17 |
from .retrieval.context import ContextRetriever
|
| 18 |
-
from .
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
|
|
@@ -41,12 +53,13 @@ class PipelineManager:
|
|
| 41 |
"""
|
| 42 |
Initialize the pipeline manager.
|
| 43 |
"""
|
|
|
|
|
|
|
| 44 |
self.config = config or {}
|
|
|
|
| 45 |
self.vectorstore_manager = None
|
| 46 |
self.context_retriever = None # Initialize as None
|
| 47 |
-
|
| 48 |
-
self.report_service = None
|
| 49 |
-
self.chunks = None
|
| 50 |
|
| 51 |
# Initialize components
|
| 52 |
self._initialize_components()
|
|
@@ -118,13 +131,7 @@ class PipelineManager:
|
|
| 118 |
try:
|
| 119 |
# Load config if not provided
|
| 120 |
if not self.config:
|
| 121 |
-
|
| 122 |
-
from src.config.loader import load_config
|
| 123 |
-
self.config = load_config()
|
| 124 |
-
except ImportError:
|
| 125 |
-
# Try alternate import path
|
| 126 |
-
from src.config.loader import load_config
|
| 127 |
-
self.config = load_config()
|
| 128 |
|
| 129 |
# Validate config structure
|
| 130 |
if not isinstance(self.config, dict):
|
|
@@ -159,7 +166,6 @@ class PipelineManager:
|
|
| 159 |
print("β
VectorStoreManager initialized successfully")
|
| 160 |
except Exception as vs_error:
|
| 161 |
print(f"β Error initializing VectorStoreManager: {vs_error}")
|
| 162 |
-
import traceback
|
| 163 |
traceback.print_exc()
|
| 164 |
self.vectorstore_manager = None
|
| 165 |
raise # Re-raise to be caught by outer try-except
|
|
@@ -175,40 +181,35 @@ class PipelineManager:
|
|
| 175 |
except Exception as e:
|
| 176 |
try:
|
| 177 |
# Try direct instantiation with config
|
| 178 |
-
from src.llm.adapters import get_llm_client
|
| 179 |
self.llm_client = get_llm_client("openai", self.config)
|
| 180 |
print("β
LLM CLIENT: Initialized using direct get_llm_client function with config")
|
| 181 |
except Exception as e2:
|
| 182 |
print(f"β LLM CLIENT: Registry methods failed - {e2}")
|
| 183 |
# Try to create a simple LLM client directly
|
| 184 |
try:
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
| 196 |
else:
|
| 197 |
-
print("β LLM CLIENT:
|
| 198 |
except Exception as e3:
|
| 199 |
print(f"β LLM CLIENT: Direct instantiation also failed - {e3}")
|
| 200 |
self.llm_client = None
|
| 201 |
|
| 202 |
# Load system prompt
|
| 203 |
-
from src.llm.templates import DEFAULT_AUDIT_SYSTEM_PROMPT
|
| 204 |
self.system_prompt = DEFAULT_AUDIT_SYSTEM_PROMPT
|
| 205 |
|
| 206 |
# Initialize report service
|
| 207 |
try:
|
| 208 |
-
try:
|
| 209 |
-
from src.reporting.service import ReportService
|
| 210 |
-
except ImportError:
|
| 211 |
-
from src.reporting.service import ReportService
|
| 212 |
self.report_service = ReportService()
|
| 213 |
except Exception as e:
|
| 214 |
print(f"Warning: Could not initialize report service: {e}")
|
|
@@ -216,7 +217,6 @@ class PipelineManager:
|
|
| 216 |
|
| 217 |
except Exception as e:
|
| 218 |
print(f"β Error initializing components: {e}")
|
| 219 |
-
import traceback
|
| 220 |
traceback.print_exc()
|
| 221 |
# Don't set vectorstore_manager to None if it was already set
|
| 222 |
if not hasattr(self, 'vectorstore_manager') or self.vectorstore_manager is None:
|
|
@@ -337,7 +337,6 @@ class PipelineManager:
|
|
| 337 |
return False
|
| 338 |
except Exception as init_error:
|
| 339 |
print(f"β Error initializing vector store manager: {init_error}")
|
| 340 |
-
import traceback
|
| 341 |
traceback.print_exc()
|
| 342 |
return False
|
| 343 |
|
|
@@ -352,7 +351,6 @@ class PipelineManager:
|
|
| 352 |
except Exception as e:
|
| 353 |
print(f"β Error connecting to vector store: {e}")
|
| 354 |
log_error(e, {"component": "vectorstore_connection"})
|
| 355 |
-
import traceback
|
| 356 |
traceback.print_exc()
|
| 357 |
|
| 358 |
# If it's a dimension mismatch error, try with force_recreate
|
|
@@ -541,9 +539,6 @@ Answer:"""
|
|
| 541 |
if auto_infer_filters and not any([reports, sources, subtype]):
|
| 542 |
print(f"π€ AUTO-INFERRING FILTERS: No explicit filters provided, analyzing query...")
|
| 543 |
try:
|
| 544 |
-
# Import get_available_metadata here to avoid circular imports
|
| 545 |
-
from src.retrieval.filter import get_available_metadata, infer_filters_from_query
|
| 546 |
-
|
| 547 |
# Get available metadata
|
| 548 |
available_metadata = get_available_metadata(self.vectorstore_manager.get_vectorstore())
|
| 549 |
|
|
|
|
| 1 |
"""Main pipeline orchestrator for the Audit QA system."""
|
| 2 |
+
import os
|
| 3 |
import time
|
| 4 |
+
import traceback
|
| 5 |
from pathlib import Path
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from typing import Dict, Any, List, Optional
|
|
|
|
| 13 |
from langchain.schema import Document
|
| 14 |
|
| 15 |
from .logging import log_error
|
| 16 |
+
|
| 17 |
+
from .llm.loader import chunks_to_documents
|
| 18 |
from .vectorstore import VectorStoreManager
|
| 19 |
+
from .reporting.service import ReportService
|
| 20 |
from .retrieval.context import ContextRetriever
|
| 21 |
+
from .llm.adapters import LLMRegistry, get_llm_client
|
| 22 |
+
from .llm.templates import DEFAULT_AUDIT_SYSTEM_PROMPT
|
| 23 |
+
from .config.loader import load_config, get_embedding_model_for_collection
|
| 24 |
+
from .retrieval.filter import get_available_metadata, infer_filters_from_query
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
from langchain_openai import ChatOpenAI
|
| 28 |
+
LANGCHAIN_OPENAI_AVAILABLE = True
|
| 29 |
+
except ImportError:
|
| 30 |
+
LANGCHAIN_OPENAI_AVAILABLE = False
|
| 31 |
|
| 32 |
|
| 33 |
|
|
|
|
| 53 |
"""
|
| 54 |
Initialize the pipeline manager.
|
| 55 |
"""
|
| 56 |
+
self.chunks = None
|
| 57 |
+
self.llm_client = None
|
| 58 |
self.config = config or {}
|
| 59 |
+
self.report_service = None
|
| 60 |
self.vectorstore_manager = None
|
| 61 |
self.context_retriever = None # Initialize as None
|
| 62 |
+
|
|
|
|
|
|
|
| 63 |
|
| 64 |
# Initialize components
|
| 65 |
self._initialize_components()
|
|
|
|
| 131 |
try:
|
| 132 |
# Load config if not provided
|
| 133 |
if not self.config:
|
| 134 |
+
self.config = load_config()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
# Validate config structure
|
| 137 |
if not isinstance(self.config, dict):
|
|
|
|
| 166 |
print("β
VectorStoreManager initialized successfully")
|
| 167 |
except Exception as vs_error:
|
| 168 |
print(f"β Error initializing VectorStoreManager: {vs_error}")
|
|
|
|
| 169 |
traceback.print_exc()
|
| 170 |
self.vectorstore_manager = None
|
| 171 |
raise # Re-raise to be caught by outer try-except
|
|
|
|
| 181 |
except Exception as e:
|
| 182 |
try:
|
| 183 |
# Try direct instantiation with config
|
|
|
|
| 184 |
self.llm_client = get_llm_client("openai", self.config)
|
| 185 |
print("β
LLM CLIENT: Initialized using direct get_llm_client function with config")
|
| 186 |
except Exception as e2:
|
| 187 |
print(f"β LLM CLIENT: Registry methods failed - {e2}")
|
| 188 |
# Try to create a simple LLM client directly
|
| 189 |
try:
|
| 190 |
+
if LANGCHAIN_OPENAI_AVAILABLE:
|
| 191 |
+
api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
|
| 192 |
+
if api_key:
|
| 193 |
+
self.llm_client = ChatOpenAI(
|
| 194 |
+
model="gpt-3.5-turbo",
|
| 195 |
+
api_key=api_key,
|
| 196 |
+
temperature=0.1,
|
| 197 |
+
max_tokens=1000
|
| 198 |
+
)
|
| 199 |
+
print("β
LLM CLIENT: Initialized using direct ChatOpenAI")
|
| 200 |
+
else:
|
| 201 |
+
print("β LLM CLIENT: No API key available")
|
| 202 |
else:
|
| 203 |
+
print("β LLM CLIENT: langchain-openai not available")
|
| 204 |
except Exception as e3:
|
| 205 |
print(f"β LLM CLIENT: Direct instantiation also failed - {e3}")
|
| 206 |
self.llm_client = None
|
| 207 |
|
| 208 |
# Load system prompt
|
|
|
|
| 209 |
self.system_prompt = DEFAULT_AUDIT_SYSTEM_PROMPT
|
| 210 |
|
| 211 |
# Initialize report service
|
| 212 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
self.report_service = ReportService()
|
| 214 |
except Exception as e:
|
| 215 |
print(f"Warning: Could not initialize report service: {e}")
|
|
|
|
| 217 |
|
| 218 |
except Exception as e:
|
| 219 |
print(f"β Error initializing components: {e}")
|
|
|
|
| 220 |
traceback.print_exc()
|
| 221 |
# Don't set vectorstore_manager to None if it was already set
|
| 222 |
if not hasattr(self, 'vectorstore_manager') or self.vectorstore_manager is None:
|
|
|
|
| 337 |
return False
|
| 338 |
except Exception as init_error:
|
| 339 |
print(f"β Error initializing vector store manager: {init_error}")
|
|
|
|
| 340 |
traceback.print_exc()
|
| 341 |
return False
|
| 342 |
|
|
|
|
| 351 |
except Exception as e:
|
| 352 |
print(f"β Error connecting to vector store: {e}")
|
| 353 |
log_error(e, {"component": "vectorstore_connection"})
|
|
|
|
| 354 |
traceback.print_exc()
|
| 355 |
|
| 356 |
# If it's a dimension mismatch error, try with force_recreate
|
|
|
|
| 539 |
if auto_infer_filters and not any([reports, sources, subtype]):
|
| 540 |
print(f"π€ AUTO-INFERRING FILTERS: No explicit filters provided, analyzing query...")
|
| 541 |
try:
|
|
|
|
|
|
|
|
|
|
| 542 |
# Get available metadata
|
| 543 |
available_metadata = get_available_metadata(self.vectorstore_manager.get_vectorstore())
|
| 544 |
|
src/reporting/__init__.py
CHANGED
|
@@ -1,4 +1,8 @@
|
|
| 1 |
-
"""Report metadata and utilities.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from .metadata import get_report_metadata, get_available_sources
|
| 4 |
from .service import ReportService
|
|
|
|
| 1 |
+
"""Report metadata and utilities.
|
| 2 |
+
|
| 3 |
+
This module is kept for backward compatibility with pipeline.py.
|
| 4 |
+
For feedback-related functionality, use src.feedback instead.
|
| 5 |
+
"""
|
| 6 |
|
| 7 |
from .metadata import get_report_metadata, get_available_sources
|
| 8 |
from .service import ReportService
|
src/reporting/feedback_schema.py
CHANGED
|
@@ -4,10 +4,12 @@ Feedback Schema for RAG Chatbot
|
|
| 4 |
This module defines dataclasses for feedback data structures
|
| 5 |
and provides Snowflake schema generation.
|
| 6 |
"""
|
| 7 |
-
|
|
|
|
| 8 |
from dataclasses import dataclass, asdict, field
|
| 9 |
from typing import List, Optional, Dict, Any, Union
|
| 10 |
-
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
@dataclass
|
|
@@ -39,34 +41,20 @@ class UserFeedback:
|
|
| 39 |
open_ended_feedback: Optional[str]
|
| 40 |
score: int
|
| 41 |
is_feedback_about_last_retrieval: bool
|
| 42 |
-
retrieved_data: List[RetrievalEntry]
|
| 43 |
conversation_id: str
|
| 44 |
timestamp: float
|
| 45 |
message_count: int
|
| 46 |
has_retrievals: bool
|
| 47 |
retrieval_count: int
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
| 51 |
|
| 52 |
def to_dict(self) -> Dict[str, Any]:
|
| 53 |
"""Convert to dictionary with nested data structures"""
|
| 54 |
result = asdict(self)
|
| 55 |
-
# Handle nested objects
|
| 56 |
-
if self.retrieved_data:
|
| 57 |
-
result['retrieved_data'] = [self._serialize_retrieval_entry(entry) for entry in self.retrieved_data]
|
| 58 |
-
return result
|
| 59 |
-
|
| 60 |
-
def _serialize_retrieval_entry(self, entry: RetrievalEntry) -> Dict[str, Any]:
|
| 61 |
-
"""Serialize retrieval entry to dict"""
|
| 62 |
-
# If raw data exists, use it (it's already properly formatted)
|
| 63 |
-
if hasattr(entry, '_raw_data') and entry._raw_data:
|
| 64 |
-
return entry._raw_data
|
| 65 |
-
|
| 66 |
-
# Otherwise, serialize the dataclass
|
| 67 |
-
result = asdict(entry)
|
| 68 |
-
if entry.documents_retrieved:
|
| 69 |
-
result['documents_retrieved'] = [asdict(doc) for doc in entry.documents_retrieved]
|
| 70 |
return result
|
| 71 |
|
| 72 |
def to_snowflake_schema(self) -> Dict[str, Any]:
|
|
@@ -81,28 +69,28 @@ class UserFeedback:
|
|
| 81 |
"message_count": "INTEGER",
|
| 82 |
"has_retrievals": "BOOLEAN",
|
| 83 |
"retrieval_count": "INTEGER",
|
| 84 |
-
"
|
| 85 |
-
"
|
|
|
|
|
|
|
| 86 |
"created_at": "TIMESTAMP_NTZ",
|
| 87 |
-
"
|
| 88 |
-
#
|
| 89 |
-
# [
|
| 90 |
# {
|
| 91 |
-
# "
|
| 92 |
-
# "
|
| 93 |
-
# "timestamp": 1234567890,
|
| 94 |
-
# "docs_retrieved": [
|
| 95 |
-
# {"filename": "...", "page": 14, "score": 0.95, ...},
|
| 96 |
-
# ...
|
| 97 |
-
# ]
|
| 98 |
# },
|
| 99 |
# ...
|
| 100 |
# ]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
}
|
| 102 |
return schema
|
| 103 |
|
| 104 |
@classmethod
|
| 105 |
-
def get_snowflake_create_table_sql(cls, table_name: str = "
|
| 106 |
"""Generate CREATE TABLE SQL for Snowflake"""
|
| 107 |
schema = cls.to_snowflake_schema(None)
|
| 108 |
|
|
@@ -117,16 +105,13 @@ class UserFeedback:
|
|
| 117 |
sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
|
| 118 |
{columns_str},
|
| 119 |
PRIMARY KEY (feedback_id)
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
-
--
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
--
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
-- Create index on score for feedback analysis
|
| 129 |
-
CREATE INDEX IF NOT EXISTS idx_feedback_score ON {table_name} (score);
|
| 130 |
"""
|
| 131 |
return sql
|
| 132 |
|
|
@@ -150,47 +135,27 @@ DOCUMENT_SCHEMA = {
|
|
| 150 |
}
|
| 151 |
|
| 152 |
|
| 153 |
-
def generate_snowflake_schema_sql() -> str:
|
| 154 |
"""Generate complete Snowflake schema SQL for feedback system"""
|
| 155 |
-
|
|
|
|
|
|
|
| 156 |
|
| 157 |
|
| 158 |
def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
|
| 159 |
"""Create UserFeedback instance from dictionary"""
|
| 160 |
-
# Parse retrieved_data if present
|
| 161 |
-
retrieved_data = []
|
| 162 |
-
if "retrieved_data" in data and data["retrieved_data"]:
|
| 163 |
-
for entry_dict in data.get("retrieved_data", []):
|
| 164 |
-
# Map the actual structure from rag_retrieval_history
|
| 165 |
-
# Entry has: conversation_up_to, rag_query_expansion, docs_retrieved
|
| 166 |
-
try:
|
| 167 |
-
# Try to map to expected structure
|
| 168 |
-
entry = RetrievalEntry(
|
| 169 |
-
rag_query=entry_dict.get("rag_query_expansion", ""),
|
| 170 |
-
documents_retrieved=[], # Empty for now, will store as raw data
|
| 171 |
-
conversation_length=len(entry_dict.get("conversation_up_to", [])),
|
| 172 |
-
filters_applied=None,
|
| 173 |
-
timestamp=entry_dict.get("timestamp", None)
|
| 174 |
-
)
|
| 175 |
-
# Store raw data in the entry
|
| 176 |
-
entry._raw_data = entry_dict # Store original for preservation
|
| 177 |
-
retrieved_data.append(entry)
|
| 178 |
-
except Exception as e:
|
| 179 |
-
# If mapping fails, store as-is without strict typing
|
| 180 |
-
pass
|
| 181 |
-
|
| 182 |
return UserFeedback(
|
| 183 |
feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
|
| 184 |
open_ended_feedback=data.get("open_ended_feedback"),
|
| 185 |
score=data["score"],
|
| 186 |
is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
|
| 187 |
-
retrieved_data=retrieved_data,
|
| 188 |
conversation_id=data["conversation_id"],
|
| 189 |
timestamp=data["timestamp"],
|
| 190 |
message_count=data["message_count"],
|
| 191 |
has_retrievals=data["has_retrievals"],
|
| 192 |
retrieval_count=data["retrieval_count"],
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
| 195 |
)
|
| 196 |
-
|
|
|
|
| 4 |
This module defines dataclasses for feedback data structures
|
| 5 |
and provides Snowflake schema generation.
|
| 6 |
"""
|
| 7 |
+
import os
|
| 8 |
+
from datetime import datetime
|
| 9 |
from dataclasses import dataclass, asdict, field
|
| 10 |
from typing import List, Optional, Dict, Any, Union
|
| 11 |
+
|
| 12 |
+
|
| 13 |
|
| 14 |
|
| 15 |
@dataclass
|
|
|
|
| 41 |
open_ended_feedback: Optional[str]
|
| 42 |
score: int
|
| 43 |
is_feedback_about_last_retrieval: bool
|
|
|
|
| 44 |
conversation_id: str
|
| 45 |
timestamp: float
|
| 46 |
message_count: int
|
| 47 |
has_retrievals: bool
|
| 48 |
retrieval_count: int
|
| 49 |
+
transcript: List[Dict[str, str]] # List of {"role": "user"/"assistant", "content": "..."}
|
| 50 |
+
retrievals: List[Dict[str, Any]] # List of retrieval objects with retrieved_docs and user_message_trigger
|
| 51 |
+
feedback_score_related_retrieval_docs: Optional[Dict[str, Any]] = None # Conversation subset + retrieved docs
|
| 52 |
+
retrieved_data: Optional[List[Dict[str, Any]]] = None # Preserved old column for backward compatibility
|
| 53 |
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
| 54 |
|
| 55 |
def to_dict(self) -> Dict[str, Any]:
|
| 56 |
"""Convert to dictionary with nested data structures"""
|
| 57 |
result = asdict(self)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
return result
|
| 59 |
|
| 60 |
def to_snowflake_schema(self) -> Dict[str, Any]:
|
|
|
|
| 69 |
"message_count": "INTEGER",
|
| 70 |
"has_retrievals": "BOOLEAN",
|
| 71 |
"retrieval_count": "INTEGER",
|
| 72 |
+
"transcript": "VARCHAR(16777216)", # JSON string of ARRAY of {"role": "user"/"assistant", "content": "..."}
|
| 73 |
+
"retrievals": "VARCHAR(16777216)", # JSON string of ARRAY of retrieval objects
|
| 74 |
+
"feedback_score_related_retrieval_docs": "VARCHAR(16777216)", # JSON string of OBJECT with conversation subset + retrieved docs
|
| 75 |
+
"retrieved_data": "VARCHAR(16777216)", # JSON string - preserved old column for backward compatibility
|
| 76 |
"created_at": "TIMESTAMP_NTZ",
|
| 77 |
+
# transcript structure: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
|
| 78 |
+
# retrievals structure: [
|
|
|
|
| 79 |
# {
|
| 80 |
+
# "retrieved_docs": [{"content": "...", "metadata": {...}, ...}], # content truncated to 100 chars
|
| 81 |
+
# "user_message_trigger": "final user message that triggered this retrieval"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
# },
|
| 83 |
# ...
|
| 84 |
# ]
|
| 85 |
+
# feedback_score_related_retrieval_docs structure: {
|
| 86 |
+
# "conversation_up_to_point": [{"role": "user", "content": "..."}, ...], # subset of transcript
|
| 87 |
+
# "retrieved_docs": [{"content": "...", "metadata": {...}, ...}] # full chunks with all info
|
| 88 |
+
# }
|
| 89 |
}
|
| 90 |
return schema
|
| 91 |
|
| 92 |
@classmethod
|
| 93 |
+
def get_snowflake_create_table_sql(cls, table_name: str = "USER_FEEDBACK_V3") -> str:
|
| 94 |
"""Generate CREATE TABLE SQL for Snowflake"""
|
| 95 |
schema = cls.to_snowflake_schema(None)
|
| 96 |
|
|
|
|
| 105 |
sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
|
| 106 |
{columns_str},
|
| 107 |
PRIMARY KEY (feedback_id)
|
| 108 |
+
)
|
| 109 |
+
CLUSTER BY (timestamp, conversation_id, score);
|
| 110 |
+
-- Note: Snowflake doesn't support traditional indexes on regular tables.
|
| 111 |
+
-- Instead, we use CLUSTER BY to optimize queries on these columns.
|
| 112 |
+
-- Snowflake automatically maintains clustering for efficient querying.
|
| 113 |
+
-- Note: transcript, retrievals, and feedback_score_related_retrieval_docs are stored as VARCHAR (JSON strings),
|
| 114 |
+
-- same approach as the old retrieved_data column. This allows easy storage and retrieval without VARIANT type complexity.
|
|
|
|
|
|
|
|
|
|
| 115 |
"""
|
| 116 |
return sql
|
| 117 |
|
|
|
|
| 135 |
}
|
| 136 |
|
| 137 |
|
| 138 |
+
def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
|
| 139 |
"""Generate complete Snowflake schema SQL for feedback system"""
|
| 140 |
+
if table_name is None:
|
| 141 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 142 |
+
return UserFeedback.get_snowflake_create_table_sql(table_name)
|
| 143 |
|
| 144 |
|
| 145 |
def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
|
| 146 |
"""Create UserFeedback instance from dictionary"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
return UserFeedback(
|
| 148 |
feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
|
| 149 |
open_ended_feedback=data.get("open_ended_feedback"),
|
| 150 |
score=data["score"],
|
| 151 |
is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
|
|
|
|
| 152 |
conversation_id=data["conversation_id"],
|
| 153 |
timestamp=data["timestamp"],
|
| 154 |
message_count=data["message_count"],
|
| 155 |
has_retrievals=data["has_retrievals"],
|
| 156 |
retrieval_count=data["retrieval_count"],
|
| 157 |
+
transcript=data.get("transcript", []),
|
| 158 |
+
retrievals=data.get("retrievals", []),
|
| 159 |
+
feedback_score_related_retrieval_docs=data.get("feedback_score_related_retrieval_docs"),
|
| 160 |
+
retrieved_data=data.get("retrieved_data")
|
| 161 |
)
|
|
|
src/reporting/snowflake_connector.py
CHANGED
|
@@ -8,8 +8,11 @@ import os
|
|
| 8 |
import json
|
| 9 |
import logging
|
| 10 |
from typing import Dict, Any, Optional
|
|
|
|
|
|
|
| 11 |
from src.reporting.feedback_schema import UserFeedback
|
| 12 |
|
|
|
|
| 13 |
# Try to import snowflake connector
|
| 14 |
try:
|
| 15 |
import snowflake.connector
|
|
@@ -79,12 +82,16 @@ class SnowflakeFeedbackConnector:
|
|
| 79 |
self._connection.close()
|
| 80 |
print("β
Disconnected from Snowflake")
|
| 81 |
|
| 82 |
-
def insert_feedback(self, feedback: UserFeedback) -> bool:
|
| 83 |
"""Insert a single feedback record into Snowflake"""
|
| 84 |
logger.info("=" * 80)
|
| 85 |
logger.info("π SNOWFLAKE INSERT: Starting feedback insertion process")
|
| 86 |
logger.info(f"π Feedback ID: {feedback.feedback_id}")
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
if not self._connection:
|
| 89 |
logger.error("β Not connected to Snowflake. Call connect() first.")
|
| 90 |
raise RuntimeError("Not connected to Snowflake. Call connect() first.")
|
|
@@ -131,38 +138,53 @@ class SnowflakeFeedbackConnector:
|
|
| 131 |
logger.error(f"β Could not set context: {e}")
|
| 132 |
raise
|
| 133 |
|
| 134 |
-
# Prepare data
|
| 135 |
-
logger.info("π§ DATA PREPARATION: Preparing
|
| 136 |
-
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
-
#
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
else:
|
| 148 |
-
|
| 149 |
-
logger.info(" -
|
| 150 |
-
retrieved_data = retrieved_data_raw
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
-
#
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
| 160 |
else:
|
| 161 |
-
logger.info(f" - Retrieved data is None, using NULL")
|
| 162 |
retrieved_data_for_db = None
|
|
|
|
| 163 |
|
| 164 |
-
# Build SQL with
|
| 165 |
-
|
|
|
|
| 166 |
feedback_id,
|
| 167 |
open_ended_feedback,
|
| 168 |
score,
|
|
@@ -172,23 +194,25 @@ class SnowflakeFeedbackConnector:
|
|
| 172 |
message_count,
|
| 173 |
has_retrievals,
|
| 174 |
retrieval_count,
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
retrieved_data
|
|
|
|
| 179 |
) VALUES (
|
| 180 |
%(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
|
| 181 |
%(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
|
| 182 |
-
%(retrieval_count)s, %(
|
| 183 |
-
%(retrieved_data)s
|
| 184 |
)"""
|
| 185 |
|
| 186 |
logger.info("π SQL PREPARATION: Building INSERT statement...")
|
| 187 |
-
logger.info(f" - Target table:
|
| 188 |
logger.info(f" - Database: {self.database}")
|
| 189 |
logger.info(f" - Schema: {self.schema}")
|
| 190 |
|
| 191 |
# Prepare parameters
|
|
|
|
| 192 |
params = {
|
| 193 |
'feedback_id': feedback.feedback_id,
|
| 194 |
'open_ended_feedback': feedback.open_ended_feedback,
|
|
@@ -199,10 +223,11 @@ class SnowflakeFeedbackConnector:
|
|
| 199 |
'message_count': feedback.message_count,
|
| 200 |
'has_retrievals': feedback.has_retrievals,
|
| 201 |
'retrieval_count': feedback.retrieval_count,
|
| 202 |
-
'
|
| 203 |
-
'
|
| 204 |
-
'
|
| 205 |
-
'retrieved_data': retrieved_data_for_db
|
|
|
|
| 206 |
}
|
| 207 |
|
| 208 |
# Execute insert
|
|
@@ -265,12 +290,16 @@ def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
|
|
| 265 |
)
|
| 266 |
|
| 267 |
|
| 268 |
-
def save_to_snowflake(feedback: UserFeedback) -> bool:
|
| 269 |
"""Helper function to save feedback to Snowflake"""
|
| 270 |
logger.info("=" * 80)
|
| 271 |
logger.info("π΅ SNOWFLAKE SAVE: Starting save process")
|
| 272 |
logger.info(f"π Feedback ID: {feedback.feedback_id}")
|
| 273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
connector = get_snowflake_connector_from_env()
|
| 275 |
|
| 276 |
if not connector:
|
|
@@ -285,7 +314,7 @@ def save_to_snowflake(feedback: UserFeedback) -> bool:
|
|
| 285 |
logger.info("β
SNOWFLAKE SAVE: Connection established")
|
| 286 |
|
| 287 |
logger.info("π₯ SNOWFLAKE SAVE: Attempting to insert feedback...")
|
| 288 |
-
success = connector.insert_feedback(feedback)
|
| 289 |
|
| 290 |
logger.info("π SNOWFLAKE SAVE: Disconnecting...")
|
| 291 |
connector.disconnect()
|
|
@@ -302,4 +331,3 @@ def save_to_snowflake(feedback: UserFeedback) -> bool:
|
|
| 302 |
logger.error(f" - Error: {e}")
|
| 303 |
logger.info("=" * 80)
|
| 304 |
return False
|
| 305 |
-
|
|
|
|
| 8 |
import json
|
| 9 |
import logging
|
| 10 |
from typing import Dict, Any, Optional
|
| 11 |
+
|
| 12 |
+
|
| 13 |
from src.reporting.feedback_schema import UserFeedback
|
| 14 |
|
| 15 |
+
|
| 16 |
# Try to import snowflake connector
|
| 17 |
try:
|
| 18 |
import snowflake.connector
|
|
|
|
| 82 |
self._connection.close()
|
| 83 |
print("β
Disconnected from Snowflake")
|
| 84 |
|
| 85 |
+
def insert_feedback(self, feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
|
| 86 |
"""Insert a single feedback record into Snowflake"""
|
| 87 |
logger.info("=" * 80)
|
| 88 |
logger.info("π SNOWFLAKE INSERT: Starting feedback insertion process")
|
| 89 |
logger.info(f"π Feedback ID: {feedback.feedback_id}")
|
| 90 |
|
| 91 |
+
# Get table name from parameter, env var, or default
|
| 92 |
+
if table_name is None:
|
| 93 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 94 |
+
|
| 95 |
if not self._connection:
|
| 96 |
logger.error("β Not connected to Snowflake. Call connect() first.")
|
| 97 |
raise RuntimeError("Not connected to Snowflake. Call connect() first.")
|
|
|
|
| 138 |
logger.error(f"β Could not set context: {e}")
|
| 139 |
raise
|
| 140 |
|
| 141 |
+
# Prepare data - convert to JSON strings for VARIANT columns (same approach as old retrieved_data)
|
| 142 |
+
logger.info("π§ DATA PREPARATION: Preparing VARIANT columns...")
|
| 143 |
+
feedback_dict = feedback.to_dict()
|
| 144 |
|
| 145 |
+
# Prepare transcript (ARRAY) - convert to JSON string
|
| 146 |
+
transcript_raw = feedback_dict.get('transcript', [])
|
| 147 |
+
if transcript_raw:
|
| 148 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 149 |
+
transcript_for_db = json.dumps(transcript_raw)
|
| 150 |
+
logger.info(f" - Transcript: {len(transcript_raw)} messages, JSON length: {len(transcript_for_db)}")
|
| 151 |
+
else:
|
| 152 |
+
transcript_for_db = None
|
| 153 |
+
logger.info(" - Transcript: None")
|
| 154 |
|
| 155 |
+
# Prepare retrievals (ARRAY) - convert to JSON string
|
| 156 |
+
retrievals_raw = feedback_dict.get('retrievals', [])
|
| 157 |
+
if retrievals_raw:
|
| 158 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 159 |
+
retrievals_for_db = json.dumps(retrievals_raw)
|
| 160 |
+
logger.info(f" - Retrievals: {len(retrievals_raw)} entries, JSON length: {len(retrievals_for_db)}")
|
| 161 |
else:
|
| 162 |
+
retrievals_for_db = None
|
| 163 |
+
logger.info(" - Retrievals: None")
|
|
|
|
| 164 |
|
| 165 |
+
# Prepare feedback_score_related_retrieval_docs (OBJECT) - convert to JSON string
|
| 166 |
+
feedback_score_related_raw = feedback_dict.get('feedback_score_related_retrieval_docs')
|
| 167 |
+
if feedback_score_related_raw:
|
| 168 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 169 |
+
feedback_score_related_for_db = json.dumps(feedback_score_related_raw)
|
| 170 |
+
logger.info(f" - Feedback score related docs: present, JSON length: {len(feedback_score_related_for_db)}")
|
| 171 |
+
else:
|
| 172 |
+
feedback_score_related_for_db = None
|
| 173 |
+
logger.info(" - Feedback score related docs: None")
|
| 174 |
|
| 175 |
+
# Prepare retrieved_data (preserved old column) - convert to JSON string
|
| 176 |
+
retrieved_data_raw = feedback_dict.get('retrieved_data')
|
| 177 |
+
if retrieved_data_raw:
|
| 178 |
+
# Convert to JSON string (same approach as old retrieved_data)
|
| 179 |
+
retrieved_data_for_db = json.dumps(retrieved_data_raw)
|
| 180 |
+
logger.info(f" - Retrieved data (preserved): present, JSON length: {len(retrieved_data_for_db)}")
|
| 181 |
else:
|
|
|
|
| 182 |
retrieved_data_for_db = None
|
| 183 |
+
logger.info(" - Retrieved data (preserved): None")
|
| 184 |
|
| 185 |
+
# Build SQL with new column structure
|
| 186 |
+
# Columns are VARCHAR (storing JSON strings), same approach as old retrieved_data
|
| 187 |
+
sql = f"""INSERT INTO {table_name} (
|
| 188 |
feedback_id,
|
| 189 |
open_ended_feedback,
|
| 190 |
score,
|
|
|
|
| 194 |
message_count,
|
| 195 |
has_retrievals,
|
| 196 |
retrieval_count,
|
| 197 |
+
transcript,
|
| 198 |
+
retrievals,
|
| 199 |
+
feedback_score_related_retrieval_docs,
|
| 200 |
+
retrieved_data,
|
| 201 |
+
created_at
|
| 202 |
) VALUES (
|
| 203 |
%(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
|
| 204 |
%(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
|
| 205 |
+
%(retrieval_count)s, %(transcript)s, %(retrievals)s, %(feedback_score_related_retrieval_docs)s,
|
| 206 |
+
%(retrieved_data)s, %(created_at)s
|
| 207 |
)"""
|
| 208 |
|
| 209 |
logger.info("π SQL PREPARATION: Building INSERT statement...")
|
| 210 |
+
logger.info(f" - Target table: {table_name}")
|
| 211 |
logger.info(f" - Database: {self.database}")
|
| 212 |
logger.info(f" - Schema: {self.schema}")
|
| 213 |
|
| 214 |
# Prepare parameters
|
| 215 |
+
# Pass JSON strings for VARIANT columns (same approach as old retrieved_data)
|
| 216 |
params = {
|
| 217 |
'feedback_id': feedback.feedback_id,
|
| 218 |
'open_ended_feedback': feedback.open_ended_feedback,
|
|
|
|
| 223 |
'message_count': feedback.message_count,
|
| 224 |
'has_retrievals': feedback.has_retrievals,
|
| 225 |
'retrieval_count': feedback.retrieval_count,
|
| 226 |
+
'transcript': transcript_for_db, # JSON string
|
| 227 |
+
'retrievals': retrievals_for_db, # JSON string
|
| 228 |
+
'feedback_score_related_retrieval_docs': feedback_score_related_for_db, # JSON string
|
| 229 |
+
'retrieved_data': retrieved_data_for_db, # JSON string - preserved old column
|
| 230 |
+
'created_at': feedback.created_at
|
| 231 |
}
|
| 232 |
|
| 233 |
# Execute insert
|
|
|
|
| 290 |
)
|
| 291 |
|
| 292 |
|
| 293 |
+
def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
|
| 294 |
"""Helper function to save feedback to Snowflake"""
|
| 295 |
logger.info("=" * 80)
|
| 296 |
logger.info("π΅ SNOWFLAKE SAVE: Starting save process")
|
| 297 |
logger.info(f"π Feedback ID: {feedback.feedback_id}")
|
| 298 |
|
| 299 |
+
# Get table name from parameter or env var
|
| 300 |
+
if table_name is None:
|
| 301 |
+
table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
|
| 302 |
+
|
| 303 |
connector = get_snowflake_connector_from_env()
|
| 304 |
|
| 305 |
if not connector:
|
|
|
|
| 314 |
logger.info("β
SNOWFLAKE SAVE: Connection established")
|
| 315 |
|
| 316 |
logger.info("π₯ SNOWFLAKE SAVE: Attempting to insert feedback...")
|
| 317 |
+
success = connector.insert_feedback(feedback, table_name=table_name)
|
| 318 |
|
| 319 |
logger.info("π SNOWFLAKE SAVE: Disconnecting...")
|
| 320 |
connector.disconnect()
|
|
|
|
| 331 |
logger.error(f" - Error: {e}")
|
| 332 |
logger.info("=" * 80)
|
| 333 |
return False
|
|
|
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/ui_components/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UI Components Module
|
| 3 |
+
|
| 4 |
+
This module contains UI-related components including styles, visualizations,
|
| 5 |
+
and utility functions for the Streamlit application.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .styles import get_custom_css
|
| 9 |
+
from .components import (
|
| 10 |
+
display_chunk_statistics_charts,
|
| 11 |
+
display_chunk_statistics_table
|
| 12 |
+
)
|
| 13 |
+
from .utils import extract_chunk_statistics
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"get_custom_css",
|
| 17 |
+
"display_chunk_statistics_charts",
|
| 18 |
+
"display_chunk_statistics_table",
|
| 19 |
+
"extract_chunk_statistics"
|
| 20 |
+
]
|
| 21 |
+
|
src/ui_components/components.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UI components for displaying statistics and visualizations
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import streamlit as st
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import plotly.express as px
|
| 8 |
+
from typing import Dict, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieval Statistics"):
|
| 12 |
+
"""Display statistics as interactive charts for 10+ results."""
|
| 13 |
+
if not stats or stats.get('total_chunks', 0) == 0:
|
| 14 |
+
return
|
| 15 |
+
|
| 16 |
+
# Wrap everything in one styled container - open it
|
| 17 |
+
st.markdown(f"""
|
| 18 |
+
<div class="retrieval-distribution-container">
|
| 19 |
+
<h3 style="margin-top: 0;">π {title}</h3>
|
| 20 |
+
<div style="display: flex; justify-content: space-around; align-items: center; padding: 15px 0; border-bottom: 1px solid #e0e0e0; margin-bottom: 20px;">
|
| 21 |
+
<div class="metric-container">
|
| 22 |
+
<div class="metric-label">Total Chunks</div>
|
| 23 |
+
<div class="metric-value">{stats['total_chunks']}</div>
|
| 24 |
+
</div>
|
| 25 |
+
<div class="metric-container">
|
| 26 |
+
<div class="metric-label">Unique Sources</div>
|
| 27 |
+
<div class="metric-value">{stats['unique_sources']}</div>
|
| 28 |
+
</div>
|
| 29 |
+
<div class="metric-container">
|
| 30 |
+
<div class="metric-label">Unique Years</div>
|
| 31 |
+
<div class="metric-value">{stats['unique_years']}</div>
|
| 32 |
+
</div>
|
| 33 |
+
<div class="metric-container">
|
| 34 |
+
<div class="metric-label">Unique Files</div>
|
| 35 |
+
<div class="metric-value">{stats['unique_filenames']}</div>
|
| 36 |
+
</div>
|
| 37 |
+
</div>
|
| 38 |
+
""", unsafe_allow_html=True)
|
| 39 |
+
|
| 40 |
+
# Charts - three columns to include Districts
|
| 41 |
+
col1, col2, col3 = st.columns(3)
|
| 42 |
+
|
| 43 |
+
with col1:
|
| 44 |
+
# Source distribution chart
|
| 45 |
+
if stats['source_distribution']:
|
| 46 |
+
source_df = pd.DataFrame(
|
| 47 |
+
list(stats['source_distribution'].items()),
|
| 48 |
+
columns=['Source', 'Count']
|
| 49 |
+
)
|
| 50 |
+
fig_source = px.bar(
|
| 51 |
+
source_df,
|
| 52 |
+
x='Count',
|
| 53 |
+
y='Source',
|
| 54 |
+
orientation='h',
|
| 55 |
+
title='Distribution by Source',
|
| 56 |
+
color='Count',
|
| 57 |
+
color_continuous_scale='viridis'
|
| 58 |
+
)
|
| 59 |
+
fig_source.update_layout(height=400, showlegend=False)
|
| 60 |
+
st.plotly_chart(fig_source, use_container_width=True) # Note: plotly_chart still uses use_container_width
|
| 61 |
+
|
| 62 |
+
with col2:
|
| 63 |
+
# Year distribution chart
|
| 64 |
+
if stats['year_distribution']:
|
| 65 |
+
# Filter out 'Unknown' years for the chart
|
| 66 |
+
year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
|
| 67 |
+
if year_dist_filtered:
|
| 68 |
+
year_df = pd.DataFrame(
|
| 69 |
+
list(year_dist_filtered.items()),
|
| 70 |
+
columns=['Year', 'Count']
|
| 71 |
+
)
|
| 72 |
+
# Sort by year as integer but keep as string for categorical display
|
| 73 |
+
year_df['Year_Int'] = year_df['Year'].astype(int)
|
| 74 |
+
year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)
|
| 75 |
+
|
| 76 |
+
fig_year = px.bar(
|
| 77 |
+
year_df,
|
| 78 |
+
x='Year',
|
| 79 |
+
y='Count',
|
| 80 |
+
title='Distribution by Year',
|
| 81 |
+
color='Count',
|
| 82 |
+
color_continuous_scale='plasma'
|
| 83 |
+
)
|
| 84 |
+
# Ensure years are treated as categorical (discrete) not continuous
|
| 85 |
+
fig_year.update_xaxes(type='category')
|
| 86 |
+
fig_year.update_layout(height=400, showlegend=False)
|
| 87 |
+
st.plotly_chart(fig_year, use_container_width=True) # Note: plotly_chart still uses use_container_width
|
| 88 |
+
else:
|
| 89 |
+
st.info("No valid years found in the results")
|
| 90 |
+
|
| 91 |
+
with col3:
|
| 92 |
+
# District distribution chart
|
| 93 |
+
if stats.get('district_distribution'):
|
| 94 |
+
district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
|
| 95 |
+
if district_dist_filtered:
|
| 96 |
+
district_df = pd.DataFrame(
|
| 97 |
+
list(district_dist_filtered.items()),
|
| 98 |
+
columns=['District', 'Count']
|
| 99 |
+
)
|
| 100 |
+
district_df = district_df.sort_values('Count', ascending=False)
|
| 101 |
+
|
| 102 |
+
fig_district = px.bar(
|
| 103 |
+
district_df,
|
| 104 |
+
x='Count',
|
| 105 |
+
y='District',
|
| 106 |
+
orientation='h',
|
| 107 |
+
title='Distribution by District',
|
| 108 |
+
color='Count',
|
| 109 |
+
color_continuous_scale='blues'
|
| 110 |
+
)
|
| 111 |
+
fig_district.update_layout(height=400, showlegend=False)
|
| 112 |
+
st.plotly_chart(fig_district, use_container_width=True) # Note: plotly_chart still uses use_container_width
|
| 113 |
+
else:
|
| 114 |
+
st.info("No valid districts found in the results")
|
| 115 |
+
|
| 116 |
+
# Close the container
|
| 117 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def display_chunk_statistics_table(stats: Dict[str, Any], title: str = "Retrieval Distribution"):
|
| 121 |
+
"""Display statistics as tables for smaller results with fixed alignment."""
|
| 122 |
+
if not stats or stats.get('total_chunks', 0) == 0:
|
| 123 |
+
return
|
| 124 |
+
|
| 125 |
+
# Wrap in styled container
|
| 126 |
+
st.markdown('<div class="retrieval-distribution-container">', unsafe_allow_html=True)
|
| 127 |
+
|
| 128 |
+
st.subheader(f"π {title}")
|
| 129 |
+
|
| 130 |
+
# Create a container with fixed height for alignment
|
| 131 |
+
stats_container = st.container()
|
| 132 |
+
|
| 133 |
+
with stats_container:
|
| 134 |
+
# Create 4 equal columns for consistent alignment
|
| 135 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 136 |
+
|
| 137 |
+
with col1:
|
| 138 |
+
st.markdown("**ποΈ Districts**")
|
| 139 |
+
if stats.get('district_distribution'):
|
| 140 |
+
district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
|
| 141 |
+
if district_dist_filtered:
|
| 142 |
+
district_data = {
|
| 143 |
+
"District": list(district_dist_filtered.keys()),
|
| 144 |
+
"Count": list(district_dist_filtered.values())
|
| 145 |
+
}
|
| 146 |
+
district_df = pd.DataFrame(district_data).sort_values('Count', ascending=False)
|
| 147 |
+
st.dataframe(district_df, hide_index=True, width='stretch')
|
| 148 |
+
else:
|
| 149 |
+
st.write("No district data")
|
| 150 |
+
else:
|
| 151 |
+
st.write("No district data")
|
| 152 |
+
|
| 153 |
+
with col2:
|
| 154 |
+
st.markdown("**π Sources**")
|
| 155 |
+
if stats['source_distribution']:
|
| 156 |
+
source_data = {
|
| 157 |
+
"Source": list(stats['source_distribution'].keys()),
|
| 158 |
+
"Count": list(stats['source_distribution'].values())
|
| 159 |
+
}
|
| 160 |
+
source_df = pd.DataFrame(source_data).sort_values('Count', ascending=False)
|
| 161 |
+
st.dataframe(source_df, hide_index=True, width='stretch')
|
| 162 |
+
else:
|
| 163 |
+
st.write("No source data")
|
| 164 |
+
|
| 165 |
+
with col3:
|
| 166 |
+
st.markdown("**π
Years**")
|
| 167 |
+
if stats['year_distribution']:
|
| 168 |
+
year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
|
| 169 |
+
if year_dist_filtered:
|
| 170 |
+
year_data = {
|
| 171 |
+
"Year": list(year_dist_filtered.keys()),
|
| 172 |
+
"Count": list(year_dist_filtered.values())
|
| 173 |
+
}
|
| 174 |
+
year_df = pd.DataFrame(year_data)
|
| 175 |
+
# Sort by year as integer but display as string
|
| 176 |
+
year_df['Year_Int'] = year_df['Year'].astype(int)
|
| 177 |
+
year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
|
| 178 |
+
st.dataframe(year_df, hide_index=True, width='stretch')
|
| 179 |
+
else:
|
| 180 |
+
st.write("No year data")
|
| 181 |
+
else:
|
| 182 |
+
st.write("No year data")
|
| 183 |
+
|
| 184 |
+
with col4:
|
| 185 |
+
st.markdown("**π Files**")
|
| 186 |
+
if stats['filename_distribution']:
|
| 187 |
+
filename_items = list(stats['filename_distribution'].items())
|
| 188 |
+
filename_items.sort(key=lambda x: x[1], reverse=True)
|
| 189 |
+
|
| 190 |
+
# Show top files with truncated names
|
| 191 |
+
file_data = {
|
| 192 |
+
"File": [f[:30] + "..." if len(f) > 30 else f for f, c in filename_items[:5]],
|
| 193 |
+
"Count": [c for f, c in filename_items[:5]]
|
| 194 |
+
}
|
| 195 |
+
file_df = pd.DataFrame(file_data)
|
| 196 |
+
st.dataframe(file_df, hide_index=True, width='stretch')
|
| 197 |
+
else:
|
| 198 |
+
st.write("No file data")
|
| 199 |
+
|
| 200 |
+
# Close container
|
| 201 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 202 |
+
|
src/ui_components/styles.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom CSS styles for Streamlit application
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_custom_css() -> str:
|
| 7 |
+
"""Get custom CSS styles as a string"""
|
| 8 |
+
return """
|
| 9 |
+
<style>
|
| 10 |
+
.main-header {
|
| 11 |
+
font-size: 2.5rem;
|
| 12 |
+
font-weight: bold;
|
| 13 |
+
color: #1f77b4;
|
| 14 |
+
text-align: center;
|
| 15 |
+
margin-bottom: 1rem;
|
| 16 |
+
width: 100%;
|
| 17 |
+
display: block;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
.subtitle {
|
| 21 |
+
font-size: 1.2rem;
|
| 22 |
+
color: #666;
|
| 23 |
+
text-align: center;
|
| 24 |
+
margin-bottom: 2rem;
|
| 25 |
+
width: 100%;
|
| 26 |
+
display: block;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
.session-info {
|
| 30 |
+
background-color: #f0f2f6;
|
| 31 |
+
padding: 10px;
|
| 32 |
+
border-radius: 5px;
|
| 33 |
+
margin-bottom: 20px;
|
| 34 |
+
font-size: 0.9rem;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
.user-message {
|
| 38 |
+
background-color: #007bff;
|
| 39 |
+
color: white;
|
| 40 |
+
padding: 12px 16px;
|
| 41 |
+
border-radius: 18px 18px 4px 18px;
|
| 42 |
+
margin: 8px 0;
|
| 43 |
+
margin-left: 20%;
|
| 44 |
+
word-wrap: break-word;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
.bot-message {
|
| 48 |
+
background-color: #f1f3f4;
|
| 49 |
+
color: #333;
|
| 50 |
+
padding: 12px 16px;
|
| 51 |
+
border-radius: 18px 18px 18px 4px;
|
| 52 |
+
margin: 8px 0;
|
| 53 |
+
margin-right: 20%;
|
| 54 |
+
word-wrap: break-word;
|
| 55 |
+
border: 1px solid #e0e0e0;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
.filter-section {
|
| 59 |
+
margin-bottom: 20px;
|
| 60 |
+
padding: 15px;
|
| 61 |
+
background-color: #f8f9fa;
|
| 62 |
+
border-radius: 8px;
|
| 63 |
+
border: 1px solid #e9ecef;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
.filter-title {
|
| 67 |
+
font-weight: bold;
|
| 68 |
+
margin-bottom: 10px;
|
| 69 |
+
color: #495057;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
.feedback-section {
|
| 73 |
+
background-color: #f8f9fa;
|
| 74 |
+
padding: 20px;
|
| 75 |
+
border-radius: 10px;
|
| 76 |
+
margin-top: 30px;
|
| 77 |
+
border: 2px solid #dee2e6;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
.retrieval-history {
|
| 81 |
+
background-color: #ffffff;
|
| 82 |
+
padding: 15px;
|
| 83 |
+
border-radius: 5px;
|
| 84 |
+
margin: 10px 0;
|
| 85 |
+
border-left: 4px solid #007bff;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
.retrieval-distribution-container {
|
| 89 |
+
background-color: #ffffff;
|
| 90 |
+
padding: 25px;
|
| 91 |
+
border-radius: 10px;
|
| 92 |
+
margin: 20px 0;
|
| 93 |
+
border: 2px solid #e0e0e0;
|
| 94 |
+
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 2px 4px rgba(0, 0, 0, 0.06);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
.metric-label {
|
| 98 |
+
font-size: 0.9rem;
|
| 99 |
+
color: #555;
|
| 100 |
+
margin-bottom: 5px;
|
| 101 |
+
text-align: center;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
.metric-value {
|
| 105 |
+
font-size: 1.8rem;
|
| 106 |
+
font-weight: bold;
|
| 107 |
+
color: #000000;
|
| 108 |
+
text-align: center;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
.metric-container {
|
| 112 |
+
text-align: center;
|
| 113 |
+
padding: 10px;
|
| 114 |
+
}
|
| 115 |
+
</style>
|
| 116 |
+
"""
|
| 117 |
+
|
src/ui_components/utils.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UI utility functions for data processing and statistics
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Dict, Any, List
|
| 6 |
+
from collections import Counter
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def extract_chunk_statistics(sources: List[Any]) -> Dict[str, Any]:
|
| 10 |
+
"""Extract statistics from retrieved chunks."""
|
| 11 |
+
if not sources:
|
| 12 |
+
return {}
|
| 13 |
+
|
| 14 |
+
sources_list = []
|
| 15 |
+
years = []
|
| 16 |
+
filenames = []
|
| 17 |
+
districts = []
|
| 18 |
+
|
| 19 |
+
for doc in sources:
|
| 20 |
+
metadata = getattr(doc, 'metadata', {})
|
| 21 |
+
|
| 22 |
+
# Extract source
|
| 23 |
+
source = metadata.get('source', 'Unknown')
|
| 24 |
+
sources_list.append(source)
|
| 25 |
+
|
| 26 |
+
# Extract year
|
| 27 |
+
year = metadata.get('year', 'Unknown')
|
| 28 |
+
if year and year != 'Unknown':
|
| 29 |
+
try:
|
| 30 |
+
# Convert to int first, then back to string to ensure it's a proper year
|
| 31 |
+
year_int = int(float(year)) # Handle both int and float strings
|
| 32 |
+
if 1900 <= year_int <= 2030: # Reasonable year range
|
| 33 |
+
years.append(str(year_int))
|
| 34 |
+
else:
|
| 35 |
+
years.append('Unknown')
|
| 36 |
+
except (ValueError, TypeError):
|
| 37 |
+
years.append('Unknown')
|
| 38 |
+
else:
|
| 39 |
+
years.append('Unknown')
|
| 40 |
+
|
| 41 |
+
# Extract filename
|
| 42 |
+
filename = metadata.get('filename', 'Unknown')
|
| 43 |
+
filenames.append(filename)
|
| 44 |
+
|
| 45 |
+
# Extract district
|
| 46 |
+
district = metadata.get('district', 'Unknown')
|
| 47 |
+
if district and district != 'Unknown':
|
| 48 |
+
districts.append(district)
|
| 49 |
+
else:
|
| 50 |
+
districts.append('Unknown')
|
| 51 |
+
|
| 52 |
+
# Count occurrences
|
| 53 |
+
source_counts = Counter(sources_list)
|
| 54 |
+
year_counts = Counter(years)
|
| 55 |
+
filename_counts = Counter(filenames)
|
| 56 |
+
district_counts = Counter(districts)
|
| 57 |
+
|
| 58 |
+
return {
|
| 59 |
+
'total_chunks': len(sources),
|
| 60 |
+
'unique_sources': len(source_counts),
|
| 61 |
+
'unique_years': len([y for y in year_counts.keys() if y != 'Unknown']),
|
| 62 |
+
'unique_filenames': len(filename_counts),
|
| 63 |
+
'unique_districts': len([d for d in district_counts.keys() if d != 'Unknown']),
|
| 64 |
+
'source_distribution': dict(source_counts),
|
| 65 |
+
'year_distribution': dict(year_counts),
|
| 66 |
+
'filename_distribution': dict(filename_counts),
|
| 67 |
+
'district_distribution': dict(district_counts),
|
| 68 |
+
'sources': sources_list,
|
| 69 |
+
'years': years,
|
| 70 |
+
'filenames': filenames,
|
| 71 |
+
'districts': districts
|
| 72 |
+
}
|
| 73 |
+
|
utils.py β src/utils.py
RENAMED
|
File without changes
|
src/vectorstore.py
CHANGED
|
@@ -1,9 +1,20 @@
|
|
| 1 |
"""Vector store management and operations."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Dict, Any, List, Optional
|
| 4 |
|
| 5 |
|
| 6 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from langchain_qdrant import QdrantVectorStore
|
| 8 |
from langchain.docstore.document import Document
|
| 9 |
from langchain_core.embeddings import Embeddings
|
|
@@ -28,11 +39,23 @@ class MatryoshkaEmbeddings(Embeddings):
|
|
| 28 |
|
| 29 |
if truncate_dim and "matryoshka" in model_name.lower():
|
| 30 |
# Use SentenceTransformer directly for Matryoshka models
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
print(f"π§ Matryoshka model configured for {truncate_dim} dimensions")
|
| 34 |
else:
|
| 35 |
# Use standard HuggingFaceEmbeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
| 37 |
|
| 38 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
@@ -76,12 +99,17 @@ class VectorStoreManager:
|
|
| 76 |
|
| 77 |
def _create_embeddings(self) -> HuggingFaceEmbeddings:
|
| 78 |
"""Create embeddings model from configuration."""
|
| 79 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 80 |
-
|
| 81 |
model_name = self.config["retriever"]["model"]
|
| 82 |
normalize = self.config["retriever"]["normalize"]
|
| 83 |
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
encode_kwargs = {
|
| 86 |
"normalize_embeddings": normalize,
|
| 87 |
"batch_size": 100,
|
|
@@ -108,6 +136,8 @@ class VectorStoreManager:
|
|
| 108 |
return embeddings
|
| 109 |
|
| 110 |
# Use standard HuggingFaceEmbeddings for non-Matryoshka models
|
|
|
|
|
|
|
| 111 |
embeddings = HuggingFaceEmbeddings(
|
| 112 |
model_name=model_name,
|
| 113 |
model_kwargs=model_kwargs,
|
|
|
|
| 1 |
"""Vector store management and operations."""
|
| 2 |
+
import os
|
| 3 |
+
# Disable MPS before importing torch to prevent meta tensor issues on Mac
|
| 4 |
+
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
|
| 5 |
+
os.environ.setdefault("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0")
|
| 6 |
+
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Dict, Any, List, Optional
|
| 9 |
|
| 10 |
|
| 11 |
import torch
|
| 12 |
+
# Disable MPS backend explicitly to prevent meta tensor issues
|
| 13 |
+
if hasattr(torch.backends, 'mps'):
|
| 14 |
+
# Monkey patch to disable MPS
|
| 15 |
+
original_mps_available = torch.backends.mps.is_available
|
| 16 |
+
torch.backends.mps.is_available = lambda: False
|
| 17 |
+
|
| 18 |
from langchain_qdrant import QdrantVectorStore
|
| 19 |
from langchain.docstore.document import Document
|
| 20 |
from langchain_core.embeddings import Embeddings
|
|
|
|
| 39 |
|
| 40 |
if truncate_dim and "matryoshka" in model_name.lower():
|
| 41 |
# Use SentenceTransformer directly for Matryoshka models
|
| 42 |
+
# Fix for meta tensor issue: Explicitly force CPU
|
| 43 |
+
# MPS is already disabled at module level
|
| 44 |
+
# Explicitly pass device="cpu" to prevent MPS/CUDA detection
|
| 45 |
+
self.model = SentenceTransformer(
|
| 46 |
+
model_name,
|
| 47 |
+
truncate_dim=truncate_dim,
|
| 48 |
+
device="cpu" # Force CPU to prevent meta tensor issues
|
| 49 |
+
)
|
| 50 |
print(f"π§ Matryoshka model configured for {truncate_dim} dimensions")
|
| 51 |
else:
|
| 52 |
# Use standard HuggingFaceEmbeddings
|
| 53 |
+
# Don't pass device parameter - let it load naturally on CPU
|
| 54 |
+
# This prevents the meta tensor error
|
| 55 |
+
if "model_kwargs" not in kwargs:
|
| 56 |
+
kwargs["model_kwargs"] = {}
|
| 57 |
+
# Remove device from model_kwargs if present to prevent meta tensor issues
|
| 58 |
+
kwargs["model_kwargs"].pop("device", None)
|
| 59 |
self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
| 60 |
|
| 61 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
|
|
| 99 |
|
| 100 |
def _create_embeddings(self) -> HuggingFaceEmbeddings:
|
| 101 |
"""Create embeddings model from configuration."""
|
|
|
|
|
|
|
| 102 |
model_name = self.config["retriever"]["model"]
|
| 103 |
normalize = self.config["retriever"]["normalize"]
|
| 104 |
|
| 105 |
+
# Fix for meta tensor issue: Force CPU usage to prevent MPS/CUDA detection
|
| 106 |
+
# The error occurs when SentenceTransformer detects MPS/CUDA and tries to move meta tensors
|
| 107 |
+
# MPS is already disabled at module level, now we explicitly force CPU in model_kwargs
|
| 108 |
+
model_kwargs = {
|
| 109 |
+
"device": "cpu", # Explicitly force CPU to prevent MPS/CUDA detection
|
| 110 |
+
"trust_remote_code": True, # Some models need this
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
encode_kwargs = {
|
| 114 |
"normalize_embeddings": normalize,
|
| 115 |
"batch_size": 100,
|
|
|
|
| 136 |
return embeddings
|
| 137 |
|
| 138 |
# Use standard HuggingFaceEmbeddings for non-Matryoshka models
|
| 139 |
+
# Don't pass device in model_kwargs - let HuggingFaceEmbeddings handle it
|
| 140 |
+
# but ensure we're not using meta device
|
| 141 |
embeddings = HuggingFaceEmbeddings(
|
| 142 |
model_name=model_name,
|
| 143 |
model_kwargs=model_kwargs,
|