# app_no_login.py
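"""
Gradio chat application for De-KCIB (Deep Knowledge Center for Injury Biomechanics).

Provides retrieval-augmented chat over DuckDB vector stores via LlamaIndex, using a
local BGE ONNX embedding model and Ollama-served LLMs. Responses are streamed to the
UI and scored for faithfulness against the retrieved context.
"""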
from llama_index.embeddings.huggingface_optimum import OptimumEmbedding
import gradio as gr
from llama_index.core import Settings
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.vector_stores.duckdb import DuckDBVectorStore
from llama_index.llms.ollama import Ollama
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.evaluation import FaithfulnessEvaluator
import json
import ollama
import os
import uuid
import nest_asyncio
from huggingface_hub import snapshot_download
import html
from gradio.themes.utils import fonts, sizes
from gradio.themes import Base
import concurrent.futures
import time
nest_asyncio.apply()
# Create a custom theme with larger text
large_text_theme = Base(
# Increase all font sizes by ~25%
font=[fonts.GoogleFont("Roboto"), "ui-sans-serif", "sans-serif"],
font_mono=[fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"],
text_size=sizes.text_lg, # Base text size (default is text_md)
radius_size=sizes.radius_md,
)
CONFIG_PATH = "config.json"
PERSISTENT_DIR = "/data"
FORCE_UPDATE_FLAG = False
DEFAULT_LLM = "Jatin19K/unsloth-q5_k_m-mistral-nemo-instruct-2407"
DEFAULT_VECTOR_STORE = "CFIR"
EMBED_MODEL_PATH = os.path.join(PERSISTENT_DIR, "bge_onnx")
VECTOR_STORE_DIR = os.path.join(PERSISTENT_DIR, "vector_stores")
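# Hugging Face access: HF_TOKEN and DATASET_ID come from the environment; the dataset
# referenced by DATASET_ID is expected to contain the BGE ONNX embedding model and the
# DuckDB vector stores, and is mirrored into PERSISTENT_DIR when the local copies are missing.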
token = os.getenv("HF_TOKEN")
dataset_id = os.getenv("DATASET_ID")
def download_data_if_needed():
global FORCE_UPDATE_FLAG
if not os.path.exists(EMBED_MODEL_PATH) or not os.path.exists(VECTOR_STORE_DIR):
FORCE_UPDATE_FLAG = True
if FORCE_UPDATE_FLAG:
snapshot_download(
repo_id=dataset_id,
repo_type="dataset",
token=token,
local_dir=PERSISTENT_DIR
)
print("Data downloaded successfully.")
else:
print("Data exists.")
download_data_if_needed()
class ModelManager:
def __init__(self):
self.config = self._load_config()
self.available_models = self._initialize_models()
def _load_config(self):
"""Load model configuration from JSON file"""
try:
with open(CONFIG_PATH, 'r') as f:
return json.load(f)
except Exception as e:
print(f"Error loading config: {e}")
return {"models": []}
def _initialize_models(self):
"""Initialize and verify all models from config"""
config_models = self.config.get("models", [])
available_models = {}
# Get currently available Ollama models
try:
current_models = {m['name']: m['name'] for m in ollama.list()['models']}
print(current_models)
except Exception as e:
print(f"Error fetching current models: {e}")
current_models = {}
# Check each configured model
for model_name in config_models:
if model_name not in current_models:
print(f"Model {model_name} not found locally. Attempting to pull...")
try:
ollama.pull(model_name)
available_models[model_name] = model_name
print(f"Successfully pulled model {model_name}")
except Exception as e:
print(f"Error pulling model {model_name}: {e}")
continue
else:
available_models[model_name] = model_name
return available_models
def get_available_models(self):
"""Return dictionary of available models"""
return self.available_models
class EmbeddingManager:
def __init__(self):
self.embed_model = None
self._initialize_embed_model()
def _initialize_embed_model(self):
"""Initialize BGE ONNX embedding model with validation"""
try:
if not os.path.exists(EMBED_MODEL_PATH):
raise FileNotFoundError(f"BGE ONNX model not found at {EMBED_MODEL_PATH}")
self.embed_model = OptimumEmbedding(folder_name=EMBED_MODEL_PATH)
Settings.embed_model = self.embed_model
print("Successfully initialized BGE embedding model")
except Exception as e:
print(f"Embedding model error: {e}")
# Initialize managers
model_manager = ModelManager()
embed_manager = EmbeddingManager()
# Warm-up function to pre-initialize resources
def warm_up_resources():
"""Pre-initialize models and resources to reduce first response time"""
print("Warming up resources...")
try:
# Use the predefined default model and vector store
default_model = DEFAULT_LLM
default_store = DEFAULT_VECTOR_STORE
# Get available models and stores
available_models = model_manager.get_available_models()
available_stores = get_available_vector_stores()
# Debugging information
print(f"Default model we want to use: {default_model}")
print(f"Available models: {available_models}")
# Check if default model is configured
if default_model not in model_manager.config.get("models", []):
print(f"Warning: {default_model} is not in configured models list")
# Check if default model and store are available
if default_store in available_stores:
# Try to use default model if it's available
if default_model in available_models:
model_to_use = default_model
print(f"Using default model {model_to_use} and store {default_store} for warmup")
# Create a dummy session
dummy_session_id = f"warmup_{uuid.uuid4()}"
# Configure LLM
llm = Ollama(
model=model_to_use,
request_timeout=120,
temperature=0.3
)
Settings.llm = llm
# Preload vector store
vs_path = available_stores[default_store]["path"]
vector_store = DuckDBVectorStore.from_local(vs_path)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
# Initialize index and chat engine
index = VectorStoreIndex.from_vector_store(
vector_store=vector_store,
storage_context=storage_context
)
# Create evaluator
evaluator = FaithfulnessEvaluator(llm=llm)
# Optionally run a simple query to fully initialize components
memory = ChatMemoryBuffer.from_defaults()
chat_engine = index.as_chat_engine(
chat_mode="context",
memory=memory,
                    system_prompt=(
                        "You are a helpful assistant that helps users understand scientific "
                        "knowledge about the biomechanics of injuries to human bodies."
                    ),
similarity_top_k=3
)
print(f"Warm-up complete. Models and resources pre-initialized.")
return True
else:
print(f"Warm-up skipped: Default model {default_model} not available")
print(f"Available models: {list(available_models.keys())}")
else:
print(f"Warm-up skipped: Default store {default_store} not available")
print(f"Available stores: {list(available_stores.keys())}")
return False
except Exception as e:
print(f"Warm-up error: {str(e)}")
return False
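# Note: warm_up_resources() is not invoked at startup in this build (the call is commented
# out under __main__); prewarm_model() defined below is used instead to warm each model.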
def get_available_vector_stores():
"""Scan vector store directory for DuckDB files, supporting nested directories"""
vector_stores = {}
if os.path.exists(VECTOR_STORE_DIR):
# Add default store if it exists
cfir_path = os.path.join(VECTOR_STORE_DIR, f"{DEFAULT_VECTOR_STORE}.duckdb")
if os.path.exists(cfir_path):
vector_stores[DEFAULT_VECTOR_STORE] = {
"path": cfir_path,
"display_name": DEFAULT_VECTOR_STORE
}
# Scan for .duckdb files in root directory and subdirectories
        for root, dirs, files in os.walk(VECTOR_STORE_DIR):
            for file in files:
                if file.endswith(".duckdb"):
                    # Skip the root-level default store since it was already added above
                    if root == VECTOR_STORE_DIR and file == f"{DEFAULT_VECTOR_STORE}.duckdb":
                        continue
# Get the full path to the file
file_path = os.path.join(root, file)
# Calculate store_name: combine category and subcategory
rel_path = os.path.relpath(file_path, VECTOR_STORE_DIR)
path_parts = rel_path.split(os.sep)
if len(path_parts) == 1:
# Files in the root directory
store_name = path_parts[0][:-7] # Remove .duckdb
display_name = store_name
else:
# Files in subdirectories
category = path_parts[0]
file_name = path_parts[-1][:-7] # Remove .duckdb
store_name = f"{category}_{file_name}"
display_name = f"{category} - {file_name}"
vector_stores[store_name] = {
"path": file_path,
"display_name": display_name
}
return vector_stores
class ChatSessionManager:
def __init__(self):
self.sessions = {}
self.evaluators = {}
self.llms = {}
self.indexes = {}
self.memories = {} # Store memories separately
self.llm_options = model_manager.get_available_models()
self.vector_stores = get_available_vector_stores()
# Track which model and store are used for each session
self.session_configs = {}
def refresh_models(self):
self.llm_options = model_manager.get_available_models()
def refresh_vector_stores(self):
self.vector_stores = get_available_vector_stores()
def get_memory(self, session_id):
"""Get or create a memory buffer for the session"""
if session_id not in self.memories:
self.memories[session_id] = ChatMemoryBuffer.from_defaults()
print(f"Created new memory for session {session_id}")
return self.memories[session_id]
def get_llm(self, session_id, llm_choice):
"""Create or get an LLM instance for the given session"""
# Verify model exists
if llm_choice not in self.llm_options.values():
raise ValueError(f"Model {llm_choice} not available")
# Create a new LLM if needed
if session_id not in self.llms or self.session_configs.get(session_id, {}).get("llm") != llm_choice:
# Configure LLM
llm = Ollama(
model=llm_choice,
request_timeout=120,
temperature=0.3
)
self.llms[session_id] = llm
# Update config
if session_id not in self.session_configs:
self.session_configs[session_id] = {}
self.session_configs[session_id]["llm"] = llm_choice
# Set as default LLM for this session
Settings.llm = llm
# We need to recreate the chat engine when LLM changes, but preserve memory
if session_id in self.sessions:
del self.sessions[session_id]
print(f"Recreating chat engine for session {session_id} - LLM changed to {llm_choice}")
return self.llms[session_id]
def get_index(self, session_id, vector_store_choice):
"""Create or get a vector index for the given session"""
# Verify vector store exists
if vector_store_choice not in self.vector_stores:
raise ValueError(f"Vector store {vector_store_choice} not found")
# Create a new index if needed
if (session_id not in self.indexes or
self.session_configs.get(session_id, {}).get("vector_store") != vector_store_choice):
# Load vector store
vs_path = self.vector_stores[vector_store_choice]["path"]
vector_store = DuckDBVectorStore.from_local(vs_path)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
# Create index
self.indexes[session_id] = VectorStoreIndex.from_vector_store(
vector_store=vector_store,
storage_context=storage_context
)
# Update config
if session_id not in self.session_configs:
self.session_configs[session_id] = {}
self.session_configs[session_id]["vector_store"] = vector_store_choice
# If we're creating a new index, we need to recreate the chat engine but preserve memory
if session_id in self.sessions:
del self.sessions[session_id]
print(f"Recreating chat engine for session {session_id} - Vector store changed to {vector_store_choice}")
return self.indexes[session_id]
def get_evaluator(self, session_id):
"""Get or create the faithfulness evaluator for a session"""
if session_id not in self.evaluators:
if session_id not in self.llms:
raise ValueError(f"LLM must be created before evaluator for session {session_id}")
# Create faithfulness evaluator
self.evaluators[session_id] = FaithfulnessEvaluator(llm=self.llms[session_id])
return self.evaluators[session_id]
def get_chat_engine(self, session_id, llm_choice, vector_store_choice):
"""Create or get a chat engine using the specified LLM and vector store"""
# First, ensure we have the right LLM
llm = self.get_llm(session_id, llm_choice)
# Then, ensure we have the right index
index = self.get_index(session_id, vector_store_choice)
# Get the memory (creates a new one if it doesn't exist)
memory = self.get_memory(session_id)
# Create a new chat engine if needed
if session_id not in self.sessions:
self.sessions[session_id] = index.as_chat_engine(
chat_mode="context",
memory=memory, # Using existing memory
                system_prompt=(
                    "You are a helpful assistant that helps users understand scientific "
                    "knowledge about the biomechanics of injuries to human bodies."
                ),
similarity_top_k=3
)
# Make sure we have an evaluator
self.get_evaluator(session_id)
print(f"Created new chat engine for session {session_id} with existing memory")
return self.sessions[session_id]
def clear_session(self, session_id):
"""Clear all resources for a session"""
if session_id in self.sessions:
del self.sessions[session_id]
if session_id in self.evaluators:
del self.evaluators[session_id]
if session_id in self.llms:
del self.llms[session_id]
if session_id in self.indexes:
del self.indexes[session_id]
if session_id in self.memories:
del self.memories[session_id] # Clear memory only when explicitly clearing session
if session_id in self.session_configs:
del self.session_configs[session_id]
print(f"Completely cleared session {session_id} including memory")
# Initialize session manager
session_manager = ChatSessionManager()
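# chat_response is a generator: every yield updates the chatbot, the session state, and the
# interactivity of the dropdowns, message box, clear button, and status indicator, keeping
# the controls locked while a response is streaming.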
def chat_response(message, history, llm_choice, vector_store_choice, session_state):
try:
# Disable UI components at the start of response generation
ui_state = {
"llm_dropdown": gr.update(interactive=False),
"vector_dropdown": gr.update(interactive=False),
"msg": gr.update(interactive=False),
"clear_btn": gr.update(interactive=False),
"status": gr.update(value='<div style="text-align:center; color:#e67e22; font-weight:bold;">⚙️ Processing...</div>')
}
# Manage session state
if not session_state:
session_id = str(uuid.uuid4())
session_state = {
"session_id": session_id,
"total_score": 0.0,
"answer_count": 0,
"current_llm": llm_choice,
"current_vs": vector_store_choice
}
else:
session_id = session_state["session_id"]
# Ensure score tracking fields exist and handle model changes
if "total_score" not in session_state:
session_state["total_score"] = 0.0
if "answer_count" not in session_state:
session_state["answer_count"] = 0
if "current_llm" not in session_state:
session_state["current_llm"] = llm_choice
if "current_vs" not in session_state:
session_state["current_vs"] = vector_store_choice
# Check if model or vector store changed
if session_state["current_llm"] != llm_choice or session_state["current_vs"] != vector_store_choice:
print(f"Configuration changed. Resetting scores.")
session_state["total_score"] = 0.0
session_state["answer_count"] = 0
session_state["current_llm"] = llm_choice
session_state["current_vs"] = vector_store_choice
chat_engine = session_manager.get_chat_engine(session_id, llm_choice, vector_store_choice)
evaluator = session_manager.get_evaluator(session_id)
start_time = time.time()
# First yield to disable UI components
yield history, session_state, ui_state["llm_dropdown"], ui_state["vector_dropdown"], ui_state["msg"], ui_state["clear_btn"], ui_state["status"]
# Use streaming chat
streamer = chat_engine.stream_chat(message)
# Simple variables for tracking content
response_text = ""
thinking_text = ""
full_response = ""
in_thinking = False
response_source_nodes = streamer.source_nodes
# Process the streaming response
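        # Some reasoning models wrap their chain-of-thought in <think>...</think> tags; the
        # tag handling below assumes each tag arrives intact within a single streamed token.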
for token in streamer.response_gen:
full_response += token
# Simple tag handling
if "<think>" in token:
in_thinking = True
token = token.replace("<think>", "")
if "</think>" in token:
in_thinking = False
token = token.replace("</think>", "")
# Add content to the appropriate buffer
if in_thinking:
thinking_text += token
else:
response_text += token
# Update the UI - thinking shown above response
elapsed_time = time.time() - start_time
if thinking_text:
# Show thinking above response during streaming (open by default)
current_response = f"<details open><summary>🧠 AI Thinking</summary><div>{html.escape(thinking_text)}</div></details>\n\n{response_text}\n\nTime: {round(elapsed_time, 3)}s"
else:
current_response = f"{response_text}\n\nTime: {round(elapsed_time, 3)}s"
yield history + [(message, current_response)], session_state, ui_state["llm_dropdown"], ui_state["vector_dropdown"], ui_state["msg"], ui_state["clear_btn"], ui_state["status"]
print(f"Full response:\n{full_response}")
# print(f"Streaming complete. Response: {response_text}")
# After streaming completes, evaluate the response
if evaluator and response_source_nodes:
try:
# Run evaluation with timeout
contexts = [node.get_content() for node in response_source_nodes]
EVAL_TIMEOUT = 5.0
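                # A running evaluation cannot be cancelled on timeout; the worker thread
                # simply finishes in the background while the UI proceeds without a score.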
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
evaluator.evaluate,
query=message,
response=response_text,
contexts=contexts
)
try:
eval_result = future.result(timeout=EVAL_TIMEOUT)
faithfulness_score = eval_result.score
# Update scoring
session_state["total_score"] += faithfulness_score
session_state["answer_count"] += 1
avg_score = session_state["total_score"] / session_state["answer_count"]
# Create evaluation info
final_info = f"Time: {round(time.time() - start_time, 3)}s • Score: {faithfulness_score:.3f} • Avg: {avg_score:.3f} ({session_state['answer_count']} questions)"
except concurrent.futures.TimeoutError:
final_info = f"Time: {round(time.time() - start_time, 3)}s • Evaluation timed out"
# Prepare final response with thinking collapsed and evaluation info
if thinking_text:
final_response = f"<details><summary>🧠 Show AI Thinking</summary><div>{html.escape(thinking_text)}</div></details>\n\n{response_text}\n\n{final_info}"
else:
final_response = f"{response_text}\n\n{final_info}"
# Re-enable UI components
enabled_ui = {
"llm_dropdown": gr.update(interactive=True),
"vector_dropdown": gr.update(interactive=True),
"msg": gr.update(interactive=True),
"clear_btn": gr.update(interactive=True),
"status": gr.update(value='<div style="text-align:center; color:#27ae60; font-weight:bold;">✓ Ready</div>')
}
yield history + [(message, final_response)], session_state, enabled_ui["llm_dropdown"], enabled_ui["vector_dropdown"], enabled_ui["msg"], enabled_ui["clear_btn"], enabled_ui["status"]
except Exception as e:
print(f"Evaluation error: {str(e)}")
# Simple error handling
if thinking_text:
error_response = f"<details><summary>🧠 Show AI Thinking</summary><div>{html.escape(thinking_text)}</div></details>\n\n{response_text}\n\nError during evaluation"
else:
error_response = f"{response_text}\n\nError during evaluation"
# Re-enable UI components on error
enabled_ui = {
"llm_dropdown": gr.update(interactive=True),
"vector_dropdown": gr.update(interactive=True),
"msg": gr.update(interactive=True),
"clear_btn": gr.update(interactive=True),
"status": gr.update(value='<div style="text-align:center; color:#e74c3c; font-weight:bold;">✗ Error</div>')
}
yield history + [(message, error_response)], session_state, enabled_ui["llm_dropdown"], enabled_ui["vector_dropdown"], enabled_ui["msg"], enabled_ui["clear_btn"], enabled_ui["status"]
else:
# No evaluation case
elapsed_time = time.time() - start_time
if thinking_text:
final_response = f"<details><summary>🧠 Show AI Thinking</summary><div>{html.escape(thinking_text)}</div></details>\n\n{response_text}\n\nTime: {round(elapsed_time, 3)}s"
else:
final_response = f"{response_text}\n\nTime: {round(elapsed_time, 3)}s"
# Re-enable UI components
enabled_ui = {
"llm_dropdown": gr.update(interactive=True),
"vector_dropdown": gr.update(interactive=True),
"msg": gr.update(interactive=True),
"clear_btn": gr.update(interactive=True),
"status": gr.update(value='<div style="text-align:center; color:#27ae60; font-weight:bold;">✓ Ready</div>')
}
yield history + [(message, final_response)], session_state, enabled_ui["llm_dropdown"], enabled_ui["vector_dropdown"], enabled_ui["msg"], enabled_ui["clear_btn"], enabled_ui["status"]
except Exception as e:
print(f"Chat error: {str(e)}")
# Re-enable UI components on error
enabled_ui = {
"llm_dropdown": gr.update(interactive=True),
"vector_dropdown": gr.update(interactive=True),
"msg": gr.update(interactive=True),
"clear_btn": gr.update(interactive=True),
"status": gr.update(value='<div style="text-align:center; color:#e74c3c; font-weight:bold;">✗ Error</div>')
}
yield history + [(message, f"Error: {str(e)}")], session_state, enabled_ui["llm_dropdown"], enabled_ui["vector_dropdown"], enabled_ui["msg"], enabled_ui["clear_btn"], enabled_ui["status"]
# Gradio interface with embedding status
with gr.Blocks(title="De-KCIB (Deep Knowledge Center for Injury Biomechanics)", theme=large_text_theme, css="""
details {
border: 1px solid #e0e0e0;
border-radius: 5px;
padding: 0;
margin: 10px 0;
}
summary {
background-color: #f5f5f5;
padding: 8px 15px;
cursor: pointer;
font-weight: 500;
border-radius: 5px 5px 0 0;
user-select: none;
}
summary:hover {
background-color: #e6f7f5;
}
details[open] summary {
border-bottom: 1px solid #e0e0e0;
}
details > div {
padding: 15px;
background-color: #fcfcfc;
border-radius: 0 0 5px 5px;
white-space: pre-wrap;
font-family: monospace;
overflow-x: auto;
}
""") as demo:
session_state = gr.State()
with gr.Row():
gr.Markdown("<img src='/gradio_api/file/logo.png' alt='Innovision Logo' height='150' width='390'>")
with gr.Row():
gr.Markdown("# De-KCIB(Deep Knowledge Center for Injury Biomechanics)")
with gr.Row():
with gr.Column(scale=1):
llm_dropdown = gr.Dropdown(
label="Select Language Model",
choices=list(session_manager.llm_options.values()),
value=next(iter(session_manager.llm_options.values()), None)
)
vector_dropdown = gr.Dropdown(
label="Injury Biomechanics Knowledge Base",
choices=[(v["display_name"], k) for k, v in session_manager.vector_stores.items()],
value=next(iter(session_manager.vector_stores.keys()), None)
)
status_indicator = gr.HTML(
value='<div style="text-align:center; margin-top:15px;">Ready</div>',
label="Status"
)
with gr.Column(scale=3):
chatbot = gr.Chatbot(
height=500,
render_markdown=True,
bubble_full_width=False,
show_copy_button=True
)
msg = gr.Textbox(label="Query")
clear_btn = gr.Button("Clear Session")
msg.submit(
chat_response,
[msg, chatbot, llm_dropdown, vector_dropdown, session_state],
[chatbot, session_state, llm_dropdown, vector_dropdown, msg, clear_btn, status_indicator]
).then(
lambda: "", # Just clear the message box
None,
[msg]
)
def clear_session(session_state):
"""Clear session and reset state"""
# Clear resources in the session manager
if session_state and "session_id" in session_state:
session_manager.clear_session(session_state["session_id"])
# Return empty chat and reset state but preserve session ID
new_state = {"total_score": 0.0, "answer_count": 0}
if session_state and "session_id" in session_state:
new_state["session_id"] = session_state["session_id"]
return [], new_state
clear_btn.click(
clear_session,
[session_state],
[chatbot, session_state],
queue=False
)
# Add queue to enable streaming
demo.queue()
def prewarm_model(model_name):
"""Send a simple query to warm up a specific model"""
try:
print(f"Pre-warming model: {model_name}")
llm = Ollama(
model=model_name,
request_timeout=30,
temperature=0.3
)
# Simple query to initialize the model
_ = llm.complete("Hello world")
print(f"Successfully pre-warmed model: {model_name}")
return True
except Exception as e:
print(f"Error pre-warming model {model_name}: {e}")
return False
# Deployment settings
if __name__ == "__main__":
# Run warm-up to pre-initialize resources
# warm_up_resources()
available_models = model_manager.get_available_models()
for model_name in available_models.values():
prewarm_model(model_name)
# Launch the Gradio app
demo.launch(allowed_paths=["logo.png"])