.gitignore ADDED
@@ -0,0 +1,113 @@
+ # ==========================================
+ # PYTHON
+ # ==========================================
+ __pycache__/
+ *.py[cod]
+ *.pyo
+ *.pyd
+ *$py.class
+ 
+ # Virtual environments
+ .venv/
+ venv/
+ env/
+ ENV/
+ .conda/
+ .venv*/
+ 
+ # Byte-compiled / optimized / DLL files
+ *.so
+ *.dll
+ *.dylib
+ 
+ # Logs and debug
+ *.log
+ *.out
+ *.err
+ logs/
+ debug/
+ *.sqlite3
+ 
+ # ==========================================
+ # BUILD / PACKAGING
+ # ==========================================
+ build/
+ dist/
+ *.egg-info/
+ .eggs/
+ pip-wheel-metadata/
+ .wheels/
+ 
+ # ==========================================
+ # JUPYTER / NOTEBOOKS
+ # ==========================================
+ .ipynb_checkpoints/
+ *.ipynb_convert/
+ 
+ # ==========================================
+ # DATA / MODELS / CACHE
+ # ==========================================
+ data/
+ datasets/
+ .cache/
+ *.ckpt
+ *.h5
+ *.hdf5
+ *.tflite
+ *.onnx
+ *.pth
+ *.pt
+ *.joblib
+ *.pkl
+ *.pickle
+ *.npz
+ *.npy
+ outputs/
+ artifacts/
+ checkpoints/
+ runs/
+ wandb/
+ mlruns/
+ lightning_logs/
+ 
+ # Hugging Face (repo-local caches only; user-level ~/.cache lives outside
+ # the repository, so it never needs ignore rules here)
+ huggingface/
+ 
+ # ==========================================
+ # EDITORS / TOOLS
+ # ==========================================
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *.bak
+ .DS_Store
+ Thumbs.db
+ 
+ # ==========================================
+ # ENV FILES / CREDENTIALS
+ # ==========================================
+ .env
+ .env.*
+ *.env.local
+ secrets.*
+ config.json
+ token.json
+ 
+ # ==========================================
+ # TESTS / TEMP FILES
+ # ==========================================
+ __tests__/
+ .tox/
+ .coverage
+ .pytest_cache/
+ tmp/
+ temp/
+ *.tmp
+ *.temp
+ 
+ 
+ local_*
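+ 
+ # Tip: `git check-ignore -v <path>` reports which rule above matches a path.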
Dockerfile CHANGED
@@ -59,6 +59,9 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
  CMD curl --fail http://localhost:8501/_stcore/health || exit 1
  
  # temp development commands
+ RUN pip3 install plotly
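+ # NOTE: installing plotly ad hoc here is a stopgap; pinning it in
+ # requirements.txt would keep the image build reproducible.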
  # RUN mkdir /app/conversations && chmod -R 777 conversations
  # RUN mkdir /app/feedback && chmod -R 777 feedback
  
add_district_metadata.py ADDED
@@ -0,0 +1,400 @@
+ #!/usr/bin/env python3
+ """
+ Script to add District metadata to Qdrant chunks based on filename analysis.
+ Handles Uganda districts, ministry mappings, and LLM inference for ambiguous cases.
+ """
+ import re
+ import yaml
+ import logging
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional
+ 
+ 
+ from qdrant_client import QdrantClient
+ 
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+ 
+ 
+ @dataclass
+ class DistrictMapping:
+     """Mapping for district-related entities"""
+     name: str
+     aliases: List[str]
+     is_district: bool = True
+ 
+ 
+ class DistrictMetadataProcessor:
+     def __init__(self, config_path: str = "src/config/settings.yaml"):
+         # Load config manually
+         with open(config_path, 'r') as f:
+             self.config = yaml.safe_load(f)
+ 
+         # Qdrant and LLM clients are created lazily, on first use
+         self.llm_client = None
+         self.qdrant_client = None
+         self.collection_name = self.config["qdrant"]["collection_name"]
+ 
+         # Initialize district mappings
+         self.district_mappings = self._initialize_district_mappings()
+         self.ministry_mappings = self._initialize_ministry_mappings()
+ 
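+     # Expected settings.yaml shape (inferred from the lookups above):
+     #   qdrant:
+     #     url: "<qdrant endpoint>"
+     #     api_key: "<key>"
+     #     collection_name: "<collection>"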
+     def _initialize_district_mappings(self) -> Dict[str, DistrictMapping]:
+         """Initialize Uganda districts and their aliases"""
+         districts = [
+             # Central Region
+             DistrictMapping("Kampala", ["KCCA", "Kampala Capital City Authority"]),
+             DistrictMapping("Wakiso", ["Wakiso"]),
+             DistrictMapping("Mukono", ["Mukono"]),
+             DistrictMapping("Luweero", ["Luweero"]),
+             DistrictMapping("Nakaseke", ["Nakaseke"]),
+             DistrictMapping("Nakasongola", ["Nakasongola"]),
+             DistrictMapping("Kayunga", ["Kayunga"]),
+             DistrictMapping("Buikwe", ["Buikwe"]),
+             DistrictMapping("Buvuma", ["Buvuma"]),
+ 
+             # Northern Region
+             DistrictMapping("Gulu", ["Gulu", "Gulu DLG"]),
+             DistrictMapping("Kitgum", ["Kitgum"]),
+             DistrictMapping("Pader", ["Pader"]),
+             DistrictMapping("Agago", ["Agago"]),
+             DistrictMapping("Lamwo", ["Lamwo"]),
+             DistrictMapping("Nwoya", ["Nwoya"]),
+             DistrictMapping("Amuru", ["Amuru"]),
+             DistrictMapping("Omoro", ["Omoro"]),
+             DistrictMapping("Oyam", ["Oyam"]),
+             DistrictMapping("Kole", ["Kole"]),
+             DistrictMapping("Apac", ["Apac", "Apac District"]),
+             DistrictMapping("Lira", ["Lira"]),
+             DistrictMapping("Alebtong", ["Alebtong"]),
+             DistrictMapping("Amolatar", ["Amolatar"]),
+             DistrictMapping("Dokolo", ["Dokolo"]),
+             DistrictMapping("Otuke", ["Otuke"]),
+             DistrictMapping("Kwania", ["Kwania"]),
+ 
+             # Eastern Region
+             DistrictMapping("Jinja", ["Jinja"]),
+             DistrictMapping("Kamuli", ["Kamuli"]),
+             DistrictMapping("Iganga", ["Iganga"]),
+             DistrictMapping("Bugiri", ["Bugiri"]),
+             DistrictMapping("Mayuge", ["Mayuge"]),
+             DistrictMapping("Namayingo", ["Namayingo"]),
+             DistrictMapping("Busia", ["Busia"]),
+             DistrictMapping("Tororo", ["Tororo"]),
+             DistrictMapping("Pallisa", ["Pallisa"]),
+             DistrictMapping("Kumi", ["Kumi"]),
+             DistrictMapping("Bukedea", ["Bukedea"]),
+             DistrictMapping("Soroti", ["Soroti"]),
+             DistrictMapping("Serere", ["Serere"]),
+             DistrictMapping("Ngora", ["Ngora"]),
+             DistrictMapping("Kaberamaido", ["Kaberamaido"]),
+             DistrictMapping("Kalaki", ["Kalaki"]),
+             DistrictMapping("Kapelebyong", ["Kapelebyong"]),
+             DistrictMapping("Amuria", ["Amuria"]),
+             DistrictMapping("Katakwi", ["Katakwi"]),
+             DistrictMapping("Kotido", ["Kotido"]),
+             DistrictMapping("Abim", ["Abim"]),
+             DistrictMapping("Kaabong", ["Kaabong", "Kaabong District"]),
+             DistrictMapping("Karenga", ["Karenga"]),
+             DistrictMapping("Moroto", ["Moroto"]),
+             DistrictMapping("Napak", ["Napak"]),
+             DistrictMapping("Nabilatuk", ["Nabilatuk"]),
+             DistrictMapping("Amudat", ["Amudat"]),
+             DistrictMapping("Nakapiripirit", ["Nakapiripirit"]),
+             DistrictMapping("Bukwo", ["Bukwo"]),
+             DistrictMapping("Kween", ["Kween"]),
+             DistrictMapping("Kapchorwa", ["Kapchorwa"]),
+             DistrictMapping("Sironko", ["Sironko"]),
+             DistrictMapping("Manafwa", ["Manafwa"]),
+             DistrictMapping("Bududa", ["Bududa"]),
+             DistrictMapping("Mbale", ["Mbale"]),
+             DistrictMapping("Butaleja", ["Butaleja"]),
+             DistrictMapping("Namisindwa", ["Namisindwa"]),
+             DistrictMapping("Bulambuli", ["Bulambuli"]),
+ 
+             # Western Region
+             DistrictMapping("Masaka", ["Masaka"]),
+             DistrictMapping("Kalungu", ["Kalungu"]),
+             DistrictMapping("Bukomansimbi", ["Bukomansimbi"]),
+             DistrictMapping("Lwengo", ["Lwengo"]),
+             DistrictMapping("Sembabule", ["Sembabule"]),
+             DistrictMapping("Rakai", ["Rakai"]),
+             DistrictMapping("Kyotera", ["Kyotera"]),
+             DistrictMapping("Mpigi", ["Mpigi"]),
+             DistrictMapping("Butambala", ["Butambala"]),
+             DistrictMapping("Gomba", ["Gomba"]),
+             DistrictMapping("Mityana", ["Mityana"]),
+             DistrictMapping("Mubende", ["Mubende"]),
+             DistrictMapping("Kassanda", ["Kassanda"]),
+             DistrictMapping("Kiboga", ["Kiboga"]),
+             DistrictMapping("Kyankwanzi", ["Kyankwanzi"]),
+             DistrictMapping("Hoima", ["Hoima"]),
+             DistrictMapping("Kikuube", ["Kikuube"]),
+             DistrictMapping("Kakumiro", ["Kakumiro"]),
+             DistrictMapping("Kibaale", ["Kibaale"]),
+             DistrictMapping("Kagadi", ["Kagadi"]),
+             DistrictMapping("Buliisa", ["Buliisa"]),
+             DistrictMapping("Masindi", ["Masindi"]),
+             DistrictMapping("Kiryandongo", ["Kiryandongo"]),
+             DistrictMapping("Pakwach", ["Pakwach"]),
+             DistrictMapping("Nebbi", ["Nebbi"]),
+             DistrictMapping("Zombo", ["Zombo"]),
+             DistrictMapping("Arua", ["Arua"]),
+             DistrictMapping("Terego", ["Terego"]),
+             DistrictMapping("Madi-Okollo", ["Madi-Okollo"]),
+             DistrictMapping("Obongi", ["Obongi"]),
+             DistrictMapping("Moyo", ["Moyo"]),
+             DistrictMapping("Yumbe", ["Yumbe"]),
+             DistrictMapping("Koboko", ["Koboko"]),
+             DistrictMapping("Maracha", ["Maracha"]),
+             DistrictMapping("Adjumani", ["Adjumani"]),
+ 
+             # South Western Region
+             DistrictMapping("Mbarara", ["Mbarara"]),
+             DistrictMapping("Ibanda", ["Ibanda"]),
+             DistrictMapping("Isingiro", ["Isingiro"]),
+             DistrictMapping("Kiruhura", ["Kiruhura"]),
+             DistrictMapping("Kazo", ["Kazo"]),
+             DistrictMapping("Ntungamo", ["Ntungamo"]),
+             DistrictMapping("Rwampara", ["Rwampara"]),
+             DistrictMapping("Rubanda", ["Rubanda"]),
+             DistrictMapping("Rukiga", ["Rukiga"]),
+             DistrictMapping("Kanungu", ["Kanungu"]),
+             DistrictMapping("Rukungiri", ["Rukungiri"]),
+             DistrictMapping("Kisoro", ["Kisoro"]),
+             DistrictMapping("Bundibugyo", ["Bundibugyo"]),
+             DistrictMapping("Ntoroko", ["Ntoroko"]),
+             DistrictMapping("Kasese", ["Kasese"]),
+             DistrictMapping("Bunyangabu", ["Bunyangabu"]),
+             DistrictMapping("Fort Portal", ["Fort Portal"]),
+             DistrictMapping("Kabarole", ["Kabarole"]),
+             DistrictMapping("Kyenjojo", ["Kyenjojo"]),
+             DistrictMapping("Kamwenge", ["Kamwenge"]),
+             DistrictMapping("Kitagwenda", ["Kitagwenda"]),
+             DistrictMapping("Kyegegwa", ["Kyegegwa"]),
+             DistrictMapping("Mitooma", ["Mitooma"]),
+             DistrictMapping("Rubirizi", ["Rubirizi"]),
+             DistrictMapping("Sheema", ["Sheema"]),
+             DistrictMapping("Bushenyi", ["Bushenyi"]),
+ 
+             # Special cases
+             DistrictMapping("Kalangala", ["Kalangala", "Kalangala DLG"]),
+         ]
+ 
+         # Create mapping dictionary
+         mapping_dict = {}
+         for district in districts:
+             mapping_dict[district.name.lower()] = district
+             for alias in district.aliases:
+                 mapping_dict[alias.lower()] = district
+         return mapping_dict
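+     # Lookup keys are lowercased, so matching is case-insensitive:
+     # e.g. mapping_dict["kcca"] resolves to DistrictMapping("Kampala", ...).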
+ 
+     def _initialize_ministry_mappings(self) -> Dict[str, str]:
+         """Initialize ministry and organization mappings"""
+         return {
+             "maaif": "Ministry of Agriculture, Animal Industry and Fisheries",
+             "mwts": "Ministry of Works and Transport",
+             "kcca": "Kampala Capital City Authority",
+             "oag": "Office of the Auditor General",
+             "arsdp": "Albertine Regional Sustainable Development Project",
+             "avcdp": "Agriculture Value Chain Development Project",
+             "ida": "International Development Association",
+             "dlg": "District Local Government",
+             "lg": "Local Government",
+         }
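+     # When no district name matched first, any of these keys in a filename
+     # marks it as a ministry/organization report (district=None); e.g. a
+     # hypothetical "OAG Report 2019.pdf" matches "oag".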
+ 
+     def _extract_district_from_filename(self, filename: str) -> Optional[str]:
+         """Extract district from filename using pattern matching"""
+         filename_lower = filename.lower()
+ 
+         # Check for explicit district mentions
+         for key, district_mapping in self.district_mappings.items():
+             if key in filename_lower:
+                 return district_mapping.name
+ 
+         # Check for ministry/organization patterns that are NOT districts
+         for ministry_key in self.ministry_mappings.keys():
+             if ministry_key in filename_lower:
+                 return None  # This is a ministry, not a district
+ 
+         # Check for patterns like "District Local Government"
+         district_pattern = r'(\w+)\s+district\s+local\s+government'
+         match = re.search(district_pattern, filename_lower)
+         if match:
+             district_name = match.group(1).title()
+             if district_name.lower() in self.district_mappings:
+                 return self.district_mappings[district_name.lower()].name
+ 
+         # Check for patterns like "DLG Report"
+         dlg_pattern = r'(\w+)\s+dlg\s+report'
+         match = re.search(dlg_pattern, filename_lower)
+         if match:
+             district_name = match.group(1).title()
+             if district_name.lower() in self.district_mappings:
+                 return self.district_mappings[district_name.lower()].name
+ 
+         return None
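+     # Example resolutions (hypothetical filenames):
+     #   "Gulu DLG Report 2021.pdf"     -> "Gulu"  (district/alias substring hit)
+     #   "MAAIF Annual Report 2020.pdf" -> None    (ministry key, not a district)
+     # Note: bare substring matching can produce false positives (e.g. "apac"
+     # inside "capacity"), so treat results as best-effort; the regex fallbacks
+     # above only fire for names already present in the mapping table.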
+ 
+     def _infer_district_with_llm(self, filename: str) -> Optional[str]:
+         """Use LLM to infer district from filename when pattern matching fails"""
+         # For now, return None - LLM integration can be added later
+         logger.info(f"LLM inference needed for filename: {filename}")
+         return None
+ 
+     def infer_district(self, filename: str) -> Optional[str]:
+         """Main method to infer district from filename"""
+         # First try pattern matching
+         district = self._extract_district_from_filename(filename)
+         if district:
+             return district
+ 
+         # If pattern matching fails, use LLM
+         return self._infer_district_with_llm(filename)
+ 
+     def fetch_chunks_batch(self, batch_size: int = 100, offset=None):
+         """Fetch one page of chunks from Qdrant (payload only, no vectors).
+ 
+         Returns (points, next_offset); next_offset is None after the last page.
+         """
+         try:
+             # Create the Qdrant client on first use
+             if self.qdrant_client is None:
+                 self.qdrant_client = QdrantClient(
+                     url=self.config["qdrant"]["url"],
+                     api_key=self.config["qdrant"]["api_key"]
+                 )
+ 
+             # scroll() pages with an opaque cursor (a point ID), not a numeric
+             # offset, so we pass along the cursor from the previous call
+             points, next_offset = self.qdrant_client.scroll(
+                 collection_name=self.collection_name,
+                 limit=batch_size,
+                 offset=offset,
+                 with_payload=True,
+                 with_vectors=False
+             )
+             return points, next_offset
+         except Exception as e:
+             logger.error(f"Failed to fetch batch: {e}")
+             return [], None
+ 
+     def update_chunks_with_district(self, points: List) -> int:
+         """Update chunks with district metadata"""
+         updated_count = 0
+ 
+         # Create the Qdrant client on first use
+         if self.qdrant_client is None:
+             self.qdrant_client = QdrantClient(
+                 url=self.config["qdrant"]["url"],
+                 api_key=self.config["qdrant"]["api_key"]
+             )
+ 
+         for point in points:
+             point_id = getattr(point, "id", None)
+             try:
+                 metadata = point.payload.get("metadata", {})
+                 filename = metadata.get("filename", "")
+ 
+                 if not filename:
+                     logger.warning(f"Point {point_id} has no filename")
+                     continue
+ 
+                 # Infer district
+                 district = self.infer_district(filename)
+ 
+                 # Update metadata
+                 updated_metadata = metadata.copy()
+                 updated_metadata["district"] = district
+ 
+                 # Update point in Qdrant
+                 self.qdrant_client.set_payload(
+                     collection_name=self.collection_name,
+                     payload={"metadata": updated_metadata},
+                     points=[point_id]
+                 )
+ 
+                 updated_count += 1
+                 logger.info(f"Updated point {point_id}: {filename} -> {district}")
+ 
+             except Exception as e:
+                 logger.error(f"Failed to update point {point_id}: {e}")
+ 
+         return updated_count
+ 
+     def process_all_chunks(self, batch_size: int = 100):
+         """Process all chunks in batches"""
+         total_updated = 0
+         offset = None  # scroll cursor; None means start from the beginning
+ 
+         logger.info(f"Starting to process chunks in batches of {batch_size}")
+ 
+         while True:
+             # Fetch batch
+             points, next_offset = self.fetch_chunks_batch(batch_size, offset)
+             if not points:
+                 break
+ 
+             logger.info(f"Processing batch: {len(points)} points")
+ 
+             # Update batch
+             updated_count = self.update_chunks_with_district(points)
+             total_updated += updated_count
+ 
+             logger.info(f"Updated {updated_count} points in this batch")
+ 
+             # Move to the next page; the cursor comes back None after the last one
+             if next_offset is None:
+                 break
+             offset = next_offset
+ 
+         logger.info(f"Total updated: {total_updated} points")
+         return total_updated
+ 
+ 
+ def main():
+     """Main function to run the district metadata processor"""
+     try:
+         processor = DistrictMetadataProcessor()
+ 
+         # Test with a small batch first
+         logger.info("Testing with first 10 chunks...")
+         test_points, _ = processor.fetch_chunks_batch(10)
+ 
+         if test_points:
+             logger.info("Test batch fetched successfully. Processing...")
+             for point in test_points:
+                 filename = point.payload.get("metadata", {}).get("filename", "")
+                 district = processor.infer_district(filename)
+                 logger.info(f"Test: {filename} -> {district}")
+ 
+             # Ask user if they want to proceed with full processing
+             response = input("\nProceed with full processing? (y/n): ")
+             if response.lower() == 'y':
+                 processor.process_all_chunks(batch_size=100)
+             else:
+                 logger.info("Processing cancelled by user")
+ 
+     except Exception as e:
+         logger.error(f"Error in main: {e}")
+         raise
+ 
+ 
+ if __name__ == "__main__":
+     main()
app.py CHANGED
@@ -3,7 +3,36 @@ Intelligent Audit Report Chatbot UI
 """
 
 import os
- import sys
+ 
+ import time
+ import json
+ import uuid
+ import logging
+ import traceback
+ from pathlib import Path
+ 
+ from collections import Counter
+ from typing import List, Dict, Any, Optional
+ 
+ 
+ import pandas as pd
+ import streamlit as st
+ import plotly.express as px
+ from langchain_core.messages import HumanMessage, AIMessage
+ 
+ 
+ from src.agents import get_multi_agent_chatbot, get_smart_chatbot, get_gemini_chatbot
+ from src.feedback import FeedbackManager
+ from src.ui_components import get_custom_css, display_chunk_statistics_charts, display_chunk_statistics_table, extract_chunk_statistics
+ 
+ from src.config.paths import (
+     IS_DEPLOYED,
+     PROJECT_DIR,
+     HF_CACHE_DIR,
+     FEEDBACK_DIR,
+     CONVERSATIONS_DIR,
+ )
+ 
 
 # ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
 # Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
@@ -29,42 +58,33 @@ except (ValueError, TypeError):
 
 # ===== Setup HuggingFace cache directories BEFORE any model imports =====
 # CRITICAL: Set these before any imports that might use HuggingFace (like sentence-transformers)
- # This ensures models downloaded during Docker build are found at runtime
- cache_dir = "/app/.cache/huggingface"
- os.environ["HF_HOME"] = cache_dir
- os.environ["TRANSFORMERS_CACHE"] = cache_dir
- os.environ["HF_DATASETS_CACHE"] = cache_dir
- os.environ["HF_HUB_CACHE"] = cache_dir
- os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_dir
- 
- # Ensure cache directory exists (created in Dockerfile, but ensure it's there)
- try:
-     os.makedirs(cache_dir, mode=0o755, exist_ok=True)
- except (PermissionError, OSError) as e:
-     # If we can't create it, log but continue (might already exist from Dockerfile)
-     # HuggingFace will try to create subdirectories, but we need parent to exist
-     pass
- 
- import time
- import json
- import uuid
- import logging
- from pathlib import Path
- 
- import argparse
- import streamlit as st
- from langchain_core.messages import HumanMessage, AIMessage
- 
- from multi_agent_chatbot import get_multi_agent_chatbot
- from smart_chatbot import get_chatbot as get_smart_chatbot
- from src.reporting.feedback_schema import create_feedback_from_dict
+ # Only override cache directories in deployed environment (local uses defaults)
+ if IS_DEPLOYED and HF_CACHE_DIR:
+     cache_dir = str(HF_CACHE_DIR)
+     os.environ["HF_HOME"] = cache_dir
+     os.environ["TRANSFORMERS_CACHE"] = cache_dir
+     os.environ["HF_DATASETS_CACHE"] = cache_dir
+     os.environ["HF_HUB_CACHE"] = cache_dir
+     os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_dir
+ 
+     # Ensure cache directory exists (created in Dockerfile, but ensure it's there)
+     try:
+         os.makedirs(cache_dir, mode=0o755, exist_ok=True)
+     except (PermissionError, OSError):
+         # If we can't create it, log but continue (might already exist from Dockerfile)
+         pass
+ else:
+     from dotenv import load_dotenv
+     load_dotenv()
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
 # Log environment setup for debugging
- logger.info(f"πŸ“ HuggingFace cache directory: {os.environ.get('HF_HOME', 'NOT SET')}")
+ logger.info(f"🌍 Environment: {'DEPLOYED' if IS_DEPLOYED else 'LOCAL'}")
+ logger.info(f"πŸ“ PROJECT_DIR: {PROJECT_DIR}")
+ logger.info(f"πŸ“ HuggingFace cache: {os.environ.get('HF_HOME', 'DEFAULT (not overridden)')}")
 logger.info(f"πŸ”§ OMP_NUM_THREADS: {os.environ.get('OMP_NUM_THREADS', 'NOT SET')}")
 
 
@@ -76,84 +96,9 @@ st.set_page_config(
     page_title="Intelligent Audit Report Chatbot"
 )
 
- # Custom CSS
- st.markdown("""
- <style>
-     .main-header {
-         font-size: 2.5rem;
-         font-weight: bold;
-         color: #1f77b4;
-         text-align: center;
-         margin-bottom: 1rem;
-     }
- 
-     .subtitle {
-         font-size: 1.2rem;
-         color: #666;
-         text-align: center;
-         margin-bottom: 2rem;
-     }
- 
-     .session-info {
-         background-color: #f0f2f6;
-         padding: 10px;
-         border-radius: 5px;
-         margin-bottom: 20px;
-         font-size: 0.9rem;
-     }
- 
-     .user-message {
-         background-color: #007bff;
-         color: white;
-         padding: 12px 16px;
-         border-radius: 18px 18px 4px 18px;
-         margin: 8px 0;
-         margin-left: 20%;
-         word-wrap: break-word;
-     }
- 
-     .bot-message {
-         background-color: #f1f3f4;
-         color: #333;
-         padding: 12px 16px;
-         border-radius: 18px 18px 18px 4px;
-         margin: 8px 0;
-         margin-right: 20%;
-         word-wrap: break-word;
-         border: 1px solid #e0e0e0;
-     }
- 
-     .filter-section {
-         margin-bottom: 20px;
-         padding: 15px;
-         background-color: #f8f9fa;
-         border-radius: 8px;
-         border: 1px solid #e9ecef;
-     }
- 
-     .filter-title {
-         font-weight: bold;
-         margin-bottom: 10px;
-         color: #495057;
-     }
- 
-     .feedback-section {
-         background-color: #f8f9fa;
-         padding: 20px;
-         border-radius: 10px;
-         margin-top: 30px;
-         border: 2px solid #dee2e6;
-     }
- 
-     .retrieval-history {
-         background-color: #ffffff;
-         padding: 15px;
-         border-radius: 5px;
-         margin: 10px 0;
-         border-left: 4px solid #007bff;
-     }
- </style>
- """, unsafe_allow_html=True)
+ 
+ st.markdown(get_custom_css(), unsafe_allow_html=True)
+ 
 
 def get_system_type():
     """Get the current system type"""
@@ -163,14 +108,17 @@ def get_system_type():
     else:
         return "Multi-Agent System"
 
- def get_chatbot():
-     """Initialize and return the chatbot based on system type"""
-     # Check environment variable for system type
-     system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
-     if system == 'smart':
-         return get_smart_chatbot()
+ def get_chatbot(version: str = "v1"):
+     """Initialize and return the chatbot based on version"""
+     if version == "beta":
+         return get_gemini_chatbot()
     else:
-         return get_multi_agent_chatbot()
+         # Check environment variable for system type (v1)
+         system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
+         if system == 'smart':
+             return get_smart_chatbot()
+         else:
+             return get_multi_agent_chatbot()
 
 def serialize_messages(messages):
     """Serialize LangChain messages to dictionaries"""
@@ -215,13 +163,18 @@ def serialize_documents(sources):
 
     return serialized
 
+ 
+ feedback_manager = FeedbackManager()
+ 
+ 
 @st.cache_data
 def load_filter_options():
     try:
-         with open("src/config/filter_options.json", "r") as f:
+         filter_options_path = PROJECT_DIR / "src" / "config" / "filter_options.json"
+         with open(filter_options_path, "r") as f:
             return json.load(f)
     except FileNotFoundError:
-         st.info([x for x in os.listdir() if x.endswith('.json')])
+         st.info(f"Looking for filter_options.json in: {PROJECT_DIR / 'src' / 'config'}")
         st.error("filter_options.json not found. Please run the metadata analysis script.")
         return {"sources": [], "years": [], "districts": [], 'filenames': []}
 
@@ -238,11 +191,48 @@ def main():
     # Track RAG retrieval history for feedback
     if 'rag_retrieval_history' not in st.session_state:
         st.session_state.rag_retrieval_history = []
-     # Initialize chatbot only once per app session (cached)
-     if 'chatbot' not in st.session_state:
-         with st.spinner("πŸ”„ Loading AI models and connecting to database..."):
-             st.session_state.chatbot = get_chatbot()
-             st.success("βœ… AI system ready!")
+     # Version selection (v1 or beta)
+     if 'chatbot_version' not in st.session_state:
+         st.session_state.chatbot_version = "v1"
+ 
+     # Initialize chatbot based on version (only if not already initialized for this version)
+     chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}"
+ 
+     # Check if we need to initialize: chatbot doesn't exist OR version changed
+     needs_init = (
+         chatbot_version_key not in st.session_state or
+         st.session_state.get('_last_version') != st.session_state.chatbot_version
+     )
+ 
+     if needs_init:
+         try:
+             # Different spinner messages for different versions
+             if st.session_state.chatbot_version == "beta":
+                 spinner_msg = "πŸ”„ Initializing Gemini FSA"
+             else:
+                 spinner_msg = "πŸ”„ Loading AI models and connecting to database..."
+ 
+             with st.spinner(spinner_msg):
+                 st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version)
+                 st.session_state['_last_version'] = st.session_state.chatbot_version
+                 st.session_state.chatbot = st.session_state[chatbot_version_key]
+             print("βœ… AI system ready!")
+         except Exception as e:
+             st.error(f"❌ Failed to initialize chatbot: {str(e)}")
+             # Only show Gemini-specific error message for beta version
+             if st.session_state.chatbot_version == "beta":
+                 st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)")
+             else:
+                 st.error("Please check your configuration and ensure all required models and databases are accessible.")
+             # Reset to v1 to prevent infinite loop
+             st.session_state.chatbot_version = "v1"
+             st.session_state['_last_version'] = "v1"
+             if 'chatbot' in st.session_state:
+                 del st.session_state['chatbot']
+             st.stop()  # Stop execution to prevent infinite loop
+     else:
+         # Chatbot already initialized for this version, just use it
+         st.session_state.chatbot = st.session_state[chatbot_version_key]
 
     # Reset conversation history if needed (but keep chatbot cached)
     if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
@@ -254,17 +244,43 @@
         st.session_state.reset_conversation = False
         st.rerun()
 
-     # Header with system indicator
+     # Version selection radio button (top right)
     col1, col2 = st.columns([3, 1])
     with col1:
-         st.markdown('<h1 class="main-header">πŸ€– Intelligent Audit Report Chatbot</h1>', unsafe_allow_html=True)
+         st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
     with col2:
-         system_type = get_system_type()
-         if "Multi-Agent" in system_type:
-             st.success(f"πŸ”§ {system_type}")
-         else:
-             st.info(f"πŸ”§ {system_type}")
-     st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
+         st.markdown("<br>", unsafe_allow_html=True)  # Add some spacing
+         selected_version = st.radio(
+             "**Version:**",
+             options=["v1", "beta"],
+             index=0 if st.session_state.chatbot_version == "v1" else 1,
+             horizontal=True,
+             key="version_selector",
+             help="Select v1 (default RAG system) or beta (Gemini FSA)"
+         )
+ 
+         # Update version if changed
+         if selected_version != st.session_state.chatbot_version:
+             # Store the old version to check if we need to switch
+             old_version = st.session_state.chatbot_version
+             st.session_state.chatbot_version = selected_version
+ 
+             # If chatbot for new version already exists, just switch to it
+             new_chatbot_key = f"chatbot_{selected_version}"
+             if new_chatbot_key in st.session_state:
+                 # Chatbot already exists, just switch
+                 st.session_state.chatbot = st.session_state[new_chatbot_key]
+                 st.session_state['_last_version'] = selected_version
+             else:
+                 # Need to initialize new version - will be handled by initialization logic above
+                 st.session_state['_last_version'] = old_version  # Set to old to trigger init check
+ 
+             st.rerun()
+ 
+     # Show version info
+     if st.session_state.chatbot_version == "beta":
+         st.info("πŸ”¬ **Beta Mode**: Using Google Gemini FSA")
 
     # Session info
     duration = int(time.time() - st.session_state.session_start_time)
@@ -280,6 +296,34 @@
 
     # Sidebar for filters
     with st.sidebar:
+         # Instructions section (collapsible)
+         with st.expander("πŸ“– How to Use", expanded=False):
+             st.markdown("""
+             #### 🎯 Using Filters
+ 
+             1. **Select filters** from the sidebar to narrow your search
+ 
+             2. **Leave filters empty** to search across all data
+ 
+             3. **Type your question** in the chat input at the bottom
+ 
+             4. **Click "Send"** to submit your question
+ 
+             #### πŸ’‘ Tips
+ 
+             - Use specific questions for better results
+             - Combine multiple filters for precise searches
+             - Check the "Retrieved Documents" tab to see source material
+ 
+             #### ⚠️ Important
+ 
+             **When finished, please close the browser window** to free up computational resources.
+ 
+             ---
+ 
+             For more detailed help, see the example questions at the bottom of the page.
+             """)
+ 
         st.markdown("### πŸ” Search Filters")
         st.markdown("Select filters to narrow down your search. Leave empty to search all data.")
 
@@ -298,7 +342,7 @@
         # Determine if filename filter is active
         filename_mode = len(selected_filenames) > 0
         # Sources filter
-         st.markdown('<div class="filter-section">', unsafe_allow_html=True)
+         # st.markdown('<div class="filter-section">', unsafe_allow_html=True)
         st.markdown('<div class="filter-title">πŸ“Š Sources</div>', unsafe_allow_html=True)
         selected_sources = st.multiselect(
             "Select sources:",
@@ -311,7 +355,7 @@
         st.markdown('</div>', unsafe_allow_html=True)
 
         # Years filter
-         st.markdown('<div class="filter-section">', unsafe_allow_html=True)
+         # st.markdown('<div class="filter-section">', unsafe_allow_html=True)
         st.markdown('<div class="filter-title">πŸ“… Years</div>', unsafe_allow_html=True)
         selected_years = st.multiselect(
             "Select years:",
@@ -324,7 +368,7 @@
         st.markdown('</div>', unsafe_allow_html=True)
 
         # Districts filter
-         st.markdown('<div class="filter-section">', unsafe_allow_html=True)
+         # st.markdown('<div class="filter-section">', unsafe_allow_html=True)
         st.markdown('<div class="filter-title">🏘️ Districts</div>', unsafe_allow_html=True)
         selected_districts = st.multiselect(
             "Select districts:",
@@ -375,26 +419,37 @@
         if 'input_counter' not in st.session_state:
             st.session_state.input_counter = 0
 
+         # Handle pending question from example questions section
+         if 'pending_question' in st.session_state and st.session_state.pending_question:
+             default_value = st.session_state.pending_question
+             # Increment counter to force new input widget
+             st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
+             del st.session_state.pending_question
+             key_suffix = st.session_state.input_counter
+         else:
+             default_value = ""
+             key_suffix = st.session_state.input_counter
+ 
         user_input = st.text_input(
             "Type your message here...",
             placeholder="Ask about budget allocations, expenditures, or audit findings...",
-             key=f"user_input_{st.session_state.input_counter}",
-             label_visibility="collapsed"
+             key=f"user_input_{key_suffix}",
+             label_visibility="collapsed",
+             value=default_value if default_value else None
         )
 
     with col2:
-         send_button = st.button("Send", key="send_button", use_container_width=True)
+         send_button = st.button("Send", key="send_button", width='stretch')
 
     # Clear chat button
     if st.button("πŸ—‘οΈ Clear Chat", key="clear_chat_button"):
         st.session_state.reset_conversation = True
         # Clear all conversation files
-         import os
-         conversations_dir = "conversations"
-         if os.path.exists(conversations_dir):
-             for file in os.listdir(conversations_dir):
-                 if file.endswith('.json'):
-                     os.remove(os.path.join(conversations_dir, file))
+         conversations_path = CONVERSATIONS_DIR
+         if conversations_path.exists():
+             for file in conversations_path.iterdir():
+                 if file.suffix == '.json':
+                     file.unlink()
         st.rerun()
 
     # Handle user input
@@ -436,6 +491,36 @@
             if rag_result:
                 sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
 
+                 # For Gemini, also check gemini_result for sources
+                 if not sources or len(sources) == 0:
+                     gemini_result = chat_result.get('gemini_result')
+                     print(f"πŸ” DEBUG: Checking gemini_result for sources...")
+                     print(f"   gemini_result exists: {gemini_result is not None}")
+                     if gemini_result:
+                         print(f"   gemini_result type: {type(gemini_result)}")
+                         print(f"   has sources attr: {hasattr(gemini_result, 'sources')}")
+                         if hasattr(gemini_result, 'sources'):
+                             print(f"   sources length: {len(gemini_result.sources) if gemini_result.sources else 0}")
+ 
+                     if gemini_result and hasattr(gemini_result, 'sources'):
+                         # Format Gemini sources for display
+                         if hasattr(st.session_state.chatbot, 'gemini_client'):
+                             sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result)
+                             print(f"βœ… Formatted {len(sources)} sources from gemini_client")
+                         elif hasattr(st.session_state.chatbot, '_format_gemini_sources'):
+                             sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
+                             print(f"βœ… Formatted {len(sources)} sources from _format_gemini_sources")
+ 
+                     # Update rag_result with sources if we found them
+                     if sources and len(sources) > 0:
+                         if isinstance(rag_result, dict):
+                             rag_result['sources'] = sources
+                         elif hasattr(rag_result, 'sources'):
+                             rag_result.sources = sources
+                         # Update last_rag_result with sources
+                         st.session_state.last_rag_result = rag_result
+                         print(f"βœ… Updated rag_result with {len(sources)} sources")
+ 
                 # Get the actual RAG query
                 actual_rag_query = chat_result.get('actual_rag_query', '')
                 if actual_rag_query:
@@ -445,12 +530,25 @@
                 else:
                     formatted_query = "No RAG query available"
 
+                 # Extract filters from active filters
+                 filters_used = {
+                     "sources": st.session_state.active_filters.get('sources', []),
+                     "years": st.session_state.active_filters.get('years', []),
+                     "districts": st.session_state.active_filters.get('districts', []),
+                     "filenames": st.session_state.active_filters.get('filenames', [])
+                 }
+ 
                 retrieval_entry = {
                     "conversation_up_to": serialize_messages(st.session_state.messages),
                     "rag_query_expansion": formatted_query,
-                     "docs_retrieved": serialize_documents(sources)
+                     "docs_retrieved": serialize_documents(sources),
+                     "filters_applied": filters_used,
+                     "timestamp": time.time()
                 }
                 st.session_state.rag_retrieval_history.append(retrieval_entry)
+ 
+                 # Debug logging
+                 print(f"πŸ“Š RETRIEVAL TRACKING: {len(sources)} sources stored in retrieval history")
             else:
                 response = chat_result
                 st.session_state.last_rag_result = None
@@ -480,6 +578,16 @@
             # Dictionary format from multi-agent system
             sources = rag_result['sources']
 
+             # For Gemini, also check if we need to format sources from gemini_result
+             if (not sources or len(sources) == 0) and isinstance(rag_result, dict):
+                 gemini_result = rag_result.get('gemini_result')
+                 if gemini_result and hasattr(gemini_result, 'sources'):
+                     # Format Gemini sources for display
+                     if hasattr(st.session_state.chatbot, 'gemini_client'):
+                         sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result)
+                     elif hasattr(st.session_state.chatbot, '_format_gemini_sources'):
+                         sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
+ 
             if sources and len(sources) > 0:
                 # Count unique filenames
                 unique_filenames = set()
@@ -487,16 +595,40 @@
                     filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown')
                     unique_filenames.add(filename)
 
-                 st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 10):**")
+                 st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 20):**")
                 if len(unique_filenames) < len(sources):
                     st.info(f"πŸ’‘ **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
 
-                 for i, doc in enumerate(sources[:10]):  # Show top 10
+                 # Extract and display statistics
+                 stats = extract_chunk_statistics(sources)
+ 
+                 # Show charts for 10+ results, tables for fewer
+                 if len(sources) >= 10:
+                     display_chunk_statistics_charts(stats, "Retrieval Statistics")
+                     # Also show tables below charts for detailed view
+                     st.markdown("---")
+                     display_chunk_statistics_table(stats, "Retrieval Distribution")
+                 else:
+                     display_chunk_statistics_table(stats, "Retrieval Distribution")
+ 
+                 st.markdown("---")
+                 st.markdown("### πŸ“„ Document Details")
+ 
+                 for i, doc in enumerate(sources):  # Show all documents
                     # Get relevance score and ID if available
                     metadata = getattr(doc, 'metadata', {})
-                     score = metadata.get('reranked_score', metadata.get('original_score', None))
-                     chunk_id = metadata.get('_id', 'Unknown')
-                     score_text = f" (Score: {score:.3f}, ID: {chunk_id[:8]}...)" if score is not None else f" (ID: {chunk_id[:8]}...)"
+                     # Handle both standard RAG scores and Gemini scores
+                     score = metadata.get('reranked_score') or metadata.get('original_score') or metadata.get('score')
+                     chunk_id = metadata.get('_id') or metadata.get('chunk_id', 'Unknown')
+                     if score is not None:
+                         try:
+                             score_text = f" (Score: {float(score):.3f})"
+                         except (ValueError, TypeError):
+                             score_text = ""
+                     else:
+                         score_text = ""
+                     if chunk_id and chunk_id != 'Unknown':
+                         score_text += f" (ID: {str(chunk_id)[:8]}...)" if score_text else f" (ID: {str(chunk_id)[:8]}...)"
 
                     with st.expander(f"πŸ“„ Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
                         # Display document metadata with emojis
@@ -543,200 +675,409 @@ def main():
543
  if 'feedback_submitted' not in st.session_state:
544
  st.session_state.feedback_submitted = False
545
 
546
- # Feedback form
547
- with st.form("feedback_form", clear_on_submit=False):
548
- col1, col2 = st.columns([1, 1])
549
-
550
- with col1:
551
- feedback_score = st.slider(
552
- "Rate this conversation (1-5)",
553
- min_value=1,
554
- max_value=5,
555
- help="How satisfied are you with the conversation?"
556
- )
557
-
558
- with col2:
559
- is_feedback_about_last_retrieval = st.checkbox(
560
- "Feedback about last retrieval only",
561
- value=True,
562
- help="If checked, feedback applies to the most recent document retrieval"
563
- )
564
-
565
- open_ended_feedback = st.text_area(
566
- "Your feedback (optional)",
567
- placeholder="Tell us what went well or what could be improved...",
568
- height=100
569
- )
570
-
571
- # Disable submit if no score selected
572
- submit_disabled = feedback_score is None
573
-
574
- submitted = st.form_submit_button(
575
- "πŸ“€ Submit Feedback",
576
- use_container_width=True,
577
- disabled=submit_disabled
578
- )
579
-
580
- if submitted and not st.session_state.feedback_submitted:
581
- # Log the feedback data being submitted
582
- print("=" * 80)
583
- print("πŸ”„ FEEDBACK SUBMISSION: Starting...")
584
- print("=" * 80)
585
- st.write("πŸ” **Debug: Feedback Data Being Submitted:**")
586
 
587
- # Create feedback data dictionary
588
- feedback_dict = {
589
- "open_ended_feedback": open_ended_feedback,
590
- "score": feedback_score,
591
- "is_feedback_about_last_retrieval": is_feedback_about_last_retrieval,
592
- "retrieved_data": st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
593
- "conversation_id": st.session_state.conversation_id,
594
- "timestamp": time.time(),
595
- "message_count": len(st.session_state.messages),
596
- "has_retrievals": has_retrievals,
597
- "retrieval_count": len(st.session_state.rag_retrieval_history)
598
- }
599
 
600
- print(f"πŸ“ FEEDBACK SUBMISSION: Score={feedback_score}, Retrievals={len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0}")
 
 
 
 
 
601
 
602
- # Create UserFeedback dataclass instance
603
- feedback_obj = None # Initialize outside try block
604
- try:
605
- feedback_obj = create_feedback_from_dict(feedback_dict)
606
- print(f"βœ… FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
607
- st.write(f"βœ… **Feedback Object Created**")
608
- st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
609
- st.write(f"- Score: {feedback_obj.score}/5")
610
- st.write(f"- Has Retrievals: {feedback_obj.has_retrievals}")
611
-
612
- # Convert back to dict for JSON serialization
613
- feedback_data = feedback_obj.to_dict()
614
- except Exception as e:
615
- print(f"❌ FEEDBACK SUBMISSION: Failed to create feedback object: {e}")
616
- st.error(f"Failed to create feedback object: {e}")
617
- feedback_data = feedback_dict
618
-
619
- # Display the data being submitted
620
- st.json(feedback_data)
621
 
622
- # Save feedback to file - use absolute path in /app to ensure writability
623
- feedback_dir = Path("/app/feedback")
624
- try:
625
- # Ensure directory exists with write permissions (777 for compatibility)
626
- feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
627
- except (PermissionError, OSError) as e:
628
- logger.warning(f"Could not create feedback directory at {feedback_dir}: {e}")
629
- # Fallback to relative path
630
- feedback_dir = Path("feedback")
631
- feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
632
 
633
- feedback_file = feedback_dir / f"feedback_{st.session_state.conversation_id}_{int(time.time())}.json"
 
 
 
 
634
 
635
- try:
636
- # Ensure parent directory exists before writing
637
- feedback_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
638
 
639
- # Save to local file
640
- print(f"πŸ’Ύ FEEDBACK SAVE: Saving to local file: {feedback_file}")
641
- with open(feedback_file, 'w') as f:
642
- json.dump(feedback_data, f, indent=2, default=str)
643
 
644
- print(f"βœ… FEEDBACK SAVE: Local file saved successfully")
645
- st.success("βœ… Thank you for your feedback! It has been saved locally.")
646
- st.balloons()
 
 
 
 
 
 
 
 
 
 
 
 
647
 
648
- # Save to Snowflake if enabled and credentials available
649
- logger.info("πŸ”„ FEEDBACK SAVE: Starting Snowflake save process...")
650
- logger.info(f"πŸ“Š FEEDBACK SAVE: feedback_obj={'exists' if feedback_obj else 'None'}")
651
 
 
 
652
  try:
653
- import os
654
- snowflake_enabled = os.getenv("SNOWFLAKE_ENABLED", "false").lower() == "true"
655
- logger.info(f"πŸ” SNOWFLAKE CHECK: enabled={snowflake_enabled}")
 
 
 
656
 
657
- if snowflake_enabled:
658
- if feedback_obj:
659
- try:
660
- from src.reporting.snowflake_connector import save_to_snowflake
661
- logger.info("πŸ“€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
662
- print("πŸ“€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...") # Also print to terminal
663
-
664
- if save_to_snowflake(feedback_obj):
665
- logger.info("βœ… SNOWFLAKE UI: Successfully saved to Snowflake")
666
- print("βœ… SNOWFLAKE UI: Successfully saved to Snowflake") # Also print to terminal
667
- st.success("βœ… Feedback also saved to Snowflake!")
668
- else:
669
- logger.warning("⚠️ SNOWFLAKE UI: Save failed")
670
- print("⚠️ SNOWFLAKE UI: Save failed") # Also print to terminal
671
- st.warning("⚠️ Snowflake save failed, but local save succeeded")
672
- except Exception as e:
673
- logger.error(f"❌ SNOWFLAKE UI ERROR: {e}")
674
- print(f"❌ SNOWFLAKE UI ERROR: {e}") # Also print to terminal
675
- import traceback
676
- traceback.print_exc() # Print full traceback to terminal
677
- st.warning(f"⚠️ Could not save to Snowflake: {e}")
678
- else:
679
- logger.warning("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")
680
- print("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)") # Also print to terminal
681
- st.warning("⚠️ Skipping Snowflake save (feedback object not created)")
682
- else:
683
- logger.info("πŸ’‘ SNOWFLAKE UI: Integration disabled")
684
- print("πŸ’‘ SNOWFLAKE UI: Integration disabled") # Also print to terminal
685
- st.info("πŸ’‘ Snowflake integration disabled (set SNOWFLAKE_ENABLED=true to enable)")
686
- except NameError as e:
687
- import traceback
688
- traceback.print_exc()
689
- logger.error(f"❌ NameError in Snowflake save: {e}")
690
- print(f"❌ NameError in Snowflake save: {e}") # Also print to terminal
691
- st.warning(f"⚠️ Snowflake save error: {e}")
692
  except Exception as e:
693
- logger.error(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")
694
- print(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}") # Also print to terminal
695
- st.warning(f"⚠️ Snowflake save error: {e}")
696
 
697
- # Mark feedback as submitted to prevent resubmission
698
- st.session_state.feedback_submitted = True
699
 
700
- print("=" * 80)
701
- print(f"βœ… FEEDBACK SUBMISSION: Completed successfully")
702
- print("=" * 80)
 
 
 
 
 
 
 
703
 
704
- # Log file location
705
- st.info(f"πŸ“ Feedback saved to: {feedback_file}")
706
 
707
- except Exception as e:
708
- print(f"❌ FEEDBACK SUBMISSION: Error saving feedback: {e}")
709
- print(f"❌ FEEDBACK SUBMISSION: Error type: {type(e).__name__}")
710
- import traceback
711
- traceback.print_exc()
712
- st.error(f"❌ Error saving feedback: {e}")
713
- st.write(f"Debug error: {str(e)}")
714
-
715
- elif st.session_state.feedback_submitted:
716
- st.success("βœ… Feedback already submitted for this conversation!")
717
- if st.button("πŸ”„ Submit New Feedback", key="new_feedback_button"):
718
- st.session_state.feedback_submitted = False
719
- st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
720
 
721
  # Display retrieval history stats
722
  if st.session_state.rag_retrieval_history:
723
  st.markdown("---")
724
  st.markdown("#### πŸ“Š Retrieval History")
725
 
726
- with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=False):
727
  for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
728
- st.markdown(f"**Retrieval #{idx}**")
 
 
 
 
 
729
 
730
  # Display the actual RAG query
731
  rag_query_expansion = entry.get("rag_query_expansion", "No query available")
 
732
  st.code(rag_query_expansion, language="text")
733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
734
  # Display summary stats
 
735
  st.json({
736
- "conversation_length": len(entry.get("conversation_up_to", [])),
737
- "documents_retrieved": len(entry.get("docs_retrieved", []))
738
  })
739
- st.markdown("---")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
740
 
741
  # Auto-scroll to bottom
742
  st.markdown("""
@@ -745,5 +1086,32 @@ def main():
745
  </script>
746
  """, unsafe_allow_html=True)
747
 
 
748
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
749
  main()
 
3
  """
4
 
5
  import os
6
+
7
+ import time
8
+ import json
9
+ import uuid
10
+ import logging
11
+ import traceback
12
+ from pathlib import Path
13
+
14
+ from collections import Counter
15
+ from typing import List, Dict, Any, Optional
16
+
17
+
18
+ import pandas as pd
19
+ import streamlit as st
20
+ import plotly.express as px
21
+ from langchain_core.messages import HumanMessage, AIMessage
22
+
23
+
24
+ from src.agents import get_multi_agent_chatbot, get_smart_chatbot, get_gemini_chatbot
25
+ from src.feedback import FeedbackManager
26
+ from src.ui_components import get_custom_css, display_chunk_statistics_charts, display_chunk_statistics_table, extract_chunk_statistics
27
+
28
+ from src.config.paths import (
29
+ IS_DEPLOYED,
30
+ PROJECT_DIR,
31
+ HF_CACHE_DIR,
32
+ FEEDBACK_DIR,
33
+ CONVERSATIONS_DIR,
34
+ )
35
+
36
 
37
  # ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
38
  # Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
 
58
 
59
  # ===== Setup HuggingFace cache directories BEFORE any model imports =====
60
  # CRITICAL: Set these before any imports that might use HuggingFace (like sentence-transformers)
61
+ # Only override cache directories in deployed environment (local uses defaults)
62
+ if IS_DEPLOYED and HF_CACHE_DIR:
63
+ cache_dir = str(HF_CACHE_DIR)
64
+ os.environ["HF_HOME"] = cache_dir
65
+ os.environ["TRANSFORMERS_CACHE"] = cache_dir
66
+ os.environ["HF_DATASETS_CACHE"] = cache_dir
67
+ os.environ["HF_HUB_CACHE"] = cache_dir
68
+ os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_dir
69
+
70
+ # Ensure cache directory exists (created in Dockerfile, but ensure it's there)
71
+ try:
72
+ os.makedirs(cache_dir, mode=0o755, exist_ok=True)
73
+ except (PermissionError, OSError):
74
+ # If we can't create it, log but continue (might already exist from Dockerfile)
75
+ pass
76
+ else:
77
+ from dotenv import load_dotenv
78
+ load_dotenv()
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  # Configure logging
81
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
82
  logger = logging.getLogger(__name__)
83
 
84
  # Log environment setup for debugging
85
+ logger.info(f"🌍 Environment: {'DEPLOYED' if IS_DEPLOYED else 'LOCAL'}")
86
+ logger.info(f"πŸ“ PROJECT_DIR: {PROJECT_DIR}")
87
+ logger.info(f"πŸ“ HuggingFace cache: {os.environ.get('HF_HOME', 'DEFAULT (not overridden)')}")
88
  logger.info(f"πŸ”§ OMP_NUM_THREADS: {os.environ.get('OMP_NUM_THREADS', 'NOT SET')}")
89
 
90
 
 
96
  page_title="Intelligent Audit Report Chatbot"
97
  )
98
 
99
+
100
+ st.markdown(get_custom_css(), unsafe_allow_html=True)
101
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  def get_system_type():
104
  """Get the current system type"""
 
108
  else:
109
  return "Multi-Agent System"
110
 
111
+ def get_chatbot(version: str = "v1"):
112
+ """Initialize and return the chatbot based on version"""
113
+ if version == "beta":
114
+ return get_gemini_chatbot()
 
 
115
  else:
116
+ # Check environment variable for system type (v1)
117
+ system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
118
+ if system == 'smart':
119
+ return get_smart_chatbot()
120
+ else:
121
+ return get_multi_agent_chatbot()
122
 
123
  def serialize_messages(messages):
124
  """Serialize LangChain messages to dictionaries"""
 
163
 
164
  return serialized
165
 
166
+
167
+ feedback_manager = FeedbackManager()
168
+
169
+
170
  @st.cache_data
171
  def load_filter_options():
172
  try:
173
+ filter_options_path = PROJECT_DIR / "src" / "config" / "filter_options.json"
174
+ with open(filter_options_path, "r") as f:
175
  return json.load(f)
176
  except FileNotFoundError:
177
+ st.info(f"Looking for filter_options.json in: {PROJECT_DIR / 'src' / 'config'}")
178
  st.error("filter_options.json not found. Please run the metadata analysis script.")
179
  return {"sources": [], "years": [], "districts": [], 'filenames': []}
180
 
 
191
  # Track RAG retrieval history for feedback
192
  if 'rag_retrieval_history' not in st.session_state:
193
  st.session_state.rag_retrieval_history = []
194
+ # Version selection (v1 or beta)
195
+ if 'chatbot_version' not in st.session_state:
196
+ st.session_state.chatbot_version = "v1"
197
+
198
+ # Initialize chatbot based on version (only if not already initialized for this version)
199
+ chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}"
200
+
201
+ # Check if we need to initialize: chatbot doesn't exist OR version changed
202
+ needs_init = (
203
+ chatbot_version_key not in st.session_state or
204
+ st.session_state.get('_last_version') != st.session_state.chatbot_version
205
+ )
206
+
207
+ if needs_init:
208
+ try:
209
+ # Different spinner messages for different versions
210
+ if st.session_state.chatbot_version == "beta":
211
+ spinner_msg = "πŸ”„ Initializing Gemini FSA"
212
+ else:
213
+ spinner_msg = "πŸ”„ Loading AI models and connecting to database..."
214
+
215
+ with st.spinner(spinner_msg):
216
+ st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version)
217
+ st.session_state['_last_version'] = st.session_state.chatbot_version
218
+ st.session_state.chatbot = st.session_state[chatbot_version_key]
219
+ print("βœ… AI system ready!")
220
+ except Exception as e:
221
+ st.error(f"❌ Failed to initialize chatbot: {str(e)}")
222
+ # Only show Gemini-specific error message for beta version
223
+ if st.session_state.chatbot_version == "beta":
224
+ st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)")
225
+ else:
226
+ st.error("Please check your configuration and ensure all required models and databases are accessible.")
227
+ # Reset to v1 to prevent infinite loop
228
+ st.session_state.chatbot_version = "v1"
229
+ st.session_state['_last_version'] = "v1"
230
+ if 'chatbot' in st.session_state:
231
+ del st.session_state['chatbot']
232
+ st.stop() # Stop execution to prevent infinite loop
233
+ else:
234
+ # Chatbot already initialized for this version, just use it
235
+ st.session_state.chatbot = st.session_state[chatbot_version_key]
236
 
237
  # Reset conversation history if needed (but keep chatbot cached)
238
  if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
 
244
  st.session_state.reset_conversation = False
245
  st.rerun()
246
 
247
+
248
+ # Version selection radio button (top right)
249
  col1, col2 = st.columns([3, 1])
250
  with col1:
251
+ st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
252
  with col2:
253
+ st.markdown("<br>", unsafe_allow_html=True) # Add some spacing
254
+ selected_version = st.radio(
255
+ "**Version:**",
256
+ options=["v1", "beta"],
257
+ index=0 if st.session_state.chatbot_version == "v1" else 1,
258
+ horizontal=True,
259
+ key="version_selector",
260
+ help="Select v1 (default RAG system) or beta (Gemini FSA)"
261
+ )
262
+
263
+ # Update version if changed
264
+ if selected_version != st.session_state.chatbot_version:
265
+ # Store the old version to check if we need to switch
266
+ old_version = st.session_state.chatbot_version
267
+ st.session_state.chatbot_version = selected_version
268
+
269
+ # If chatbot for new version already exists, just switch to it
270
+ new_chatbot_key = f"chatbot_{selected_version}"
271
+ if new_chatbot_key in st.session_state:
272
+ # Chatbot already exists, just switch
273
+ st.session_state.chatbot = st.session_state[new_chatbot_key]
274
+ st.session_state['_last_version'] = selected_version
275
+ else:
276
+ # Need to initialize new version - will be handled by initialization logic above
277
+ st.session_state['_last_version'] = old_version # Set to old to trigger init check
278
+
279
+ st.rerun()
280
+
281
+ # Show version info
282
+ if st.session_state.chatbot_version == "beta":
283
+ st.info("πŸ”¬ **Beta Mode**: Using Google Gemini FSA")
284
 
285
  # Session info
286
  duration = int(time.time() - st.session_state.session_start_time)
 
296
 
297
  # Sidebar for filters
298
  with st.sidebar:
299
+ # Instructions section (collapsible)
300
+ with st.expander("πŸ“– How to Use", expanded=False):
301
+ st.markdown("""
302
+ #### 🎯 Using Filters
303
+
304
+ 1. **Select filters** from the sidebar to narrow your search:
305
+
306
+ 2. **Leave filters empty** to search across all data
307
+
308
+ 3. **Type your question** in the chat input at the bottom
309
+
310
+ 4. **Click "Send"** to submit your question
311
+
312
+ #### πŸ’‘ Tips
313
+
314
+ - Use specific questions for better results
315
+ - Combine multiple filters for precise searches
316
+ - Check the "Retrieved Documents" tab to see source material
317
+
318
+ #### ⚠️ Important
319
+
320
+ **When finished, please close the browser window** to free up computational resources.
321
+
322
+ ---
323
+
324
+ For more detailed help, see the example questions at the bottom of the page.
325
+ """)
326
+
327
  st.markdown("### πŸ” Search Filters")
328
  st.markdown("Select filters to narrow down your search. Leave empty to search all data.")
329
 
 
342
  # Determine if filename filter is active
343
  filename_mode = len(selected_filenames) > 0
344
  # Sources filter
345
+ # st.markdown('<div class="filter-section">', unsafe_allow_html=True)
346
  st.markdown('<div class="filter-title">πŸ“Š Sources</div>', unsafe_allow_html=True)
347
  selected_sources = st.multiselect(
348
  "Select sources:",
 
355
  st.markdown('</div>', unsafe_allow_html=True)
356
 
357
  # Years filter
358
+ # st.markdown('<div class="filter-section">', unsafe_allow_html=True)
359
  st.markdown('<div class="filter-title">πŸ“… Years</div>', unsafe_allow_html=True)
360
  selected_years = st.multiselect(
361
  "Select years:",
 
368
  st.markdown('</div>', unsafe_allow_html=True)
369
 
370
  # Districts filter
371
+ # st.markdown('<div class="filter-section">', unsafe_allow_html=True)
372
  st.markdown('<div class="filter-title">🏘️ Districts</div>', unsafe_allow_html=True)
373
  selected_districts = st.multiselect(
374
  "Select districts:",
 
419
  if 'input_counter' not in st.session_state:
420
  st.session_state.input_counter = 0
421
 
422
+ # Handle pending question from example questions section
423
+ if 'pending_question' in st.session_state and st.session_state.pending_question:
424
+ default_value = st.session_state.pending_question
425
+ # Increment counter to force new input widget
426
+ st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
427
+ del st.session_state.pending_question
428
+ key_suffix = st.session_state.input_counter
429
+ else:
430
+ default_value = ""
431
+ key_suffix = st.session_state.input_counter
432
+
433
  user_input = st.text_input(
434
  "Type your message here...",
435
  placeholder="Ask about budget allocations, expenditures, or audit findings...",
436
+ key=f"user_input_{key_suffix}",
437
+ label_visibility="collapsed",
438
+ value=default_value if default_value else None
439
  )
440
 
441
  with col2:
442
+ send_button = st.button("Send", key="send_button", width='stretch')
443
 
444
  # Clear chat button
445
  if st.button("πŸ—‘οΈ Clear Chat", key="clear_chat_button"):
446
  st.session_state.reset_conversation = True
447
  # Clear all conversation files
448
+ conversations_path = CONVERSATIONS_DIR
449
+ if conversations_path.exists():
450
+ for file in conversations_path.iterdir():
451
+ if file.suffix == '.json':
452
+ file.unlink()
 
453
  st.rerun()
454
 
455
  # Handle user input
 
491
  if rag_result:
492
  sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
493
 
494
+ # For Gemini, also check gemini_result for sources
495
+ if not sources or len(sources) == 0:
496
+ gemini_result = chat_result.get('gemini_result')
497
+ print(f"πŸ” DEBUG: Checking gemini_result for sources...")
498
+ print(f" gemini_result exists: {gemini_result is not None}")
499
+ if gemini_result:
500
+ print(f" gemini_result type: {type(gemini_result)}")
501
+ print(f" has sources attr: {hasattr(gemini_result, 'sources')}")
502
+ if hasattr(gemini_result, 'sources'):
503
+ print(f" sources length: {len(gemini_result.sources) if gemini_result.sources else 0}")
504
+
505
+ if gemini_result and hasattr(gemini_result, 'sources'):
506
+ # Format Gemini sources for display
507
+ if hasattr(st.session_state.chatbot, 'gemini_client'):
508
+ sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result)
509
+ print(f"βœ… Formatted {len(sources)} sources from gemini_client")
510
+ elif hasattr(st.session_state.chatbot, '_format_gemini_sources'):
511
+ sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
512
+ print(f"βœ… Formatted {len(sources)} sources from _format_gemini_sources")
513
+
514
+ # Update rag_result with sources if we found them
515
+ if sources and len(sources) > 0:
516
+ if isinstance(rag_result, dict):
517
+ rag_result['sources'] = sources
518
+ elif hasattr(rag_result, 'sources'):
519
+ rag_result.sources = sources
520
+ # Update last_rag_result with sources
521
+ st.session_state.last_rag_result = rag_result
522
+ print(f"βœ… Updated rag_result with {len(sources)} sources")
523
+
524
  # Get the actual RAG query
525
  actual_rag_query = chat_result.get('actual_rag_query', '')
526
  if actual_rag_query:
 
530
  else:
531
  formatted_query = "No RAG query available"
532
 
533
+ # Extract filters from active filters
534
+ filters_used = {
535
+ "sources": st.session_state.active_filters.get('sources', []),
536
+ "years": st.session_state.active_filters.get('years', []),
537
+ "districts": st.session_state.active_filters.get('districts', []),
538
+ "filenames": st.session_state.active_filters.get('filenames', [])
539
+ }
540
+
541
  retrieval_entry = {
542
  "conversation_up_to": serialize_messages(st.session_state.messages),
543
  "rag_query_expansion": formatted_query,
544
+ "docs_retrieved": serialize_documents(sources),
545
+ "filters_applied": filters_used,
546
+ "timestamp": time.time()
547
  }
548
  st.session_state.rag_retrieval_history.append(retrieval_entry)
549
+
550
+ # Debug logging
551
+ print(f"πŸ“Š RETRIEVAL TRACKING: {len(sources)} sources stored in retrieval history")
552
  else:
553
  response = chat_result
554
  st.session_state.last_rag_result = None
 
578
  # Dictionary format from multi-agent system
579
  sources = rag_result['sources']
580
 
581
+ # For Gemini, also check if we need to format sources from gemini_result
582
+ if (not sources or len(sources) == 0) and isinstance(rag_result, dict):
583
+ gemini_result = rag_result.get('gemini_result')
584
+ if gemini_result and hasattr(gemini_result, 'sources'):
585
+ # Format Gemini sources for display
586
+ if hasattr(st.session_state.chatbot, 'gemini_client'):
587
+ sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result)
588
+ elif hasattr(st.session_state.chatbot, '_format_gemini_sources'):
589
+ sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
590
+
591
  if sources and len(sources) > 0:
592
  # Count unique filenames
593
  unique_filenames = set()
 
595
  filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown')
596
  unique_filenames.add(filename)
597
 
598
+ st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 20):**")
599
  if len(unique_filenames) < len(sources):
600
  st.info(f"πŸ’‘ **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
601
 
602
+ # Extract and display statistics
603
+ stats = extract_chunk_statistics(sources)
604
+
605
+ # Show charts for 10+ results, tables for fewer
606
+ if len(sources) >= 10:
607
+ display_chunk_statistics_charts(stats, "Retrieval Statistics")
608
+ # Also show tables below charts for detailed view
609
+ st.markdown("---")
610
+ display_chunk_statistics_table(stats, "Retrieval Distribution")
611
+ else:
612
+ display_chunk_statistics_table(stats, "Retrieval Distribution")
613
+
614
+ st.markdown("---")
615
+ st.markdown("### πŸ“„ Document Details")
616
+
617
+ for i, doc in enumerate(sources): # Show all documents
618
  # Get relevance score and ID if available
619
  metadata = getattr(doc, 'metadata', {})
620
+ # Handle both standard RAG scores and Gemini scores
621
+ score = metadata.get('reranked_score') or metadata.get('original_score') or metadata.get('score')
622
+ chunk_id = metadata.get('_id') or metadata.get('chunk_id', 'Unknown')
623
+ if score is not None:
624
+ try:
625
+ score_text = f" (Score: {float(score):.3f})"
626
+ except (ValueError, TypeError):
627
+ score_text = ""
628
+ else:
629
+ score_text = ""
630
+ if chunk_id and chunk_id != 'Unknown':
631
+ score_text += f" (ID: {str(chunk_id)[:8]}...)" if score_text else f" (ID: {str(chunk_id)[:8]}...)"
632
 
633
  with st.expander(f"πŸ“„ Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
634
  # Display document metadata with emojis
 
675
  if 'feedback_submitted' not in st.session_state:
676
  st.session_state.feedback_submitted = False
677
 
678
+ # Feedback form - only show if feedback not already submitted
679
+ if not st.session_state.feedback_submitted:
680
+ with st.form("feedback_form", clear_on_submit=False):
681
+ col1, col2 = st.columns([1, 1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
682
 
683
+ with col1:
684
+ feedback_score = st.slider(
685
+ "Rate this conversation (1-5)",
686
+ min_value=1,
687
+ max_value=5,
688
+ help="How satisfied are you with the conversation?"
689
+ )
 
 
 
 
 
690
 
691
+ with col2:
692
+ is_feedback_about_last_retrieval = st.checkbox(
693
+ "Feedback about last retrieval only",
694
+ value=True,
695
+ help="If checked, feedback applies to the most recent document retrieval"
696
+ )
697
 
698
+ open_ended_feedback = st.text_area(
699
+ "Your feedback (optional)",
700
+ placeholder="Tell us what went well or what could be improved...",
701
+ height=100
702
+ )
703
 
704
+ # st.slider always returns a value, so this never actually disables the button; kept as a guard in case the score input becomes optional
705
+ submit_disabled = feedback_score is None
706
 
707
+ submitted = st.form_submit_button(
708
+ "πŸ“€ Submit Feedback",
709
+ width='stretch',
710
+ disabled=submit_disabled
711
+ )
712
 
713
+ if submitted:
714
+ # Log the feedback data being submitted
715
+ print("=" * 80)
716
+ print("πŸ”„ FEEDBACK SUBMISSION: Starting...")
717
+ print("=" * 80)
718
+ st.write("πŸ” **Debug: Feedback Data Being Submitted:**")
719
+
720
+ # Extract transcript from messages
721
+ transcript = feedback_manager.extract_transcript(st.session_state.messages)
722
+
723
+ # Build retrievals structure
724
+ retrievals = feedback_manager.build_retrievals_structure(
725
+
726
+ st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
727
+ st.session_state.messages
728
+ )
729
+
730
+ # Build feedback_score_related_retrieval_docs
731
+
732
+ feedback_score_related_retrieval_docs = feedback_manager.build_feedback_score_related_retrieval_docs(
733
+ is_feedback_about_last_retrieval,
734
+ st.session_state.messages,
735
+ st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
736
+ )
737
 
738
+ # Preserve old retrieved_data format for backward compatibility
739
+ retrieved_data_old_format = st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
 
 
740
 
741
+ # Create feedback data dictionary
742
+ feedback_dict = {
743
+ "open_ended_feedback": open_ended_feedback,
744
+ "score": feedback_score,
745
+ "is_feedback_about_last_retrieval": is_feedback_about_last_retrieval,
746
+ "conversation_id": st.session_state.conversation_id,
747
+ "timestamp": time.time(),
748
+ "message_count": len(st.session_state.messages),
749
+ "has_retrievals": has_retrievals,
750
+ "retrieval_count": len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0,
751
+ "transcript": transcript,
752
+ "retrievals": retrievals,
753
+ "feedback_score_related_retrieval_docs": feedback_score_related_retrieval_docs,
754
+ "retrieved_data": retrieved_data_old_format # Preserved old column
755
+ }
756
 
757
+ print(f"πŸ“ FEEDBACK SUBMISSION: Score={feedback_score}, Retrievals={len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0}")
 
 
758
 
759
+ # Create UserFeedback dataclass instance
760
+ feedback_obj = None # Initialize outside try block
761
  try:
762
+ feedback_obj = feedback_manager.create_feedback_from_dict(feedback_dict)
763
+ print(f"βœ… FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
764
+ st.write(f"βœ… **Feedback Object Created**")
765
+ st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
766
+ st.write(f"- Score: {feedback_obj.score}/5")
767
+ st.write(f"- Has Retrievals: {feedback_obj.has_retrievals}")
768
 
769
+ # Convert back to dict for JSON serialization
770
+ feedback_data = feedback_obj.to_dict()
771
  except Exception as e:
772
+ print(f"❌ FEEDBACK SUBMISSION: Failed to create feedback object: {e}")
773
+ st.error(f"Failed to create feedback object: {e}")
774
+ feedback_data = feedback_dict
775
 
776
+ # Display the data being submitted
777
+ st.json(feedback_data)
778
 
779
+ # Save feedback to file - use FEEDBACK_DIR to ensure writability
780
+ feedback_dir = FEEDBACK_DIR
781
+ try:
782
+ # Ensure directory exists with write permissions (777 for compatibility)
783
+ feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
784
+ except (PermissionError, OSError) as e:
785
+ logger.warning(f"Could not create feedback directory at {feedback_dir}: {e}")
786
+ # Fallback to relative path
787
+ feedback_dir = Path("feedback")
788
+ feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
789
 
790
+ feedback_file = feedback_dir / f"feedback_{st.session_state.conversation_id}_{int(time.time())}.json"
 
791
 
792
+ try:
793
+ # Ensure parent directory exists before writing
794
+ feedback_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
795
+
796
+ # Save to local file first
797
+ print(f"πŸ’Ύ FEEDBACK SAVE: Saving to local file: {feedback_file}")
798
+ with open(feedback_file, 'w') as f:
799
+ json.dump(feedback_data, f, indent=2, default=str)
800
+
801
+ print(f"βœ… FEEDBACK SAVE: Local file saved successfully")
802
+
803
+ # Save to Snowflake if enabled and credentials available
804
+ logger.info("πŸ”„ FEEDBACK SAVE: Starting Snowflake save process...")
805
+ logger.info(f"πŸ“Š FEEDBACK SAVE: feedback_obj={'exists' if feedback_obj else 'None'}")
806
+
807
+ snowflake_success = False
808
+ try:
809
+ snowflake_enabled = os.getenv("SNOWFLAKE_ENABLED", "false").lower() == "true"
810
+ logger.info(f"πŸ” SNOWFLAKE CHECK: enabled={snowflake_enabled}")
811
+
812
+ if snowflake_enabled:
813
+ if feedback_obj:
814
+ try:
815
+ logger.info("πŸ“€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
816
+ print("πŸ“€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
817
+
818
+ # Show spinner while saving to Snowflake (can take 10-15 seconds)
819
+ # This includes: connection establishment (~5s), data preparation, and SQL execution (~5s)
820
+ with st.spinner("πŸ’Ύ Saving feedback to Snowflake... This may take 10-15 seconds (connecting to database, preparing data, and executing query)"):
821
+ snowflake_success = feedback_manager.save_to_snowflake(feedback_obj)
822
+
823
+ if snowflake_success:
824
+ logger.info("βœ… SNOWFLAKE UI: Successfully saved to Snowflake")
825
+ print("βœ… SNOWFLAKE UI: Successfully saved to Snowflake")
826
+ else:
827
+ logger.warning("⚠️ SNOWFLAKE UI: Save failed")
828
+ print("⚠️ SNOWFLAKE UI: Save failed")
829
+ except Exception as e:
830
+ logger.error(f"❌ SNOWFLAKE UI ERROR: {e}")
831
+ print(f"❌ SNOWFLAKE UI ERROR: {e}")
832
+ traceback.print_exc()
833
+ snowflake_success = False
834
+ else:
835
+ logger.warning("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")
836
+ print("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")
837
+ snowflake_success = False
838
+ else:
839
+ logger.info("πŸ’‘ SNOWFLAKE UI: Integration disabled")
840
+ print("πŸ’‘ SNOWFLAKE UI: Integration disabled")
841
+ # If Snowflake is disabled, consider it successful (local save only)
842
+ snowflake_success = True
843
+
844
+ except Exception as e:
845
+ logger.error(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")
846
+ print(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")
847
+ snowflake_success = False
848
+
849
+ # Only show success if Snowflake save succeeded (or if Snowflake is disabled)
850
+ if snowflake_success:
851
+ st.success("βœ… Thank you for your feedback! It has been saved successfully.")
852
+ st.balloons()
853
+ else:
854
+ st.warning("⚠️ Feedback saved locally, but Snowflake save failed. Please check logs.")
855
+
856
+ # Mark feedback as submitted to prevent resubmission
857
+ st.session_state.feedback_submitted = True
858
+
859
+ print("=" * 80)
860
+ print(f"βœ… FEEDBACK SUBMISSION: Completed successfully")
861
+ print("=" * 80)
862
+
863
+ # Log file location
864
+ st.info(f"πŸ“ Feedback saved to: {feedback_file}")
865
+
866
+ except Exception as e:
867
+ print(f"❌ FEEDBACK SUBMISSION: Error saving feedback: {e}")
868
+ print(f"❌ FEEDBACK SUBMISSION: Error type: {type(e).__name__}")
869
+ traceback.print_exc()
870
+ st.error(f"❌ Error saving feedback: {e}")
871
+ st.write(f"Debug error: {str(e)}")
872
+ else:
873
+ # Feedback already submitted - show success message and reset option
874
+ st.success("βœ… Feedback already submitted for this conversation!")
875
+ col1, col2 = st.columns([1, 1])
876
+ with col1:
877
+ if st.button("πŸ”„ Submit New Feedback", key="new_feedback_button", width='stretch'):
878
+ try:
879
+ st.session_state.feedback_submitted = False
880
+ st.rerun()
881
+ except Exception as e:
882
+ # Handle any Streamlit API exceptions gracefully
883
+ logger.error(f"Error resetting feedback state: {e}")
884
+ st.error(f"Error resetting feedback. Please refresh the page.")
885
+ with col2:
886
+ if st.button("πŸ“‹ View Conversation", key="view_conversation_button", width='stretch'):
887
+ # Scroll to conversation - this is handled by the auto-scroll at bottom
888
+ pass
889
 
890
  # Display retrieval history stats
891
  if st.session_state.rag_retrieval_history:
892
  st.markdown("---")
893
  st.markdown("#### πŸ“Š Retrieval History")
894
 
895
+ with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=True):
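+ # Each history entry is assumed to carry (keys inferred from the reads below):
+ # {"timestamp": float, "rag_query_expansion": str, "filters_applied": dict,
+ #  "conversation_up_to": [...], "docs_retrieved": [...]}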
896
  for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
897
+ st.markdown(f"### **Retrieval #{idx}**")
898
+
899
+ # Display timestamp if available
900
+ if entry.get("timestamp"):
901
+ timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(entry["timestamp"]))
902
+ st.caption(f"πŸ• {timestamp_str}")
903
 
904
  # Display the actual RAG query
905
  rag_query_expansion = entry.get("rag_query_expansion", "No query available")
906
+ st.markdown("**πŸ” RAG Query:**")
907
  st.code(rag_query_expansion, language="text")
908
 
909
+ # Display filters used
910
+ filters_applied = entry.get("filters_applied", {})
911
+ if filters_applied and any(filters_applied.values()):
912
+ st.markdown("**🎯 Filters Applied:**")
913
+ filter_display = {}
914
+ if filters_applied.get("sources"):
915
+ filter_display["Sources"] = filters_applied["sources"]
916
+ if filters_applied.get("years"):
917
+ filter_display["Years"] = filters_applied["years"]
918
+ if filters_applied.get("districts"):
919
+ filter_display["Districts"] = filters_applied["districts"]
920
+ if filters_applied.get("filenames"):
921
+ filter_display["Filenames"] = filters_applied["filenames"]
922
+
923
+ if filter_display:
924
+ st.json(filter_display)
925
+ else:
926
+ st.info("No filters applied")
927
+ else:
928
+ st.info("No filters applied")
929
+
930
+ # Display conversation history up to retrieval point
931
+ conversation_up_to = entry.get("conversation_up_to", [])
932
+ if conversation_up_to:
933
+ st.markdown("**πŸ’¬ Conversation History (up to retrieval point):**")
934
+ with st.expander(f"View {len(conversation_up_to)} messages", expanded=False):
935
+ for msg_idx, msg in enumerate(conversation_up_to, 1):
936
+ role = msg.get("type", "unknown")
937
+ content = msg.get("content", "")
938
+
939
+ if role == "HumanMessage" or role == "human":
940
+ st.markdown(f"**πŸ‘€ User {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
941
+ elif role == "AIMessage" or role == "ai":
942
+ st.markdown(f"**πŸ€– Assistant {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
943
+ else:
944
+ st.info("No conversation history available")
945
+
946
+ # Display documents retrieved
947
+ docs_retrieved = entry.get("docs_retrieved", [])
948
+ if docs_retrieved:
949
+ st.markdown(f"**πŸ“„ Documents Retrieved ({len(docs_retrieved)}):**")
950
+ with st.expander(f"View {len(docs_retrieved)} documents", expanded=False):
951
+ for doc_idx, doc in enumerate(docs_retrieved, 1):
952
+ st.markdown(f"**Document {doc_idx}:**")
953
+
954
+ # Display metadata
955
+ metadata = doc.get("metadata", {})
956
+ if metadata:
957
+ col1, col2, col3 = st.columns(3)
958
+ with col1:
959
+ st.write(f"πŸ“„ **File:** {metadata.get('filename', 'Unknown')}")
960
+ with col2:
961
+ st.write(f"πŸ›οΈ **Source:** {metadata.get('source', 'Unknown')}")
962
+ with col3:
963
+ st.write(f"πŸ“… **Year:** {metadata.get('year', 'Unknown')}")
964
+
965
+ # Additional metadata
966
+ if metadata.get('district'):
967
+ st.write(f"πŸ“ **District:** {metadata.get('district')}")
968
+ if metadata.get('page'):
969
+ st.write(f"πŸ“– **Page:** {metadata.get('page')}")
970
+ if metadata.get('score') is not None:
971
+ st.write(f"⭐ **Score:** {metadata.get('score'):.3f}" if isinstance(metadata.get('score'), (int, float)) else f"⭐ **Score:** {metadata.get('score')}")
972
+
973
+ # Display content preview (first 200 chars)
974
+ content = doc.get("content", doc.get("page_content", ""))
975
+ if content:
976
+ st.markdown("**Content Preview:**")
977
+ st.text_area(
978
+ "Content Preview",
979
+ value=content[:200] + ("..." if len(content) > 200 else ""),
980
+ height=100,
981
+ disabled=True,
982
+ label_visibility="collapsed",
983
+ key=f"retrieval_{idx}_doc_{doc_idx}_preview"
984
+ )
985
+
986
+ if doc_idx < len(docs_retrieved):
987
+ st.markdown("---")
988
+ else:
989
+ st.info("No documents retrieved")
990
+
991
  # Display summary stats
992
+ st.markdown("**πŸ“Š Summary:**")
993
  st.json({
994
+ "conversation_length": len(conversation_up_to),
995
+ "documents_retrieved": len(docs_retrieved)
996
  })
997
+
998
+ if idx < len(st.session_state.rag_retrieval_history):
999
+ st.markdown("---")
1000
+
1001
+ # Example Questions Section
1002
+ st.markdown("---")
1003
+ st.markdown("### πŸ’‘ Example Questions")
1004
+ st.markdown("Click on any question below to use it, or modify the editable examples:")
1005
+
1006
+ # Initialize example question state
1007
+ if 'custom_question_1' not in st.session_state:
1008
+ st.session_state.custom_question_1 = "How were administrative costs managed in the PDM implementation, and what issues arose with budget execution regarding staff salaries?"
1009
+ if 'custom_question_2' not in st.session_state:
1010
+ st.session_state.custom_question_2 = "What did the National Coordinator say about the release of funds for PDM administrative costs in the letter dated 29th September 2022 and how did the funding received affect the activities of the PDCs and PDM SACCOs in the FY 2022/23?"
1011
+
1012
+ # Question 1: Filename insights (fixed, clickable)
1013
+ st.markdown("#### πŸ“„ Question 1: List insights from a specific file")
1014
+ col1, col2 = st.columns([3, 1])
1015
+ with col1:
1016
+ example_q1 = "List a couple of insights from the selected file."
1017
+ st.markdown(f"**Example:** `{example_q1}`")
1018
+ st.info("πŸ’‘ **Filter to apply:** Select a Filename from the sidebar panel before asking this question.")
1019
+ with col2:
1020
+ if st.button("πŸ“‹ Use This Question", key="use_example_1", width='stretch'):
1021
+ st.session_state.pending_question = example_q1
1022
+ st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
1023
+ st.rerun()
1024
+
1025
+ st.markdown("---")
1026
+
1027
+ # Questions 2 & 3: Editable examples
1028
+ st.markdown("#### ✏️ Customizable Questions (Edit and use)")
1029
+
1030
+ # Question 2
1031
+ # st.markdown("**Question 2:**")
1032
+ custom_q1 = st.text_area(
1033
+ "Edit question 2:",
1034
+ value=st.session_state.custom_question_1,
1035
+ height=80,
1036
+ key="edit_question_2",
1037
+ help="Modify this question to fit your needs, then click 'Use This Question'"
1038
+ )
1039
+ col1, col2 = st.columns([1, 4])
1040
+ with col1:
1041
+ if st.button("πŸ“‹ Use Question 2", key="use_custom_1", width='stretch'):
1042
+ if custom_q1.strip():
1043
+ st.session_state.pending_question = custom_q1.strip()
1044
+ st.session_state.custom_question_1 = custom_q1.strip()
1045
+ st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
1046
+ st.rerun()
1047
+ else:
1048
+ st.warning("Please enter a question first!")
1049
+ with col2:
1050
+ st.caption("πŸ’‘ Tip: Add specific details like dates, names, or amounts to get more precise answers")
1051
+
1052
+ st.info("πŸ’‘ **Filter to apply:** Select District(s) and Year(s) sidebar panel before asking this question.")
1053
+
1054
+ st.markdown("---")
1055
+
1056
+ # Question 3
1057
+ # st.markdown("**Question 3:**")
1058
+ custom_q2 = st.text_area(
1059
+ "Edit question 3:",
1060
+ value=st.session_state.custom_question_2,
1061
+ height=80,
1062
+ key="edit_question_3",
1063
+ help="Modify this question to fit your needs, then click 'Use This Question'"
1064
+ )
1065
+ col1, col2 = st.columns([1, 4])
1066
+ with col1:
1067
+ if st.button("πŸ“‹ Use Question 3", key="use_custom_2", width='stretch'):
1068
+ if custom_q2.strip():
1069
+ st.session_state.pending_question = custom_q2.strip()
1070
+ st.session_state.custom_question_2 = custom_q2.strip()
1071
+ st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
1072
+ st.rerun()
1073
+ else:
1074
+ st.warning("Please enter a question first!")
1075
+ with col2:
1076
+ st.caption("πŸ’‘ Tip: Use specific terms from the documents (e.g., 'PDM', 'SACCOs', 'FY 2022/23')")
1077
+
1078
+
1079
+ # Store selected question for next render (handled in input section above)
1080
+ # This ensures the question populates the input field correctly
1081
 
1082
  # Auto-scroll to bottom
1083
  st.markdown("""
 
1086
  </script>
1087
  """, unsafe_allow_html=True)
1088
 
1089
+
1090
  if __name__ == "__main__":
1091
+ # Check if running in Streamlit context
1092
+ try:
1093
+ from streamlit.runtime.scriptrunner import get_script_run_ctx
1094
+ if get_script_run_ctx() is None:
1095
+ # Not in Streamlit runtime - show helpful message
1096
+ print("=" * 80)
1097
+ print("⚠️ WARNING: This is a Streamlit app!")
1098
+ print("=" * 80)
1099
+ print("\nPlease run this app using:")
1100
+ print(" streamlit run app.py")
1101
+ print("\nNot: python app.py")
1102
+ print("\nThe app will not function correctly when run with 'python app.py'")
1103
+ print("=" * 80)
1104
+ import sys
1105
+ sys.exit(1)
1106
+ except ImportError:
1107
+ # Streamlit not installed or not in Streamlit context
1108
+ print("=" * 80)
1109
+ print("⚠️ WARNING: This is a Streamlit app!")
1110
+ print("=" * 80)
1111
+ print("\nPlease run this app using:")
1112
+ print(" streamlit run app.py")
1113
+ print("\nNot: python app.py")
1114
+ print("=" * 80)
1115
+ import sys
1116
+ sys.exit(1)
1117
  main()
src/agents/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Agent modules for chatbot implementations
3
+ """
4
+
5
+ from .smart_chatbot import get_chatbot as get_smart_chatbot
6
+ from .multi_agent_chatbot import get_multi_agent_chatbot
7
+ from .gemini_chatbot import get_gemini_chatbot
8
+
9
+ __all__ = ["get_smart_chatbot", "get_multi_agent_chatbot", "get_gemini_chatbot"]
10
+
src/agents/gemini_chatbot.py ADDED
@@ -0,0 +1,392 @@
1
+ """
2
+ Gemini File Search Chatbot (Beta Version)
3
+
4
+ This chatbot uses Google Gemini File Search API for RAG.
5
+ It provides a simpler architecture: Main Agent + Gemini Agent
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import time
11
+ import logging
12
+ import traceback
13
+ from pathlib import Path
14
+ from typing import Dict, List, Any, Optional, TypedDict
15
+
16
+ from langgraph.graph import StateGraph, END
17
+ from langchain_core.prompts import ChatPromptTemplate
18
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
19
+
20
+ from src.gemini.file_search import GeminiFileSearchClient, GeminiFileSearchResult
21
+ from src.config.paths import CONVERSATIONS_DIR
22
+
23
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class GeminiState(TypedDict):
28
+ """State for Gemini chatbot conversation flow"""
29
+ conversation_id: str
30
+ messages: List[Any]
31
+ current_query: str
32
+ query_context: Optional[Dict[str, Any]]
33
+ gemini_result: Optional[GeminiFileSearchResult]
34
+ final_response: Optional[str]
35
+ agent_logs: List[str]
36
+ conversation_context: Dict[str, Any]
37
+ session_start_time: float
38
+ last_ai_message_time: float
39
+ filters: Optional[Dict[str, Any]]
40
+
41
+
42
+ class GeminiRAGChatbot:
43
+ """Gemini File Search RAG chatbot (Beta version)"""
44
+
45
+ def __init__(self):
46
+ """Initialize the Gemini chatbot"""
47
+ logger.info("πŸ€– INITIALIZING: Gemini File Search Chatbot (Beta)")
48
+
49
+ # Initialize Gemini File Search client
50
+ try:
51
+ self.gemini_client = GeminiFileSearchClient()
52
+ logger.info("βœ… Gemini File Search client initialized")
53
+ except Exception as e:
54
+ logger.error(f"❌ Failed to initialize Gemini client: {e}")
55
+ raise RuntimeError(f"Gemini client initialization failed: {e}")
56
+
57
+ # Build the LangGraph with LangSmith tracing if enabled
58
+ self.graph = self._build_graph()
59
+
60
+ # Enable LangSmith tracing if configured
61
+ langsmith_enabled = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true"
62
+ if langsmith_enabled:
63
+ logger.info("πŸ” LangSmith tracing enabled")
64
+ langsmith_project = os.getenv("LANGCHAIN_PROJECT", "gemini-chatbot")
65
+ logger.info(f"πŸ“Š LangSmith project: {langsmith_project}")
66
+
67
+ # Conversations directory
68
+ self.conversations_dir = CONVERSATIONS_DIR
69
+ try:
70
+ self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
71
+ except (PermissionError, OSError) as e:
72
+ logger.warning(f"Could not create conversations directory: {e}")
73
+ self.conversations_dir = Path("conversations")
74
+ self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
75
+
76
+ logger.info("βœ… Gemini File Search Chatbot initialized")
77
+
78
+ def _build_graph(self) -> StateGraph:
79
+ """Build the LangGraph for Gemini chatbot"""
80
+ graph = StateGraph(GeminiState)
81
+
82
+ # Add nodes
83
+ graph.add_node("main_agent", self._main_agent)
84
+ graph.add_node("gemini_agent", self._gemini_agent)
85
+
86
+ # Define the flow
87
+ graph.set_entry_point("main_agent")
88
+ graph.add_edge("main_agent", "gemini_agent")
89
+ graph.add_edge("gemini_agent", END)
90
+
91
+ return graph.compile()
92
+
93
+ def _main_agent(self, state: GeminiState) -> GeminiState:
94
+ """Main Agent: Extracts filters and prepares query"""
95
+ logger.info("🎯 MAIN AGENT: Processing query")
96
+
97
+ query = state["current_query"]
98
+ messages = state["messages"]
99
+
100
+ # Extract UI filters if present in query
101
+ ui_filters = self._extract_ui_filters(query)
102
+
103
+ # Extract context from conversation
104
+ context = self._extract_context_from_conversation(messages, ui_filters)
105
+
106
+ # Store context and filters
107
+ state["query_context"] = context
108
+ state["filters"] = context.get("filters", {})
109
+
110
+ logger.info(f"🎯 MAIN AGENT: Filters extracted: {state['filters']}")
111
+
112
+ return state
113
+
114
+ def _gemini_agent(self, state: GeminiState) -> GeminiState:
115
+ """Gemini Agent: Performs file search and generates response"""
116
+ logger.info("πŸ” GEMINI AGENT: Starting file search")
117
+
118
+ query = state["current_query"]
119
+ filters = state.get("filters", {})
120
+
121
+ # Perform Gemini file search
122
+ try:
123
+ result = self.gemini_client.search(query=query, filters=filters)
124
+ logger.info(f"βœ… GEMINI AGENT: Search completed, {len(result.sources)} sources found")
125
+
126
+ # Enhance response with document references
127
+ enhanced_response = self._enhance_response_with_references(
128
+ result.answer,
129
+ result.sources,
130
+ query
131
+ )
132
+
133
+ state["gemini_result"] = result
134
+ state["final_response"] = enhanced_response
135
+ state["last_ai_message_time"] = time.time()
136
+
137
+ state["agent_logs"].append(f"GEMINI AGENT: Found {len(result.sources)} sources")
138
+
139
+ except Exception as e:
140
+ logger.error(f"❌ GEMINI AGENT ERROR: {e}")
141
+ traceback.print_exc()
142
+ state["final_response"] = "I apologize, but I encountered an error while searching. Please try again."
143
+ state["last_ai_message_time"] = time.time()
144
+
145
+ return state
146
+
147
+ def _enhance_response_with_references(self, answer: str, sources: List[Any], query: str) -> str:
148
+ """Enhance Gemini response to include document references and format nicely"""
149
+ if not sources or not answer:
150
+ return answer
151
+
152
+ # Use LLM to intelligently add document references and format nicely
153
+ try:
154
+ from src.llm.adapters import get_llm_client
155
+ llm = get_llm_client()
156
+
157
+ # Prepare document summaries for the LLM
158
+ doc_summaries = []
159
+ for idx, doc in enumerate(sources, 1):
160
+ metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
161
+ content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
162
+
163
+ filename = metadata.get('filename', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
164
+ year = metadata.get('year', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
165
+ source = metadata.get('source', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
166
+ district = metadata.get('district', '') if isinstance(metadata, dict) else ''
167
+
168
+ doc_info = f"{filename}"
169
+ if year and year != 'Unknown':
170
+ doc_info += f" ({year})"
171
+ if source and source != 'Unknown':
172
+ doc_info += f" - {source}"
173
+ if district:
174
+ doc_info += f" - {district}"
175
+
176
+ doc_summaries.append(f"[Doc {idx}] {doc_info}: {content[:300]}...")
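+ # Produces lines like "[Doc 1] report.pdf (2022) - Local Government - Gulu: <first 300 chars>..." (filename illustrative)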
177
+
178
+ prompt = f"""You are enhancing a response from a document search system. The original response is:
179
+
180
+ {answer}
181
+
182
+ The following documents were retrieved and used to generate this response:
183
+
184
+ {chr(10).join(doc_summaries)}
185
+
186
+ CRITICAL RULES:
187
+ 1. Format the response nicely with proper paragraphs, bullet points, or structured sections where appropriate
188
+ 2. The response should ONLY contain information from the retrieved documents listed above
189
+ 3. If the response mentions information NOT found in the retrieved documents, you must REMOVE or CORRECT that information
190
+ 4. Add document references [Doc i] at the end of sentences that use information from specific documents
191
+ 5. Only reference documents that are actually used in the response
192
+ 6. If the response mentions years, sources, or data that don't match the retrieved documents, you must correct it
193
+ 7. Keep the response natural, conversational, and well-formatted
194
+ 8. Use proper formatting: paragraphs, line breaks, and structure for readability
195
+ 9. Don't change the core content that matches the documents, just add references where appropriate and improve formatting
196
+ 10. If multiple documents support the same claim, use [Doc i, Doc j] format
197
+ 11. If the response contains information that cannot be verified in the retrieved documents, add a note like: "Note: This information may not be in the retrieved documents."
198
+
199
+ Return ONLY the enhanced, well-formatted response with references added and any corrections made. Do not include any explanation or meta-commentary."""
200
+
201
+ llm_response = llm.invoke(prompt)  # invoke once and reuse; the previous one-liner could call the LLM up to three times
+ enhanced = llm_response.content if hasattr(llm_response, 'content') else str(llm_response)
202
+
203
+ # Fallback: if LLM fails, just return original with basic formatting
204
+ if not enhanced or len(enhanced) < len(answer) * 0.5:
205
+ logger.warning("LLM enhancement failed, using original response with basic formatting")
206
+ # Basic formatting: add line breaks after periods for readability
207
+ formatted = answer.replace('. ', '.\n\n')
208
+ if sources:
209
+ ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
210
+ formatted += f"\n\n*Based on documents: {ref_list}*"
211
+ return formatted
212
+
213
+ return enhanced
214
+
215
+ except Exception as e:
216
+ logger.warning(f"Failed to enhance response with references: {e}")
217
+ # Fallback: add basic formatting and references at the end
218
+ formatted = answer.replace('. ', '.\n\n') # Basic paragraph formatting
219
+ if sources:
220
+ ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
221
+ formatted += f"\n\n*Based on documents: {ref_list}*"
222
+ return formatted
223
+
224
+ def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
225
+ """Extract UI filters from query if present"""
226
+ filters = {}
227
+
228
+ if "FILTER CONTEXT:" in query:
229
+ filter_section = query.split("FILTER CONTEXT:")[1]
230
+ if "USER QUERY:" in filter_section:
231
+ filter_section = filter_section.split("USER QUERY:")[0]
232
+ filter_section = filter_section.strip()
233
+
234
+ if "Sources:" in filter_section:
235
+ sources_line = [line for line in filter_section.split('\n') if line.strip().startswith('Sources:')]
236
+ if sources_line:
237
+ sources_str = sources_line[0].split("Sources:")[1].strip()
238
+ if sources_str and sources_str != "None":
239
+ filters["sources"] = [s.strip() for s in sources_str.split(",")]
240
+
241
+ if "Years:" in filter_section:
242
+ years_line = [line for line in filter_section.split('\n') if line.strip().startswith('Years:')]
243
+ if years_line:
244
+ years_str = years_line[0].split("Years:")[1].strip()
245
+ if years_str and years_str != "None":
246
+ filters["year"] = [y.strip() for y in years_str.split(",")]
247
+
248
+ if "Districts:" in filter_section:
249
+ districts_line = [line for line in filter_section.split('\n') if line.strip().startswith('Districts:')]
250
+ if districts_line:
251
+ districts_str = districts_line[0].split("Districts:")[1].strip()
252
+ if districts_str and districts_str != "None":
253
+ filters["district"] = [d.strip() for d in districts_str.split(",")]
254
+
255
+ if "Filenames:" in filter_section:
256
+ filenames_line = [line for line in filter_section.split('\n') if line.strip().startswith('Filenames:')]
257
+ if filenames_line:
258
+ filenames_str = filenames_line[0].split("Filenames:")[1].strip()
259
+ if filenames_str and filenames_str != "None":
260
+ filters["filenames"] = [f.strip() for f in filenames_str.split(",")]
261
+
262
+ return filters
263
+
264
+ def _extract_context_from_conversation(
265
+ self,
266
+ messages: List[Any],
267
+ ui_filters: Dict[str, List[str]]
268
+ ) -> Dict[str, Any]:
269
+ """Extract context from conversation history"""
270
+ # Use UI filters if available
271
+ filters = ui_filters.copy() if ui_filters else {}
272
+
273
+ # For Gemini, we pass filters directly to the search function
274
+ # The filters will be used to add context to the query
275
+
276
+ return {
277
+ "filters": filters,
278
+ "has_filters": bool(filters)
279
+ }
280
+
281
+ def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
282
+ """Main chat interface"""
283
+ logger.info(f"πŸ’¬ GEMINI CHAT: Processing '{user_input[:50]}...'")
284
+
285
+ # Load conversation
286
+ conversation_file = self.conversations_dir / f"{conversation_id}.json"
287
+ conversation = self._load_conversation(conversation_file)
288
+
289
+ # Add user message
290
+ conversation["messages"].append(HumanMessage(content=user_input))
291
+
292
+ # Prepare state
293
+ state = GeminiState(
294
+ conversation_id=conversation_id,
295
+ messages=conversation["messages"],
296
+ current_query=user_input,
297
+ query_context=None,
298
+ gemini_result=None,
299
+ final_response=None,
300
+ agent_logs=[],
301
+ conversation_context=conversation.get("context", {}),
302
+ session_start_time=conversation["session_start_time"],
303
+ last_ai_message_time=conversation["last_ai_message_time"],
304
+ filters=None
305
+ )
306
+
307
+ # Run graph
308
+ final_state = self.graph.invoke(state)
309
+
310
+ # Add AI response to conversation
311
+ if final_state["final_response"]:
312
+ conversation["messages"].append(AIMessage(content=final_state["final_response"]))
313
+
314
+ # Update conversation
315
+ conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
316
+ conversation["context"] = final_state["conversation_context"]
317
+
318
+ # Save conversation
319
+ self._save_conversation(conversation_file, conversation)
320
+
321
+ # Format sources for display
322
+ sources = []
323
+ gemini_result = final_state.get("gemini_result")
324
+ if gemini_result:
325
+ sources = self.gemini_client.format_sources_for_display(gemini_result)
326
+ logger.info(f"πŸ“‹ GEMINI CHAT: Formatted {len(sources)} sources for display")
327
+
328
+ return {
329
+ 'response': final_state["final_response"] or "I apologize, but I couldn't process your request.",
330
+ 'rag_result': {
331
+ 'sources': sources,
332
+ 'answer': final_state["final_response"]
333
+ },
334
+ 'agent_logs': final_state["agent_logs"],
335
+ 'actual_rag_query': final_state["current_query"],
336
+ 'gemini_result': gemini_result # Include raw result for tracking
337
+ }
338
+
339
+ def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
340
+ """Load conversation from file"""
341
+ if conversation_file.exists():
342
+ try:
343
+ with open(conversation_file) as f:
344
+ data = json.load(f)
345
+ messages = []
346
+ for msg_data in data.get("messages", []):
347
+ if msg_data["type"] == "human":
348
+ messages.append(HumanMessage(content=msg_data["content"]))
349
+ elif msg_data["type"] == "ai":
350
+ messages.append(AIMessage(content=msg_data["content"]))
351
+ data["messages"] = messages
352
+ return data
353
+ except Exception as e:
354
+ logger.warning(f"Could not load conversation: {e}")
355
+
356
+ return {
357
+ "messages": [],
358
+ "session_start_time": time.time(),
359
+ "last_ai_message_time": time.time(),
360
+ "context": {}
361
+ }
362
+
363
+ def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
364
+ """Save conversation to file"""
365
+ try:
366
+ conversation_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
367
+
368
+ messages_data = []
369
+ for msg in conversation["messages"]:
370
+ if isinstance(msg, HumanMessage):
371
+ messages_data.append({"type": "human", "content": msg.content})
372
+ elif isinstance(msg, AIMessage):
373
+ messages_data.append({"type": "ai", "content": msg.content})
374
+
375
+ conversation_data = {
376
+ "messages": messages_data,
377
+ "session_start_time": conversation["session_start_time"],
378
+ "last_ai_message_time": conversation["last_ai_message_time"],
379
+ "context": conversation.get("context", {})
380
+ }
381
+
382
+ with open(conversation_file, 'w') as f:
383
+ json.dump(conversation_data, f, indent=2)
384
+
385
+ except Exception as e:
386
+ logger.error(f"Could not save conversation: {e}")
387
+
388
+
389
+ def get_gemini_chatbot():
390
+ """Get Gemini chatbot instance"""
391
+ return GeminiRAGChatbot()
392
+
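+ # Usage sketch (assumes GeminiFileSearchClient reads its API key from the environment):
+ #   bot = get_gemini_chatbot()
+ #   result = bot.chat("Summarize the key audit findings for Gulu", conversation_id="demo")
+ #   print(result["response"], len(result["rag_result"]["sources"]))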
multi_agent_chatbot.py β†’ src/agents/multi_agent_chatbot.py RENAMED
@@ -8,24 +8,26 @@ This system implements a 3-agent architecture:
8
 
9
  Each agent has specialized prompts and responsibilities.
10
  """
 
11
  import json
12
  import time
13
  import logging
 
14
  from pathlib import Path
15
  from datetime import datetime
16
  from dataclasses import dataclass
17
  from typing import Dict, List, Any, Optional, TypedDict
18
 
19
-
20
  from langchain_core.tools import tool
21
  from langgraph.graph import StateGraph, END
22
- from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
23
  from langchain_core.prompts import ChatPromptTemplate
 
24
 
25
 
26
  from src.pipeline import PipelineManager
27
- from src.config.loader import load_config
28
  from src.llm.adapters import get_llm_client
 
 
29
 
30
 
31
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -46,6 +48,7 @@ class QueryContext:
46
  needs_follow_up: bool = False
47
  follow_up_question: Optional[str] = None
48
 
 
49
  class MultiAgentState(TypedDict):
50
  """State for the multi-agent conversation flow"""
51
  conversation_id: str
@@ -61,6 +64,7 @@ class MultiAgentState(TypedDict):
61
  session_start_time: float
62
  last_ai_message_time: float
63
 
 
64
  class MultiAgentRAGChatbot:
65
  """Multi-agent RAG chatbot with specialized agents"""
66
 
@@ -112,7 +116,6 @@ class MultiAgentRAGChatbot:
112
  logger.info("βœ… Pipeline manager initialized and models loaded")
113
  except Exception as e:
114
  logger.error(f"❌ Failed to initialize pipeline manager: {e}")
115
- import traceback
116
  traceback.print_exc()
117
  raise RuntimeError(f"Pipeline manager initialization failed: {e}")
118
 
@@ -129,7 +132,6 @@ class MultiAgentRAGChatbot:
129
  raise # Re-raise RuntimeError as-is
130
  except Exception as e:
131
  logger.error(f"❌ Error during vector store connection: {e}")
132
- import traceback
133
  traceback.print_exc()
134
  raise RuntimeError(f"Vector store connection failed: {e}")
135
 
@@ -139,8 +141,8 @@ class MultiAgentRAGChatbot:
139
  # Build the multi-agent graph
140
  self.graph = self._build_graph()
141
 
142
- # Conversations directory - use absolute path in /app to ensure writability
143
- self.conversations_dir = Path("/app/conversations")
144
  try:
145
  # Use 777 permissions for maximum compatibility (HF Spaces runs as different user)
146
  self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
@@ -158,9 +160,9 @@ class MultiAgentRAGChatbot:
158
 
159
  def _load_dynamic_data(self):
160
  """Load dynamic data from filter_options.json and add_district_metadata.py"""
161
- # Load filter options
162
  try:
163
- fo = Path("src/config/filter_options.json")
164
  if fo.exists():
165
  with open(fo) as f:
166
  data = json.load(f)
@@ -178,7 +180,7 @@ class MultiAgentRAGChatbot:
178
  self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
179
  self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
180
 
181
- # Enrich district list from add_district_metadata.py
182
  try:
183
  from add_district_metadata import DistrictMetadataProcessor
184
  proc = DistrictMetadataProcessor()
@@ -206,6 +208,59 @@ class MultiAgentRAGChatbot:
206
  logger.info(f" Sources: {self.source_whitelist}")
207
  logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  def _build_graph(self) -> StateGraph:
210
  """Build the multi-agent LangGraph"""
211
  graph = StateGraph(MultiAgentState)
@@ -510,6 +565,10 @@ class MultiAgentRAGChatbot:
510
  - If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
511
  - If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
512
  - Always return districts as JSON arrays when multiple districts are mentioned
513
  - If no exact matches found, set extracted values to null
514
 
515
  4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
@@ -590,7 +649,6 @@ Analyze this query using ONLY the exact values provided above:""")
590
  # Clean and parse JSON with better error handling
591
  try:
592
  # Remove comments (// and /* */) from JSON
593
- import re
594
  # Remove single-line comments
595
  content = re.sub(r'//.*?$', '', content, flags=re.MULTILINE)
596
  # Remove multi-line comments
@@ -603,7 +661,6 @@ Analyze this query using ONLY the exact values provided above:""")
603
  logger.error(f"❌ Raw content: {content[:200]}...")
604
 
605
  # Try to extract JSON from text if embedded
606
- import re
607
  json_match = re.search(r'\{.*\}', content, re.DOTALL)
608
  if json_match:
609
  try:
@@ -656,13 +713,9 @@ Analyze this query using ONLY the exact values provided above:""")
656
  # Validate each district in the array
657
  valid_districts = []
658
  for district in extracted_district:
659
- if district in self.district_whitelist:
660
- valid_districts.append(district)
661
- else:
662
- # Try removing "District" suffix
663
- district_name = district.replace(" District", "").replace(" district", "")
664
- if district_name in self.district_whitelist:
665
- valid_districts.append(district_name)
666
 
667
  if valid_districts:
668
  extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
@@ -671,16 +724,15 @@ Analyze this query using ONLY the exact values provided above:""")
671
  logger.warning(f"⚠️ No valid districts found in: '{extracted_district}'")
672
  extracted_district = None
673
  else:
674
- # Single district validation
675
- if extracted_district not in self.district_whitelist:
676
- # Try removing "District" suffix
677
- district_name = extracted_district.replace(" District", "").replace(" district", "")
678
- if district_name in self.district_whitelist:
679
- logger.info(f"πŸ” QUERY ANALYSIS: Normalized district '{extracted_district}' to '{district_name}'")
680
- extracted_district = district_name
681
- else:
682
- logger.warning(f"⚠️ Invalid district extracted: '{extracted_district}' not in whitelist")
683
- extracted_district = None
684
 
685
  # Validate source (handle both single values and arrays)
686
  if extracted_source:
@@ -918,6 +970,23 @@ Rewrite the best retrieval query:""")
918
  logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β†’ normalized: {normalized_districts}")
919
 
920
  # Merge with extracted context for missing filters
921
  if not filters.get("year") and context.extracted_year:
922
  # Handle both single values and arrays
923
  if isinstance(context.extracted_year, list):
@@ -926,16 +995,6 @@ Rewrite the best retrieval query:""")
926
  filters["year"] = [context.extracted_year]
927
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
928
 
929
- if not filters.get("district") and context.extracted_district:
930
- # Handle both single values and arrays
931
- if isinstance(context.extracted_district, list):
932
- # Normalize district names to title case (match Qdrant metadata format)
933
- normalized = [d.title() for d in context.extracted_district]
934
- filters["district"] = normalized
935
- else:
936
- filters["district"] = [context.extracted_district.title()]
937
- logger.info(f"πŸ”§ FILTER BUILDING: Added extracted district filter (UI missing): {context.extracted_district}")
938
-
939
  if not filters.get("sources") and context.extracted_source:
940
  # Handle both single values and arrays
941
  if isinstance(context.extracted_source, list):
@@ -963,12 +1022,21 @@ Rewrite the best retrieval query:""")
963
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
964
 
965
  if context.extracted_district:
966
- # Handle both single values and arrays
967
  if isinstance(context.extracted_district, list):
968
- filters["district"] = context.extracted_district
 
 
 
 
 
 
 
969
  else:
970
- filters["district"] = [context.extracted_district]
971
- logger.info(f"πŸ”§ FILTER BUILDING: Added extracted district filter: {context.extracted_district}")
 
 
972
 
973
  logger.info(f"πŸ”§ FILTER BUILDING: Final filters: {filters}")
974
  return filters
@@ -978,49 +1046,212 @@ Rewrite the best retrieval query:""")
978
  logger.info("πŸ’¬ RESPONSE GENERATION: Starting conversational response generation")
979
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Processing {len(documents)} documents")
980
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
 
 
 
 
 
 
 
 
 
 
981
 
982
  # Create response prompt
983
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Building response prompt")
984
  response_prompt = ChatPromptTemplate.from_messages([
985
  SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
986
987
  RULES:
988
  1. Answer the user's question directly and clearly
989
- 2. Use the retrieved documents as evidence
990
  3. Be conversational, not technical
991
  4. Don't mention scores, retrieval details, or technical implementation
992
  5. If relevant documents were found, reference them naturally
993
- 6. If no relevant documents, explain based on your knowledge (if you have it) or just say you do not have enough information.
994
- 7. If the passages have useful facts or numbers, use them in your answer.
995
- 8. When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
996
  9. Do not use the sentence 'Doc i says ...' to say where information came from.
997
  10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
998
  11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
999
  12. If it makes sense, use bullet points and lists to make your answers easier to understand.
1000
  13. You do not need to use every passage. Only use the ones that help answer the question.
1001
- 14. If the documents do not have the information needed to answer the question, just say you do not have enough information.
1002
-
 
1003
 
1004
  TONE: Professional but friendly, like talking to a colleague."""),
1005
- HumanMessage(content=f"""User Question: {query}
 
 
 
1006
 
1007
  Retrieved Documents: {len(documents)} documents found
1008
 
 
 
 
 
 
 
1009
  RAG Answer: {rag_answer}
1010
 
1011
- Generate a conversational response:""")
1012
  ])
1013
 
1014
  try:
1015
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Calling LLM for final response")
1016
  response = self.llm.invoke(response_prompt.format_messages())
1017
  logger.info(f"πŸ’¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
1018
- return response.content.strip()
1019
  except Exception as e:
1020
  logger.error(f"❌ RESPONSE GENERATION: Error during generation: {e}")
1021
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Using RAG answer as fallback")
1022
  return rag_answer # Fallback to RAG answer
1023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1024
  def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
1025
  """Generate conversational response using only LLM knowledge and conversation history"""
1026
  logger.info("πŸ’¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
@@ -1178,7 +1409,6 @@ Generate a conversational response based on your knowledge:""")
1178
 
1179
  except Exception as e:
1180
  logger.error(f"Could not save conversation: {e}")
1181
- import traceback
1182
  logger.error(f"Traceback: {traceback.format_exc()}")
1183
 
1184
 
 
8
 
9
  Each agent has specialized prompts and responsibilities.
10
  """
11
+ import re
12
  import json
13
  import time
14
  import logging
15
+ import traceback
16
  from pathlib import Path
17
  from datetime import datetime
18
  from dataclasses import dataclass
19
  from typing import Dict, List, Any, Optional, TypedDict
20
 
 
21
  from langchain_core.tools import tool
22
  from langgraph.graph import StateGraph, END
 
23
  from langchain_core.prompts import ChatPromptTemplate
24
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
25
 
26
 
27
  from src.pipeline import PipelineManager
 
28
  from src.llm.adapters import get_llm_client
29
+ from src.config.paths import PROJECT_DIR, CONVERSATIONS_DIR
30
+ from src.config.loader import load_config, get_embedding_model_for_collection
31
 
32
 
33
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
48
  needs_follow_up: bool = False
49
  follow_up_question: Optional[str] = None
50
 
51
+
52
  class MultiAgentState(TypedDict):
53
  """State for the multi-agent conversation flow"""
54
  conversation_id: str
 
64
  session_start_time: float
65
  last_ai_message_time: float
66
 
67
+
68
  class MultiAgentRAGChatbot:
69
  """Multi-agent RAG chatbot with specialized agents"""
70
 
 
116
  logger.info("βœ… Pipeline manager initialized and models loaded")
117
  except Exception as e:
118
  logger.error(f"❌ Failed to initialize pipeline manager: {e}")
 
119
  traceback.print_exc()
120
  raise RuntimeError(f"Pipeline manager initialization failed: {e}")
121
 
 
132
  raise # Re-raise RuntimeError as-is
133
  except Exception as e:
134
  logger.error(f"❌ Error during vector store connection: {e}")
 
135
  traceback.print_exc()
136
  raise RuntimeError(f"Vector store connection failed: {e}")
137
 
 
141
  # Build the multi-agent graph
142
  self.graph = self._build_graph()
143
 
144
+ # Conversations directory - use PROJECT_DIR for local vs deployed compatibility
145
+ self.conversations_dir = CONVERSATIONS_DIR
146
  try:
147
  # Use 777 permissions for maximum compatibility (HF Spaces runs as different user)
148
  self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
 
160
 
161
  def _load_dynamic_data(self):
162
  """Load dynamic data from filter_options.json and add_district_metadata.py"""
163
+ # Load filter options - use PROJECT_DIR relative path
164
  try:
165
+ fo = PROJECT_DIR / "src" / "config" / "filter_options.json"
166
  if fo.exists():
167
  with open(fo) as f:
168
  data = json.load(f)
 
180
  self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
181
  self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
182
 
183
+ # Enrich district list from add_district_metadata.py (if available)
184
  try:
185
  from add_district_metadata import DistrictMetadataProcessor
186
  proc = DistrictMetadataProcessor()
 
208
  logger.info(f" Sources: {self.source_whitelist}")
209
  logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
210
 
211
+ def _normalize_district_name(self, district: str) -> Optional[str]:
212
+ """Normalize district name with fuzzy matching for common misspellings."""
213
+ if not district:
214
+ return None
215
+
216
+ district = district.strip()
217
+
218
+ # Direct match
219
+ if district in self.district_whitelist:
220
+ return district
221
+
222
+ # Remove "District" suffix
223
+ district_name = district.replace(" District", "").replace(" district", "").strip()
224
+ if district_name in self.district_whitelist:
225
+ return district_name
226
+
227
+ # Common misspellings mapping (lowercase keys only: the lookup below lowercases the input first,
+ # so mixed-case duplicates would never be reached)
228
+ misspelling_map = {
229
+ "kalagala": "Kalangala",
230
+ "kalangala": "Kalangala",
231
+ "gulu": "Gulu",
232
+ "kampala": "Kampala",
233
+ }
238
+
239
+ # Check misspelling map (case-insensitive)
240
+ district_lower = district_name.lower()
241
+ if district_lower in misspelling_map:
242
+ corrected = misspelling_map[district_lower]
243
+ if corrected in self.district_whitelist:
244
+ return corrected
245
+
246
+ # Fuzzy matching for similar names (case-insensitive equality, then a length-gated substring check)
247
+ # Check if the district name is very similar to any whitelist entry
248
+ for whitelist_district in self.district_whitelist:
249
+ # Case-insensitive comparison
250
+ if district_name.lower() == whitelist_district.lower():
251
+ return whitelist_district
252
+
253
+ # Check if one is a substring of the other (for partial matches)
254
+ if len(district_name) >= 4 and len(whitelist_district) >= 4:
255
+ if district_name.lower() in whitelist_district.lower() or whitelist_district.lower() in district_name.lower():
256
+ # Treat substring containment as a strong match only when the two lengths are within 80% of each other
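+ # e.g. "Kalangal" (8) vs "Kalangala" (9): 8/9 = 0.89 -> accepted; "Kala" (4) vs "Kalangala" (9): 4/9 = 0.44 -> rejected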
257
+ min_len = min(len(district_name), len(whitelist_district))
258
+ max_len = max(len(district_name), len(whitelist_district))
259
+ if min_len / max_len >= 0.8:
260
+ return whitelist_district
261
+
262
+ return None
263
+
264
  def _build_graph(self) -> StateGraph:
265
  """Build the multi-agent LangGraph"""
266
  graph = StateGraph(MultiAgentState)
 
565
  - If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
566
  - If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
567
  - Always return districts as JSON arrays when multiple districts are mentioned
568
+ - **COMMON MISSPELLINGS**: Handle common misspellings intelligently:
569
+ * "Kalagala" (missing 'n') should be extracted as "Kalangala"
570
+ * "kalagala", "Kalagala", "KALAGALA" should all be normalized to "Kalangala"
571
+ * Similar case-insensitive variations should be normalized to the correct district name
572
  - If no exact matches found, set extracted values to null
573
 
574
  4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
 
649
  # Clean and parse JSON with better error handling
650
  try:
651
  # Remove comments (// and /* */) from JSON
 
652
  # Remove single-line comments
653
  content = re.sub(r'//.*?$', '', content, flags=re.MULTILINE)
654
  # Remove multi-line comments
 
661
  logger.error(f"❌ Raw content: {content[:200]}...")
662
 
663
  # Try to extract JSON from text if embedded
 
664
  json_match = re.search(r'\{.*\}', content, re.DOTALL)
665
  if json_match:
666
  try:
 
713
  # Validate each district in the array
714
  valid_districts = []
715
  for district in extracted_district:
716
+ normalized = self._normalize_district_name(district)
717
+ if normalized:
718
+ valid_districts.append(normalized)
719
 
720
  if valid_districts:
721
  extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
 
724
  logger.warning(f"⚠️ No valid districts found in: '{extracted_district}'")
725
  extracted_district = None
726
  else:
727
+ # Single district validation with fuzzy matching
728
+ normalized = self._normalize_district_name(extracted_district)
729
+ if normalized:
730
+ if normalized != extracted_district:
731
+ logger.info(f"πŸ” QUERY ANALYSIS: Normalized district '{extracted_district}' to '{normalized}'")
732
+ extracted_district = normalized
733
+ else:
734
+ logger.warning(f"⚠️ Invalid district extracted: '{extracted_district}' not in whitelist")
735
+ extracted_district = None
 
736
 
737
  # Validate source (handle both single values and arrays)
738
  if extracted_source:
 
970
  logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} β†’ normalized: {normalized_districts}")
971
 
972
  # Merge with extracted context for missing filters
973
+ if not filters.get("district") and context.extracted_district:
974
+ # Normalize district names using the normalization function
975
+ if isinstance(context.extracted_district, list):
976
+ normalized_districts = []
977
+ for d in context.extracted_district:
978
+ normalized = self._normalize_district_name(d)
979
+ if normalized:
980
+ normalized_districts.append(normalized)
981
+ if normalized_districts:
982
+ filters["district"] = normalized_districts
983
+ logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β†’ normalized: {normalized_districts}")
984
+ else:
985
+ normalized = self._normalize_district_name(context.extracted_district)
986
+ if normalized:
987
+ filters["district"] = [normalized]
988
+ logger.info(f"πŸ”§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β†’ normalized: {normalized}")
989
+
990
  if not filters.get("year") and context.extracted_year:
991
  # Handle both single values and arrays
992
  if isinstance(context.extracted_year, list):
 
995
  filters["year"] = [context.extracted_year]
996
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
997
 
 
 
 
 
 
 
 
 
 
 
998
  if not filters.get("sources") and context.extracted_source:
999
  # Handle both single values and arrays
1000
  if isinstance(context.extracted_source, list):
 
1022
  logger.info(f"πŸ”§ FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
1023
 
1024
  if context.extracted_district:
1025
+ # Normalize district names using the normalization function
1026
  if isinstance(context.extracted_district, list):
1027
+ normalized_districts = []
1028
+ for d in context.extracted_district:
1029
+ normalized = self._normalize_district_name(d)
1030
+ if normalized:
1031
+ normalized_districts.append(normalized)
1032
+ if normalized_districts:
1033
+ filters["district"] = normalized_districts
1034
+ logger.info(f"πŸ”§ FILTER BUILDING: Added districts filter from context: {context.extracted_district} β†’ normalized: {normalized_districts}")
1035
  else:
1036
+ normalized = self._normalize_district_name(context.extracted_district)
1037
+ if normalized:
1038
+ filters["district"] = [normalized]
1039
+ logger.info(f"πŸ”§ FILTER BUILDING: Added district filter from context: {context.extracted_district} β†’ normalized: {normalized}")
1040
 
1041
  logger.info(f"πŸ”§ FILTER BUILDING: Final filters: {filters}")
1042
  return filters
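For illustration: if the UI supplied no district filter and the extracted context held district "Kalagala" and year 2022, the merge above would produce a dict shaped like this (values hypothetical):

filters = {
    "district": ["Kalangala"],  # normalized from the extracted "Kalagala"
    "year": ["2022"],
}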
 
1046
  logger.info("πŸ’¬ RESPONSE GENERATION: Starting conversational response generation")
1047
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Processing {len(documents)} documents")
1048
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Query: '{query[:50]}...'")
1049
+ logger.info(f"πŸ’¬ RESPONSE GENERATION: Conversation history: {len(messages)} messages")
1050
+
1051
+ # Build conversation history context
1052
+ conversation_context = self._build_conversation_context(messages)
1053
+
1054
+ # Build detailed document information
1055
+ document_details = self._build_document_details(documents)
1056
+
1057
+ # Extract correct district/source/year names from documents (to correct misspellings)
1058
+ correct_names = self._extract_correct_names_from_documents(documents)
1059
 
1060
  # Create response prompt
1061
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Building response prompt")
1062
  response_prompt = ChatPromptTemplate.from_messages([
1063
  SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
1064
 
1065
+ CRITICAL RULES - NO HALLUCINATION:
1066
+ 1. **ONLY use information from the retrieved documents provided below**
1067
+ 2. **EVERY sentence with facts, numbers, or specific claims MUST have a [Doc i] reference**
1068
+ 3. **If a document doesn't contain the information, DO NOT make it up**
1069
+ 4. **If the user asks about a year/district that's NOT in the retrieved documents, explicitly state that**
1070
+ 5. **Check the document years/districts before making any claims about them**
1071
+ 6. **USE CORRECT NAMES**: If the conversation mentions a misspelled district/source name (e.g., "Kalagala"), use the CORRECT spelling from the document metadata (e.g., "Kalangala"). Always use the exact names from document metadata, not misspellings from conversation.
1072
+
1073
  RULES:
1074
  1. Answer the user's question directly and clearly
1075
+ 2. Use ONLY the retrieved documents as evidence - DO NOT use your training data
1076
  3. Be conversational, not technical
1077
  4. Don't mention scores, retrieval details, or technical implementation
1078
  5. If relevant documents were found, reference them naturally
1079
+ 6. If no relevant documents, say you do not have enough information - DO NOT hallucinate
1080
+ 7. If the passages have useful facts or numbers, use them in your answer WITH references
1081
+ 8. **MANDATORY**: When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence, where i is the document number.
1082
  9. Do not use phrasing like 'Doc i says ...' to attribute information.
1083
  10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
1084
  11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
1085
  12. If it makes sense, use bullet points and lists to make your answers easier to understand.
1086
  13. You do not need to use every passage. Only use the ones that help answer the question.
1087
+ 14. **VERIFY**: Before mentioning any year, district, or number, check that it exists in the retrieved documents. If it doesn't, say "I don't have information about [year/district] in the retrieved documents."
1088
+ 15. **NO HALLUCINATION**: If documents show years 2021, 2022, 2023 but user asks about 2020, DO NOT provide 2020 data. Instead say "The retrieved documents cover 2021-2023, but I don't have information for 2020."
1089
+ 16. **USE CORRECT SPELLING**: Always use the district/source names exactly as they appear in the document metadata below, even if the conversation history has misspellings.
1090
 
1091
  TONE: Professional but friendly, like talking to a colleague."""),
1092
+ HumanMessage(content=f"""Conversation History:
1093
+ {conversation_context}
1094
+
1095
+ Current User Question: {query}
1096
 
1097
  Retrieved Documents: {len(documents)} documents found
1098
 
1099
+ CORRECT NAMES TO USE (from document metadata - use these exact spellings):
1100
+ {correct_names}
1101
+
1102
+ Full Document Details:
1103
+ {document_details}
1104
+
1105
  RAG Answer: {rag_answer}
1106
 
1107
+ CRITICAL:
1108
+ - Responses must be grounded in what is available in the retrieved documents
1109
+ - If the user asks about a specific year, district, or source that the retrieved documents do not cover, explicitly state "can't provide response on ... because ..."
1110
+ - Every factual claim MUST have [Doc i] reference
1111
+ - If information is not in documents, explicitly state it's not available
1112
+ - **USE THE CORRECT DISTRICT/SOURCE NAMES from the document metadata above, not misspellings from conversation**
1113
+
1114
+ Generate a conversational response with proper document references:""")
1115
  ])
1116
 
1117
  try:
1118
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Calling LLM for final response")
1119
  response = self.llm.invoke(response_prompt.format_messages())
1120
  logger.info(f"πŸ’¬ RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
1121
+
1122
+ # Post-process response to ensure no hallucination
1123
+ final_response = self._validate_and_enhance_response(
1124
+ response.content.strip(),
1125
+ documents,
1126
+ query
1127
+ )
1128
+
1129
+ return final_response
1130
  except Exception as e:
1131
  logger.error(f"❌ RESPONSE GENERATION: Error during generation: {e}")
1132
  logger.info(f"πŸ’¬ RESPONSE GENERATION: Using RAG answer as fallback")
1133
  return rag_answer # Fallback to RAG answer
1134
 
1135
+ def _build_conversation_context(self, messages: List[Any]) -> str:
1136
+ """Build conversation history context for response generation."""
1137
+ if not messages:
1138
+ return "No previous conversation."
1139
+
1140
+ context_lines = []
1141
+ # Show last 6 messages for context (to capture the current exchange)
1142
+ for msg in messages[-6:]:
1143
+ if isinstance(msg, HumanMessage):
1144
+ context_lines.append(f"User: {msg.content}")
1145
+ elif isinstance(msg, AIMessage):
1146
+ context_lines.append(f"Assistant: {msg.content}")
1147
+
1148
+ return "\n".join(context_lines) if context_lines else "No previous conversation."
1149
+
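A quick sketch of the resulting format, where `bot` stands in for a chatbot instance (hypothetical name):

from langchain_core.messages import HumanMessage, AIMessage

messages = [
    HumanMessage(content="What did the 2022 Kalangala audit find?"),
    AIMessage(content="It flagged unaccounted-for funds [Doc 1]."),
]
print(bot._build_conversation_context(messages))
# User: What did the 2022 Kalangala audit find?
# Assistant: It flagged unaccounted-for funds [Doc 1].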
1150
+ def _build_document_details(self, documents: List[Any]) -> str:
1151
+ """Build detailed document information for response generation."""
1152
+ if not documents:
1153
+ return "No documents retrieved."
1154
+
1155
+ details = []
1156
+ for i, doc in enumerate(documents[:15], 1): # Show up to 15 documents
1157
+ metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
1158
+ content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
1159
+
1160
+ if isinstance(metadata, dict):
1161
+ filename = metadata.get('filename', 'Unknown')
1162
+ year = metadata.get('year', 'Unknown')
1163
+ district = metadata.get('district', 'Unknown')
1164
+ source = metadata.get('source', 'Unknown')
1165
+ page = metadata.get('page', metadata.get('page_label', 'Unknown'))
1166
+
1167
+ doc_info = f"[Doc {i}]"
1168
+ doc_info += f"\n Filename: {filename}"
1169
+ doc_info += f"\n Year: {year}"
1170
+ doc_info += f"\n District: {district}"
1171
+ doc_info += f"\n Source: {source}"
1172
+ if page != 'Unknown':
1173
+ doc_info += f"\n Page: {page}"
1174
+ doc_info += f"\n Content: {content[:300]}{'...' if len(content) > 300 else ''}"
1175
+ details.append(doc_info)
1176
+
1177
+ return "\n\n".join(details) if details else "No document details available."
1178
+
1179
+ def _extract_correct_names_from_documents(self, documents: List[Any]) -> str:
1180
+ """Extract correct district/source names from documents to correct misspellings."""
1181
+ districts = set()
1182
+ sources = set()
1183
+ years = set()
1184
+
1185
+ for doc in documents:
1186
+ metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
1187
+ if isinstance(metadata, dict):
1188
+ if metadata.get('district'):
1189
+ districts.add(str(metadata['district']))
1190
+ if metadata.get('source'):
1191
+ sources.add(str(metadata['source']))
1192
+ if metadata.get('year'):
1193
+ years.add(str(metadata['year']))
1194
+
1195
+ result = []
1196
+ if districts:
1197
+ result.append(f"Districts: {', '.join(sorted(districts))}")
1198
+ if sources:
1199
+ result.append(f"Sources: {', '.join(sorted(sources))}")
1200
+ if years:
1201
+ result.append(f"Years: {', '.join(sorted(years))}")
1202
+
1203
+ if result:
1204
+ return "\n".join(result) + "\n\nIMPORTANT: Use these EXACT spellings in your response, even if the conversation history has misspellings."
1205
+ return "No metadata available."
1206
+
1207
+ def _validate_and_enhance_response(self, response: str, documents: List[Any], query: str) -> str:
1208
+ """Validate response and ensure all claims are referenced."""
1209
+ # Extract years and districts from documents
1210
+ doc_years = set()
1211
+ doc_districts = set()
1212
+ doc_sources = set()
1213
+
1214
+ for doc in documents:
1215
+ metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
1216
+ if isinstance(metadata, dict):
1217
+ if metadata.get('year'):
1218
+ doc_years.add(str(metadata['year']))
1219
+ if metadata.get('district'):
1220
+ doc_districts.add(str(metadata['district']))
1221
+ if metadata.get('source'):
1222
+ doc_sources.add(str(metadata['source']))
1223
+
1224
+ # Optionally correct misspellings in the response using names from the documents (currently disabled)
1225
+ # response = self._correct_misspellings_in_response(response, doc_districts, doc_sources)
1226
+
1227
+ # Check if response mentions years not in documents
1228
+ year_pattern = r'\b(20\d{2})\b'
1229
+ mentioned_years = set(re.findall(year_pattern, response))
1230
+
1231
+ # Check if user query mentions a year
1232
+ query_years = set(re.findall(year_pattern, query))
1233
+
1234
+ # If user asks about a year not in documents, add a warning
1235
+ missing_years = query_years - doc_years
1236
+ if missing_years and doc_years:
1237
+ warning = f"\n\n⚠️ Note: The retrieved documents cover years {', '.join(sorted(doc_years))}, but I don't have information for {', '.join(sorted(missing_years))} in the retrieved documents."
1238
+ if warning not in response:
1239
+ response = response + warning
1240
+
1241
+ # Check if response has document references
1242
+ doc_ref_pattern = r'\[Doc\s+\d+\]'
1243
+ has_refs = bool(re.search(doc_ref_pattern, response))
1244
+
1245
+ # If response has factual claims but no references, add a note
1246
+ if not has_refs and len(documents) > 0:
1247
+ # Check if response has numbers or specific claims (simple heuristic)
1248
+ has_numbers = bool(re.search(r'\d+', response))
1249
+ if has_numbers and len(response) > 50:
1250
+ logger.warning("⚠️ Response contains factual claims but no document references")
1251
+ # Don't modify response, but log the issue
1252
+
1253
+ return response
1254
+
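To make the year check concrete, here is the same heuristic as a standalone snippet (sample strings are illustrative):

import re

response = "Spending rose sharply [Doc 1]."
query = "What were the findings in 2020?"
doc_years = {"2021", "2022", "2023"}

query_years = set(re.findall(r"\b(20\d{2})\b", query))
missing = query_years - doc_years
if missing and doc_years:
    response += (
        f"\n\n⚠️ Note: The retrieved documents cover years "
        f"{', '.join(sorted(doc_years))}, but I don't have information for "
        f"{', '.join(sorted(missing))} in the retrieved documents."
    )
print(response)  # ends with the 2020 warning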
1255
  def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
1256
  """Generate conversational response using only LLM knowledge and conversation history"""
1257
  logger.info("πŸ’¬ RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
 
1409
 
1410
  except Exception as e:
1411
  logger.error(f"Could not save conversation: {e}")
 
1412
  logger.error(f"Traceback: {traceback.format_exc()}")
1413
 
1414
 
smart_chatbot.py β†’ src/agents/smart_chatbot.py RENAMED
@@ -26,6 +26,7 @@ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
26
 
27
  from src.pipeline import PipelineManager
28
  from src.config.loader import load_config
 
29
 
30
 
31
  @dataclass
@@ -161,7 +162,7 @@ class IntelligentRAGChatbot:
161
 
162
  # Try to load district whitelist from filter_options.json
163
  try:
164
- fo = Path("filter_options.json")
165
  if fo.exists():
166
  with open(fo) as f:
167
  data = json.load(f)
@@ -174,7 +175,7 @@ class IntelligentRAGChatbot:
174
  except Exception:
175
  self.district_whitelist = self.available_metadata['districts']
176
 
177
- # Enrich whitelist from add_district_metadata.py if available
178
  try:
179
  from add_district_metadata import DistrictMetadataProcessor
180
  proc = DistrictMetadataProcessor()
@@ -195,7 +196,7 @@ class IntelligentRAGChatbot:
195
 
196
  # Get dynamic year list from filter_options.json
197
  try:
198
- fo = Path("filter_options.json")
199
  if fo.exists():
200
  with open(fo) as f:
201
  data = json.load(f)
 
26
 
27
  from src.pipeline import PipelineManager
28
  from src.config.loader import load_config
29
+ from src.config.paths import PROJECT_DIR
30
 
31
 
32
  @dataclass
 
162
 
163
  # Try to load district whitelist from filter_options.json
164
  try:
165
+ fo = PROJECT_DIR / "src" / "config" / "filter_options.json"
166
  if fo.exists():
167
  with open(fo) as f:
168
  data = json.load(f)
 
175
  except Exception:
176
  self.district_whitelist = self.available_metadata['districts']
177
 
178
+ # Enrich whitelist from add_district_metadata.py if available (optional module)
179
  try:
180
  from add_district_metadata import DistrictMetadataProcessor
181
  proc = DistrictMetadataProcessor()
 
196
 
197
  # Get dynamic year list from filter_options.json
198
  try:
199
+ fo = PROJECT_DIR / "src" / "config" / "filter_options.json"
200
  if fo.exists():
201
  with open(fo) as f:
202
  data = json.load(f)
src/config/paths.py ADDED
@@ -0,0 +1,59 @@
1
+ """
2
+ Path configuration for local vs deployed environments.
3
+
4
+ This module handles different paths for local development vs deployed (HF Spaces) environments.
5
+ """
6
+ import os
7
+ from pathlib import Path
8
+
9
+ # Determine if we're in a deployed environment (HF Spaces/Docker) or local
10
+ # Check for environment variable or Docker-like paths
11
+ IS_DEPLOYED = (
12
+ os.getenv("DEPLOYED", "false").lower() == "true" or
13
+ os.path.exists("/app") or
14
+ os.getenv("SPACES_ID") is not None or
15
+ os.path.exists("/.dockerenv")
16
+ )
17
+
18
+ # PROJECT_DIR: Base directory for application files
19
+ # In deployed: /app, in local: current working directory or project root
20
+ if IS_DEPLOYED:
21
+ PROJECT_DIR = Path("/app")
22
+ else:
23
+ # For local development, use current working directory or find project root
24
+ cwd = Path.cwd()
25
+ # Try to find project root (directory containing this src/ folder)
26
+ project_root = cwd
27
+ while project_root != project_root.parent:
28
+ if (project_root / "src" / "config").exists():
29
+ break
30
+ project_root = project_root.parent
31
+ PROJECT_DIR = project_root
32
+
33
+ # Cache directories - different for local vs deployed
34
+ # Local: Use default user cache locations (don't override)
35
+ # Deployed: Use PROJECT_DIR/.cache
36
+ if IS_DEPLOYED:
37
+ CACHE_DIR = PROJECT_DIR / ".cache"
38
+ HF_CACHE_DIR = CACHE_DIR / "huggingface"
39
+ STREAMLIT_CACHE_DIR = CACHE_DIR / "streamlit"
40
+ else:
41
+ # For local, use default user cache (let libraries use their defaults)
42
+ HF_CACHE_DIR = None # Will use HF defaults (~/.cache/huggingface)
43
+ STREAMLIT_CACHE_DIR = None # Will use Streamlit defaults
44
+
45
+ # Application directories
46
+ FEEDBACK_DIR = PROJECT_DIR / "feedback"
47
+ CONVERSATIONS_DIR = PROJECT_DIR / "conversations"
48
+ STREAMLIT_CONFIG_DIR = PROJECT_DIR / ".streamlit"
49
+
50
+ # Print the configuration when this module is run directly
51
+ if __name__ == "__main__":
52
+ print(f"IS_DEPLOYED: {IS_DEPLOYED}")
53
+ print(f"PROJECT_DIR: {PROJECT_DIR}")
54
+ print(f"HF_CACHE_DIR: {HF_CACHE_DIR}")
55
+ print(f"FEEDBACK_DIR: {FEEDBACK_DIR}")
56
+ print(f"CONVERSATIONS_DIR: {CONVERSATIONS_DIR}")
57
+
58
+
59
+
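Callers resolve files against PROJECT_DIR instead of the working directory (the smart_chatbot.py hunks above show the pattern), and the deployed layout can be forced locally with one environment variable. A usage sketch; note that IS_DEPLOYED is evaluated at import time, so the variable must be set before the first import:

import os
os.environ["DEPLOYED"] = "true"  # must precede the import below

from src.config.paths import PROJECT_DIR, FEEDBACK_DIR

filter_options = PROJECT_DIR / "src" / "config" / "filter_options.json"
FEEDBACK_DIR.mkdir(parents=True, exist_ok=True)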
src/feedback/__init__.py ADDED
@@ -0,0 +1,152 @@
1
+ """
2
+ Feedback Management Module
3
+
4
+ This module provides a unified interface for handling user feedback,
5
+ including data preparation, validation, and Snowflake storage.
6
+ """
7
+
8
+ from typing import Dict, Any, List, Optional
9
+ from langchain_core.messages import HumanMessage, AIMessage
10
+
11
+ from .feedback_schema import UserFeedback, create_feedback_from_dict, generate_snowflake_schema_sql
12
+ from .snowflake_connector import SnowflakeFeedbackConnector, save_to_snowflake, get_snowflake_connector_from_env
13
+
14
+
15
+ class FeedbackManager:
16
+ """
17
+ Unified manager for feedback operations.
18
+
19
+ This class provides a single interface for all feedback-related functionality,
20
+ including data preparation, validation, and storage.
21
+ """
22
+
23
+ def __init__(self):
24
+ """Initialize the FeedbackManager"""
25
+ pass
26
+
27
+ @staticmethod
28
+ def extract_transcript(messages: List[Any]) -> List[Dict[str, str]]:
29
+ """Extract transcript from messages - only user and bot messages, no extra metadata"""
30
+ transcript = []
31
+ for msg in messages:
32
+ if isinstance(msg, HumanMessage):
33
+ transcript.append({
34
+ "role": "user",
35
+ "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
36
+ })
37
+ elif isinstance(msg, AIMessage):
38
+ transcript.append({
39
+ "role": "assistant",
40
+ "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
41
+ })
42
+ return transcript
43
+
44
+ @staticmethod
45
+ def build_retrievals_structure(rag_retrieval_history: List[Dict[str, Any]], messages: List[Any]) -> List[Dict[str, Any]]:
46
+ """Build retrievals structure from retrieval history"""
47
+ retrievals = []
48
+
49
+ for entry in rag_retrieval_history:
50
+ # Get the user message that triggered this retrieval
51
+ # The entry has conversation_up_to which includes messages up to that point
52
+ conversation_up_to = entry.get("conversation_up_to", [])
53
+
54
+ # Find the last user message in conversation_up_to (this is the trigger)
55
+ user_message_trigger = ""
56
+ for msg_dict in reversed(conversation_up_to):
57
+ if msg_dict.get("type") == "HumanMessage":
58
+ user_message_trigger = msg_dict.get("content", "")
59
+ break
60
+
61
+ # Fallback: if not found in conversation_up_to, get from actual messages
62
+ # This handles edge cases where conversation_up_to might be incomplete
63
+ if not user_message_trigger:
64
+ # Find which retrieval this is (0-indexed)
65
+ retrieval_idx = rag_retrieval_history.index(entry)
66
+ # Retrievals happen once per user turn, so the user message that
67
+ # triggered this retrieval is the retrieval_idx-th user message
68
+ # when counting user messages only (0-indexed)
69
+ user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
70
+ if retrieval_idx < len(user_msgs):
71
+ user_message_trigger = str(user_msgs[retrieval_idx].content)
72
+ elif user_msgs:
73
+ # Fallback to last user message
74
+ user_message_trigger = str(user_msgs[-1].content)
75
+
76
+ # Get retrieved documents and truncate content to 100 chars
77
+ docs_retrieved = entry.get("docs_retrieved", [])
78
+ retrieved_docs = []
79
+ for doc in docs_retrieved:
80
+ doc_copy = doc.copy()
81
+ # Truncate content to 100 characters (keep all other fields)
82
+ if "content" in doc_copy:
83
+ doc_copy["content"] = doc_copy["content"][:100]
84
+ retrieved_docs.append(doc_copy)
85
+
86
+ retrievals.append({
87
+ "retrieved_docs": retrieved_docs,
88
+ "user_message_trigger": user_message_trigger
89
+ })
90
+
91
+ return retrievals
92
+
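The index fallback assumes one retrieval per user turn. A worked example with hypothetical data (the empty conversation_up_to entries force the fallback path, and the ts keys just keep the dicts distinct so list.index() resolves each entry correctly):

from langchain_core.messages import HumanMessage, AIMessage
from src.feedback import FeedbackManager

messages = [
    HumanMessage(content="Summarise the 2022 report"),    # triggers retrieval 0
    AIMessage(content="It flagged arrears [Doc 1]."),
    HumanMessage(content="And for Kalangala?"),           # triggers retrieval 1
    AIMessage(content="Two queries were raised [Doc 2]."),
]
history = [
    {"conversation_up_to": [], "docs_retrieved": [], "ts": 1},
    {"conversation_up_to": [], "docs_retrieved": [], "ts": 2},
]
retrievals = FeedbackManager.build_retrievals_structure(history, messages)
print([r["user_message_trigger"] for r in retrievals])
# ['Summarise the 2022 report', 'And for Kalangala?']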
93
+ @staticmethod
94
+ def build_feedback_score_related_retrieval_docs(
95
+ is_feedback_about_last_retrieval: bool,
96
+ messages: List[Any],
97
+ rag_retrieval_history: List[Dict[str, Any]]
98
+ ) -> Optional[Dict[str, Any]]:
99
+ """Build feedback_score_related_retrieval_docs structure"""
100
+ if not rag_retrieval_history:
101
+ return None
102
+
103
+ # Get the relevant retrieval entry
104
+ if is_feedback_about_last_retrieval:
105
+ relevant_entry = rag_retrieval_history[-1]
106
+ else:
107
+ # If feedback is about all retrievals, use the last one as default
108
+ relevant_entry = rag_retrieval_history[-1]
109
+
110
+ # Get conversation up to that point
111
+ conversation_up_to = relevant_entry.get("conversation_up_to", [])
112
+
113
+ # Convert to transcript format (role/content)
114
+ conversation_up_to_point = []
115
+ for msg_dict in conversation_up_to:
116
+ if msg_dict.get("type") == "HumanMessage":
117
+ conversation_up_to_point.append({
118
+ "role": "user",
119
+ "content": msg_dict.get("content", "")
120
+ })
121
+ elif msg_dict.get("type") == "AIMessage":
122
+ conversation_up_to_point.append({
123
+ "role": "assistant",
124
+ "content": msg_dict.get("content", "")
125
+ })
126
+
127
+ # Get retrieved docs with full content (not truncated)
128
+ retrieved_docs = relevant_entry.get("docs_retrieved", [])
129
+
130
+ return {
131
+ "conversation_up_to_point": conversation_up_to_point,
132
+ "retrieved_docs": retrieved_docs
133
+ }
134
+
135
+ @staticmethod
136
+ def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
137
+ """Create UserFeedback instance from dictionary"""
138
+ return create_feedback_from_dict(data)
139
+
140
+ @staticmethod
141
+ def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
142
+ """Save feedback to Snowflake"""
143
+ return save_to_snowflake(feedback, table_name)
144
+
145
+ @staticmethod
146
+ def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
147
+ """Generate Snowflake schema SQL"""
148
+ return generate_snowflake_schema_sql(table_name)
149
+
150
+
151
+ __all__ = ["FeedbackManager", "UserFeedback", "save_to_snowflake", "SnowflakeFeedbackConnector"]
152
+
src/feedback/feedback_schema.py ADDED
@@ -0,0 +1,161 @@
1
+ """
2
+ Feedback Schema for RAG Chatbot
3
+
4
+ This module defines dataclasses for feedback data structures
5
+ and provides Snowflake schema generation.
6
+ """
7
+ import os
8
+ from datetime import datetime
9
+ from dataclasses import dataclass, asdict, field
10
+ from typing import List, Optional, Dict, Any, Union
11
+
12
+
13
+
14
+ @dataclass
15
+ class RetrievedDocument:
16
+ """Single retrieved document metadata"""
17
+ doc_id: str
18
+ filename: str
19
+ page: int
20
+ score: float
21
+ content: str
22
+ metadata: Dict[str, Any]
23
+
24
+
25
+ @dataclass
26
+ class RetrievalEntry:
27
+ """Single retrieval operation metadata"""
28
+ rag_query: str
29
+ documents_retrieved: List[RetrievedDocument]
30
+ conversation_length: int
31
+ filters_applied: Optional[Dict[str, Any]] = None
32
+ timestamp: Optional[float] = None
33
+ _raw_data: Optional[Dict[str, Any]] = None
34
+
35
+
36
+ @dataclass
37
+ class UserFeedback:
38
+ """User feedback submission data"""
39
+ feedback_id: str
40
+ open_ended_feedback: Optional[str]
41
+ score: int
42
+ is_feedback_about_last_retrieval: bool
43
+ conversation_id: str
44
+ timestamp: float
45
+ message_count: int
46
+ has_retrievals: bool
47
+ retrieval_count: int
48
+ transcript: List[Dict[str, str]] # List of {"role": "user"/"assistant", "content": "..."}
49
+ retrievals: List[Dict[str, Any]] # List of retrieval objects with retrieved_docs and user_message_trigger
50
+ feedback_score_related_retrieval_docs: Optional[Dict[str, Any]] = None # Conversation subset + retrieved docs
51
+ retrieved_data: Optional[List[Dict[str, Any]]] = None # Preserved old column for backward compatibility
52
+ created_at: str = field(default_factory=lambda: datetime.now().isoformat())
53
+
54
+ def to_dict(self) -> Dict[str, Any]:
55
+ """Convert to dictionary with nested data structures"""
56
+ result = asdict(self)
57
+ return result
58
+
59
+ def to_snowflake_schema(self) -> Dict[str, Any]:
60
+ """Generate Snowflake schema for this dataclass"""
61
+ schema = {
62
+ "feedback_id": "VARCHAR(255)",
63
+ "open_ended_feedback": "VARCHAR(16777216)", # Large text
64
+ "score": "INTEGER",
65
+ "is_feedback_about_last_retrieval": "BOOLEAN",
66
+ "conversation_id": "VARCHAR(255)",
67
+ "timestamp": "NUMBER(20, 0)",
68
+ "message_count": "INTEGER",
69
+ "has_retrievals": "BOOLEAN",
70
+ "retrieval_count": "INTEGER",
71
+ "transcript": "VARCHAR(16777216)", # JSON string of ARRAY of {"role": "user"/"assistant", "content": "..."}
72
+ "retrievals": "VARCHAR(16777216)", # JSON string of ARRAY of retrieval objects
73
+ "feedback_score_related_retrieval_docs": "VARCHAR(16777216)", # JSON string of OBJECT with conversation subset + retrieved docs
74
+ "retrieved_data": "VARCHAR(16777216)", # JSON string - preserved old column for backward compatibility
75
+ "created_at": "TIMESTAMP_NTZ",
76
+ # transcript structure: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
77
+ # retrievals structure: [
78
+ # {
79
+ # "retrieved_docs": [{"content": "...", "metadata": {...}, ...}], # content truncated to 100 chars
80
+ # "user_message_trigger": "final user message that triggered this retrieval"
81
+ # },
82
+ # ...
83
+ # ]
84
+ # feedback_score_related_retrieval_docs structure: {
85
+ # "conversation_up_to_point": [{"role": "user", "content": "..."}, ...], # subset of transcript
86
+ # "retrieved_docs": [{"content": "...", "metadata": {...}, ...}] # full chunks with all info
87
+ # }
88
+ }
89
+ return schema
90
+
91
+ @classmethod
92
+ def get_snowflake_create_table_sql(cls, table_name: str = "USER_FEEDBACK_V3") -> str:
93
+ """Generate CREATE TABLE SQL for Snowflake"""
94
+ schema = cls.to_snowflake_schema(None)  # the method never touches self, so passing None is safe
95
+
96
+ columns = []
97
+ for col_name, col_type in schema.items():
98
+ nullable = "NULL" if col_name not in ["feedback_id", "score", "timestamp"] else "NOT NULL"
99
+ columns.append(f" {col_name} {col_type} {nullable}")
100
+
101
+ # Build SQL string properly
102
+ columns_str = ",\n".join(columns)
103
+
104
+ sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
105
+ {columns_str},
106
+ PRIMARY KEY (feedback_id)
107
+ )
108
+ CLUSTER BY (timestamp, conversation_id, score);
109
+ -- Note: Snowflake doesn't support traditional indexes on regular tables.
110
+ -- Instead, we use CLUSTER BY to optimize queries on these columns.
111
+ -- Snowflake automatically maintains clustering for efficient querying.
112
+ -- Note: transcript, retrievals, and feedback_score_related_retrieval_docs are stored as VARCHAR (JSON strings),
113
+ -- same approach as the old retrieved_data column. This allows easy storage and retrieval without VARIANT type complexity.
114
+ """
115
+ return sql
116
+
117
+
118
+ # Snowflake variant schema for retrieved_data array
119
+ RETRIEVAL_ENTRY_SCHEMA = {
120
+ "rag_query": "VARCHAR",
121
+ "documents_retrieved": "ARRAY", # Array of document objects
122
+ "conversation_length": "INTEGER",
123
+ "filters_applied": "OBJECT",
124
+ "timestamp": "NUMBER"
125
+ }
126
+
127
+ DOCUMENT_SCHEMA = {
128
+ "doc_id": "VARCHAR",
129
+ "filename": "VARCHAR",
130
+ "page": "INTEGER",
131
+ "score": "DOUBLE",
132
+ "content": "VARCHAR(16777216)",
133
+ "metadata": "OBJECT"
134
+ }
135
+
136
+
137
+ def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
138
+ """Generate complete Snowflake schema SQL for feedback system"""
139
+ if table_name is None:
140
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
141
+ return UserFeedback.get_snowflake_create_table_sql(table_name)
142
+
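Bootstrapping the table is then a one-liner per environment (the first call uses the module default, the second an illustrative override):

from src.feedback.feedback_schema import generate_snowflake_schema_sql

print(generate_snowflake_schema_sql())                 # defaults to USER_FEEDBACK_V3
print(generate_snowflake_schema_sql("FEEDBACK_DEV"))   # hypothetical dev table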
143
+
144
+ def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
145
+ """Create UserFeedback instance from dictionary"""
146
+ return UserFeedback(
147
+ feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
148
+ open_ended_feedback=data.get("open_ended_feedback"),
149
+ score=data["score"],
150
+ is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
151
+ conversation_id=data["conversation_id"],
152
+ timestamp=data["timestamp"],
153
+ message_count=data["message_count"],
154
+ has_retrievals=data["has_retrievals"],
155
+ retrieval_count=data["retrieval_count"],
156
+ transcript=data.get("transcript", []),
157
+ retrievals=data.get("retrievals", []),
158
+ feedback_score_related_retrieval_docs=data.get("feedback_score_related_retrieval_docs"),
159
+ retrieved_data=data.get("retrieved_data")
160
+ )
161
+
src/feedback/snowflake_connector.py ADDED
@@ -0,0 +1,331 @@
1
+ """
2
+ Snowflake Connector for Feedback System
3
+
4
+ This module handles inserting user feedback into Snowflake.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import logging
10
+ from typing import Dict, Any, Optional
11
+ from .feedback_schema import UserFeedback
12
+
13
+ # Try to import snowflake connector
14
+ try:
15
+ import snowflake.connector
16
+ SNOWFLAKE_AVAILABLE = True
17
+ except ImportError:
18
+ SNOWFLAKE_AVAILABLE = False
19
+ logging.warning("⚠️ snowflake-connector-python not installed. Install with: pip install snowflake-connector-python")
20
+
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class SnowflakeFeedbackConnector:
27
+ """Connector for inserting feedback into Snowflake"""
28
+
29
+ def __init__(
30
+ self,
31
+ user: str,
32
+ password: str,
33
+ account: str,
34
+ warehouse: str,
35
+ database: str = "SNOWFLAKE_LEARNING",
36
+ schema: str = "PUBLIC"
37
+ ):
38
+ self.user = user
39
+ self.password = password
40
+ self.account = account
41
+ self.warehouse = warehouse
42
+ self.database = database
43
+ self.schema = schema
44
+ self._connection = None
45
+
46
+ def connect(self):
47
+ """Establish Snowflake connection"""
48
+ if not SNOWFLAKE_AVAILABLE:
49
+ raise ImportError("snowflake-connector-python is not installed. Install with: pip install snowflake-connector-python")
50
+
51
+ logger.info("=" * 80)
52
+ logger.info("πŸ”Œ SNOWFLAKE CONNECTION: Attempting to connect...")
53
+ logger.info(f" - Account: {self.account}")
54
+ logger.info(f" - Warehouse: {self.warehouse}")
55
+ logger.info(f" - Database: {self.database}")
56
+ logger.info(f" - Schema: {self.schema}")
57
+ logger.info(f" - User: {self.user}")
58
+
59
+ try:
60
+ self._connection = snowflake.connector.connect(
61
+ user=self.user,
62
+ password=self.password,
63
+ account=self.account,
64
+ warehouse=self.warehouse
65
+ # Don't set database/schema in connection - we'll do it per query
66
+ )
67
+ logger.info("βœ… SNOWFLAKE CONNECTION: Successfully connected")
68
+ logger.info("=" * 80)
69
+ print(f"βœ… Connected to Snowflake: {self.database}.{self.schema}")
70
+ except Exception as e:
71
+ logger.error(f"❌ SNOWFLAKE CONNECTION FAILED: {e}")
72
+ logger.error("=" * 80)
73
+ print(f"❌ Failed to connect to Snowflake: {e}")
74
+ raise
75
+
76
+ def disconnect(self):
77
+ """Close Snowflake connection"""
78
+ if self._connection:
79
+ self._connection.close()
80
+ print("βœ… Disconnected from Snowflake")
81
+
82
+ def insert_feedback(self, feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
83
+ """Insert a single feedback record into Snowflake"""
84
+ logger.info("=" * 80)
85
+ logger.info("πŸ”„ SNOWFLAKE INSERT: Starting feedback insertion process")
86
+ logger.info(f"πŸ“ Feedback ID: {feedback.feedback_id}")
87
+
88
+ # Get table name from parameter, env var, or default
89
+ if table_name is None:
90
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
91
+
92
+ if not self._connection:
93
+ logger.error("❌ Not connected to Snowflake. Call connect() first.")
94
+ raise RuntimeError("Not connected to Snowflake. Call connect() first.")
95
+
96
+ try:
97
+ logger.info("πŸ“Š VALIDATION: Validating feedback data structure...")
98
+
99
+ # Validate feedback object
100
+ validation_errors = []
101
+ if not feedback.feedback_id:
102
+ validation_errors.append("Missing feedback_id")
103
+ if feedback.score is None:
104
+ validation_errors.append("Missing score")
105
+ if feedback.timestamp is None:
106
+ validation_errors.append("Missing timestamp")
107
+
108
+ if validation_errors:
109
+ logger.error(f"❌ VALIDATION FAILED: {validation_errors}")
110
+ return False
111
+ else:
112
+ logger.info("βœ… VALIDATION PASSED: All required fields present")
113
+
114
+ logger.info("πŸ“‹ Data Summary:")
115
+ logger.info(f" - Feedback ID: {feedback.feedback_id}")
116
+ logger.info(f" - Score: {feedback.score}")
117
+ logger.info(f" - Conversation ID: {feedback.conversation_id}")
118
+ logger.info(f" - Has Retrievals: {feedback.has_retrievals}")
119
+ logger.info(f" - Retrieval Count: {feedback.retrieval_count}")
120
+ logger.info(f" - Message Count: {feedback.message_count}")
121
+ logger.info(f" - Timestamp: {feedback.timestamp}")
122
+
123
+ cursor = self._connection.cursor()
124
+ logger.info("βœ… SNOWFLAKE CONNECTION: Cursor created")
125
+
126
+ # Set database and schema context
127
+ logger.info(f"πŸ”§ SETTING CONTEXT: Database={self.database}, Schema={self.schema}")
128
+ try:
129
+ cursor.execute(f'USE DATABASE "{self.database}"')
130
+ cursor.execute(f'USE SCHEMA "{self.schema}"')
131
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
132
+ current_db, current_schema = cursor.fetchone()
133
+ logger.info(f"βœ… Current context verified: Database={current_db}, Schema={current_schema}")
134
+ except Exception as e:
135
+ logger.error(f"❌ Could not set context: {e}")
136
+ raise
137
+
138
+ # Prepare data - serialize nested structures to JSON strings for the VARCHAR columns (same approach as the old retrieved_data column)
139
+ logger.info("πŸ”§ DATA PREPARATION: Preparing JSON-string columns...")
140
+ feedback_dict = feedback.to_dict()
141
+
142
+ # Prepare transcript (ARRAY) - convert to JSON string
143
+ transcript_raw = feedback_dict.get('transcript', [])
144
+ if transcript_raw:
145
+ # Convert to JSON string (same approach as old retrieved_data)
146
+ transcript_for_db = json.dumps(transcript_raw)
147
+ logger.info(f" - Transcript: {len(transcript_raw)} messages, JSON length: {len(transcript_for_db)}")
148
+ else:
149
+ transcript_for_db = None
150
+ logger.info(" - Transcript: None")
151
+
152
+ # Prepare retrievals (ARRAY) - convert to JSON string
153
+ retrievals_raw = feedback_dict.get('retrievals', [])
154
+ if retrievals_raw:
155
+ # Convert to JSON string (same approach as old retrieved_data)
156
+ retrievals_for_db = json.dumps(retrievals_raw)
157
+ logger.info(f" - Retrievals: {len(retrievals_raw)} entries, JSON length: {len(retrievals_for_db)}")
158
+ else:
159
+ retrievals_for_db = None
160
+ logger.info(" - Retrievals: None")
161
+
162
+ # Prepare feedback_score_related_retrieval_docs (OBJECT) - convert to JSON string
163
+ feedback_score_related_raw = feedback_dict.get('feedback_score_related_retrieval_docs')
164
+ if feedback_score_related_raw:
165
+ # Convert to JSON string (same approach as old retrieved_data)
166
+ feedback_score_related_for_db = json.dumps(feedback_score_related_raw)
167
+ logger.info(f" - Feedback score related docs: present, JSON length: {len(feedback_score_related_for_db)}")
168
+ else:
169
+ feedback_score_related_for_db = None
170
+ logger.info(" - Feedback score related docs: None")
171
+
172
+ # Prepare retrieved_data (preserved old column) - convert to JSON string
173
+ retrieved_data_raw = feedback_dict.get('retrieved_data')
174
+ if retrieved_data_raw:
175
+ # Convert to JSON string (same approach as old retrieved_data)
176
+ retrieved_data_for_db = json.dumps(retrieved_data_raw)
177
+ logger.info(f" - Retrieved data (preserved): present, JSON length: {len(retrieved_data_for_db)}")
178
+ else:
179
+ retrieved_data_for_db = None
180
+ logger.info(" - Retrieved data (preserved): None")
181
+
182
+ # Build SQL with new column structure
183
+ # Columns are VARCHAR (storing JSON strings), same approach as old retrieved_data
184
+ sql = f"""INSERT INTO {table_name} (
185
+ feedback_id,
186
+ open_ended_feedback,
187
+ score,
188
+ is_feedback_about_last_retrieval,
189
+ conversation_id,
190
+ timestamp,
191
+ message_count,
192
+ has_retrievals,
193
+ retrieval_count,
194
+ transcript,
195
+ retrievals,
196
+ feedback_score_related_retrieval_docs,
197
+ retrieved_data,
198
+ created_at
199
+ ) VALUES (
200
+ %(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
201
+ %(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
202
+ %(retrieval_count)s, %(transcript)s, %(retrievals)s, %(feedback_score_related_retrieval_docs)s,
203
+ %(retrieved_data)s, %(created_at)s
204
+ )"""
205
+
206
+ logger.info("πŸ“ SQL PREPARATION: Building INSERT statement...")
207
+ logger.info(f" - Target table: {table_name}")
208
+ logger.info(f" - Database: {self.database}")
209
+ logger.info(f" - Schema: {self.schema}")
210
+
211
+ # Prepare parameters
212
+ # Pass JSON strings for the VARCHAR columns (same approach as the old retrieved_data column)
213
+ params = {
214
+ 'feedback_id': feedback.feedback_id,
215
+ 'open_ended_feedback': feedback.open_ended_feedback,
216
+ 'score': feedback.score,
217
+ 'is_feedback_about_last_retrieval': feedback.is_feedback_about_last_retrieval,
218
+ 'conversation_id': feedback.conversation_id,
219
+ 'timestamp': int(feedback.timestamp),
220
+ 'message_count': feedback.message_count,
221
+ 'has_retrievals': feedback.has_retrievals,
222
+ 'retrieval_count': feedback.retrieval_count,
223
+ 'transcript': transcript_for_db, # JSON string
224
+ 'retrievals': retrievals_for_db, # JSON string
225
+ 'feedback_score_related_retrieval_docs': feedback_score_related_for_db, # JSON string
226
+ 'retrieved_data': retrieved_data_for_db, # JSON string - preserved old column
227
+ 'created_at': feedback.created_at
228
+ }
229
+
230
+ # Execute insert
231
+ logger.info("πŸš€ SQL EXECUTION: Executing INSERT query...")
232
+ cursor.execute(sql, params)
233
+
234
+ logger.info("βœ… SQL EXECUTION: Query executed successfully")
235
+ logger.info(f" - Rows affected: 1")
236
+ logger.info(f" - Status: SUCCESS")
237
+
238
+ cursor.close()
239
+ logger.info("βœ… SNOWFLAKE INSERT: Feedback inserted successfully")
240
+ logger.info(f"πŸ“ Inserted feedback: {feedback.feedback_id}")
241
+ logger.info("=" * 80)
242
+ return True
243
+
244
+ except Exception as e:
245
+ # Check if it's a Snowflake error
246
+ if SNOWFLAKE_AVAILABLE and "ProgrammingError" in str(type(e)):
247
+ logger.error(f"❌ SQL EXECUTION ERROR: {e}")
248
+ logger.error(f" - Error code: {getattr(e, 'errno', 'Unknown')}")
249
+ logger.error(f" - SQL state: {getattr(e, 'sqlstate', 'Unknown')}")
250
+ else:
251
+ logger.error(f"❌ SNOWFLAKE INSERT FAILED: {type(e).__name__}")
252
+ logger.error(f" - Error: {e}")
253
+ logger.error("=" * 80)
254
+ return False
255
+
256
+ def __enter__(self):
257
+ """Context manager entry"""
258
+ self.connect()
259
+ return self
260
+
261
+ def __exit__(self, exc_type, exc_val, exc_tb):
262
+ """Context manager exit"""
263
+ self.disconnect()
264
+
265
+
266
+ def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
267
+ """Create Snowflake connector from environment variables"""
268
+ user = os.getenv("SNOWFLAKE_USER")
269
+ password = os.getenv("SNOWFLAKE_PASSWORD")
270
+ account = os.getenv("SNOWFLAKE_ACCOUNT")
271
+ warehouse = os.getenv("SNOWFLAKE_WAREHOUSE")
272
+ database = os.getenv("SNOWFLAKE_DATABASE", "SNOWFLAKE_LEARN")
273
+ schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
274
+
275
+ if not all([user, password, account, warehouse]):
276
+ print("⚠️ Snowflake credentials not found in environment variables")
277
+ print("Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
278
+ return None
279
+
280
+ return SnowflakeFeedbackConnector(
281
+ user=user,
282
+ password=password,
283
+ account=account,
284
+ warehouse=warehouse,
285
+ database=database,
286
+ schema=schema
287
+ )
288
+
289
+
290
+ def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
291
+ """Helper function to save feedback to Snowflake"""
292
+ logger.info("=" * 80)
293
+ logger.info("πŸ”΅ SNOWFLAKE SAVE: Starting save process")
294
+ logger.info(f"πŸ“ Feedback ID: {feedback.feedback_id}")
295
+
296
+ # Get table name from parameter or env var
297
+ if table_name is None:
298
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
299
+
300
+ connector = get_snowflake_connector_from_env()
301
+
302
+ if not connector:
303
+ logger.warning("⚠️ SNOWFLAKE SAVE: Skipping insertion (credentials not configured)")
304
+ logger.warning(" Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
305
+ logger.info("=" * 80)
306
+ return False
307
+
308
+ try:
309
+ logger.info("πŸ“‘ SNOWFLAKE SAVE: Establishing connection...")
310
+ connector.connect()
311
+ logger.info("βœ… SNOWFLAKE SAVE: Connection established")
312
+
313
+ logger.info("πŸ“₯ SNOWFLAKE SAVE: Attempting to insert feedback...")
314
+ success = connector.insert_feedback(feedback, table_name=table_name)
315
+
316
+ logger.info("πŸ”Œ SNOWFLAKE SAVE: Disconnecting...")
317
+ connector.disconnect()
318
+
319
+ if success:
320
+ logger.info("βœ… SNOWFLAKE SAVE: Successfully saved feedback")
321
+ else:
322
+ logger.error("❌ SNOWFLAKE SAVE: Failed to save feedback")
323
+
324
+ logger.info("=" * 80)
325
+ return success
326
+ except Exception as e:
327
+ logger.error(f"❌ SNOWFLAKE SAVE ERROR: {type(e).__name__}")
328
+ logger.error(f" - Error: {e}")
329
+ logger.info("=" * 80)
330
+ return False
331
+
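End to end, a caller needs only the four credential variables plus a UserFeedback instance. A hedged sketch with illustrative field values:

import time
from src.feedback import FeedbackManager, save_to_snowflake

# Assumes SNOWFLAKE_USER / SNOWFLAKE_PASSWORD / SNOWFLAKE_ACCOUNT / SNOWFLAKE_WAREHOUSE are exported
feedback = FeedbackManager.create_feedback_from_dict({
    "score": 4,
    "is_feedback_about_last_retrieval": True,
    "conversation_id": "conv_123",
    "timestamp": time.time(),
    "message_count": 4,
    "has_retrievals": True,
    "retrieval_count": 2,
    "open_ended_feedback": "Citations were accurate.",
})
ok = save_to_snowflake(feedback)  # returns False (and logs why) if credentials are missing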
src/gemini/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """
2
+ Gemini File Search Integration Module
3
+
4
+ This module provides integration with Google Gemini File Search API
5
+ for RAG functionality using Gemini's built-in file search capabilities.
6
+ """
7
+
8
+ from .file_search import GeminiFileSearchClient, GeminiFileSearchResult
9
+
10
+ __all__ = ["GeminiFileSearchClient", "GeminiFileSearchResult"]
11
+
src/gemini/file_search.py ADDED
@@ -0,0 +1,427 @@
1
+ """
2
+ Gemini File Search Client
3
+
4
+ Handles interaction with Google Gemini File Search API for RAG.
5
+ """
6
+
7
+ import os
8
+ import logging
9
+ from typing import List, Dict, Any, Optional
10
+ from dataclasses import dataclass
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ try:
15
+ from google import genai
16
+ from google.genai import types
17
+ GEMINI_AVAILABLE = True
18
+ except ImportError:
19
+ GEMINI_AVAILABLE = False
20
+
21
+
22
+ @dataclass
23
+ class GeminiFileSearchResult:
24
+ """Result from Gemini File Search query"""
25
+ answer: str
26
+ sources: List[Dict[str, Any]] # List of document references
27
+ grounding_metadata: Optional[Dict[str, Any]] = None
28
+ query: str = ""
29
+
30
+
31
+ class GeminiFileSearchClient:
32
+ """Client for interacting with Gemini File Search API"""
33
+
34
+ def __init__(self, api_key: Optional[str] = None, store_name: Optional[str] = None):
35
+ """
36
+ Initialize Gemini File Search client.
37
+
38
+ Args:
39
+ api_key: Gemini API key (defaults to GEMINI_API_KEY env var)
40
+ store_name: File search store name (defaults to GEMINI_FILESTORE_NAME env var)
41
+ """
42
+ if not GEMINI_AVAILABLE:
43
+ raise ImportError("google-genai package not installed. Install with: pip install google-genai")
44
+
45
+ self.api_key = api_key or os.getenv("GEMINI_API_KEY")
46
+ if not self.api_key:
47
+ raise ValueError("GEMINI_API_KEY not found. Set it in .env file or pass as argument.")
48
+
49
+ store_name_raw = store_name or os.getenv("GEMINI_FILESTORE_NAME")
50
+ if not store_name_raw:
51
+ raise ValueError("GEMINI_FILESTORE_NAME not found. Set it in .env file or pass as argument.")
52
+
53
+ # Normalize store name: API expects the FULL path format (fileSearchStores/xxx)
54
+ # If just the ID is provided, construct the full path
55
+ if store_name_raw.startswith("fileSearchStores/"):
56
+ self.store_name = store_name_raw # Already full path
57
+ else:
58
+ # Just the ID provided, construct full path
59
+ self.store_name = f"fileSearchStores/{store_name_raw}"
60
+
61
+ logger.info(f"πŸ“¦ Using file search store: {self.store_name}")
62
+
63
+ self.client = genai.Client(api_key=self.api_key)
64
+ self.model = "gemini-2.5-flash" # or "gemini-2.5-pro"
65
+
66
+ def search(
67
+ self,
68
+ query: str,
69
+ filters: Optional[Dict[str, Any]] = None,
70
+ model: Optional[str] = None
71
+ ) -> GeminiFileSearchResult:
72
+ """
73
+ Search using Gemini File Search.
74
+
75
+ Args:
76
+ query: User query
77
+ filters: Optional filters (year, source, district, etc.)
78
+ model: Model to use (defaults to gemini-2.5-flash)
79
+
80
+ Returns:
81
+ GeminiFileSearchResult with answer and sources
82
+ """
83
+ model = model or self.model
84
+
85
+ # Build filter context for the query if filters are provided
86
+ # Gemini File Search doesn't support explicit filters in the API,
87
+ # so we add them as context in the query
88
+ filter_context = ""
89
+ if filters:
90
+ filter_parts = []
91
+ if filters.get("year"):
92
+ years = filters["year"] if isinstance(filters["year"], list) else [filters["year"]]
93
+ filter_parts.append(f"Year: {', '.join(years)}")
94
+ if filters.get("sources"):
95
+ sources = filters["sources"] if isinstance(filters["sources"], list) else [filters["sources"]]
96
+ filter_parts.append(f"Source: {', '.join(sources)}")
97
+ if filters.get("district"):
98
+ districts = filters["district"] if isinstance(filters["district"], list) else [filters["district"]]
99
+ filter_parts.append(f"District: {', '.join(districts)}")
100
+ if filters.get("filenames"):
101
+ filenames = filters["filenames"] if isinstance(filters["filenames"], list) else [filters["filenames"]]
102
+ filter_parts.append(f"Filename: {', '.join(filenames)}")
103
+
104
+ if filter_parts:
105
+ filter_context = f"\n\nPlease focus on documents matching these criteria: {', '.join(filter_parts)}"
106
+
107
+ # Combine query with filter context
108
+ # Add comprehensive system instructions similar to multi-agent system
109
+ system_instructions = """You are a helpful audit report assistant specialized in analyzing government audit reports from Uganda's Office of the Auditor General.
110
+
111
+ CRITICAL RULES:
112
+ 1. **NO HALLUCINATION**: Only use information that is explicitly stated in the retrieved documents. Do not make up facts, numbers, or details.
113
+ 2. **Document References**: Always cite which documents you're using with [Doc i] references at the end of sentences that use specific information.
114
+ 3. **Formatting**: Structure your response with clear paragraphs, bullet points, or sections for readability.
115
+ 4. **Accuracy**: If the retrieved documents don't contain the requested information, explicitly state "The retrieved documents do not contain information about [topic]."
116
+ 5. **Years and Data**: Pay careful attention to years mentioned in documents. If a user asks about a specific year but documents show different years, explicitly state this.
117
+ 6. **District/Source Names**: Use the exact district and source names as they appear in the document metadata (e.g., "Kalangala" not "Kalagala").
118
+ 7. **Financial Data**: When providing financial figures, include the currency (UGX) and be precise about amounts.
119
+ 8. **Conversational Tone**: Be helpful, clear, and conversational while maintaining accuracy.
120
+
121
+ IMPORTANT: Only use information from the retrieved documents. Do not use information from your training data unless it's explicitly mentioned in the retrieved documents."""
122
+
123
+ # Combine system instructions with query
124
+ full_query = f"{system_instructions}\n\nUser Question: {query}{filter_context}\n\nPlease provide a detailed, well-formatted response with proper document references."
125
+
126
+ try:
127
+ # Generate content with file search
128
+ # Based on Gemini API docs: https://ai.google.dev/gemini-api/docs/file-search
129
+ # Try with full path format first, then fallback to just ID if needed
130
+ store_name_to_try = self.store_name
131
+
132
+ try:
133
+ # Try the documented format first with full path
134
+ response = self.client.models.generate_content(
135
+ model=model,
136
+ contents=full_query,
137
+ config=types.GenerateContentConfig(
138
+ tools=[
139
+ types.Tool(
140
+ file_search=types.FileSearch(
141
+ file_search_store_names=[store_name_to_try]
142
+ )
143
+ )
144
+ ]
145
+ )
146
+ )
147
+ except Exception as api_error:
148
+ error_str = str(api_error).lower()
149
+ # If format error, try with just the ID (without fileSearchStores/ prefix)
150
+ if 'format' in error_str or 'invalid' in error_str or 'too long' in error_str:
151
+ logger.warning(f"Full path format failed, trying with just store ID: {api_error}")
152
+ # Extract just the ID part
153
+ if store_name_to_try.startswith("fileSearchStores/"):
154
+ store_id = store_name_to_try.split("/", 1)[1]
155
+ store_name_to_try = store_id
156
+
157
+ try:
158
+ response = self.client.models.generate_content(
159
+ model=model,
160
+ contents=full_query,
161
+ config=types.GenerateContentConfig(
162
+ tools=[
163
+ types.Tool(
164
+ file_search=types.FileSearch(
165
+ file_search_store_names=[store_name_to_try]
166
+ )
167
+ )
168
+ ]
169
+ )
170
+ )
171
+ except Exception as e2:
172
+ raise Exception(f"Failed to call Gemini API with both formats. Full path error: {api_error}, ID-only error: {e2}")
173
+ else:
174
+ # Try alternative dict format
175
+ logger.warning(f"Primary API format failed, trying alternative: {api_error}")
176
+ try:
177
+ response = self.client.models.generate_content(
178
+ model=model,
179
+ contents=full_query,
180
+ tools=[{
181
+ "file_search": {
182
+ "file_search_store_names": [store_name_to_try]
183
+ }
184
+ }]
185
+ )
186
+ except Exception as e2:
187
+ raise Exception(f"Failed to call Gemini API: {e2}")
188
+
189
+ # Extract answer
190
+ answer = ""
191
+ if hasattr(response, 'text'):
192
+ answer = response.text
193
+ elif hasattr(response, 'candidates') and response.candidates:
194
+ # Try to get text from first candidate
195
+ candidate = response.candidates[0]
196
+ if hasattr(candidate, 'content') and candidate.content:
197
+ if hasattr(candidate.content, 'parts'):
198
+ text_parts = []
199
+ for part in candidate.content.parts:
200
+ if hasattr(part, 'text'):
201
+ text_parts.append(part.text)
202
+ answer = " ".join(text_parts)
203
+ elif isinstance(candidate.content, str):
204
+ answer = candidate.content
205
+ else:
206
+ answer = str(response)
207
+
208
+ # Extract grounding metadata (document references)
209
+ sources = []
210
+ grounding_metadata = None
211
+
212
+ logger.info(f"πŸ” Extracting sources from Gemini response...")
213
+
214
+ if hasattr(response, 'candidates') and response.candidates:
215
+ candidate = response.candidates[0]
216
+ logger.info(f" Found candidate, checking for grounding_metadata...")
217
+
218
+ # Get grounding metadata
219
+ if hasattr(candidate, 'grounding_metadata'):
220
+ grounding_metadata = candidate.grounding_metadata
221
+ logger.info(f" Found grounding_metadata: {type(grounding_metadata)}")
222
+
223
+ # Extract source documents from grounding metadata
224
+ # Handle different response formats
225
+ grounding_chunks = None
226
+ if hasattr(grounding_metadata, 'grounding_chunks'):
227
+ grounding_chunks = grounding_metadata.grounding_chunks
228
+ logger.info(f" Found grounding_chunks (attr): {len(grounding_chunks) if grounding_chunks else 0}")
229
+ elif isinstance(grounding_metadata, dict) and 'grounding_chunks' in grounding_metadata:
230
+ grounding_chunks = grounding_metadata['grounding_chunks']
231
+ logger.info(f" Found grounding_chunks (dict): {len(grounding_chunks) if grounding_chunks else 0}")
232
+ elif hasattr(grounding_metadata, '__dict__'):
233
+ # Try to access as object attributes
234
+ metadata_dict = grounding_metadata.__dict__
235
+ if 'grounding_chunks' in metadata_dict:
236
+ grounding_chunks = metadata_dict['grounding_chunks']
237
+ logger.info(f" Found grounding_chunks (__dict__): {len(grounding_chunks) if grounding_chunks else 0}")
238
+
239
+ if grounding_chunks:
240
+ logger.info(f" Processing {len(grounding_chunks)} grounding chunks...")
241
+ for idx, chunk in enumerate(grounding_chunks):
242
+ # Handle both object and dict formats
243
+ try:
244
+ if isinstance(chunk, dict):
245
+ chunk_data = chunk
246
+ else:
247
+ # Object format - convert to dict-like access
248
+ chunk_data = {}
249
+ if hasattr(chunk, 'chunk'):
250
+ chunk_obj = chunk.chunk
251
+ chunk_data['chunk'] = {
252
+ 'text': getattr(chunk_obj, 'text', ''),
253
+ 'file_name': getattr(chunk_obj, 'file_name', '')
254
+ }
255
+ if hasattr(chunk, 'relevance_score'):
256
+ score_obj = chunk.relevance_score
257
+ chunk_data['relevance_score'] = {
258
+ 'score': getattr(score_obj, 'score', 0.0)
259
+ }
260
+
261
+ chunk_info = chunk_data.get('chunk', {})
262
+ text = chunk_info.get('text', '') if isinstance(chunk_info, dict) else ''
263
+ file_name = chunk_info.get('file_name', '') if isinstance(chunk_info, dict) else ''
264
+
265
+ # Try to extract file URI and parse metadata from it
266
+ file_uri = chunk_info.get('file_uri', '') if isinstance(chunk_info, dict) else ''
267
+
268
+ # Also check for 'web' attribute (GroundingChunkData structure)
269
+ if hasattr(chunk, 'web') and chunk.web:
270
+ web_data = chunk.web
271
+ file_uri = getattr(web_data, 'file_uri', '') or file_uri
272
+ file_name = getattr(web_data, 'title', '') or getattr(web_data, 'filename', '') or file_name
273
+ text = getattr(web_data, 'text', '') or getattr(web_data, 'content', '') or text
274
+
275
+ # Check retrieved_context - this is where the actual data seems to be!
276
+ if hasattr(chunk, 'retrieved_context') and chunk.retrieved_context:
277
+ rc = chunk.retrieved_context
278
+ # Get text content
279
+ if hasattr(rc, 'text'):
280
+ text = getattr(rc, 'text', '') or text
281
+ # Get document name
282
+ if hasattr(rc, 'document_name'):
283
+ doc_name = getattr(rc, 'document_name', '')
284
+ if doc_name:
285
+ file_name = doc_name or file_name
286
+
287
+ # Fallback: parse the string representation if we still don't have a filename
288
+ if not file_name:
289
+ chunk_str = str(chunk)
290
+ import re
291
+ # Look for PDF filenames
292
+ pdf_match = re.search(r"([A-Za-z0-9\s_-]+\.pdf)", chunk_str)
293
+ if pdf_match:
294
+ file_name = pdf_match.group(1)
295
+ # Or look for title= pattern
296
+ if not file_name and 'title=' in chunk_str:
297
+ title_match = re.search(r"title=['\"]([^'\"]+)['\"]", chunk_str)
298
+ if title_match:
299
+ file_name = title_match.group(1)
300
+
301
+ if not file_name and file_uri:
302
+ # Extract filename from URI if available
303
+ file_name = file_uri.split('/')[-1] if '/' in file_uri else file_uri
304
+
305
+ score_data = chunk_data.get('relevance_score', {})
306
+ score = score_data.get('score', 0.0) if isinstance(score_data, dict) else 0.0
307
+
308
+ if text or file_name: # Only add if we have content
309
+ source_info = {
310
+ "content": text,
311
+ "filename": file_name,
312
+ "score": score,
313
+ "file_uri": file_uri,
314
+ }
315
+ sources.append(source_info)
316
+ logger.info(f"πŸ“„ Extracted source {idx+1}: {file_name} (score: {score:.3f}, content length: {len(text)})")
317
+ except Exception as e:
318
+ logger.warning(f"Error extracting chunk {idx+1} info: {e}")
319
+ import traceback
320
+ logger.debug(traceback.format_exc())
321
+ continue
322
+ else:
323
+ logger.warning(f" No grounding_chunks found in grounding_metadata")
324
+ else:
325
+ logger.warning(f" Candidate does not have grounding_metadata attribute")
326
+
327
+ # Also try to get file references from other parts of the response
328
+ # Sometimes Gemini includes file references in the response itself
329
+ if not sources:
330
+ logger.info(f" No sources from grounding_metadata, trying alternative extraction...")
331
+ # Check if response has file references in other attributes
332
+ if hasattr(candidate, 'content') and candidate.content:
333
+ if hasattr(candidate.content, 'parts'):
334
+ for part in candidate.content.parts:
335
+ if hasattr(part, 'file_data'):
336
+ file_data = part.file_data
337
+ if hasattr(file_data, 'file_uri') or (isinstance(file_data, dict) and 'file_uri' in file_data):
338
+ file_uri = getattr(file_data, 'file_uri', None) or (file_data.get('file_uri') if isinstance(file_data, dict) else None)
339
+ if file_uri:
340
+ file_name = file_uri.split('/')[-1] if '/' in file_uri else file_uri
341
+ sources.append({
342
+ "content": "",
343
+ "filename": file_name,
344
+ "score": 0.0,
345
+ "file_uri": file_uri,
346
+ })
347
+ logger.info(f"πŸ“„ Extracted source from file_data: {file_name}")
348
+
349
+ logger.info(f"βœ… Total sources extracted: {len(sources)}")
350
+
351
+ return GeminiFileSearchResult(
352
+ answer=answer,
353
+ sources=sources,
354
+ grounding_metadata=grounding_metadata,
355
+ query=query
356
+ )
357
+
358
+ except Exception as e:
359
+ # Return error result
360
+ return GeminiFileSearchResult(
361
+ answer=f"I apologize, but I encountered an error: {str(e)}",
362
+ sources=[],
363
+ query=query
364
+ )
365
+
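For orientation, `GeminiFileSearchResult` is not defined in this hunk; a minimal sketch of a compatible container, with fields inferred from the constructor calls above (defaults are assumptions):

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

@dataclass
class GeminiFileSearchResult:
    """Result container assumed to be defined earlier in this module."""
    answer: str                                # generated answer text
    sources: List[Dict[str, Any]]              # extracted source dicts (content/filename/score/file_uri)
    query: str                                 # the original user query
    grounding_metadata: Optional[Any] = None   # raw grounding metadata, when available
```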
366
+ def format_sources_for_display(self, result: GeminiFileSearchResult) -> List[Any]:
367
+ """
368
+ Format Gemini sources to match the format expected by the UI.
369
+
370
+ Returns list of document-like objects compatible with existing display code.
371
+ """
372
+ from langchain.docstore.document import Document
373
+
374
+ formatted_sources = []
375
+
376
+ for i, source in enumerate(result.sources):
377
+ filename = source.get("filename", "Unknown")
378
+
379
+ # Try to extract metadata from filename (e.g., "Kalangala DLG Report of Auditor General 2021.pdf")
380
+ year = None
381
+ district = None
382
+ source_name = "Gemini File Search"
383
+
384
+ # Parse filename for year
385
+ import re
386
+ year_match = re.search(r'\b(20\d{2})\b', filename)
387
+ if year_match:
388
+ year = int(year_match.group(1))
389
+
390
+ # Parse filename for district/source
391
+ if "Kalangala" in filename:
392
+ district = "Kalangala"
393
+ source_name = "Kalangala DLG"
394
+ elif "Gulu" in filename:
395
+ district = "Gulu"
396
+ source_name = "Gulu DLG"
397
+ elif "KCCA" in filename:
398
+ district = "Kampala"
399
+ source_name = "KCCA"
400
+ elif "MAAIF" in filename:
401
+ source_name = "MAAIF"
402
+ elif "MWTS" in filename:
403
+ source_name = "MWTS"
404
+ elif "Consolidated" in filename:
405
+ source_name = "Consolidated"
406
+
407
+ # Create a Document object compatible with existing code
408
+ doc = Document(
409
+ page_content=source.get("content", ""),
410
+ metadata={
411
+ "filename": filename,
412
+ "source": source_name,
413
+ "score": source.get("score"),
414
+ "chunk_index": i,
415
+ "page": None, # Gemini doesn't provide page numbers
416
+ "year": year,
417
+ "district": district,
418
+ "chunk_id": f"gemini_{i}",
419
+ "_id": f"gemini_{i}",
420
+ }
421
+ )
422
+ formatted_sources.append(doc)
423
+ logger.info(f"πŸ“‹ Formatted source {i+1}: {filename} ({year}, {source_name})")
424
+
425
+ logger.info(f"βœ… Formatted {len(formatted_sources)} sources for display")
426
+ return formatted_sources
427
+
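The filename heuristics in `format_sources_for_display` can be exercised standalone; a minimal sketch, assuming filenames follow the "Kalangala DLG Report of Auditor General 2021.pdf" pattern (`parse_report_filename` is a hypothetical helper, not part of the module):

```python
import re

def parse_report_filename(filename: str) -> dict:
    """Mirror of the year/district parsing used in format_sources_for_display."""
    year_match = re.search(r'\b(20\d{2})\b', filename)
    year = int(year_match.group(1)) if year_match else None
    district = None
    if "Kalangala" in filename:
        district = "Kalangala"
    elif "Gulu" in filename:
        district = "Gulu"
    elif "KCCA" in filename:
        district = "Kampala"  # KCCA reports map to the Kampala district
    return {"year": year, "district": district}

print(parse_report_filename("Kalangala DLG Report of Auditor General 2021.pdf"))
# {'year': 2021, 'district': 'Kalangala'}
```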
src/{loader.py β†’ llm/loader.py} RENAMED
File without changes
src/pipeline.py CHANGED
@@ -1,5 +1,7 @@
1
  """Main pipeline orchestrator for the Audit QA system."""
 
2
  import time
 
3
  from pathlib import Path
4
  from dataclasses import dataclass
5
  from typing import Dict, Any, List, Optional
@@ -11,11 +13,21 @@ except ModuleNotFoundError as me:
11
  from langchain.schema import Document
12
 
13
  from .logging import log_error
14
- from .llm.adapters import LLMRegistry
15
- from .loader import chunks_to_documents
16
  from .vectorstore import VectorStoreManager
 
17
  from .retrieval.context import ContextRetriever
18
- from .config.loader import get_embedding_model_for_collection
 
 
 
 
 
 
 
 
 
19
 
20
 
21
 
@@ -41,12 +53,13 @@ class PipelineManager:
41
  """
42
  Initialize the pipeline manager.
43
  """
 
 
44
  self.config = config or {}
 
45
  self.vectorstore_manager = None
46
  self.context_retriever = None # Initialize as None
47
- self.llm_client = None
48
- self.report_service = None
49
- self.chunks = None
50
 
51
  # Initialize components
52
  self._initialize_components()
@@ -118,13 +131,7 @@ class PipelineManager:
118
  try:
119
  # Load config if not provided
120
  if not self.config:
121
- try:
122
- from src.config.loader import load_config
123
- self.config = load_config()
124
- except ImportError:
125
- # Try alternate import path
126
- from src.config.loader import load_config
127
- self.config = load_config()
128
 
129
  # Validate config structure
130
  if not isinstance(self.config, dict):
@@ -159,7 +166,6 @@ class PipelineManager:
159
  print("βœ… VectorStoreManager initialized successfully")
160
  except Exception as vs_error:
161
  print(f"❌ Error initializing VectorStoreManager: {vs_error}")
162
- import traceback
163
  traceback.print_exc()
164
  self.vectorstore_manager = None
165
  raise # Re-raise to be caught by outer try-except
@@ -175,40 +181,35 @@ class PipelineManager:
175
  except Exception as e:
176
  try:
177
  # Try direct instantiation with config
178
- from src.llm.adapters import get_llm_client
179
  self.llm_client = get_llm_client("openai", self.config)
180
  print("βœ… LLM CLIENT: Initialized using direct get_llm_client function with config")
181
  except Exception as e2:
182
  print(f"❌ LLM CLIENT: Registry methods failed - {e2}")
183
  # Try to create a simple LLM client directly
184
  try:
185
- from langchain_openai import ChatOpenAI
186
- import os
187
- api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
188
- if api_key:
189
- self.llm_client = ChatOpenAI(
190
- model="gpt-3.5-turbo",
191
- api_key=api_key,
192
- temperature=0.1,
193
- max_tokens=1000
194
- )
195
- print("βœ… LLM CLIENT: Initialized using direct ChatOpenAI")
 
196
  else:
197
- print("❌ LLM CLIENT: No API key available")
198
  except Exception as e3:
199
  print(f"❌ LLM CLIENT: Direct instantiation also failed - {e3}")
200
  self.llm_client = None
201
 
202
  # Load system prompt
203
- from src.llm.templates import DEFAULT_AUDIT_SYSTEM_PROMPT
204
  self.system_prompt = DEFAULT_AUDIT_SYSTEM_PROMPT
205
 
206
  # Initialize report service
207
  try:
208
- try:
209
- from src.reporting.service import ReportService
210
- except ImportError:
211
- from src.reporting.service import ReportService
212
  self.report_service = ReportService()
213
  except Exception as e:
214
  print(f"Warning: Could not initialize report service: {e}")
@@ -216,7 +217,6 @@ class PipelineManager:
216
 
217
  except Exception as e:
218
  print(f"❌ Error initializing components: {e}")
219
- import traceback
220
  traceback.print_exc()
221
  # Don't set vectorstore_manager to None if it was already set
222
  if not hasattr(self, 'vectorstore_manager') or self.vectorstore_manager is None:
@@ -337,7 +337,6 @@ class PipelineManager:
337
  return False
338
  except Exception as init_error:
339
  print(f"❌ Error initializing vector store manager: {init_error}")
340
- import traceback
341
  traceback.print_exc()
342
  return False
343
 
@@ -352,7 +351,6 @@ class PipelineManager:
352
  except Exception as e:
353
  print(f"❌ Error connecting to vector store: {e}")
354
  log_error(e, {"component": "vectorstore_connection"})
355
- import traceback
356
  traceback.print_exc()
357
 
358
  # If it's a dimension mismatch error, try with force_recreate
@@ -541,9 +539,6 @@ Answer:"""
541
  if auto_infer_filters and not any([reports, sources, subtype]):
542
  print(f"πŸ€– AUTO-INFERRING FILTERS: No explicit filters provided, analyzing query...")
543
  try:
544
- # Import get_available_metadata here to avoid circular imports
545
- from src.retrieval.filter import get_available_metadata, infer_filters_from_query
546
-
547
  # Get available metadata
548
  available_metadata = get_available_metadata(self.vectorstore_manager.get_vectorstore())
549
 
 
1
  """Main pipeline orchestrator for the Audit QA system."""
2
+ import os
3
  import time
4
+ import traceback
5
  from pathlib import Path
6
  from dataclasses import dataclass
7
  from typing import Dict, Any, List, Optional
 
13
  from langchain.schema import Document
14
 
15
  from .logging import log_error
16
+
17
+ from .llm.loader import chunks_to_documents
18
  from .vectorstore import VectorStoreManager
19
+ from .reporting.service import ReportService
20
  from .retrieval.context import ContextRetriever
21
+ from .llm.adapters import LLMRegistry, get_llm_client
22
+ from .llm.templates import DEFAULT_AUDIT_SYSTEM_PROMPT
23
+ from .config.loader import load_config, get_embedding_model_for_collection
24
+ from .retrieval.filter import get_available_metadata, infer_filters_from_query
25
+
26
+ try:
27
+ from langchain_openai import ChatOpenAI
28
+ LANGCHAIN_OPENAI_AVAILABLE = True
29
+ except ImportError:
30
+ LANGCHAIN_OPENAI_AVAILABLE = False
31
 
32
 
33
 
 
53
  """
54
  Initialize the pipeline manager.
55
  """
56
+ self.chunks = None
57
+ self.llm_client = None
58
  self.config = config or {}
59
+ self.report_service = None
60
  self.vectorstore_manager = None
61
  self.context_retriever = None # Initialize as None
62
+
 
 
63
 
64
  # Initialize components
65
  self._initialize_components()
 
131
  try:
132
  # Load config if not provided
133
  if not self.config:
134
+ self.config = load_config()
 
 
 
 
 
 
135
 
136
  # Validate config structure
137
  if not isinstance(self.config, dict):
 
166
  print("βœ… VectorStoreManager initialized successfully")
167
  except Exception as vs_error:
168
  print(f"❌ Error initializing VectorStoreManager: {vs_error}")
 
169
  traceback.print_exc()
170
  self.vectorstore_manager = None
171
  raise # Re-raise to be caught by outer try-except
 
181
  except Exception as e:
182
  try:
183
  # Try direct instantiation with config
 
184
  self.llm_client = get_llm_client("openai", self.config)
185
  print("βœ… LLM CLIENT: Initialized using direct get_llm_client function with config")
186
  except Exception as e2:
187
  print(f"❌ LLM CLIENT: Registry methods failed - {e2}")
188
  # Try to create a simple LLM client directly
189
  try:
190
+ if LANGCHAIN_OPENAI_AVAILABLE:
191
+ api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
192
+ if api_key:
193
+ self.llm_client = ChatOpenAI(
194
+ model="gpt-3.5-turbo",
195
+ api_key=api_key,
196
+ temperature=0.1,
197
+ max_tokens=1000
198
+ )
199
+ print("βœ… LLM CLIENT: Initialized using direct ChatOpenAI")
200
+ else:
201
+ print("❌ LLM CLIENT: No API key available")
202
  else:
203
+ print("❌ LLM CLIENT: langchain-openai not available")
204
  except Exception as e3:
205
  print(f"❌ LLM CLIENT: Direct instantiation also failed - {e3}")
206
  self.llm_client = None
207
 
208
  # Load system prompt
 
209
  self.system_prompt = DEFAULT_AUDIT_SYSTEM_PROMPT
210
 
211
  # Initialize report service
212
  try:
 
 
 
 
213
  self.report_service = ReportService()
214
  except Exception as e:
215
  print(f"Warning: Could not initialize report service: {e}")
 
217
 
218
  except Exception as e:
219
  print(f"❌ Error initializing components: {e}")
 
220
  traceback.print_exc()
221
  # Don't set vectorstore_manager to None if it was already set
222
  if not hasattr(self, 'vectorstore_manager') or self.vectorstore_manager is None:
 
337
  return False
338
  except Exception as init_error:
339
  print(f"❌ Error initializing vector store manager: {init_error}")
 
340
  traceback.print_exc()
341
  return False
342
 
 
351
  except Exception as e:
352
  print(f"❌ Error connecting to vector store: {e}")
353
  log_error(e, {"component": "vectorstore_connection"})
 
354
  traceback.print_exc()
355
 
356
  # If it's a dimension mismatch error, try with force_recreate
 
539
  if auto_infer_filters and not any([reports, sources, subtype]):
540
  print(f"πŸ€– AUTO-INFERRING FILTERS: No explicit filters provided, analyzing query...")
541
  try:
 
 
 
542
  # Get available metadata
543
  available_metadata = get_available_metadata(self.vectorstore_manager.get_vectorstore())
544
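The refactored pipeline hoists all imports to module level and guards the optional `langchain_openai` dependency behind an availability flag. A minimal standalone sketch of that fallback pattern (`build_fallback_llm` is a hypothetical name):

```python
import os

try:
    from langchain_openai import ChatOpenAI
    LANGCHAIN_OPENAI_AVAILABLE = True
except ImportError:
    LANGCHAIN_OPENAI_AVAILABLE = False

def build_fallback_llm():
    """Return a ChatOpenAI client when the package and an API key are present, else None."""
    if not LANGCHAIN_OPENAI_AVAILABLE:
        return None
    api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        return None
    return ChatOpenAI(model="gpt-3.5-turbo", api_key=api_key,
                      temperature=0.1, max_tokens=1000)
```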
 
src/reporting/__init__.py CHANGED
@@ -1,4 +1,8 @@
1
- """Report metadata and utilities."""
 
 
 
 
2
 
3
  from .metadata import get_report_metadata, get_available_sources
4
  from .service import ReportService
 
1
+ """Report metadata and utilities.
2
+
3
+ This module is kept for backward compatibility with pipeline.py.
4
+ For feedback-related functionality, use src.feedback instead.
5
+ """
6
 
7
  from .metadata import get_report_metadata, get_available_sources
8
  from .service import ReportService
src/reporting/feedback_schema.py CHANGED
@@ -4,10 +4,12 @@ Feedback Schema for RAG Chatbot
4
  This module defines dataclasses for feedback data structures
5
  and provides Snowflake schema generation.
6
  """
7
-
 
8
  from dataclasses import dataclass, asdict, field
9
  from typing import List, Optional, Dict, Any, Union
10
- from datetime import datetime
 
11
 
12
 
13
  @dataclass
@@ -39,34 +41,20 @@ class UserFeedback:
39
  open_ended_feedback: Optional[str]
40
  score: int
41
  is_feedback_about_last_retrieval: bool
42
- retrieved_data: List[RetrievalEntry]
43
  conversation_id: str
44
  timestamp: float
45
  message_count: int
46
  has_retrievals: bool
47
  retrieval_count: int
48
- user_query: Optional[str] = None
49
- bot_response: Optional[str] = None
 
 
50
  created_at: str = field(default_factory=lambda: datetime.now().isoformat())
51
 
52
  def to_dict(self) -> Dict[str, Any]:
53
  """Convert to dictionary with nested data structures"""
54
  result = asdict(self)
55
- # Handle nested objects
56
- if self.retrieved_data:
57
- result['retrieved_data'] = [self._serialize_retrieval_entry(entry) for entry in self.retrieved_data]
58
- return result
59
-
60
- def _serialize_retrieval_entry(self, entry: RetrievalEntry) -> Dict[str, Any]:
61
- """Serialize retrieval entry to dict"""
62
- # If raw data exists, use it (it's already properly formatted)
63
- if hasattr(entry, '_raw_data') and entry._raw_data:
64
- return entry._raw_data
65
-
66
- # Otherwise, serialize the dataclass
67
- result = asdict(entry)
68
- if entry.documents_retrieved:
69
- result['documents_retrieved'] = [asdict(doc) for doc in entry.documents_retrieved]
70
  return result
71
 
72
  def to_snowflake_schema(self) -> Dict[str, Any]:
@@ -81,28 +69,28 @@ class UserFeedback:
81
  "message_count": "INTEGER",
82
  "has_retrievals": "BOOLEAN",
83
  "retrieval_count": "INTEGER",
84
- "user_query": "VARCHAR(16777216)",
85
- "bot_response": "VARCHAR(16777216)",
 
 
86
  "created_at": "TIMESTAMP_NTZ",
87
- "retrieved_data": "VARIANT", # Array of retrieval entries
88
- # retrieved_data structure:
89
- # [
90
  # {
91
- # "rag_query": "...",
92
- # "conversation_length": 5,
93
- # "timestamp": 1234567890,
94
- # "docs_retrieved": [
95
- # {"filename": "...", "page": 14, "score": 0.95, ...},
96
- # ...
97
- # ]
98
  # },
99
  # ...
100
  # ]
 
 
 
 
101
  }
102
  return schema
103
 
104
  @classmethod
105
- def get_snowflake_create_table_sql(cls, table_name: str = "user_feedback") -> str:
106
  """Generate CREATE TABLE SQL for Snowflake"""
107
  schema = cls.to_snowflake_schema(None)
108
 
@@ -117,16 +105,13 @@ class UserFeedback:
117
  sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
118
  {columns_str},
119
  PRIMARY KEY (feedback_id)
120
- );
121
-
122
- -- Create index on timestamp for querying by time
123
- CREATE INDEX IF NOT EXISTS idx_feedback_timestamp ON {table_name} (timestamp);
124
-
125
- -- Create index on conversation_id for querying by conversation
126
- CREATE INDEX IF NOT EXISTS idx_feedback_conversation ON {table_name} (conversation_id);
127
-
128
- -- Create index on score for feedback analysis
129
- CREATE INDEX IF NOT EXISTS idx_feedback_score ON {table_name} (score);
130
  """
131
  return sql
132
 
@@ -150,47 +135,27 @@ DOCUMENT_SCHEMA = {
150
  }
151
 
152
 
153
- def generate_snowflake_schema_sql() -> str:
154
  """Generate complete Snowflake schema SQL for feedback system"""
155
- return UserFeedback.get_snowflake_create_table_sql("user_feedback")
 
 
156
 
157
 
158
  def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
159
  """Create UserFeedback instance from dictionary"""
160
- # Parse retrieved_data if present
161
- retrieved_data = []
162
- if "retrieved_data" in data and data["retrieved_data"]:
163
- for entry_dict in data.get("retrieved_data", []):
164
- # Map the actual structure from rag_retrieval_history
165
- # Entry has: conversation_up_to, rag_query_expansion, docs_retrieved
166
- try:
167
- # Try to map to expected structure
168
- entry = RetrievalEntry(
169
- rag_query=entry_dict.get("rag_query_expansion", ""),
170
- documents_retrieved=[], # Empty for now, will store as raw data
171
- conversation_length=len(entry_dict.get("conversation_up_to", [])),
172
- filters_applied=None,
173
- timestamp=entry_dict.get("timestamp", None)
174
- )
175
- # Store raw data in the entry
176
- entry._raw_data = entry_dict # Store original for preservation
177
- retrieved_data.append(entry)
178
- except Exception as e:
179
- # If mapping fails, store as-is without strict typing
180
- pass
181
-
182
  return UserFeedback(
183
  feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
184
  open_ended_feedback=data.get("open_ended_feedback"),
185
  score=data["score"],
186
  is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
187
- retrieved_data=retrieved_data,
188
  conversation_id=data["conversation_id"],
189
  timestamp=data["timestamp"],
190
  message_count=data["message_count"],
191
  has_retrievals=data["has_retrievals"],
192
  retrieval_count=data["retrieval_count"],
193
- user_query=data.get("user_query"),
194
- bot_response=data.get("bot_response")
 
 
195
  )
196
-
 
4
  This module defines dataclasses for feedback data structures
5
  and provides Snowflake schema generation.
6
  """
7
+ import os
8
+ from datetime import datetime
9
  from dataclasses import dataclass, asdict, field
10
  from typing import List, Optional, Dict, Any, Union
11
+
12
+
13
 
14
 
15
  @dataclass
 
41
  open_ended_feedback: Optional[str]
42
  score: int
43
  is_feedback_about_last_retrieval: bool
 
44
  conversation_id: str
45
  timestamp: float
46
  message_count: int
47
  has_retrievals: bool
48
  retrieval_count: int
49
+ transcript: List[Dict[str, str]] # List of {"role": "user"/"assistant", "content": "..."}
50
+ retrievals: List[Dict[str, Any]] # List of retrieval objects with retrieved_docs and user_message_trigger
51
+ feedback_score_related_retrieval_docs: Optional[Dict[str, Any]] = None # Conversation subset + retrieved docs
52
+ retrieved_data: Optional[List[Dict[str, Any]]] = None # Preserved old column for backward compatibility
53
  created_at: str = field(default_factory=lambda: datetime.now().isoformat())
54
 
55
  def to_dict(self) -> Dict[str, Any]:
56
  """Convert to dictionary with nested data structures"""
57
  result = asdict(self)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  return result
59
 
60
  def to_snowflake_schema(self) -> Dict[str, Any]:
 
69
  "message_count": "INTEGER",
70
  "has_retrievals": "BOOLEAN",
71
  "retrieval_count": "INTEGER",
72
+ "transcript": "VARCHAR(16777216)", # JSON string of ARRAY of {"role": "user"/"assistant", "content": "..."}
73
+ "retrievals": "VARCHAR(16777216)", # JSON string of ARRAY of retrieval objects
74
+ "feedback_score_related_retrieval_docs": "VARCHAR(16777216)", # JSON string of OBJECT with conversation subset + retrieved docs
75
+ "retrieved_data": "VARCHAR(16777216)", # JSON string - preserved old column for backward compatibility
76
  "created_at": "TIMESTAMP_NTZ",
77
+ # transcript structure: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
78
+ # retrievals structure: [
 
79
  # {
80
+ # "retrieved_docs": [{"content": "...", "metadata": {...}, ...}], # content truncated to 100 chars
81
+ # "user_message_trigger": "final user message that triggered this retrieval"
 
 
 
 
 
82
  # },
83
  # ...
84
  # ]
85
+ # feedback_score_related_retrieval_docs structure: {
86
+ # "conversation_up_to_point": [{"role": "user", "content": "..."}, ...], # subset of transcript
87
+ # "retrieved_docs": [{"content": "...", "metadata": {...}, ...}] # full chunks with all info
88
+ # }
89
  }
90
  return schema
91
 
92
  @classmethod
93
+ def get_snowflake_create_table_sql(cls, table_name: str = "USER_FEEDBACK_V3") -> str:
94
  """Generate CREATE TABLE SQL for Snowflake"""
95
  schema = cls.to_snowflake_schema(None)
96
 
 
105
  sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
106
  {columns_str},
107
  PRIMARY KEY (feedback_id)
108
+ )
109
+ CLUSTER BY (timestamp, conversation_id, score);
110
+ -- Note: Snowflake doesn't support traditional indexes on regular tables.
111
+ -- Instead, we use CLUSTER BY to optimize queries on these columns.
112
+ -- Snowflake automatically maintains clustering for efficient querying.
113
+ -- Note: transcript, retrievals, and feedback_score_related_retrieval_docs are stored as VARCHAR (JSON strings),
114
+ -- same approach as the old retrieved_data column. This allows easy storage and retrieval without VARIANT type complexity.
 
 
 
115
  """
116
  return sql
117
 
 
135
  }
136
 
137
 
138
+ def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
139
  """Generate complete Snowflake schema SQL for feedback system"""
140
+ if table_name is None:
141
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
142
+ return UserFeedback.get_snowflake_create_table_sql(table_name)
143
 
144
 
145
  def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
146
  """Create UserFeedback instance from dictionary"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  return UserFeedback(
148
  feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
149
  open_ended_feedback=data.get("open_ended_feedback"),
150
  score=data["score"],
151
  is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
 
152
  conversation_id=data["conversation_id"],
153
  timestamp=data["timestamp"],
154
  message_count=data["message_count"],
155
  has_retrievals=data["has_retrievals"],
156
  retrieval_count=data["retrieval_count"],
157
+ transcript=data.get("transcript", []),
158
+ retrievals=data.get("retrievals", []),
159
+ feedback_score_related_retrieval_docs=data.get("feedback_score_related_retrieval_docs"),
160
+ retrieved_data=data.get("retrieved_data")
161
  )
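A usage sketch for the schema helper above; the table name falls back to the `SNOWFLAKE_FEEDBACK_TABLE` env var, then to `USER_FEEDBACK_V3`:

```python
import os
from src.reporting.feedback_schema import generate_snowflake_schema_sql

# Explicit table name wins over the env var
print(generate_snowflake_schema_sql("USER_FEEDBACK_STAGING"))

os.environ["SNOWFLAKE_FEEDBACK_TABLE"] = "USER_FEEDBACK_V3"
print(generate_snowflake_schema_sql())  # emits CREATE TABLE ... CLUSTER BY (timestamp, conversation_id, score)
```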
 
src/reporting/snowflake_connector.py CHANGED
@@ -8,8 +8,11 @@ import os
8
  import json
9
  import logging
10
  from typing import Dict, Any, Optional
 
 
11
  from src.reporting.feedback_schema import UserFeedback
12
 
 
13
  # Try to import snowflake connector
14
  try:
15
  import snowflake.connector
@@ -79,12 +82,16 @@ class SnowflakeFeedbackConnector:
79
  self._connection.close()
80
  print("βœ… Disconnected from Snowflake")
81
 
82
- def insert_feedback(self, feedback: UserFeedback) -> bool:
83
  """Insert a single feedback record into Snowflake"""
84
  logger.info("=" * 80)
85
  logger.info("πŸ”„ SNOWFLAKE INSERT: Starting feedback insertion process")
86
  logger.info(f"πŸ“ Feedback ID: {feedback.feedback_id}")
87
 
 
 
 
 
88
  if not self._connection:
89
  logger.error("❌ Not connected to Snowflake. Call connect() first.")
90
  raise RuntimeError("Not connected to Snowflake. Call connect() first.")
@@ -131,38 +138,53 @@ class SnowflakeFeedbackConnector:
131
  logger.error(f"❌ Could not set context: {e}")
132
  raise
133
 
134
- # Prepare data
135
- logger.info("πŸ”§ DATA PREPARATION: Preparing retrieved_data...")
136
- retrieved_data_raw = feedback.to_dict()['retrieved_data']
137
 
138
- logger.info(f" - Retrieved data type (raw): {type(retrieved_data_raw).__name__}")
139
- logger.info(f" - Retrieved data: {repr(retrieved_data_raw)[:200]}")
 
 
 
 
 
 
 
140
 
141
- # If retrieved_data is already a string (from UI), parse it
142
- if isinstance(retrieved_data_raw, str):
143
- logger.info(" - Parsing string to Python object")
144
- retrieved_data = json.loads(retrieved_data_raw)
145
- elif retrieved_data_raw is None:
146
- retrieved_data = None
147
  else:
148
- # It's already a Python object (list/dict)
149
- logger.info(" - Data is already a Python object")
150
- retrieved_data = retrieved_data_raw
151
 
152
- logger.info(f" - Retrieved data size: {len(str(retrieved_data)) if retrieved_data else 0} characters")
153
- logger.info(f" - Retrieved data type: {type(retrieved_data).__name__}")
 
 
 
 
 
 
 
154
 
155
- # Convert to JSON string for TEXT column
156
- if retrieved_data:
157
- retrieved_data_for_db = json.dumps(retrieved_data)
158
- logger.info(f" - Converting to JSON string for TEXT column")
159
- logger.info(f" - JSON string length: {len(retrieved_data_for_db)}")
 
160
  else:
161
- logger.info(f" - Retrieved data is None, using NULL")
162
  retrieved_data_for_db = None
 
163
 
164
- # Build SQL with retrieved_data as a TEXT column parameter
165
- sql = f"""INSERT INTO user_feedback (
 
166
  feedback_id,
167
  open_ended_feedback,
168
  score,
@@ -172,23 +194,25 @@ class SnowflakeFeedbackConnector:
172
  message_count,
173
  has_retrievals,
174
  retrieval_count,
175
- user_query,
176
- bot_response,
177
- created_at,
178
- retrieved_data
 
179
  ) VALUES (
180
  %(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
181
  %(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
182
- %(retrieval_count)s, %(user_query)s, %(bot_response)s, %(created_at)s,
183
- %(retrieved_data)s
184
  )"""
185
 
186
  logger.info("πŸ“ SQL PREPARATION: Building INSERT statement...")
187
- logger.info(f" - Target table: user_feedback")
188
  logger.info(f" - Database: {self.database}")
189
  logger.info(f" - Schema: {self.schema}")
190
 
191
  # Prepare parameters
 
192
  params = {
193
  'feedback_id': feedback.feedback_id,
194
  'open_ended_feedback': feedback.open_ended_feedback,
@@ -199,10 +223,11 @@ class SnowflakeFeedbackConnector:
199
  'message_count': feedback.message_count,
200
  'has_retrievals': feedback.has_retrievals,
201
  'retrieval_count': feedback.retrieval_count,
202
- 'user_query': feedback.user_query,
203
- 'bot_response': feedback.bot_response,
204
- 'created_at': feedback.created_at,
205
- 'retrieved_data': retrieved_data_for_db
 
206
  }
207
 
208
  # Execute insert
@@ -265,12 +290,16 @@ def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
265
  )
266
 
267
 
268
- def save_to_snowflake(feedback: UserFeedback) -> bool:
269
  """Helper function to save feedback to Snowflake"""
270
  logger.info("=" * 80)
271
  logger.info("πŸ”΅ SNOWFLAKE SAVE: Starting save process")
272
  logger.info(f"πŸ“ Feedback ID: {feedback.feedback_id}")
273
 
 
 
 
 
274
  connector = get_snowflake_connector_from_env()
275
 
276
  if not connector:
@@ -285,7 +314,7 @@ def save_to_snowflake(feedback: UserFeedback) -> bool:
285
  logger.info("βœ… SNOWFLAKE SAVE: Connection established")
286
 
287
  logger.info("πŸ“₯ SNOWFLAKE SAVE: Attempting to insert feedback...")
288
- success = connector.insert_feedback(feedback)
289
 
290
  logger.info("πŸ”Œ SNOWFLAKE SAVE: Disconnecting...")
291
  connector.disconnect()
@@ -302,4 +331,3 @@ def save_to_snowflake(feedback: UserFeedback) -> bool:
302
  logger.error(f" - Error: {e}")
303
  logger.info("=" * 80)
304
  return False
305
-
 
8
  import json
9
  import logging
10
  from typing import Dict, Any, Optional
11
+
12
+
13
  from src.reporting.feedback_schema import UserFeedback
14
 
15
+
16
  # Try to import snowflake connector
17
  try:
18
  import snowflake.connector
 
82
  self._connection.close()
83
  print("βœ… Disconnected from Snowflake")
84
 
85
+ def insert_feedback(self, feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
86
  """Insert a single feedback record into Snowflake"""
87
  logger.info("=" * 80)
88
  logger.info("πŸ”„ SNOWFLAKE INSERT: Starting feedback insertion process")
89
  logger.info(f"πŸ“ Feedback ID: {feedback.feedback_id}")
90
 
91
+ # Get table name from parameter, env var, or default
92
+ if table_name is None:
93
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
94
+
95
  if not self._connection:
96
  logger.error("❌ Not connected to Snowflake. Call connect() first.")
97
  raise RuntimeError("Not connected to Snowflake. Call connect() first.")
 
138
  logger.error(f"❌ Could not set context: {e}")
139
  raise
140
 
141
+ # Prepare data - convert nested fields to JSON strings for the VARCHAR columns (same approach as the old retrieved_data)
142
+ logger.info("πŸ”§ DATA PREPARATION: Preparing VARIANT columns...")
143
+ feedback_dict = feedback.to_dict()
144
 
145
+ # Prepare transcript (ARRAY) - convert to JSON string
146
+ transcript_raw = feedback_dict.get('transcript', [])
147
+ if transcript_raw:
148
+ # Convert to JSON string (same approach as old retrieved_data)
149
+ transcript_for_db = json.dumps(transcript_raw)
150
+ logger.info(f" - Transcript: {len(transcript_raw)} messages, JSON length: {len(transcript_for_db)}")
151
+ else:
152
+ transcript_for_db = None
153
+ logger.info(" - Transcript: None")
154
 
155
+ # Prepare retrievals (ARRAY) - convert to JSON string
156
+ retrievals_raw = feedback_dict.get('retrievals', [])
157
+ if retrievals_raw:
158
+ # Convert to JSON string (same approach as old retrieved_data)
159
+ retrievals_for_db = json.dumps(retrievals_raw)
160
+ logger.info(f" - Retrievals: {len(retrievals_raw)} entries, JSON length: {len(retrievals_for_db)}")
161
  else:
162
+ retrievals_for_db = None
163
+ logger.info(" - Retrievals: None")
 
164
 
165
+ # Prepare feedback_score_related_retrieval_docs (OBJECT) - convert to JSON string
166
+ feedback_score_related_raw = feedback_dict.get('feedback_score_related_retrieval_docs')
167
+ if feedback_score_related_raw:
168
+ # Convert to JSON string (same approach as old retrieved_data)
169
+ feedback_score_related_for_db = json.dumps(feedback_score_related_raw)
170
+ logger.info(f" - Feedback score related docs: present, JSON length: {len(feedback_score_related_for_db)}")
171
+ else:
172
+ feedback_score_related_for_db = None
173
+ logger.info(" - Feedback score related docs: None")
174
 
175
+ # Prepare retrieved_data (preserved old column) - convert to JSON string
176
+ retrieved_data_raw = feedback_dict.get('retrieved_data')
177
+ if retrieved_data_raw:
178
+ # Convert to JSON string (same approach as old retrieved_data)
179
+ retrieved_data_for_db = json.dumps(retrieved_data_raw)
180
+ logger.info(f" - Retrieved data (preserved): present, JSON length: {len(retrieved_data_for_db)}")
181
  else:
 
182
  retrieved_data_for_db = None
183
+ logger.info(" - Retrieved data (preserved): None")
184
 
185
+ # Build SQL with new column structure
186
+ # Columns are VARCHAR (storing JSON strings), same approach as old retrieved_data
187
+ sql = f"""INSERT INTO {table_name} (
188
  feedback_id,
189
  open_ended_feedback,
190
  score,
 
194
  message_count,
195
  has_retrievals,
196
  retrieval_count,
197
+ transcript,
198
+ retrievals,
199
+ feedback_score_related_retrieval_docs,
200
+ retrieved_data,
201
+ created_at
202
  ) VALUES (
203
  %(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
204
  %(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
205
+ %(retrieval_count)s, %(transcript)s, %(retrievals)s, %(feedback_score_related_retrieval_docs)s,
206
+ %(retrieved_data)s, %(created_at)s
207
  )"""
208
 
209
  logger.info("πŸ“ SQL PREPARATION: Building INSERT statement...")
210
+ logger.info(f" - Target table: {table_name}")
211
  logger.info(f" - Database: {self.database}")
212
  logger.info(f" - Schema: {self.schema}")
213
 
214
  # Prepare parameters
215
+ # Pass JSON strings for VARIANT columns (same approach as old retrieved_data)
216
  params = {
217
  'feedback_id': feedback.feedback_id,
218
  'open_ended_feedback': feedback.open_ended_feedback,
 
223
  'message_count': feedback.message_count,
224
  'has_retrievals': feedback.has_retrievals,
225
  'retrieval_count': feedback.retrieval_count,
226
+ 'transcript': transcript_for_db, # JSON string
227
+ 'retrievals': retrievals_for_db, # JSON string
228
+ 'feedback_score_related_retrieval_docs': feedback_score_related_for_db, # JSON string
229
+ 'retrieved_data': retrieved_data_for_db, # JSON string - preserved old column
230
+ 'created_at': feedback.created_at
231
  }
232
 
233
  # Execute insert
 
290
  )
291
 
292
 
293
+ def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
294
  """Helper function to save feedback to Snowflake"""
295
  logger.info("=" * 80)
296
  logger.info("πŸ”΅ SNOWFLAKE SAVE: Starting save process")
297
  logger.info(f"πŸ“ Feedback ID: {feedback.feedback_id}")
298
 
299
+ # Get table name from parameter or env var
300
+ if table_name is None:
301
+ table_name = os.getenv("SNOWFLAKE_FEEDBACK_TABLE", "USER_FEEDBACK_V3")
302
+
303
  connector = get_snowflake_connector_from_env()
304
 
305
  if not connector:
 
314
  logger.info("βœ… SNOWFLAKE SAVE: Connection established")
315
 
316
  logger.info("πŸ“₯ SNOWFLAKE SAVE: Attempting to insert feedback...")
317
+ success = connector.insert_feedback(feedback, table_name=table_name)
318
 
319
  logger.info("πŸ”Œ SNOWFLAKE SAVE: Disconnecting...")
320
  connector.disconnect()
 
331
  logger.error(f" - Error: {e}")
332
  logger.info("=" * 80)
333
  return False
 
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/ui_components/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UI Components Module
3
+
4
+ This module contains UI-related components including styles, visualizations,
5
+ and utility functions for the Streamlit application.
6
+ """
7
+
8
+ from .styles import get_custom_css
9
+ from .components import (
10
+ display_chunk_statistics_charts,
11
+ display_chunk_statistics_table
12
+ )
13
+ from .utils import extract_chunk_statistics
14
+
15
+ __all__ = [
16
+ "get_custom_css",
17
+ "display_chunk_statistics_charts",
18
+ "display_chunk_statistics_table",
19
+ "extract_chunk_statistics"
20
+ ]
21
+
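A sketch of how these exports are expected to compose inside the Streamlit app; the 10-chunk threshold is an assumption taken from the display functions' docstrings:

```python
from src.ui_components import (
    extract_chunk_statistics,
    display_chunk_statistics_charts,
    display_chunk_statistics_table,
)

def display_retrieval_stats(sources) -> None:
    """Pick the chart view for larger result sets, the table view otherwise."""
    stats = extract_chunk_statistics(sources)
    if stats.get("total_chunks", 0) >= 10:  # threshold assumed from the docstrings
        display_chunk_statistics_charts(stats)
    else:
        display_chunk_statistics_table(stats)
```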
src/ui_components/components.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UI components for displaying statistics and visualizations
3
+ """
4
+
5
+ import streamlit as st
6
+ import pandas as pd
7
+ import plotly.express as px
8
+ from typing import Dict, Any
9
+
10
+
11
+ def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieval Statistics"):
12
+ """Display statistics as interactive charts for 10+ results."""
13
+ if not stats or stats.get('total_chunks', 0) == 0:
14
+ return
15
+
16
+ # Wrap everything in one styled container - open it
17
+ st.markdown(f"""
18
+ <div class="retrieval-distribution-container">
19
+ <h3 style="margin-top: 0;">πŸ“Š {title}</h3>
20
+ <div style="display: flex; justify-content: space-around; align-items: center; padding: 15px 0; border-bottom: 1px solid #e0e0e0; margin-bottom: 20px;">
21
+ <div class="metric-container">
22
+ <div class="metric-label">Total Chunks</div>
23
+ <div class="metric-value">{stats['total_chunks']}</div>
24
+ </div>
25
+ <div class="metric-container">
26
+ <div class="metric-label">Unique Sources</div>
27
+ <div class="metric-value">{stats['unique_sources']}</div>
28
+ </div>
29
+ <div class="metric-container">
30
+ <div class="metric-label">Unique Years</div>
31
+ <div class="metric-value">{stats['unique_years']}</div>
32
+ </div>
33
+ <div class="metric-container">
34
+ <div class="metric-label">Unique Files</div>
35
+ <div class="metric-value">{stats['unique_filenames']}</div>
36
+ </div>
37
+ </div>
38
+ """, unsafe_allow_html=True)
39
+
40
+ # Charts - three columns to include Districts
41
+ col1, col2, col3 = st.columns(3)
42
+
43
+ with col1:
44
+ # Source distribution chart
45
+ if stats['source_distribution']:
46
+ source_df = pd.DataFrame(
47
+ list(stats['source_distribution'].items()),
48
+ columns=['Source', 'Count']
49
+ )
50
+ fig_source = px.bar(
51
+ source_df,
52
+ x='Count',
53
+ y='Source',
54
+ orientation='h',
55
+ title='Distribution by Source',
56
+ color='Count',
57
+ color_continuous_scale='viridis'
58
+ )
59
+ fig_source.update_layout(height=400, showlegend=False)
60
+ st.plotly_chart(fig_source, use_container_width=True) # Note: plotly_chart still uses use_container_width
61
+
62
+ with col2:
63
+ # Year distribution chart
64
+ if stats['year_distribution']:
65
+ # Filter out 'Unknown' years for the chart
66
+ year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
67
+ if year_dist_filtered:
68
+ year_df = pd.DataFrame(
69
+ list(year_dist_filtered.items()),
70
+ columns=['Year', 'Count']
71
+ )
72
+ # Sort by year as integer but keep as string for categorical display
73
+ year_df['Year_Int'] = year_df['Year'].astype(int)
74
+ year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)
75
+
76
+ fig_year = px.bar(
77
+ year_df,
78
+ x='Year',
79
+ y='Count',
80
+ title='Distribution by Year',
81
+ color='Count',
82
+ color_continuous_scale='plasma'
83
+ )
84
+ # Ensure years are treated as categorical (discrete) not continuous
85
+ fig_year.update_xaxes(type='category')
86
+ fig_year.update_layout(height=400, showlegend=False)
87
+ st.plotly_chart(fig_year, use_container_width=True) # Note: plotly_chart still uses use_container_width
88
+ else:
89
+ st.info("No valid years found in the results")
90
+
91
+ with col3:
92
+ # District distribution chart
93
+ if stats.get('district_distribution'):
94
+ district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
95
+ if district_dist_filtered:
96
+ district_df = pd.DataFrame(
97
+ list(district_dist_filtered.items()),
98
+ columns=['District', 'Count']
99
+ )
100
+ district_df = district_df.sort_values('Count', ascending=False)
101
+
102
+ fig_district = px.bar(
103
+ district_df,
104
+ x='Count',
105
+ y='District',
106
+ orientation='h',
107
+ title='Distribution by District',
108
+ color='Count',
109
+ color_continuous_scale='blues'
110
+ )
111
+ fig_district.update_layout(height=400, showlegend=False)
112
+ st.plotly_chart(fig_district, use_container_width=True) # Note: plotly_chart still uses use_container_width
113
+ else:
114
+ st.info("No valid districts found in the results")
115
+
116
+ # Close the container
117
+ st.markdown('</div>', unsafe_allow_html=True)
118
+
119
+
120
+ def display_chunk_statistics_table(stats: Dict[str, Any], title: str = "Retrieval Distribution"):
121
+ """Display statistics as tables for smaller results with fixed alignment."""
122
+ if not stats or stats.get('total_chunks', 0) == 0:
123
+ return
124
+
125
+ # Wrap in styled container
126
+ st.markdown('<div class="retrieval-distribution-container">', unsafe_allow_html=True)
127
+
128
+ st.subheader(f"πŸ“Š {title}")
129
+
130
+ # Create a container with fixed height for alignment
131
+ stats_container = st.container()
132
+
133
+ with stats_container:
134
+ # Create 4 equal columns for consistent alignment
135
+ col1, col2, col3, col4 = st.columns(4)
136
+
137
+ with col1:
138
+ st.markdown("**🏘️ Districts**")
139
+ if stats.get('district_distribution'):
140
+ district_dist_filtered = {k: v for k, v in stats['district_distribution'].items() if k != 'Unknown'}
141
+ if district_dist_filtered:
142
+ district_data = {
143
+ "District": list(district_dist_filtered.keys()),
144
+ "Count": list(district_dist_filtered.values())
145
+ }
146
+ district_df = pd.DataFrame(district_data).sort_values('Count', ascending=False)
147
+ st.dataframe(district_df, hide_index=True, width='stretch')
148
+ else:
149
+ st.write("No district data")
150
+ else:
151
+ st.write("No district data")
152
+
153
+ with col2:
154
+ st.markdown("**πŸ“‚ Sources**")
155
+ if stats['source_distribution']:
156
+ source_data = {
157
+ "Source": list(stats['source_distribution'].keys()),
158
+ "Count": list(stats['source_distribution'].values())
159
+ }
160
+ source_df = pd.DataFrame(source_data).sort_values('Count', ascending=False)
161
+ st.dataframe(source_df, hide_index=True, width='stretch')
162
+ else:
163
+ st.write("No source data")
164
+
165
+ with col3:
166
+ st.markdown("**πŸ“… Years**")
167
+ if stats['year_distribution']:
168
+ year_dist_filtered = {k: v for k, v in stats['year_distribution'].items() if k != 'Unknown'}
169
+ if year_dist_filtered:
170
+ year_data = {
171
+ "Year": list(year_dist_filtered.keys()),
172
+ "Count": list(year_dist_filtered.values())
173
+ }
174
+ year_df = pd.DataFrame(year_data)
175
+ # Sort by year as integer but display as string
176
+ year_df['Year_Int'] = year_df['Year'].astype(int)
177
+ year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
178
+ st.dataframe(year_df, hide_index=True, width='stretch')
179
+ else:
180
+ st.write("No year data")
181
+ else:
182
+ st.write("No year data")
183
+
184
+ with col4:
185
+ st.markdown("**πŸ“„ Files**")
186
+ if stats['filename_distribution']:
187
+ filename_items = list(stats['filename_distribution'].items())
188
+ filename_items.sort(key=lambda x: x[1], reverse=True)
189
+
190
+ # Show top files with truncated names
191
+ file_data = {
192
+ "File": [f[:30] + "..." if len(f) > 30 else f for f, c in filename_items[:5]],
193
+ "Count": [c for f, c in filename_items[:5]]
194
+ }
195
+ file_df = pd.DataFrame(file_data)
196
+ st.dataframe(file_df, hide_index=True, width='stretch')
197
+ else:
198
+ st.write("No file data")
199
+
200
+ # Close container
201
+ st.markdown('</div>', unsafe_allow_html=True)
202
+
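For reference, a minimal `stats` payload that satisfies both display functions (values are illustrative; this must run inside a Streamlit script):

```python
from src.ui_components.components import display_chunk_statistics_charts

stats = {
    "total_chunks": 12,
    "unique_sources": 3,
    "unique_years": 2,
    "unique_filenames": 2,
    "source_distribution": {"Gulu DLG": 7, "KCCA": 3, "MAAIF": 2},
    "year_distribution": {"2020": 5, "2021": 7},
    "district_distribution": {"Gulu": 7, "Kampala": 3, "Unknown": 2},
    "filename_distribution": {
        "Gulu DLG Report of Auditor General 2021.pdf": 7,
        "KCCA Report of Auditor General 2020.pdf": 5,
    },
}

display_chunk_statistics_charts(stats)  # 12 chunks -> the chart view is appropriate
```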
src/ui_components/styles.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom CSS styles for Streamlit application
3
+ """
4
+
5
+
6
+ def get_custom_css() -> str:
7
+ """Get custom CSS styles as a string"""
8
+ return """
9
+ <style>
10
+ .main-header {
11
+ font-size: 2.5rem;
12
+ font-weight: bold;
13
+ color: #1f77b4;
14
+ text-align: center;
15
+ margin-bottom: 1rem;
16
+ width: 100%;
17
+ display: block;
18
+ }
19
+
20
+ .subtitle {
21
+ font-size: 1.2rem;
22
+ color: #666;
23
+ text-align: center;
24
+ margin-bottom: 2rem;
25
+ width: 100%;
26
+ display: block;
27
+ }
28
+
29
+ .session-info {
30
+ background-color: #f0f2f6;
31
+ padding: 10px;
32
+ border-radius: 5px;
33
+ margin-bottom: 20px;
34
+ font-size: 0.9rem;
35
+ }
36
+
37
+ .user-message {
38
+ background-color: #007bff;
39
+ color: white;
40
+ padding: 12px 16px;
41
+ border-radius: 18px 18px 4px 18px;
42
+ margin: 8px 0;
43
+ margin-left: 20%;
44
+ word-wrap: break-word;
45
+ }
46
+
47
+ .bot-message {
48
+ background-color: #f1f3f4;
49
+ color: #333;
50
+ padding: 12px 16px;
51
+ border-radius: 18px 18px 18px 4px;
52
+ margin: 8px 0;
53
+ margin-right: 20%;
54
+ word-wrap: break-word;
55
+ border: 1px solid #e0e0e0;
56
+ }
57
+
58
+ .filter-section {
59
+ margin-bottom: 20px;
60
+ padding: 15px;
61
+ background-color: #f8f9fa;
62
+ border-radius: 8px;
63
+ border: 1px solid #e9ecef;
64
+ }
65
+
66
+ .filter-title {
67
+ font-weight: bold;
68
+ margin-bottom: 10px;
69
+ color: #495057;
70
+ }
71
+
72
+ .feedback-section {
73
+ background-color: #f8f9fa;
74
+ padding: 20px;
75
+ border-radius: 10px;
76
+ margin-top: 30px;
77
+ border: 2px solid #dee2e6;
78
+ }
79
+
80
+ .retrieval-history {
81
+ background-color: #ffffff;
82
+ padding: 15px;
83
+ border-radius: 5px;
84
+ margin: 10px 0;
85
+ border-left: 4px solid #007bff;
86
+ }
87
+
88
+ .retrieval-distribution-container {
89
+ background-color: #ffffff;
90
+ padding: 25px;
91
+ border-radius: 10px;
92
+ margin: 20px 0;
93
+ border: 2px solid #e0e0e0;
94
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 2px 4px rgba(0, 0, 0, 0.06);
95
+ }
96
+
97
+ .metric-label {
98
+ font-size: 0.9rem;
99
+ color: #555;
100
+ margin-bottom: 5px;
101
+ text-align: center;
102
+ }
103
+
104
+ .metric-value {
105
+ font-size: 1.8rem;
106
+ font-weight: bold;
107
+ color: #000000;
108
+ text-align: center;
109
+ }
110
+
111
+ .metric-container {
112
+ text-align: center;
113
+ padding: 10px;
114
+ }
115
+ </style>
116
+ """
117
+
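Usage sketch: inject the stylesheet once near the top of the app, then reference the classes from markdown (the header text is illustrative):

```python
import streamlit as st
from src.ui_components.styles import get_custom_css

st.markdown(get_custom_css(), unsafe_allow_html=True)  # inject once per page
st.markdown('<div class="main-header">Audit QA</div>', unsafe_allow_html=True)
```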
src/ui_components/utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UI utility functions for data processing and statistics
3
+ """
4
+
5
+ from typing import Dict, Any, List
6
+ from collections import Counter
7
+
8
+
9
+ def extract_chunk_statistics(sources: List[Any]) -> Dict[str, Any]:
10
+ """Extract statistics from retrieved chunks."""
11
+ if not sources:
12
+ return {}
13
+
14
+ sources_list = []
15
+ years = []
16
+ filenames = []
17
+ districts = []
18
+
19
+ for doc in sources:
20
+ metadata = getattr(doc, 'metadata', {})
21
+
22
+ # Extract source
23
+ source = metadata.get('source', 'Unknown')
24
+ sources_list.append(source)
25
+
26
+ # Extract year
27
+ year = metadata.get('year', 'Unknown')
28
+ if year and year != 'Unknown':
29
+ try:
30
+ # Convert to int first, then back to string to ensure it's a proper year
31
+ year_int = int(float(year)) # Handle both int and float strings
32
+ if 1900 <= year_int <= 2030: # Reasonable year range
33
+ years.append(str(year_int))
34
+ else:
35
+ years.append('Unknown')
36
+ except (ValueError, TypeError):
37
+ years.append('Unknown')
38
+ else:
39
+ years.append('Unknown')
40
+
41
+ # Extract filename
42
+ filename = metadata.get('filename', 'Unknown')
43
+ filenames.append(filename)
44
+
45
+ # Extract district
46
+ district = metadata.get('district', 'Unknown')
47
+ if district and district != 'Unknown':
48
+ districts.append(district)
49
+ else:
50
+ districts.append('Unknown')
51
+
52
+ # Count occurrences
53
+ source_counts = Counter(sources_list)
54
+ year_counts = Counter(years)
55
+ filename_counts = Counter(filenames)
56
+ district_counts = Counter(districts)
57
+
58
+ return {
59
+ 'total_chunks': len(sources),
60
+ 'unique_sources': len(source_counts),
61
+ 'unique_years': len([y for y in year_counts.keys() if y != 'Unknown']),
62
+ 'unique_filenames': len(filename_counts),
63
+ 'unique_districts': len([d for d in district_counts.keys() if d != 'Unknown']),
64
+ 'source_distribution': dict(source_counts),
65
+ 'year_distribution': dict(year_counts),
66
+ 'filename_distribution': dict(filename_counts),
67
+ 'district_distribution': dict(district_counts),
68
+ 'sources': sources_list,
69
+ 'years': years,
70
+ 'filenames': filenames,
71
+ 'districts': districts
72
+ }
73
+
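A small sketch of the extractor on LangChain `Document` objects, showing how float-string years are normalized and how missing metadata buckets as `'Unknown'`:

```python
from langchain.docstore.document import Document
from src.ui_components.utils import extract_chunk_statistics

docs = [
    Document(page_content="...", metadata={
        "source": "Gulu DLG", "year": "2021.0", "district": "Gulu",
        "filename": "Gulu DLG Report of Auditor General 2021.pdf",
    }),
    Document(page_content="...", metadata={"source": "KCCA"}),  # missing fields bucket as 'Unknown'
]

stats = extract_chunk_statistics(docs)
print(stats["year_distribution"])   # {'2021': 1, 'Unknown': 1} - '2021.0' is normalized
print(stats["unique_districts"])    # 1 - 'Unknown' is excluded from the unique count
```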
utils.py β†’ src/utils.py RENAMED
File without changes
src/vectorstore.py CHANGED
@@ -1,9 +1,20 @@
1
  """Vector store management and operations."""
 
 
 
 
 
2
  from pathlib import Path
3
  from typing import Dict, Any, List, Optional
4
 
5
 
6
  import torch
 
 
 
 
 
 
7
  from langchain_qdrant import QdrantVectorStore
8
  from langchain.docstore.document import Document
9
  from langchain_core.embeddings import Embeddings
@@ -28,11 +39,23 @@ class MatryoshkaEmbeddings(Embeddings):
28
 
29
  if truncate_dim and "matryoshka" in model_name.lower():
30
  # Use SentenceTransformer directly for Matryoshka models
31
- device = "cuda" if torch.cuda.is_available() else "cpu"
32
- self.model = SentenceTransformer(model_name, truncate_dim=truncate_dim, device=device)
 
 
 
 
 
 
33
  print(f"πŸ”§ Matryoshka model configured for {truncate_dim} dimensions")
34
  else:
35
  # Use standard HuggingFaceEmbeddings
 
 
 
 
 
 
36
  self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
37
 
38
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -76,12 +99,17 @@ class VectorStoreManager:
76
 
77
  def _create_embeddings(self) -> HuggingFaceEmbeddings:
78
  """Create embeddings model from configuration."""
79
- device = "cuda" if torch.cuda.is_available() else "cpu"
80
-
81
  model_name = self.config["retriever"]["model"]
82
  normalize = self.config["retriever"]["normalize"]
83
 
84
- model_kwargs = {"device": device}
 
 
 
 
 
 
 
85
  encode_kwargs = {
86
  "normalize_embeddings": normalize,
87
  "batch_size": 100,
@@ -108,6 +136,8 @@ class VectorStoreManager:
108
  return embeddings
109
 
110
  # Use standard HuggingFaceEmbeddings for non-Matryoshka models
 
 
111
  embeddings = HuggingFaceEmbeddings(
112
  model_name=model_name,
113
  model_kwargs=model_kwargs,
 
1
  """Vector store management and operations."""
2
+ import os
3
+ # Set MPS-related env vars before importing torch to mitigate meta tensor issues on Mac (MPS itself is patched off below)
4
+ os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
5
+ os.environ.setdefault("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0")
6
+
7
  from pathlib import Path
8
  from typing import Dict, Any, List, Optional
9
 
10
 
11
  import torch
12
+ # Disable MPS backend explicitly to prevent meta tensor issues
13
+ if hasattr(torch.backends, 'mps'):
14
+ # Monkey patch to disable MPS
15
+ original_mps_available = torch.backends.mps.is_available  # keep a handle to the original so the patch can be reverted
16
+ torch.backends.mps.is_available = lambda: False
17
+
18
  from langchain_qdrant import QdrantVectorStore
19
  from langchain.docstore.document import Document
20
  from langchain_core.embeddings import Embeddings
 
39
 
40
  if truncate_dim and "matryoshka" in model_name.lower():
41
  # Use SentenceTransformer directly for Matryoshka models
42
+ # Fix for meta tensor issue: Explicitly force CPU
43
+ # MPS is already disabled at module level
44
+ # Explicitly pass device="cpu" to prevent MPS/CUDA detection
45
+ self.model = SentenceTransformer(
46
+ model_name,
47
+ truncate_dim=truncate_dim,
48
+ device="cpu" # Force CPU to prevent meta tensor issues
49
+ )
50
  print(f"πŸ”§ Matryoshka model configured for {truncate_dim} dimensions")
51
  else:
52
  # Use standard HuggingFaceEmbeddings
53
+ # Don't pass device parameter - let it load naturally on CPU
54
+ # This prevents the meta tensor error
55
+ if "model_kwargs" not in kwargs:
56
+ kwargs["model_kwargs"] = {}
57
+ # Remove device from model_kwargs if present to prevent meta tensor issues
58
+ kwargs["model_kwargs"].pop("device", None)
59
  self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
60
 
61
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
 
99
 
100
  def _create_embeddings(self) -> HuggingFaceEmbeddings:
101
  """Create embeddings model from configuration."""
 
 
102
  model_name = self.config["retriever"]["model"]
103
  normalize = self.config["retriever"]["normalize"]
104
 
105
+ # Fix for meta tensor issue: Force CPU usage to prevent MPS/CUDA detection
106
+ # The error occurs when SentenceTransformer detects MPS/CUDA and tries to move meta tensors
107
+ # MPS is already disabled at module level, now we explicitly force CPU in model_kwargs
108
+ model_kwargs = {
109
+ "device": "cpu", # Explicitly force CPU to prevent MPS/CUDA detection
110
+ "trust_remote_code": True, # Some models need this
111
+ }
112
+
113
  encode_kwargs = {
114
  "normalize_embeddings": normalize,
115
  "batch_size": 100,
 
136
  return embeddings
137
 
138
  # Use standard HuggingFaceEmbeddings for non-Matryoshka models
139
+ # model_kwargs already pins device="cpu" (set above), so HuggingFaceEmbeddings
140
+ # loads on CPU and never touches the meta device
141
  embeddings = HuggingFaceEmbeddings(
142
  model_name=model_name,
143
  model_kwargs=model_kwargs,
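The CPU-forcing workaround above can be smoke-tested in isolation; a minimal sketch (model name illustrative, `sentence-transformers` assumed installed):

```python
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch
if hasattr(torch.backends, "mps"):
    torch.backends.mps.is_available = lambda: False  # hide MPS from device auto-detection

from sentence_transformers import SentenceTransformer

# device="cpu" sidesteps meta-tensor moves onto MPS/CUDA
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
print(model.encode(["smoke test"]).shape)  # (1, 384) for this model
```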