lazy load vLLM
app.py CHANGED

@@ -1,38 +1,10 @@
  """
-
  """
 
  import os
-
- import time
- import json
- import uuid
- import logging
- import traceback
- from pathlib import Path
-
- from collections import Counter
- from typing import List, Dict, Any, Optional
-
-
- import pandas as pd
- import streamlit as st
- import plotly.express as px
- from langchain_core.messages import HumanMessage, AIMessage
-
-
- from src.agents import get_multi_agent_chatbot, get_smart_chatbot, get_gemini_chatbot
- from src.feedback import FeedbackManager
- from src.ui_components import get_custom_css, display_chunk_statistics_charts, display_chunk_statistics_table, extract_chunk_statistics
-
- from src.config.paths import (
-     IS_DEPLOYED,
-     PROJECT_DIR,
-     HF_CACHE_DIR,
-     FEEDBACK_DIR,
-     CONVERSATIONS_DIR,
- )
-
 
  # ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
  # Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
@@ -56,6 +28,17 @@ try:
  except (ValueError, TypeError):
      os.environ["OMP_NUM_THREADS"] = "1"
 
  # ===== Setup HuggingFace cache directories BEFORE any model imports =====
  # CRITICAL: Set these before any imports that might use HuggingFace (like sentence-transformers)
  # Only override cache directories in deployed environment (local uses defaults)
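The guard above is shown only from its fallback branch. A minimal sketch of the full pattern the comments describe, assuming the try-branch parses the current value (the unshown context lines may differ):

import os

try:
    threads = int(os.environ.get("OMP_NUM_THREADS"))
    if threads < 1:
        raise ValueError
except (ValueError, TypeError):
    # int(None) raises TypeError, int("not-a-number") raises ValueError
    os.environ["OMP_NUM_THREADS"] = "1"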
@@ -73,1067 +56,857 @@ if IS_DEPLOYED and HF_CACHE_DIR:
  except (PermissionError, OSError):
      # If we can't create it, log but continue (might already exist from Dockerfile)
      pass
-
  else:
      from dotenv import load_dotenv
      load_dotenv()
 
- #
-
-
 
-
- logger.info(f"PROJECT_DIR: {PROJECT_DIR}")
- logger.info(f"Environment: {'DEPLOYED' if IS_DEPLOYED else 'LOCAL'}")
- logger.info(f"OMP_NUM_THREADS: {os.environ.get('OMP_NUM_THREADS', 'NOT SET')}")
- logger.info(f"HuggingFace cache: {os.environ.get('HF_HOME', 'DEFAULT (not overridden)')}")
 
 
-
-
-
- if gpu_check not in st.session_state:
-     st.write(f"GPU check skipped: {e.__str__}")
- print("CUDA:", cuda_)
- logger.info("CUDA:", cuda_)
- if cuda_:
-     if gpu_check not in st.session_state:
-         st.write(f"Device: {torch.cuda.get_device_name(0)}")
-     print("Device:", torch.cuda.get_device_name(0))
-     logger.info(f"Device: {torch.cuda.get_device_name(0)}")
- except Exception as e:
-     if gpu_check not in st.session_state:
-         st.write(f"GPU check skipped: {e.__str__}")
-     logger.error(f"GPU check skipped: {e.__str__}")
-     print("GPU check skipped:", e, file=sys.stderr)
- finally:
-     st.session_state.gpu_check = True
-
 
- #
-
-
-
-     page_title="Intelligent Audit Report Chatbot"
  )
 
 
-
-
-
- def get_system_type():
-     """Get the current system type"""
-     system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
-     if system == 'smart':
-         return "Smart Chatbot System"
-     else:
-         return "Multi-Agent System"
 
-
-
-
-
-
- # Check environment variable for system type (v1)
- system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
- if system == 'smart':
-     return get_smart_chatbot()
- else:
-     return get_multi_agent_chatbot()
 
-
-
-
-
-
-
-     "type": type(msg).__name__,
-     "content": str(msg.content)
- })
- return serialized
 
-
-
-
 
-
-
 
-
- if content in seen_content:
-     continue
-
- seen_content.add(content)
 
-
-
-
-
- "
-
-
-
-
-
- "
-
-
 
-
 
 
-
 
 
-
-
  try:
-
-
-
-
-
-
 
-
-
-
- st.session_state.messages = []
- if 'conversation_id' not in st.session_state:
-     st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
- if 'session_start_time' not in st.session_state:
-     st.session_state.session_start_time = time.time()
- if 'active_filters' not in st.session_state:
-     st.session_state.active_filters = {'sources': [], 'years': [], 'districts': [], 'filenames': []}
- # Track RAG retrieval history for feedback
- if 'rag_retrieval_history' not in st.session_state:
-     st.session_state.rag_retrieval_history = []
- # Version selection (v1 or beta)
- if 'chatbot_version' not in st.session_state:
-     st.session_state.chatbot_version = "v1"
-
- # Initialize chatbot based on version (only if not already initialized for this version)
- chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}"
-
- # Check if we need to initialize: chatbot doesn't exist OR version changed
- needs_init = (
-     chatbot_version_key not in st.session_state or
-     st.session_state.get('_last_version') != st.session_state.chatbot_version
- )
 
-
-
-
-
- else:
-     spinner_msg = "Loading AI models and connecting to database..."
-
- with st.spinner(spinner_msg):
-     st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version)
-     st.session_state['_last_version'] = st.session_state.chatbot_version
-     st.session_state.chatbot = st.session_state[chatbot_version_key]
-     print("AI system ready!")
- except Exception as e:
-     st.error(f"Failed to initialize chatbot: {str(e)}")
-     # Only show Gemini-specific error message for beta version
-     if st.session_state.chatbot_version == "beta":
-         st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)")
-     else:
-         st.error("Please check your configuration and ensure all required models and databases are accessible.")
-     # Reset to v1 to prevent infinite loop
-     st.session_state.chatbot_version = "v1"
-     st.session_state['_last_version'] = "v1"
-     if 'chatbot' in st.session_state:
-         del st.session_state['chatbot']
-     st.stop()  # Stop execution to prevent infinite loop
- else:
-     # Chatbot already initialized for this version, just use it
-     st.session_state.chatbot = st.session_state[chatbot_version_key]
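Condensed, the deleted block caches one chatbot instance per selected version in session state, so switching versions never rebuilds an existing instance. A sketch of that pattern using the app's own get_chatbot factory:

version = st.session_state.chatbot_version  # "v1" or "beta"
key = f"chatbot_{version}"
if key not in st.session_state or st.session_state.get('_last_version') != version:
    st.session_state[key] = get_chatbot(version)  # expensive build, done once per version
    st.session_state['_last_version'] = version
st.session_state.chatbot = st.session_state[key]  # point at the active instance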
 
- #
-
- st.session_state.messages = []
- st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
- st.session_state.session_start_time = time.time()
- st.session_state.rag_retrieval_history = []
- st.session_state.feedback_submitted = False
- st.session_state.reset_conversation = False
- st.rerun()
 
 
-
-
-
-
-
-
-
- "
-
-     index=0 if st.session_state.chatbot_version == "v1" else 1,
-     horizontal=True,
-     key="version_selector",
-     help="Select v1 (default RAG system) or beta (Gemini FSA)"
- )
 
-
-
-
-
-
 
- #
-
-
-
-
-
-
-
 
-
 
- #
-
- st.info("**Beta Mode**: Using Google Gemini FSA")
 
-
- duration = int(time.time() - st.session_state.session_start_time)
- duration_str = f"{duration // 60}m {duration % 60}s"
- st.markdown(f'''
- <div class="session-info">
-     <strong>Session Info:</strong> Messages: {len(st.session_state.messages)} | Duration: {duration_str} | Status: Active | ID: {st.session_state.conversation_id}
- </div>
- ''', unsafe_allow_html=True)
 
- # Load filter options
- filter_options = load_filter_options()
 
-
-
-
-
-
 
-
 
-
 
-
 
-
 
-
 
-
-
- - Check the "Retrieved Documents" tab to see source material
 
-
 
-
 
-
 
-
-
-
-
-
-
- st.markdown('<div class="filter-title">Specific Reports (Filename Filter)</div>', unsafe_allow_html=True)
- st.markdown('<p style="font-size: 0.85em; color: #666;">Selecting specific reports will ignore all other filters</p>', unsafe_allow_html=True)
- selected_filenames = st.multiselect(
-     "Select specific reports:",
-     options=filter_options.get('filenames', []),
-     default=st.session_state.active_filters.get('filenames', []),
-     key="filenames_filter",
-     help="Choose specific reports to search. When enabled, all other filters are ignored."
  )
-
 
- #
-
-
- # st.markdown('<div class="filter-section">', unsafe_allow_html=True)
- st.markdown('<div class="filter-title">Sources</div>', unsafe_allow_html=True)
- selected_sources = st.multiselect(
-     "Select sources:",
-     options=filter_options['sources'],
-     default=st.session_state.active_filters['sources'],
-     disabled=filename_mode,
-     key="sources_filter",
-     help="Choose which types of reports to search"
- )
- st.markdown('</div>', unsafe_allow_html=True)
 
-
-
- st.markdown('<div class="filter-title">Years</div>', unsafe_allow_html=True)
- selected_years = st.multiselect(
-     "Select years:",
-     options=filter_options['years'],
-     default=st.session_state.active_filters['years'],
-     disabled=filename_mode,
-     key="years_filter",
-     help="Choose which years to search"
- )
- st.markdown('</div>', unsafe_allow_html=True)
 
- #
-
-
- selected_districts = st.multiselect(
-     "Select districts:",
-     options=filter_options['districts'],
-     default=st.session_state.active_filters['districts'],
-     disabled=filename_mode,
-     key="districts_filter",
-     help="Choose which districts to search"
- )
- st.markdown('</div>', unsafe_allow_html=True)
 
-
- st.session_state.active_filters = {
-     'sources': selected_sources if not filename_mode else [],
-     'years': selected_years if not filename_mode else [],
-     'districts': selected_districts if not filename_mode else [],
-     'filenames': selected_filenames
- }
 
-
-
-
-
 
-
-
 
-
- #
-
 
-
-
-
-
-
-
-
 
- #
- st.
 
-
-
 
-
- # Use a counter to force input clearing
- if 'input_counter' not in st.session_state:
-     st.session_state.input_counter = 0
-
- # Handle pending question from example questions section
- if 'pending_question' in st.session_state and st.session_state.pending_question:
-     default_value = st.session_state.pending_question
-     # Increment counter to force new input widget
-     st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
-     del st.session_state.pending_question
-     key_suffix = st.session_state.input_counter
- else:
-     default_value = ""
-     key_suffix = st.session_state.input_counter
-
- user_input = st.text_input(
-     "Type your message here...",
-     placeholder="Ask about budget allocations, expenditures, or audit findings...",
-     key=f"user_input_{key_suffix}",
-     label_visibility="collapsed",
-     value=default_value if default_value else None
- )
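The counter trick above is a common Streamlit idiom: a widget resets when its key changes, because Streamlit treats it as a brand-new widget on the next rerun. A minimal standalone sketch of the same idea (names invented here):

if 'input_counter' not in st.session_state:
    st.session_state.input_counter = 0

def clear_input():
    st.session_state.input_counter += 1  # new key => fresh, empty text_input on rerun

user_input = st.text_input("Message", key=f"user_input_{st.session_state.input_counter}")
st.button("Clear", on_click=clear_input)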
 
-
-
 
-
- if st.button("Clear Chat", key="clear_chat_button"):
-     st.session_state.reset_conversation = True
-     # Clear all conversation files
-     conversations_path = CONVERSATIONS_DIR
-     if conversations_path.exists():
-         for file in conversations_path.iterdir():
-             if file.suffix == '.json':
-                 file.unlink()
-     st.rerun()
 
-
-
-
- filter_context_str = ""
- if selected_filenames:
-     filter_context_str += "FILTER CONTEXT:\n"
-     filter_context_str += f"Filenames: {', '.join(selected_filenames)}\n"
-     filter_context_str += "USER QUERY:\n"
- elif selected_sources or selected_years or selected_districts:
-     filter_context_str += "FILTER CONTEXT:\n"
-     if selected_sources:
-         filter_context_str += f"Sources: {', '.join(selected_sources)}\n"
-     if selected_years:
-         filter_context_str += f"Years: {', '.join(selected_years)}\n"
-     if selected_districts:
-         filter_context_str += f"Districts: {', '.join(selected_districts)}\n"
-     filter_context_str += "USER QUERY:\n"
-
- full_query = filter_context_str + user_input
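With, for example, Sources = ['Audit Report'] and Years = ['2021'] selected (values invented for illustration), the full_query string the deleted code sends to the chatbot has this shape:

FILTER CONTEXT:
Sources: Audit Report
Years: 2021
USER QUERY:
What were the main findings?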
 
- #
-
 
-
-
-
-
-
-
- if isinstance(chat_result, dict):
-     response = chat_result['response']
-     rag_result = chat_result.get('rag_result')
-     st.session_state.last_rag_result = rag_result
-
-     # Track RAG retrieval for feedback
-     if rag_result:
-         sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
-
-         # For Gemini, also check gemini_result for sources
-         if not sources or len(sources) == 0:
-             gemini_result = chat_result.get('gemini_result')
-             print(f"DEBUG: Checking gemini_result for sources...")
-             print(f"  gemini_result exists: {gemini_result is not None}")
-             if gemini_result:
-                 print(f"  gemini_result type: {type(gemini_result)}")
-                 print(f"  has sources attr: {hasattr(gemini_result, 'sources')}")
-                 if hasattr(gemini_result, 'sources'):
-                     print(f"  sources length: {len(gemini_result.sources) if gemini_result.sources else 0}")
-
-             if gemini_result and hasattr(gemini_result, 'sources'):
-                 # Format Gemini sources for display
-                 if hasattr(st.session_state.chatbot, 'gemini_client'):
-                     sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result)
-                     print(f"Formatted {len(sources)} sources from gemini_client")
-                 elif hasattr(st.session_state.chatbot, '_format_gemini_sources'):
-                     sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
-                     print(f"Formatted {len(sources)} sources from _format_gemini_sources")
-
-                 # Update rag_result with sources if we found them
-                 if sources and len(sources) > 0:
-                     if isinstance(rag_result, dict):
-                         rag_result['sources'] = sources
-                     elif hasattr(rag_result, 'sources'):
-                         rag_result.sources = sources
-                     # Update last_rag_result with sources
-                     st.session_state.last_rag_result = rag_result
-                     print(f"Updated rag_result with {len(sources)} sources")
-
-         # Get the actual RAG query
-         actual_rag_query = chat_result.get('actual_rag_query', '')
-         if actual_rag_query:
-             # Format it like the log message
-             timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
-             formatted_query = f"{timestamp} - INFO - ACTUAL RAG QUERY: '{actual_rag_query}'"
-         else:
-             formatted_query = "No RAG query available"
-
-         # Extract filters from active filters
-         filters_used = {
-             "sources": st.session_state.active_filters.get('sources', []),
-             "years": st.session_state.active_filters.get('years', []),
-             "districts": st.session_state.active_filters.get('districts', []),
-             "filenames": st.session_state.active_filters.get('filenames', [])
-         }
-
-         retrieval_entry = {
-             "conversation_up_to": serialize_messages(st.session_state.messages),
-             "rag_query_expansion": formatted_query,
-             "docs_retrieved": serialize_documents(sources),
-             "filters_applied": filters_used,
-             "timestamp": time.time()
-         }
-         st.session_state.rag_retrieval_history.append(retrieval_entry)
-
-         # Debug logging
-         print(f"RETRIEVAL TRACKING: {len(sources)} sources stored in retrieval history")
- else:
-     response = chat_result
-     st.session_state.last_rag_result = None
-
- # Add bot response to history
- st.session_state.messages.append(AIMessage(content=response))
-
- except Exception as e:
-     error_msg = f"Sorry, I encountered an error: {str(e)}"
-     st.session_state.messages.append(AIMessage(content=error_msg))
 
- #
- st.session_state.
-
-
- with tab2:
-     # Document retrieval panel
-     if hasattr(st.session_state, 'last_rag_result') and st.session_state.last_rag_result:
-         rag_result = st.session_state.last_rag_result
 
-         #
-
-         if
-
-             sources = rag_result.sources
-         elif isinstance(rag_result, dict) and 'sources' in rag_result:
-             # Dictionary format from multi-agent system
-             sources = rag_result['sources']
 
-         #
-
-
-         if
-
-
-
-
-             sources = st.session_state.chatbot._format_gemini_sources(gemini_result)
 
-
-
-
-
-             filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown')
-             unique_filenames.add(filename)
-
-         st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 20):**")
-         if len(unique_filenames) < len(sources):
-             st.info(f"**Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
-
-         # Extract and display statistics
-         stats = extract_chunk_statistics(sources)
-
-         # Show charts for 10+ results, tables for fewer
-         if len(sources) >= 10:
-             display_chunk_statistics_charts(stats, "Retrieval Statistics")
-             # Also show tables below charts for detailed view
-             st.markdown("---")
-             display_chunk_statistics_table(stats, "Retrieval Distribution")
-         else:
-             display_chunk_statistics_table(stats, "Retrieval Distribution")
-
-         st.markdown("---")
-         st.markdown("### Document Details")
-
-         for i, doc in enumerate(sources):  # Show all documents
-             # Get relevance score and ID if available
-             metadata = getattr(doc, 'metadata', {})
-             # Handle both standard RAG scores and Gemini scores
-             score = metadata.get('reranked_score') or metadata.get('original_score') or metadata.get('score')
-             chunk_id = metadata.get('_id') or metadata.get('chunk_id', 'Unknown')
-             if score is not None:
-                 try:
-                     score_text = f" (Score: {float(score):.3f})"
-                 except (ValueError, TypeError):
-                     score_text = ""
-             else:
-                 score_text = ""
-             if chunk_id and chunk_id != 'Unknown':
-                 score_text += f" (ID: {str(chunk_id)[:8]}...)" if score_text else f" (ID: {str(chunk_id)[:8]}...)"
-
-             with st.expander(f"Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
-                 # Display document metadata
-                 metadata = getattr(doc, 'metadata', {})
-                 col1, col2, col3, col4 = st.columns([2, 1.5, 1, 1])
-
-                 with col1:
-                     st.write(f"**File:** {metadata.get('filename', 'Unknown')}")
-                 with col2:
-                     st.write(f"**Source:** {metadata.get('source', 'Unknown')}")
-                 with col3:
-                     st.write(f"**Year:** {metadata.get('year', 'Unknown')}")
-                 with col4:
-                     # Display page number and chunk ID
-                     page = metadata.get('page_label', metadata.get('page', 'Unknown'))
-                     chunk_id = metadata.get('_id', 'Unknown')
-                     st.write(f"**Page:** {page}")
-                     st.write(f"**ID:** {chunk_id}")
-
-                 # Display full content (no truncation)
-                 content = getattr(doc, 'page_content', 'No content available')
-                 st.write(f"**Full Content:**")
-                 st.text_area("Full Content", value=content, height=300, disabled=True, label_visibility="collapsed", key=f"preview_{i}")
  else:
-     st.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- with col2:
-     is_feedback_about_last_retrieval = st.checkbox(
-         "Feedback about last retrieval only",
-         value=True,
-         help="If checked, feedback applies to the most recent document retrieval"
-     )
-
- open_ended_feedback = st.text_area(
-     "Your feedback (optional)",
-     placeholder="Tell us what went well or what could be improved...",
-     height=100
- )
-
- # Disable submit if no score selected
- submit_disabled = feedback_score is None
-
- submitted = st.form_submit_button(
-     "Submit Feedback",
-     width='stretch',
-     disabled=submit_disabled
- )
-
- if submitted:
-     # Log the feedback data being submitted
-     print("=" * 80)
-     print("FEEDBACK SUBMISSION: Starting...")
-     print("=" * 80)
-     st.write("**Debug: Feedback Data Being Submitted:**")
-
-     # Extract transcript from messages
-     transcript = feedback_manager.extract_transcript(st.session_state.messages)
-
-     # Build retrievals structure
-     retrievals = feedback_manager.build_retrievals_structure(
-
-         st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
-         st.session_state.messages
-     )
-
-     # Build feedback_score_related_retrieval_docs
-
-     feedback_score_related_retrieval_docs = feedback_manager.build_feedback_score_related_retrieval_docs(
-         is_feedback_about_last_retrieval,
-         st.session_state.messages,
-         st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
  )
 
-
-
 
-
-
-
-         "score": feedback_score,
-         "is_feedback_about_last_retrieval": is_feedback_about_last_retrieval,
-         "conversation_id": st.session_state.conversation_id,
-         "timestamp": time.time(),
-         "message_count": len(st.session_state.messages),
-         "has_retrievals": has_retrievals,
-         "retrieval_count": len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0,
-         "transcript": transcript,
-         "retrievals": retrievals,
-         "feedback_score_related_retrieval_docs": feedback_score_related_retrieval_docs,
-         "retrieved_data": retrieved_data_old_format  # Preserved old column
-     }
-
-     print(f"FEEDBACK SUBMISSION: Score={feedback_score}, Retrievals={len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0}")
-
-     # Create UserFeedback dataclass instance
-     feedback_obj = None  # Initialize outside try block
-     try:
-         feedback_obj = feedback_manager.create_feedback_from_dict(feedback_dict)
-         print(f"FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
-         st.write(f"**Feedback Object Created**")
-         st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
-         st.write(f"- Score: {feedback_obj.score}/5")
-         st.write(f"- Has Retrievals: {feedback_obj.has_retrievals}")
-
-         # Convert back to dict for JSON serialization
-         feedback_data = feedback_obj.to_dict()
-     except Exception as e:
-         print(f"FEEDBACK SUBMISSION: Failed to create feedback object: {e}")
-         st.error(f"Failed to create feedback object: {e}")
-         feedback_data = feedback_dict
-
-     # Display the data being submitted
-     st.json(feedback_data)
-
-     # Save feedback to file - use PROJECT_DIR to ensure writability
-     feedback_dir = FEEDBACK_DIR
-     try:
-         # Ensure directory exists with write permissions (777 for compatibility)
-         feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
-     except (PermissionError, OSError) as e:
-         logger.warning(f"Could not create feedback directory at {feedback_dir}: {e}")
-         # Fallback to relative path
-         feedback_dir = Path("feedback")
-         feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
-
-     feedback_file = feedback_dir / f"feedback_{st.session_state.conversation_id}_{int(time.time())}.json"
-
-     try:
-         # Ensure parent directory exists before writing
-         feedback_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
-
-         # Save to local file first
-         print(f"FEEDBACK SAVE: Saving to local file: {feedback_file}")
-         with open(feedback_file, 'w') as f:
-             json.dump(feedback_data, f, indent=2, default=str)
-
-         print(f"FEEDBACK SAVE: Local file saved successfully")
-
-         # Save to Snowflake if enabled and credentials available
-         logger.info("FEEDBACK SAVE: Starting Snowflake save process...")
-         logger.info(f"FEEDBACK SAVE: feedback_obj={'exists' if feedback_obj else 'None'}")
-
-         snowflake_success = False
-         try:
-             snowflake_enabled = os.getenv("SNOWFLAKE_ENABLED", "false").lower() == "true"
-             logger.info(f"SNOWFLAKE CHECK: enabled={snowflake_enabled}")
-
-             if snowflake_enabled:
-                 if feedback_obj:
-                     try:
-                         logger.info("SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
-                         print("SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
-
-                         # Show spinner while saving to Snowflake (can take 10-15 seconds)
-                         # This includes: connection establishment (~5s), data preparation, and SQL execution (~5s)
-                         with st.spinner("Saving feedback to Snowflake... This may take 10-15 seconds (connecting to database, preparing data, and executing query)"):
-                             snowflake_success = feedback_manager.save_to_snowflake(feedback_obj)
-
-                         if snowflake_success:
-                             logger.info("SNOWFLAKE UI: Successfully saved to Snowflake")
-                             print("SNOWFLAKE UI: Successfully saved to Snowflake")
-                         else:
-                             logger.warning("SNOWFLAKE UI: Save failed")
-                             print("SNOWFLAKE UI: Save failed")
-                     except Exception as e:
-                         logger.error(f"SNOWFLAKE UI ERROR: {e}")
-                         print(f"SNOWFLAKE UI ERROR: {e}")
-                         traceback.print_exc()
-                         snowflake_success = False
-                 else:
-                     logger.warning("SNOWFLAKE UI: Skipping (feedback object not created)")
-                     print("SNOWFLAKE UI: Skipping (feedback object not created)")
-                     snowflake_success = False
-             else:
-                 logger.info("SNOWFLAKE UI: Integration disabled")
-                 print("SNOWFLAKE UI: Integration disabled")
-                 # If Snowflake is disabled, consider it successful (local save only)
-                 snowflake_success = True
-
-         except Exception as e:
-             logger.error(f"Exception in Snowflake save: {type(e).__name__}: {e}")
-             print(f"Exception in Snowflake save: {type(e).__name__}: {e}")
-             snowflake_success = False
-
-         # Only show success if Snowflake save succeeded (or if Snowflake is disabled)
-         if snowflake_success:
-             st.success("Thank you for your feedback! It has been saved successfully.")
-             st.balloons()
  else:
-
-
-
-
-
-
-         print("=" * 80)
-
-         # Log file location
-         st.info(f"Feedback saved to: {feedback_file}")
-
-     except Exception as e:
-         print(f"FEEDBACK SUBMISSION: Error saving feedback: {e}")
-         print(f"FEEDBACK SUBMISSION: Error type: {type(e).__name__}")
-         traceback.print_exc()
-         st.error(f"Error saving feedback: {e}")
-         st.write(f"Debug error: {str(e)}")
- else:
-     # Feedback already submitted - show success message and reset option
-     st.success("Feedback already submitted for this conversation!")
-     col1, col2 = st.columns([1, 1])
-     with col1:
-         if st.button("Submit New Feedback", key="new_feedback_button", width='stretch'):
-             try:
-                 st.session_state.feedback_submitted = False
-                 st.rerun()
-             except Exception as e:
-                 # Handle any Streamlit API exceptions gracefully
-                 logger.error(f"Error resetting feedback state: {e}")
-                 st.error(f"Error resetting feedback. Please refresh the page.")
-     with col2:
-         if st.button("View Conversation", key="view_conversation_button", width='stretch'):
-             # Scroll to conversation - this is handled by the auto-scroll at bottom
-             pass
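The save path above follows a write-local-first pattern: the JSON file is always written, the Snowflake upload is attempted only when enabled, and success is reported only if the remote step succeeded or was deliberately disabled. Condensed into a sketch around the app's own FeedbackManager:

import json, os

def save_feedback(feedback_data, feedback_file, feedback_manager, feedback_obj):
    with open(feedback_file, 'w') as f:
        json.dump(feedback_data, f, indent=2, default=str)  # local copy always written
    if os.getenv("SNOWFLAKE_ENABLED", "false").lower() != "true":
        return True  # remote disabled: local-only save counts as success
    try:
        return feedback_manager.save_to_snowflake(feedback_obj)
    except Exception:
        return False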
-
- # Display retrieval history stats
- if st.session_state.rag_retrieval_history:
-     st.markdown("---")
-     st.markdown("#### Retrieval History")
-
-     with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=True):
-         for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
-             st.markdown(f"### **Retrieval #{idx}**")
-
-             # Display timestamp if available
-             if entry.get("timestamp"):
-                 timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(entry["timestamp"]))
-                 st.caption(f"{timestamp_str}")
-
-             # Display the actual RAG query
-             rag_query_expansion = entry.get("rag_query_expansion", "No query available")
-             st.markdown("**RAG Query:**")
-             st.code(rag_query_expansion, language="text")
-
-             # Display filters used
-             filters_applied = entry.get("filters_applied", {})
-             if filters_applied and any(filters_applied.values()):
-                 st.markdown("**Filters Applied:**")
-                 filter_display = {}
-                 if filters_applied.get("sources"):
-                     filter_display["Sources"] = filters_applied["sources"]
-                 if filters_applied.get("years"):
-                     filter_display["Years"] = filters_applied["years"]
-                 if filters_applied.get("districts"):
-                     filter_display["Districts"] = filters_applied["districts"]
-                 if filters_applied.get("filenames"):
-                     filter_display["Filenames"] = filters_applied["filenames"]
 
-             if
-             st.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                 st.info("No conversation history available")
-
-             # Display documents retrieved
-             docs_retrieved = entry.get("docs_retrieved", [])
-             if docs_retrieved:
-                 st.markdown(f"**Documents Retrieved ({len(docs_retrieved)}):**")
-                 with st.expander(f"View {len(docs_retrieved)} documents", expanded=False):
-                     for doc_idx, doc in enumerate(docs_retrieved, 1):
-                         st.markdown(f"**Document {doc_idx}:**")
-
-                         # Display metadata
-                         metadata = doc.get("metadata", {})
-                         if metadata:
-                             col1, col2, col3 = st.columns(3)
-                             with col1:
-                                 st.write(f"**File:** {metadata.get('filename', 'Unknown')}")
-                             with col2:
-                                 st.write(f"**Source:** {metadata.get('source', 'Unknown')}")
-                             with col3:
-                                 st.write(f"**Year:** {metadata.get('year', 'Unknown')}")
-
-                             # Additional metadata
-                             if metadata.get('district'):
-                                 st.write(f"**District:** {metadata.get('district')}")
-                             if metadata.get('page'):
-                                 st.write(f"**Page:** {metadata.get('page')}")
-                             if metadata.get('score') is not None:
-                                 st.write(f"**Score:** {metadata.get('score'):.3f}" if isinstance(metadata.get('score'), (int, float)) else f"**Score:** {metadata.get('score')}")
-
-                         # Display content preview (first 200 chars)
-                         content = doc.get("content", doc.get("page_content", ""))
-                         if content:
-                             st.markdown("**Content Preview:**")
-                             st.text_area(
-                                 "Content Preview",
-                                 value=content[:200] + ("..." if len(content) > 200 else ""),
-                                 height=100,
-                                 disabled=True,
-                                 label_visibility="collapsed",
-                                 key=f"retrieval_{idx}_doc_{doc_idx}_preview"
-                             )
-
-                         if doc_idx < len(docs_retrieved):
-                             st.markdown("---")
-             else:
-                 st.info("No documents retrieved")
-
-             # Display summary stats
-             st.markdown("**Summary:**")
-             st.json({
-                 "conversation_length": len(conversation_up_to),
-                 "documents_retrieved": len(docs_retrieved)
-             })
 
- if
- st.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
  st.rerun()
-
-
-
-
- st.info("**Filter to apply:** Select District(s) and Year(s) sidebar panel before asking this question.")
-
- st.markdown("---")
-
- # Question 3
- # st.markdown("**Question 3:**")
- custom_q2 = st.text_area(
-     "Edit question 3:",
-     value=st.session_state.custom_question_2,
-     height=80,
-     key="edit_question_3",
-     help="Modify this question to fit your needs, then click 'Use This Question'"
- )
- col1, col2 = st.columns([1, 4])
- with col1:
-     if st.button("Use Question 3", key="use_custom_2", width='stretch'):
-         if custom_q2.strip():
-             st.session_state.pending_question = custom_q2.strip()
-             st.session_state.custom_question_2 = custom_q2.strip()
-             st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
  st.rerun()
-
-
-
-
-
-
-
 
- #
- st.
-
-
-
-
 
-
-
-
-
-
-
-
- print("=" * 80)
- print("\nPlease run this app using:")
- print("  streamlit run app.py")
- print("\nNot: python app.py")
- print("\nThe app will not function correctly when run with 'python app.py'")
- print("=" * 80)
- import sys
- sys.exit(1)
- except ImportError:
-     # Streamlit not installed or not in Streamlit context
-     print("=" * 80)
-     print("WARNING: This is a Streamlit app!")
-     print("=" * 80)
-     print("\nPlease run this app using:")
-     print("  streamlit run app.py")
-     print("\nNot: python app.py")
-     print("=" * 80)
-     import sys
-     sys.exit(1)
- main()
app.py after the change (addition side of the same hunks):

  """
+ FempowerBot Training Simulator - Main Application
+ Interactive chatbot for practicing communication strategies.
  """
 
  import os
+ import sys
 
  # ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
  # Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
 
  except (ValueError, TypeError):
      os.environ["OMP_NUM_THREADS"] = "1"
 
+ # ===== Import path configuration BEFORE other imports =====
+ from src.config.paths import (
+     IS_DEPLOYED,
+     PROJECT_DIR,
+     HF_CACHE_DIR,
+     FEEDBACK_DIR,
+     CONVERSATIONS_DIR,
+     PROMPTS_DIR,
+     LOGS_DIR,
+ )
+
  # ===== Setup HuggingFace cache directories BEFORE any model imports =====
  # CRITICAL: Set these before any imports that might use HuggingFace (like sentence-transformers)
  # Only override cache directories in deployed environment (local uses defaults)
 
  except (PermissionError, OSError):
      # If we can't create it, log but continue (might already exist from Dockerfile)
      pass
  else:
+     # Local development - load .env file and ensure NO cache vars are set
+     # Let HuggingFace use its defaults (~/.cache/huggingface)
      from dotenv import load_dotenv
      load_dotenv()
+
+     # Unset any HF cache variables that might exist in the environment
+     for var in ["HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "HF_HUB_CACHE", "SENTENCE_TRANSFORMERS_HOME"]:
+         if var in os.environ:
+             del os.environ[var]
 
+ # ===== NOW safe to import everything else =====
+ import streamlit as st
+ from pathlib import Path
+ import json
+ from datetime import datetime
+ import logging
 
+ from src.config.loader import config
 
+ # Import Ollama for local development (lightweight import)
+ if not IS_DEPLOYED:
+     import ollama
 
+ # NOTE: vLLM is NOT imported here! It's imported lazily in load_model()
+ # Reason: vLLM import takes 30-60 seconds due to CUDA initialization
+ # This keeps app startup fast and lets Streamlit render UI immediately
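Deferring a heavy import into the function that first needs it is the core of this commit; paired with st.cache_resource (used in load_model below), the import cost is paid once per process instead of at every app start. A minimal standalone sketch of the pattern:

import streamlit as st

@st.cache_resource
def get_engine(model_path: str):
    from vllm import LLM  # heavy import deferred until the first model load
    return LLM(model=model_path)  # cached, so subsequent reruns reuse the engine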
 
+ # ===== Setup logging =====
+ log_level = os.getenv("LOG_LEVEL", config.get("logging.level", "INFO"))
+ logging.basicConfig(
+     level=getattr(logging, log_level),
+     format=config.get("logging.format", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
  )
+ logger = logging.getLogger(__name__)
 
+ # Reduce noise from external libraries
+ logging.getLogger("httpcore").setLevel(logging.WARNING)
+ logging.getLogger("httpx").setLevel(logging.WARNING)
+ logging.getLogger("fsevents").setLevel(logging.WARNING)
 
+ logger.info(f"Starting FempowerBot - Deployed: {IS_DEPLOYED}")
 
+ # Log startup message
+ if IS_DEPLOYED:
+     logger.info("App starting in DEPLOYED mode (vLLM will be imported when model is loaded)")
+ else:
+     logger.info("App starting in LOCAL mode (using Ollama)")
 
+ # ===== Page Configuration =====
+ st.set_page_config(
+     page_title=config.get("app.title", "FempowerBot Training Simulator"),
+     page_icon=config.get("app.page_icon", "💬"),
+     layout=config.get("app.layout", "wide")
+ )
 
+ # ===== Initialize session state =====
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+ if "model_loaded" not in st.session_state:
+     st.session_state.model_loaded = False
+ if "current_model" not in st.session_state:
+     st.session_state.current_model = None
+ if "current_persona" not in st.session_state:
+     st.session_state.current_persona = None
+ if "current_prompt_type" not in st.session_state:
+     st.session_state.current_prompt_type = None
+ if "session_id" not in st.session_state:
+     st.session_state.session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
+ if "response_times" not in st.session_state:
+     st.session_state.response_times = []
+ if "custom_prompt" not in st.session_state:
+     st.session_state.custom_prompt = None
+ if "loaded_prompt_text" not in st.session_state:
+     st.session_state.loaded_prompt_text = ""
+ if "prompt_edited" not in st.session_state:
+     st.session_state.prompt_edited = False
+ if "few_shot_examples" not in st.session_state:
+     st.session_state.few_shot_examples = ""
+ if "custom_gen_params" not in st.session_state:
+     st.session_state.custom_gen_params = None
+ if "show_save_dialog" not in st.session_state:
+     st.session_state.show_save_dialog = False
+ if "last_prompt_selection" not in st.session_state:
+     st.session_state.last_prompt_selection = ""
+
+
+ def load_prompt_from_file(persona: str, prompt_type: str) -> str:
+     """Load prompt text from file.
 
+     For 'modular' type: Combines base instructions + persona-specific module.
+     For other types: Loads single file.
+     """
+     # Handle modular prompts (base + persona-specific)
+     if prompt_type.lower() == "modular":
+         base_path = PROMPTS_DIR / "_base_instructions.txt"
+         persona_path = PROMPTS_DIR / f"{persona.lower()}_modular.txt"
 
+         prompt_parts = []
 
+         # Load base instructions
+         if base_path.exists():
+             with open(base_path, "r", encoding="utf-8") as f:
+                 prompt_parts.append(f.read())
+         else:
+             logger.warning(f"Base instructions not found: {base_path}")
+
+         # Load persona-specific module
+         if persona_path.exists():
+             with open(persona_path, "r", encoding="utf-8") as f:
+                 prompt_parts.append(f.read())
+         else:
+             logger.error(f"Persona module not found: {persona_path}")
+             return ""
+
+         # Combine with clear separation between the parts
+         return ("\n\n" + "=" * 80 + "\n\n").join(prompt_parts)
+
+     # Handle standard single-file prompts
+     filename = f"{persona.lower()}_{prompt_type.lower()}.txt"
+     prompt_path = PROMPTS_DIR / filename
 
+     if prompt_path.exists():
+         with open(prompt_path, "r", encoding="utf-8") as f:
+             return f.read()
+     else:
+         logger.error(f"Prompt file not found: {prompt_path}")
+         return ""
 
 
+ def get_available_prompt_types(persona: str) -> list:
+     """Get all available prompt types for a persona by scanning prompts directory."""
+     persona_lower = persona.lower()
+     prompt_files = list(PROMPTS_DIR.glob(f"{persona_lower}_*.txt"))
+
+     # Extract prompt types from filenames
+     types = []
+     for file in prompt_files:
+         # Format: {persona}_{type}.txt
+         type_name = file.stem.replace(f"{persona_lower}_", "")
+         types.append(type_name.capitalize())
+
+     return sorted(types)
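With the hypothetical files from the sketch above on disk, the scan yields the options for a type selector (note that _base_instructions.txt does not match the persona glob):

get_available_prompt_types("Nurse")  # -> ['Full', 'Modular']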
 
+ def save_custom_prompt(persona: str, prompt_type_name: str, prompt_text: str) -> bool:
+     """Save a custom prompt to disk."""
+     try:
+         filename = f"{persona.lower()}_{prompt_type_name.lower()}.txt"
+         prompt_path = PROMPTS_DIR / filename
+
+         # Don't overwrite compressed or full
+         if prompt_type_name.lower() in ['compressed', 'full']:
+             logger.error(f"Cannot overwrite default prompt types: {prompt_type_name}")
+             return False
+
+         with open(prompt_path, "w", encoding="utf-8") as f:
+             f.write(prompt_text)
+
+         logger.info(f"Saved custom prompt: {prompt_path}")
+         return True
+     except Exception as e:
+         logger.error(f"Error saving custom prompt: {e}")
+         return False
 
+
+ def load_prompt(persona: str, prompt_type: str) -> str:
+     """Load the system prompt based on persona and type.
 
+     Priority: custom_prompt (edited in UI) > loaded file
+     """
+     # Use edited prompt if it exists
+     if st.session_state.prompt_edited and st.session_state.custom_prompt:
+         return st.session_state.custom_prompt
 
+     # Otherwise load from file
+     return load_prompt_from_file(persona, prompt_type)
 
 
+ @st.cache_resource
+ def load_model(model_key: str):
+     """Load and cache the model (Ollama for local, vLLM for deployed)."""
+     try:
+         # Get model configuration
+         models = config.get_models()
+         if model_key not in models:
+             st.error(f"Model '{model_key}' not found in configuration")
+             return None
 
+         model_info = models[model_key]
+         model_name = model_info['name']
+
+         if IS_DEPLOYED:
+             # DEPLOYED: Use vLLM for optimized inference on T4 GPU
 
+             # Lazy import vLLM (takes 30-60s due to CUDA initialization)
+             # This keeps app startup fast - vLLM only imported when user loads model
+             logger.info("Importing vLLM (this may take 30-60 seconds)...")
+             from vllm import LLM, SamplingParams
+             logger.info("vLLM imported successfully")
+
+             model_path = model_info['hf_id']
+
+             logger.info(f"Loading model with vLLM: {model_name} ({model_path})")
+
+             with st.spinner(f"Loading {model_name} with vLLM... This may take a few minutes."):
+                 # Initialize vLLM with optimized settings for T4
+                 llm = LLM(
+                     model=model_path,
+                     download_dir=str(HF_CACHE_DIR) if HF_CACHE_DIR else None,
+                     dtype="half",  # FP16 for T4
+                     gpu_memory_utilization=0.90,  # Use 90% of GPU memory
+                     max_model_len=4096,  # Adjust based on model and T4 VRAM
+                     trust_remote_code=False,
+                 )
+
+             logger.info(f"vLLM model loaded successfully: {model_name}")
+             return {"type": "vllm", "llm": llm, "model_key": model_key, "model_name": model_name}
+         else:
+             # LOCAL: Use Ollama for Apple Silicon optimization
+             ollama_model = model_info.get('ollama_id', 'mistral')
+
+             logger.info(f"Checking Ollama for model: {ollama_model}")
+
+             with st.spinner(f"Checking Ollama model {ollama_model}..."):
+                 try:
+                     # Check if Ollama is running and get available models
+                     available_models_response = ollama.list()
+
+                     # Extract model names from response
+                     # Response format: {'models': [{'name': '...', 'model': '...', ...}]}
+                     models = available_models_response.get('models', [])
+                     model_names = []
+
+                     for m in models:
+                         # Handle both 'name' and 'model' keys
+                         name = m.get('name') or m.get('model', '')
+                         if name:
+                             model_names.append(name)
+
+                     logger.info(f"Available Ollama models: {model_names}")
+
+                     # Check if requested model is available (check base name without tag)
+                     model_available = any(
+                         ollama_model in name or name.startswith(ollama_model.split(':')[0])
+                         for name in model_names
+                     )
+
+                     if not model_available:
+                         st.warning(f"Model '{ollama_model}' not found locally. Pulling...")
+                         logger.info(f"Pulling Ollama model: {ollama_model}")
+
+                         # Pull the model
+                         with st.spinner(f"Downloading {ollama_model}... This may take a few minutes."):
+                             ollama.pull(ollama_model)
+
+                         st.success(f"Model '{ollama_model}' downloaded successfully!")
+
+                     logger.info(f"Ollama ready with model: {ollama_model}")
+                     return {"type": "ollama", "model_name": ollama_model, "model_key": model_key}
+
+                 except Exception as e:
+                     st.error(f"Ollama error: {str(e)}")
+                     st.info("Make sure Ollama is running: `ollama serve`")
+                     logger.error(f"Ollama error: {e}", exc_info=True)
+                     return None
 
+     except Exception as e:
+         logger.error(f"Error loading model: {str(e)}", exc_info=True)
+         st.error(f"Error loading model: {str(e)}")
+         return None
+
+
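Because load_model is wrapped in st.cache_resource, the returned handle is memoized per model_key, so each model loads once per server process and is shared across sessions and reruns. A usage sketch ('mistral-7b' is an invented key that would have to exist in the models config):

model_obj = load_model("mistral-7b")
if model_obj:
    st.sidebar.success(f"Loaded: {model_obj['model_name']}")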
+ def format_chat_prompt(system_prompt: str, conversation_history: list, few_shot_examples: str = "") -> str:
+     """Format the conversation history with system prompt and optional few-shot examples."""
+     # Build conversation string
+     formatted = system_prompt
+
+     # Add few-shot examples if provided
+     if few_shot_examples and few_shot_examples.strip():
+         formatted += "\n\n### Example Conversations\n\n" + few_shot_examples.strip()
+
+     formatted += "\n\n---\n\n"
+
+     for msg in conversation_history:
+         if msg["role"] == "user":
+             formatted += f"User: {msg['content']}\n\n"
+         elif msg["role"] == "assistant":
+             formatted += f"FempowerBot: {msg['content']}\n\n"
 
+     # Add prompt for next response
+     formatted += "FempowerBot:"
 
+     return formatted
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
|
|
|
|
|
|
|
| 354 |
|
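
# Illustrative output of format_chat_prompt (turn contents assumed) - the flat
# prompt string that the Ollama path in generate_response re-parses into role
# messages:
#
#   <system prompt text>
#
#   ### Example Conversations
#
#   <few-shot examples, if provided>
#
#   ---
#
#   User: What is feminism?
#
#   FempowerBot: It's about equal rights for everyone.
#
#   FempowerBot: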

def generate_response(model_obj, prompt: str) -> tuple[str, float]:
    """Generate a response from the model (Ollama for local, vLLM for deployed).

    Returns:
        tuple: (response_text, generation_time_seconds)
    """
    import time
    start_time = time.time()

    try:
        # Use custom params if set, otherwise use config defaults
        gen_params = st.session_state.custom_gen_params if st.session_state.custom_gen_params else config.get_generation_params()

        if model_obj["type"] == "ollama":
            # LOCAL: Use Ollama with proper conversation history

            # Get model type from config to handle response format
            model_key = model_obj.get('model_key')
            models = config.get_models()
            model_type = models.get(model_key, {}).get('model_type', 'standard')

            # Build options for Ollama
            max_tokens = gen_params.get('max_new_tokens', 2000)
            ollama_options = {
                "temperature": gen_params.get('temperature', 0.8),
                "num_predict": max_tokens,
                "top_p": gen_params.get('top_p', 0.9),
                "top_k": gen_params.get('top_k', 50),
            }
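            # num_predict is Ollama's cap on generated tokens - the counterpart of
            # max_new_tokens (Transformers) / max_tokens (vLLM) used elsewhere here.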

            # Build proper message history instead of one big prompt string.
            # Extract system prompt and conversation from the formatted prompt;
            # the prompt is: system_prompt + "---" + conversation history.
            messages = []

            # Add system message
            if prompt.startswith("You are"):
                parts = prompt.split("\n\n---\n\n", 1)
                system_prompt_text = parts[0]
                messages.append({"role": "system", "content": system_prompt_text})

                # Parse conversation history if present
                if len(parts) > 1:
                    conv_text = parts[1]
                    # Split by "User:" and "FempowerBot:"
                    lines = conv_text.split('\n\n')
                    for line in lines:
                        line = line.strip()
                        if line.startswith("User:"):
                            messages.append({"role": "user", "content": line[5:].strip()})
                        elif line.startswith("FempowerBot:") and len(line) > 12:
                            messages.append({"role": "assistant", "content": line[12:].strip()})
            else:
                # Fallback: use as single user message
                messages = [{"role": "user", "content": prompt}]

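            # Example of the parsed structure (turn contents assumed):
            # messages = [
            #     {"role": "system", "content": "You are Alex, ..."},
            #     {"role": "user", "content": "What is feminism?"},
            #     {"role": "assistant", "content": "It's about equal rights..."},
            #     {"role": "user", "content": "That makes sense!"},
            # ]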
            response = ollama.chat(
                model=model_obj['model_name'],
                messages=messages,
                options=ollama_options
            )

            # Extract content from response based on model type
            message = response.get('message', {})

            if model_type == "reasoning":
                # Reasoning models (Qwen3, DeepSeek-R1) use 'thinking' field for internal reasoning
                # and 'content' for final response. If content is empty, extract from thinking.
                generated = message.get('content', '').strip()

                if not generated:
                    # Fallback: extract actual response from thinking field
                    thinking = message.get('thinking', '').strip()
                    if thinking:
                        logger.info("Extracting response from 'thinking' field (reasoning model)")

                        # Try to find where the model formulated the actual response.
                        # Look for patterns like "Final response:", "I'll write:", quotation marks with response
                        best_match = None

                        # Pattern 1: Look for quoted responses (most reliable)
                        import re
                        quoted = re.findall(r'"([^"]+(?:\?|\.)[^"]*)"', thinking)
                        if quoted:
                            # Take the longest quoted response that looks like a conversational reply
                            best_match = max(quoted, key=len) if len(quoted[-1]) > 30 else quoted[-1]

                        # Pattern 2: Look for "Final response" or similar markers
                        if not best_match:
                            for delimiter in ['Final response (', 'I\'ll write:\n"', 'Revised to ', 'Final response:\n"']:
                                if delimiter in thinking:
                                    parts = thinking.split(delimiter, 1)
                                    if len(parts) > 1:
                                        # Extract text in quotes after delimiter
                                        text_after = parts[1]
                                        match = re.search(r'"([^"]+)"', text_after)
                                        if match:
                                            best_match = match.group(1)
                                            break

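                        # Illustrative 'thinking' text (assumed) and what the patterns
                        # extract from it:
                        #   ... the user sounds hostile. Final response:
                        #   "I hear you. Can I ask what shaped that view?"
                        # Pattern 1 captures the quoted sentence (it ends in '?');
                        # Pattern 2 would find the same text via 'Final response:\n"'.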
                        generated = best_match if best_match else "I apologize, but I couldn't generate a proper response."
                    else:
                        logger.error(f"Empty response from reasoning model. Full response: {response}")
                        generated = "I apologize, but I couldn't generate a response."
            else:
                # Standard models put response directly in 'content'
                generated = message.get('content', '').strip()

                if not generated:
                    logger.error(f"Empty response from Ollama. Full response: {response}")
                    generated = "I apologize, but I couldn't generate a response."

        else:
            # DEPLOYED: Use vLLM.
            # Note: vLLM is already imported in load_model() before this is called
            from vllm import SamplingParams

            llm = model_obj['llm']

            # Create sampling parameters
            sampling_params = SamplingParams(
                temperature=gen_params.get('temperature', 0.8),
                top_p=gen_params.get('top_p', 0.9),
                top_k=gen_params.get('top_k', 50),
                max_tokens=gen_params.get('max_new_tokens', 200),
                repetition_penalty=gen_params.get('repetition_penalty', 1.1),
            )

            # Generate
            outputs = llm.generate([prompt], sampling_params)
            generated = outputs[0].outputs[0].text.strip()
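            # llm.generate returns one RequestOutput per input prompt; each carries
            # a list of CompletionOutput objects in .outputs, whose .text holds the
            # generated continuation - hence outputs[0].outputs[0].text above.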

            # Clean up response
            if generated.startswith("FempowerBot:"):
                generated = generated[len("FempowerBot:"):].strip()

            if "User:" in generated:
                generated = generated.split("User:")[0].strip()

        # Calculate generation time
        generation_time = time.time() - start_time
        logger.info(f"Response generated in {generation_time:.2f}s ({len(generated)} chars)")

        return generated, generation_time

    except Exception as e:
        generation_time = time.time() - start_time
        logger.error(f"Error generating response: {str(e)}")
        st.error(f"Error generating response: {str(e)}")
        return "I apologize, but I encountered an error generating a response.", generation_time


def save_conversation():
    """Save the current conversation to disk."""
    if not config.get("storage.save_conversations", True):
        return

    try:
        # Ensure directory exists
        CONVERSATIONS_DIR.mkdir(parents=True, exist_ok=True)

        # Create conversation data
        conversation_data = {
            "session_id": st.session_state.session_id,
            "timestamp": datetime.now().isoformat(),
            "model": st.session_state.current_model,
            "persona": st.session_state.current_persona,
            "prompt_type": st.session_state.current_prompt_type,
            "messages": st.session_state.messages
        }

        # Save to file
        filename = f"conversation_{st.session_state.session_id}.json"
        filepath = CONVERSATIONS_DIR / filename

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, indent=2, ensure_ascii=False)

        logger.debug(f"Conversation saved: {filepath}")

    except Exception as e:
        logger.error(f"Error saving conversation: {str(e)}", exc_info=True)

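# Illustrative on-disk shape of a saved conversation (field values assumed):
#
# {
#   "session_id": "20250101_120000",
#   "timestamp": "2025-01-01T12:00:05.123456",
#   "model": "qwen2.5-3b",
#   "persona": "Skeptic",
#   "prompt_type": "compressed",
#   "messages": [
#     {"role": "user", "content": "What is feminism?"},
#     {"role": "assistant", "content": "...", "gen_time": 2.31}
#   ]
# }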

# ===== Main UI =====
st.title("💬 FempowerBot Training Simulator")
st.markdown("""
Practice your communication strategies with realistic conversation partners who are
**unfamiliar**, **skeptical**, or **antagonistic** toward feminism and gender equality.
""")

# ===== Sidebar for configuration =====
with st.sidebar:
    st.header("⚙️ Configuration")

    # Model selection
    st.subheader("🤖 Select Model")

    # Get available models from config
    available_models = config.get_models()
    model_display_names = {key: info['name'] for key, info in available_models.items()}

    selected_model_key = st.selectbox(
        "Choose Language Model",
        options=list(model_display_names.keys()),
        format_func=lambda x: model_display_names[x],
        help="Select a model that fits within T4 GPU constraints"
    )

    # Show model info
    model_info = available_models[selected_model_key]
    with st.expander("ℹ️ Model Info"):
        st.write(f"**Recommended GPU:** {model_info.get('recommended_gpu', 'N/A')}")
        st.write(f"**VRAM Required:** {model_info.get('vram_required', 'N/A')}")
        st.write(f"**Context Length:** {model_info.get('context_length', 'N/A')}")

    # Persona selection
    st.subheader("🎭 Select Persona")

    # Get personas from config
    personas_list = config.get_personas()
    persona_options = [p['name'] for p in personas_list]

    selected_persona = st.radio(
        "Choose conversation partner type",
        options=persona_options,
        help="Select who you want to practice talking with"
    )

    # Prompt type selection
    st.subheader("📝 Prompt Type")

    # Get available prompt types for selected persona (scans files dynamically)
    available_types = get_available_prompt_types(selected_persona)

    if not available_types:
        # Fallback to config if no files found
        prompt_types = config.get("prompts.types", [])
        available_types = [p['name'] for p in prompt_types]

    selected_prompt_type = st.radio(
        "Choose prompt type",
        options=available_types,
        help="Select prompt type. Custom types appear here after saving."
    )

    # Load model button
    if st.button("🚀 Load Model", type="primary"):
        model_obj = load_model(selected_model_key)
        if model_obj:
            st.session_state.model_loaded = True
            st.session_state.current_model = selected_model_key
            st.session_state.model_obj = model_obj
            backend = "Ollama (Apple Silicon)" if model_obj["type"] == "ollama" else "vLLM (T4 GPU)"
            st.success(f"✅ {model_display_names[selected_model_key]} loaded via {backend}!")

    # Reset conversation button
    if st.button("🔄 Reset Conversation"):
        st.session_state.messages = []
        st.session_state.response_times = []
        st.session_state.session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        st.rerun()

    # # Debug info (if debug mode enabled)
    # if os.getenv("DEBUG", "false").lower() == "true":
    #     st.divider()
    #     st.caption("🔧 Debug Info")
    #     st.caption(f"Deployed: {IS_DEPLOYED}")
    #     st.caption(f"Session: {st.session_state.session_id}")
    #     st.caption(f"Messages: {len(st.session_state.messages)}")

# ===== Main interface with tabs =====
if not st.session_state.model_loaded:
    st.info("👈 Please select and load a model from the sidebar to begin.")
else:
    # Check if settings changed
    settings_changed = (
        selected_persona != st.session_state.current_persona or
        selected_prompt_type != st.session_state.current_prompt_type
    )

    if settings_changed and len(st.session_state.messages) > 0:
        st.warning("⚠️ Settings changed! Click 'Reset Conversation' to apply new settings.")

    # Update current settings
    st.session_state.current_persona = selected_persona
    st.session_state.current_prompt_type = selected_prompt_type

    # Create tabs
    tab_chat, tab_config = st.tabs(["💬 Chat", "⚙️ Configuration"])

    with tab_chat:
        # ===== Chat Tab - Display messages only =====
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
                # Show timing for assistant messages
                if message["role"] == "assistant" and "gen_time" in message:
                    st.caption(f"⏱️ {message['gen_time']:.2f}s")

    with tab_config:
        # ===== Configuration Tab =====
        st.subheader("🎯 Advanced Configuration")

        col1, col2 = st.columns([2, 1])

        with col1:
            # ===== System Prompt Editor =====
            st.markdown("### 📝 System Prompt")

            # Auto-load prompt when persona/type changes
            current_selection = f"{selected_persona}_{selected_prompt_type}"
            if "last_prompt_selection" not in st.session_state:
                st.session_state.last_prompt_selection = current_selection

            if st.session_state.last_prompt_selection != current_selection:
                # Selection changed - reload prompt from file
                loaded_text = load_prompt_from_file(selected_persona, selected_prompt_type)
                st.session_state.loaded_prompt_text = loaded_text
                st.session_state.custom_prompt = loaded_text
                st.session_state.prompt_edited = False
                st.session_state.last_prompt_selection = current_selection

            # Initialize loaded prompt if empty
            if not st.session_state.loaded_prompt_text:
                st.session_state.loaded_prompt_text = load_prompt_from_file(selected_persona, selected_prompt_type)

            # Show current selection info
            st.caption(f"📄 Currently loaded: **{selected_persona} / {selected_prompt_type}**")
            if st.session_state.prompt_edited:
                st.caption("✏️ *Prompt has been edited (not saved to disk)*")

            # Prompt editor
            custom_prompt = st.text_area(
                "Edit system prompt",
                value=st.session_state.custom_prompt if st.session_state.custom_prompt else st.session_state.loaded_prompt_text,
                height=300,
                key="prompt_editor",
                help="Edit the prompt. Changes apply immediately to chat (RAM only)."
            )

            # Track if edited
            if custom_prompt != st.session_state.loaded_prompt_text:
                st.session_state.prompt_edited = True
                st.session_state.custom_prompt = custom_prompt
            else:
                st.session_state.prompt_edited = False
                st.session_state.custom_prompt = None

            # Action buttons
            col_a, col_b, col_c = st.columns(3)

            with col_a:
                if st.button("🔄 Reload from File"):
                    loaded_text = load_prompt_from_file(selected_persona, selected_prompt_type)
                    st.session_state.loaded_prompt_text = loaded_text
                    st.session_state.custom_prompt = loaded_text
                    st.session_state.prompt_edited = False
                    st.success("✅ Reloaded from file!")
                    st.rerun()

            with col_b:
                if st.button("💾 Save as New Type"):
                    st.session_state.show_save_dialog = True

            with col_c:
                if st.session_state.prompt_edited:
                    if st.button("❌ Discard Changes"):
                        st.session_state.custom_prompt = st.session_state.loaded_prompt_text
                        st.session_state.prompt_edited = False
                        st.rerun()

            # Save as new type dialog
            if st.session_state.get("show_save_dialog", False):
                with st.form("save_prompt_form"):
                    st.markdown("#### 💾 Save as New Prompt Type")
                    new_type_name = st.text_input(
                        "Prompt Type Name",
                        placeholder="e.g., detailed, brief, custom1",
                        help="Avoid 'compressed' and 'full' (reserved)"
                    )
                    col_save, col_cancel = st.columns(2)
                    with col_save:
                        save_submitted = st.form_submit_button("✅ Save", type="primary")
                    with col_cancel:
                        cancel_submitted = st.form_submit_button("❌ Cancel")

                    if save_submitted and new_type_name:
                        if new_type_name.lower() in ['compressed', 'full']:
                            st.error("❌ Cannot use reserved names: 'compressed' or 'full'")
                        else:
                            success = save_custom_prompt(selected_persona, new_type_name, custom_prompt)
                            if success:
                                st.success(f"✅ Saved as: {selected_persona.lower()}_{new_type_name.lower()}.txt")
                                st.session_state.show_save_dialog = False
                                st.rerun()
                            else:
                                st.error("❌ Failed to save prompt")

                    if cancel_submitted:
                        st.session_state.show_save_dialog = False
                        st.rerun()

            # Preview expander
            with st.expander("👁️ Preview Current Prompt", expanded=False):
                current = st.session_state.custom_prompt if st.session_state.prompt_edited else st.session_state.loaded_prompt_text
                st.code(current, language="text")

            # ===== Few-Shot Examples =====
            st.markdown("### 📚 Few-Shot Examples")
            st.caption("Add example conversations to guide the bot's responses (currently empty by default)")

            with st.expander("➕ Add Few-Shot Examples", expanded=False):
                few_shot_text = st.text_area(
                    "Paste example conversation (format: User: ... / FempowerBot: ...)",
                    value=st.session_state.few_shot_examples,
                    height=200,
                    placeholder="Example:\n\nUser: What is feminism?\n\nFempowerBot: It's about equal rights for everyone, regardless of gender.\n\nUser: That makes sense!\n\nFempowerBot: Glad I could help clarify!",
                    help="Provide multi-turn conversation examples to improve bot responses"
                )
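
                # These examples are inserted verbatim under "### Example Conversations"
                # by format_chat_prompt() ahead of the live conversation history.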

                if st.button("💾 Save Few-Shot Examples"):
                    st.session_state.few_shot_examples = few_shot_text
                    st.success("✅ Few-shot examples saved!")
                    st.rerun()

        with col2:
            # ===== Generation Parameters =====
            st.markdown("### ⚙️ Generation Parameters")

            # Get default params
            default_params = config.get_generation_params()
            current_params = st.session_state.custom_gen_params if st.session_state.custom_gen_params else default_params

            max_tokens = st.number_input(
                "Max Tokens",
                min_value=50,
                max_value=4000,
                value=current_params.get('max_new_tokens', 2000),
                step=50,
                help="Maximum length of generated response"
            )

            temperature = st.slider(
                "Temperature",
                min_value=0.0,
                max_value=2.0,
                value=current_params.get('temperature', 0.8),
                step=0.1,
                help="Higher = more creative, lower = more focused"
            )

            top_p = st.slider(
                "Top P",
                min_value=0.0,
                max_value=1.0,
                value=current_params.get('top_p', 0.9),
                step=0.05,
                help="Nucleus sampling threshold"
            )

            top_k = st.number_input(
                "Top K",
                min_value=1,
                max_value=100,
                value=current_params.get('top_k', 50),
                step=5,
                help="Sample from top K tokens"
            )

            repetition_penalty = st.slider(
                "Repetition Penalty",
                min_value=1.0,
                max_value=2.0,
                value=current_params.get('repetition_penalty', 1.1),
                step=0.05,
                help="Penalty for repeating tokens"
            )

            if st.button("💾 Apply Parameters"):
                st.session_state.custom_gen_params = {
                    'max_new_tokens': max_tokens,
                    'temperature': temperature,
                    'top_p': top_p,
                    'top_k': top_k,
                    'repetition_penalty': repetition_penalty
                }
                st.success("✅ Parameters applied!")
                st.rerun()
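
            # custom_gen_params overrides the config defaults: generate_response()
            # falls back to config.get_generation_params() when it is None.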

            if st.button("🔄 Reset Parameters"):
                st.session_state.custom_gen_params = None
                st.success("✅ Reset to defaults!")
                st.rerun()

            # ===== Performance Stats =====
            if st.session_state.response_times:
                st.markdown("### 📊 Performance Stats")
                avg_time = sum(r['time'] for r in st.session_state.response_times) / len(st.session_state.response_times)
                total_responses = len(st.session_state.response_times)

                st.metric("Avg Response Time", f"{avg_time:.2f}s")
                st.metric("Total Responses", total_responses)

    # ===== Chat Input (outside tabs - must be at this level) =====
    if prompt := st.chat_input(config.get("app.chat_input_placeholder", "Type your message here...")):
        # Add user message to chat
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Generate bot response
        with st.spinner("Thinking..."):
            # Load system prompt
            system_prompt = load_prompt(
                st.session_state.current_persona,
                st.session_state.current_prompt_type
            )

            # Format prompt with conversation history and few-shot examples
            full_prompt = format_chat_prompt(
                system_prompt,
                st.session_state.messages,
                st.session_state.few_shot_examples
            )

            # Generate response
            response, gen_time = generate_response(
                st.session_state.model_obj,
                full_prompt
            )

            # Track response time
            st.session_state.response_times.append({
                "time": gen_time,
                "chars": len(response),
                "timestamp": datetime.now().isoformat()
            })

            # Add assistant response to chat with timing info
            st.session_state.messages.append({
                "role": "assistant",
                "content": response,
                "gen_time": gen_time
            })

            # Save conversation
            save_conversation()

        st.rerun()

# ===== Footer =====
st.divider()
st.markdown("""
<div style='text-align: center; color: gray; font-size: 0.9em;'>
    <p>🎯 <strong>FempowerBot</strong> is a training tool for practicing difficult conversations.</p>
    <p>The bot stays in character to provide realistic practice scenarios.</p>
</div>
""", unsafe_allow_html=True)