Spaces:
Sleeping
Sleeping
Ara Yeroyan
committed on
Commit
·
3fc1b5f
1
Parent(s):
f8a1d41
finalize gemini version
Browse files
- app.py +60 -7
- src/vectorstore.py +27 -8
app.py
CHANGED
|
@@ -33,6 +33,7 @@ from src.config.paths import (
|
|
| 33 |
CONVERSATIONS_DIR,
|
| 34 |
)
|
| 35 |
|
|
|
|
| 36 |
# ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
|
| 37 |
# Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
|
| 38 |
omp_threads = os.environ.get("OMP_NUM_THREADS", "")
|
|
@@ -72,6 +73,9 @@ if IS_DEPLOYED and HF_CACHE_DIR:
|
|
| 72 |
except (PermissionError, OSError):
|
| 73 |
# If we can't create it, log but continue (might already exist from Dockerfile)
|
| 74 |
pass
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
# Configure logging
|
| 77 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
@@ -191,18 +195,35 @@ def main():
|
|
| 191 |
if 'chatbot_version' not in st.session_state:
|
| 192 |
st.session_state.chatbot_version = "v1"
|
| 193 |
|
| 194 |
-
# Initialize chatbot based on version (
|
| 195 |
chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}"
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
try:
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version)
|
| 200 |
st.session_state['_last_version'] = st.session_state.chatbot_version
|
| 201 |
st.session_state.chatbot = st.session_state[chatbot_version_key]
|
| 202 |
-
|
| 203 |
except Exception as e:
|
| 204 |
st.error(f"❌ Failed to initialize chatbot: {str(e)}")
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
# Reset to v1 to prevent infinite loop
|
| 207 |
st.session_state.chatbot_version = "v1"
|
| 208 |
st.session_state['_last_version'] = "v1"
|
|
@@ -210,6 +231,7 @@ def main():
|
|
| 210 |
del st.session_state['chatbot']
|
| 211 |
st.stop() # Stop execution to prevent infinite loop
|
| 212 |
else:
|
|
|
|
| 213 |
st.session_state.chatbot = st.session_state[chatbot_version_key]
|
| 214 |
|
| 215 |
# Reset conversation history if needed (but keep chatbot cached)
|
|
@@ -223,7 +245,38 @@ def main():
|
|
| 223 |
st.rerun()
|
| 224 |
|
| 225 |
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
# Show version info
|
| 229 |
if st.session_state.chatbot_version == "beta":
|
|
@@ -289,7 +342,7 @@ def main():
|
|
| 289 |
# Determine if filename filter is active
|
| 290 |
filename_mode = len(selected_filenames) > 0
|
| 291 |
# Sources filter
|
| 292 |
-
st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 293 |
st.markdown('<div class="filter-title">π Sources</div>', unsafe_allow_html=True)
|
| 294 |
selected_sources = st.multiselect(
|
| 295 |
"Select sources:",
|
|
|
|
| 33 |
CONVERSATIONS_DIR,
|
| 34 |
)
|
| 35 |
|
| 36 |
+
|
| 37 |
# ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
|
| 38 |
# Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
|
| 39 |
omp_threads = os.environ.get("OMP_NUM_THREADS", "")
|
|
|
|
| 73 |
except (PermissionError, OSError):
|
| 74 |
# If we can't create it, log but continue (might already exist from Dockerfile)
|
| 75 |
pass
|
| 76 |
+
else:
|
| 77 |
+
from dotenv import load_dotenv
|
| 78 |
+
load_dotenv()
|
| 79 |
|
| 80 |
# Configure logging
|
| 81 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
| 195 |
if 'chatbot_version' not in st.session_state:
|
| 196 |
st.session_state.chatbot_version = "v1"
|
| 197 |
|
| 198 |
+
# Initialize chatbot based on version (only if not already initialized for this version)
|
| 199 |
chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}"
|
| 200 |
+
|
| 201 |
+
# Check if we need to initialize: chatbot doesn't exist OR version changed
|
| 202 |
+
needs_init = (
|
| 203 |
+
chatbot_version_key not in st.session_state or
|
| 204 |
+
st.session_state.get('_last_version') != st.session_state.chatbot_version
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
if needs_init:
|
| 208 |
try:
|
| 209 |
+
# Different spinner messages for different versions
|
| 210 |
+
if st.session_state.chatbot_version == "beta":
|
| 211 |
+
spinner_msg = "π Initializing Gemini File Search..."
|
| 212 |
+
else:
|
| 213 |
+
spinner_msg = "π Loading AI models and connecting to database..."
|
| 214 |
+
|
| 215 |
+
with st.spinner(spinner_msg):
|
| 216 |
st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version)
|
| 217 |
st.session_state['_last_version'] = st.session_state.chatbot_version
|
| 218 |
st.session_state.chatbot = st.session_state[chatbot_version_key]
|
| 219 |
+
print("✅ AI system ready!")
|
| 220 |
except Exception as e:
|
| 221 |
st.error(f"❌ Failed to initialize chatbot: {str(e)}")
|
| 222 |
+
# Only show Gemini-specific error message for beta version
|
| 223 |
+
if st.session_state.chatbot_version == "beta":
|
| 224 |
+
st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)")
|
| 225 |
+
else:
|
| 226 |
+
st.error("Please check your configuration and ensure all required models and databases are accessible.")
|
| 227 |
# Reset to v1 to prevent infinite loop
|
| 228 |
st.session_state.chatbot_version = "v1"
|
| 229 |
st.session_state['_last_version'] = "v1"
|
|
|
|
| 231 |
del st.session_state['chatbot']
|
| 232 |
st.stop() # Stop execution to prevent infinite loop
|
| 233 |
else:
|
| 234 |
+
# Chatbot already initialized for this version, just use it
|
| 235 |
st.session_state.chatbot = st.session_state[chatbot_version_key]
|
| 236 |
|
| 237 |
# Reset conversation history if needed (but keep chatbot cached)
|
|
|
|
| 245 |
st.rerun()
|
| 246 |
|
| 247 |
|
| 248 |
+
# Version selection radio button (top right)
|
| 249 |
+
col1, col2 = st.columns([3, 1])
|
| 250 |
+
with col1:
|
| 251 |
+
st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
|
| 252 |
+
with col2:
|
| 253 |
+
st.markdown("<br>", unsafe_allow_html=True) # Add some spacing
|
| 254 |
+
selected_version = st.radio(
|
| 255 |
+
"**Version:**",
|
| 256 |
+
options=["v1", "beta"],
|
| 257 |
+
index=0 if st.session_state.chatbot_version == "v1" else 1,
|
| 258 |
+
horizontal=True,
|
| 259 |
+
key="version_selector",
|
| 260 |
+
help="Select v1 (default RAG system) or beta (Gemini File Search)"
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# Update version if changed
|
| 264 |
+
if selected_version != st.session_state.chatbot_version:
|
| 265 |
+
# Store the old version to check if we need to switch
|
| 266 |
+
old_version = st.session_state.chatbot_version
|
| 267 |
+
st.session_state.chatbot_version = selected_version
|
| 268 |
+
|
| 269 |
+
# If chatbot for new version already exists, just switch to it
|
| 270 |
+
new_chatbot_key = f"chatbot_{selected_version}"
|
| 271 |
+
if new_chatbot_key in st.session_state:
|
| 272 |
+
# Chatbot already exists, just switch
|
| 273 |
+
st.session_state.chatbot = st.session_state[new_chatbot_key]
|
| 274 |
+
st.session_state['_last_version'] = selected_version
|
| 275 |
+
else:
|
| 276 |
+
# Need to initialize new version - will be handled by initialization logic above
|
| 277 |
+
st.session_state['_last_version'] = old_version # Set to old to trigger init check
|
| 278 |
+
|
| 279 |
+
st.rerun()
|
| 280 |
|
| 281 |
# Show version info
|
| 282 |
if st.session_state.chatbot_version == "beta":
|
|
|
|
| 342 |
# Determine if filename filter is active
|
| 343 |
filename_mode = len(selected_filenames) > 0
|
| 344 |
# Sources filter
|
| 345 |
+
# st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 346 |
st.markdown('<div class="filter-title">π Sources</div>', unsafe_allow_html=True)
|
| 347 |
selected_sources = st.multiselect(
|
| 348 |
"Select sources:",
|
src/vectorstore.py
CHANGED
|
@@ -1,9 +1,20 @@
|
|
| 1 |
"""Vector store management and operations."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Dict, Any, List, Optional
|
| 4 |
|
| 5 |
|
| 6 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from langchain_qdrant import QdrantVectorStore
|
| 8 |
from langchain.docstore.document import Document
|
| 9 |
from langchain_core.embeddings import Embeddings
|
|
@@ -28,19 +39,23 @@ class MatryoshkaEmbeddings(Embeddings):
|
|
| 28 |
|
| 29 |
if truncate_dim and "matryoshka" in model_name.lower():
|
| 30 |
# Use SentenceTransformer directly for Matryoshka models
|
| 31 |
-
#
|
|
|
|
|
|
|
| 32 |
self.model = SentenceTransformer(
|
| 33 |
model_name,
|
| 34 |
truncate_dim=truncate_dim,
|
| 35 |
-
device="cpu" #
|
| 36 |
)
|
| 37 |
print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
|
| 38 |
else:
|
| 39 |
# Use standard HuggingFaceEmbeddings
|
| 40 |
-
#
|
|
|
|
| 41 |
if "model_kwargs" not in kwargs:
|
| 42 |
kwargs["model_kwargs"] = {}
|
| 43 |
-
|
|
|
|
| 44 |
self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
| 45 |
|
| 46 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
@@ -87,12 +102,14 @@ class VectorStoreManager:
|
|
| 87 |
model_name = self.config["retriever"]["model"]
|
| 88 |
normalize = self.config["retriever"]["normalize"]
|
| 89 |
|
| 90 |
-
# Fix for meta tensor issue:
|
| 91 |
-
#
|
| 92 |
-
#
|
| 93 |
model_kwargs = {
|
| 94 |
-
"device": "cpu" #
|
|
|
|
| 95 |
}
|
|
|
|
| 96 |
encode_kwargs = {
|
| 97 |
"normalize_embeddings": normalize,
|
| 98 |
"batch_size": 100,
|
|
@@ -119,6 +136,8 @@ class VectorStoreManager:
|
|
| 119 |
return embeddings
|
| 120 |
|
| 121 |
# Use standard HuggingFaceEmbeddings for non-Matryoshka models
|
|
|
|
|
|
|
| 122 |
embeddings = HuggingFaceEmbeddings(
|
| 123 |
model_name=model_name,
|
| 124 |
model_kwargs=model_kwargs,
|
|
|
|
| 1 |
"""Vector store management and operations."""
|
| 2 |
+
import os
|
| 3 |
+
# Disable MPS before importing torch to prevent meta tensor issues on Mac
|
| 4 |
+
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
|
| 5 |
+
os.environ.setdefault("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0")
|
| 6 |
+
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Dict, Any, List, Optional
|
| 9 |
|
| 10 |
|
| 11 |
import torch
|
| 12 |
+
# Disable MPS backend explicitly to prevent meta tensor issues
|
| 13 |
+
if hasattr(torch.backends, 'mps'):
|
| 14 |
+
# Monkey patch to disable MPS
|
| 15 |
+
original_mps_available = torch.backends.mps.is_available
|
| 16 |
+
torch.backends.mps.is_available = lambda: False
|
| 17 |
+
|
| 18 |
from langchain_qdrant import QdrantVectorStore
|
| 19 |
from langchain.docstore.document import Document
|
| 20 |
from langchain_core.embeddings import Embeddings
|
|
|
|
| 39 |
|
| 40 |
if truncate_dim and "matryoshka" in model_name.lower():
|
| 41 |
# Use SentenceTransformer directly for Matryoshka models
|
| 42 |
+
# Fix for meta tensor issue: Explicitly force CPU
|
| 43 |
+
# MPS is already disabled at module level
|
| 44 |
+
# Explicitly pass device="cpu" to prevent MPS/CUDA detection
|
| 45 |
self.model = SentenceTransformer(
|
| 46 |
model_name,
|
| 47 |
truncate_dim=truncate_dim,
|
| 48 |
+
device="cpu" # Force CPU to prevent meta tensor issues
|
| 49 |
)
|
| 50 |
print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
|
| 51 |
else:
|
| 52 |
# Use standard HuggingFaceEmbeddings
|
| 53 |
+
# Don't pass device parameter - let it load naturally on CPU
|
| 54 |
+
# This prevents the meta tensor error
|
| 55 |
if "model_kwargs" not in kwargs:
|
| 56 |
kwargs["model_kwargs"] = {}
|
| 57 |
+
# Remove device from model_kwargs if present to prevent meta tensor issues
|
| 58 |
+
kwargs["model_kwargs"].pop("device", None)
|
| 59 |
self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
| 60 |
|
| 61 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
|
|
| 102 |
model_name = self.config["retriever"]["model"]
|
| 103 |
normalize = self.config["retriever"]["normalize"]
|
| 104 |
|
| 105 |
+
# Fix for meta tensor issue: Force CPU usage to prevent MPS/CUDA detection
|
| 106 |
+
# The error occurs when SentenceTransformer detects MPS/CUDA and tries to move meta tensors
|
| 107 |
+
# MPS is already disabled at module level, now we explicitly force CPU in model_kwargs
|
| 108 |
model_kwargs = {
|
| 109 |
+
"device": "cpu", # Explicitly force CPU to prevent MPS/CUDA detection
|
| 110 |
+
"trust_remote_code": True, # Some models need this
|
| 111 |
}
|
| 112 |
+
|
| 113 |
encode_kwargs = {
|
| 114 |
"normalize_embeddings": normalize,
|
| 115 |
"batch_size": 100,
|
|
|
|
| 136 |
return embeddings
|
| 137 |
|
| 138 |
# Use standard HuggingFaceEmbeddings for non-Matryoshka models
|
| 139 |
+
# Don't pass device in model_kwargs - let HuggingFaceEmbeddings handle it
|
| 140 |
+
# but ensure we're not using meta device
|
| 141 |
embeddings = HuggingFaceEmbeddings(
|
| 142 |
model_name=model_name,
|
| 143 |
model_kwargs=model_kwargs,
|