# app.py
from llama_index.embeddings.huggingface_optimum import OptimumEmbedding
import gradio as gr
from llama_index.core import Settings
from llama_index.core import VectorStoreIndex, StorageContext, Response
from llama_index.vector_stores.duckdb import DuckDBVectorStore
from llama_index.llms.ollama import Ollama
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.evaluation import FaithfulnessEvaluator
import json
import ollama
import os
import uuid
import nest_asyncio
from huggingface_hub import snapshot_download
import html
from gradio.themes.utils import fonts, sizes
from gradio.themes import Base
import concurrent.futures
import time

nest_asyncio.apply()
# Create a custom theme with larger text
large_text_theme = Base(
    # Increase all font sizes by ~25%
    font=[fonts.GoogleFont("Roboto"), "ui-sans-serif", "sans-serif"],
    font_mono=[fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"],
    text_size=sizes.text_lg,  # Base text size (default is text_md)
    radius_size=sizes.radius_md,
)
CONFIG_PATH = "config.json"
PERSISTENT_DIR = "/data"
FORCE_UPDATE_FLAG = False
DEFAULT_LLM = "Jatin19K/unsloth-q5_k_m-mistral-nemo-instruct-2407"
DEFAULT_VECTOR_STORE = "CFIR"
EMBED_MODEL_PATH = os.path.join(PERSISTENT_DIR, "bge_onnx")
VECTOR_STORE_DIR = os.path.join(PERSISTENT_DIR, "vector_stores")

token = os.getenv("HF_TOKEN")
dataset_id = os.getenv("DATASET_ID")
def download_data_if_needed():
    global FORCE_UPDATE_FLAG
    if not os.path.exists(EMBED_MODEL_PATH) or not os.path.exists(VECTOR_STORE_DIR):
        FORCE_UPDATE_FLAG = True
    if FORCE_UPDATE_FLAG:
        snapshot_download(
            repo_id=dataset_id,
            repo_type="dataset",
            token=token,
            local_dir=PERSISTENT_DIR
        )
        print("Data downloaded successfully.")
    else:
        print("Data exists.")

download_data_if_needed()
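# Assumed layout of PERSISTENT_DIR after the snapshot download, inferred from the
# paths used in this file (the actual dataset structure may differ):
#   /data/bge_onnx/                               - BGE embedding model exported to ONNX
#   /data/vector_stores/CFIR.duckdb               - default DuckDB vector store
#   /data/vector_stores/<category>/<name>.duckdb  - additional nested stores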
class ModelManager:
    def __init__(self):
        self.config = self._load_config()
        self.available_models = self._initialize_models()

    def _load_config(self):
        """Load model configuration from JSON file"""
        try:
            with open(CONFIG_PATH, 'r') as f:
                return json.load(f)
        except Exception as e:
            print(f"Error loading config: {e}")
            return {"models": []}

    def _initialize_models(self):
        """Initialize and verify all models from config"""
        config_models = self.config.get("models", [])
        available_models = {}
        # Get currently available Ollama models
        try:
            current_models = {m['name']: m['name'] for m in ollama.list()['models']}
            print(current_models)
        except Exception as e:
            print(f"Error fetching current models: {e}")
            current_models = {}
        # Check each configured model, pulling any that are missing locally
        for model_name in config_models:
            if model_name not in current_models:
                print(f"Model {model_name} not found locally. Attempting to pull...")
                try:
                    ollama.pull(model_name)
                    available_models[model_name] = model_name
                    print(f"Successfully pulled model {model_name}")
                except Exception as e:
                    print(f"Error pulling model {model_name}: {e}")
                    continue
            else:
                available_models[model_name] = model_name
        return available_models

    def get_available_models(self):
        """Return dictionary of available models"""
        return self.available_models
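# config.json is not shown here; from _load_config and _initialize_models above it
# is assumed to hold a list of Ollama model names, e.g. (hypothetical contents):
# {
#     "models": ["Jatin19K/unsloth-q5_k_m-mistral-nemo-instruct-2407"]
# }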
class EmbeddingManager:
    def __init__(self):
        self.embed_model = None
        self._initialize_embed_model()

    def _initialize_embed_model(self):
        """Initialize BGE ONNX embedding model with validation"""
        try:
            if not os.path.exists(EMBED_MODEL_PATH):
                raise FileNotFoundError(f"BGE ONNX model not found at {EMBED_MODEL_PATH}")
            self.embed_model = OptimumEmbedding(folder_name=EMBED_MODEL_PATH)
            Settings.embed_model = self.embed_model
            print("Successfully initialized BGE embedding model")
        except Exception as e:
            print(f"Embedding model error: {e}")
# Initialize managers
model_manager = ModelManager()
embed_manager = EmbeddingManager()
# Warm-up function to pre-initialize resources
def warm_up_resources():
    """Pre-initialize models and resources to reduce first response time"""
    print("Warming up resources...")
    try:
        # Use the predefined default model and vector store
        default_model = DEFAULT_LLM
        default_store = DEFAULT_VECTOR_STORE
        # Get available models and stores
        available_models = model_manager.get_available_models()
        available_stores = get_available_vector_stores()
        # Debugging information
        print(f"Default model we want to use: {default_model}")
        print(f"Available models: {available_models}")
        # Check if the default model is configured
        if default_model not in model_manager.config.get("models", []):
            print(f"Warning: {default_model} is not in configured models list")
        # Check if the default model and store are available
        if default_store in available_stores:
            # Try to use the default model if it's available
            if default_model in available_models:
                model_to_use = default_model
                print(f"Using default model {model_to_use} and store {default_store} for warmup")
                # Create a dummy session
                dummy_session_id = f"warmup_{uuid.uuid4()}"
                # Configure LLM
                llm = Ollama(
                    model=model_to_use,
                    request_timeout=120,
                    temperature=0.3
                )
                Settings.llm = llm
                # Preload vector store
                vs_path = available_stores[default_store]["path"]
                vector_store = DuckDBVectorStore.from_local(vs_path)
                storage_context = StorageContext.from_defaults(vector_store=vector_store)
                # Initialize index and chat engine
                index = VectorStoreIndex.from_vector_store(
                    vector_store=vector_store,
                    storage_context=storage_context
                )
                # Create evaluator
                evaluator = FaithfulnessEvaluator(llm=llm)
                # Build a chat engine so all components are fully initialized
                memory = ChatMemoryBuffer.from_defaults()
                chat_engine = index.as_chat_engine(
                    chat_mode="context",
                    memory=memory,
                    system_prompt=(
                        "You are a helpful assistant which helps users to understand scientific knowledge "
                        "about biomechanics of injuries to human bodies."
                    ),
                    similarity_top_k=3
                )
                print("Warm-up complete. Models and resources pre-initialized.")
                return True
            else:
                print(f"Warm-up skipped: Default model {default_model} not available")
                print(f"Available models: {list(available_models.keys())}")
        else:
            print(f"Warm-up skipped: Default store {default_store} not available")
            print(f"Available stores: {list(available_stores.keys())}")
        return False
    except Exception as e:
        print(f"Warm-up error: {str(e)}")
        return False
def get_available_vector_stores():
    """Scan vector store directory for DuckDB files, supporting nested directories"""
    vector_stores = {}
    if os.path.exists(VECTOR_STORE_DIR):
        # Add the default store first if it exists
        cfir_path = os.path.join(VECTOR_STORE_DIR, f"{DEFAULT_VECTOR_STORE}.duckdb")
        if os.path.exists(cfir_path):
            vector_stores[DEFAULT_VECTOR_STORE] = {
                "path": cfir_path,
                "display_name": DEFAULT_VECTOR_STORE
            }
        # Scan for .duckdb files in the root directory and subdirectories
        for root, dirs, files in os.walk(VECTOR_STORE_DIR):
            for file in files:
                if not file.endswith(".duckdb"):
                    continue
                # Skip the default store in the root directory; it was handled above
                if root == VECTOR_STORE_DIR and file == f"{DEFAULT_VECTOR_STORE}.duckdb":
                    continue
                # Get the full path to the file
                file_path = os.path.join(root, file)
                # Calculate store_name: combine category and subcategory
                rel_path = os.path.relpath(file_path, VECTOR_STORE_DIR)
                path_parts = rel_path.split(os.sep)
                if len(path_parts) == 1:
                    # Files in the root directory
                    store_name = path_parts[0][:-7]  # Remove .duckdb
                    display_name = store_name
                else:
                    # Files in subdirectories
                    category = path_parts[0]
                    file_name = path_parts[-1][:-7]  # Remove .duckdb
                    store_name = f"{category}_{file_name}"
                    display_name = f"{category} - {file_name}"
                vector_stores[store_name] = {
                    "path": file_path,
                    "display_name": display_name
                }
    return vector_stores
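# For example (hypothetical files): /data/vector_stores/Head/Concussion.duckdb would be
# registered under store_name "Head_Concussion" with display name "Head - Concussion",
# while /data/vector_stores/Spine.duckdb in the root would be registered simply as "Spine".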
class ChatSessionManager:
    def __init__(self):
        self.sessions = {}
        self.evaluators = {}
        self.llms = {}
        self.indexes = {}
        self.memories = {}  # Store memories separately
        self.llm_options = model_manager.get_available_models()
        self.vector_stores = get_available_vector_stores()
        # Track which model and store are used for each session
        self.session_configs = {}

    def refresh_models(self):
        self.llm_options = model_manager.get_available_models()

    def refresh_vector_stores(self):
        self.vector_stores = get_available_vector_stores()

    def get_memory(self, session_id):
        """Get or create a memory buffer for the session"""
        if session_id not in self.memories:
            self.memories[session_id] = ChatMemoryBuffer.from_defaults()
            print(f"Created new memory for session {session_id}")
        return self.memories[session_id]
    def get_llm(self, session_id, llm_choice):
        """Create or get an LLM instance for the given session"""
        # Verify model exists
        if llm_choice not in self.llm_options.values():
            raise ValueError(f"Model {llm_choice} not available")
        # Create a new LLM if needed
        if session_id not in self.llms or self.session_configs.get(session_id, {}).get("llm") != llm_choice:
            # Configure LLM
            llm = Ollama(
                model=llm_choice,
                request_timeout=120,
                temperature=0.3
            )
            self.llms[session_id] = llm
            # Update config
            if session_id not in self.session_configs:
                self.session_configs[session_id] = {}
            self.session_configs[session_id]["llm"] = llm_choice
            # Set as default LLM for this session
            Settings.llm = llm
            # We need to recreate the chat engine when LLM changes, but preserve memory
            if session_id in self.sessions:
                del self.sessions[session_id]
                print(f"Recreating chat engine for session {session_id} - LLM changed to {llm_choice}")
        return self.llms[session_id]
    def get_index(self, session_id, vector_store_choice):
        """Create or get a vector index for the given session"""
        # Verify vector store exists
        if vector_store_choice not in self.vector_stores:
            raise ValueError(f"Vector store {vector_store_choice} not found")
        # Create a new index if needed
        if (session_id not in self.indexes or
                self.session_configs.get(session_id, {}).get("vector_store") != vector_store_choice):
            # Load vector store
            vs_path = self.vector_stores[vector_store_choice]["path"]
            vector_store = DuckDBVectorStore.from_local(vs_path)
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            # Create index
            self.indexes[session_id] = VectorStoreIndex.from_vector_store(
                vector_store=vector_store,
                storage_context=storage_context
            )
            # Update config
            if session_id not in self.session_configs:
                self.session_configs[session_id] = {}
            self.session_configs[session_id]["vector_store"] = vector_store_choice
            # If we're creating a new index, we need to recreate the chat engine but preserve memory
            if session_id in self.sessions:
                del self.sessions[session_id]
                print(f"Recreating chat engine for session {session_id} - Vector store changed to {vector_store_choice}")
        return self.indexes[session_id]
    def get_evaluator(self, session_id):
        """Get or create the faithfulness evaluator for a session"""
        if session_id not in self.evaluators:
            if session_id not in self.llms:
                raise ValueError(f"LLM must be created before evaluator for session {session_id}")
            # Create faithfulness evaluator
            self.evaluators[session_id] = FaithfulnessEvaluator(llm=self.llms[session_id])
        return self.evaluators[session_id]
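    # Note: llama-index's FaithfulnessEvaluator returns a binary score by default
    # (1.0 if the response is grounded in the retrieved context, 0.0 otherwise), so
    # the running average computed in chat_response below is effectively a pass rate.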
    def get_chat_engine(self, session_id, llm_choice, vector_store_choice):
        """Create or get a chat engine using the specified LLM and vector store"""
        # First, ensure we have the right LLM
        llm = self.get_llm(session_id, llm_choice)
        # Then, ensure we have the right index
        index = self.get_index(session_id, vector_store_choice)
        # Get the memory (creates a new one if it doesn't exist)
        memory = self.get_memory(session_id)
        # Create a new chat engine if needed
        if session_id not in self.sessions:
            self.sessions[session_id] = index.as_chat_engine(
                chat_mode="context",
                memory=memory,  # Using existing memory
                system_prompt=(
                    "You are a helpful assistant which helps users to understand scientific knowledge "
                    "about biomechanics of injuries to human bodies."
                ),
                similarity_top_k=3
            )
            # Make sure we have an evaluator
            self.get_evaluator(session_id)
            print(f"Created new chat engine for session {session_id} with existing memory")
        return self.sessions[session_id]
    def clear_session(self, session_id):
        """Clear all resources for a session"""
        if session_id in self.sessions:
            del self.sessions[session_id]
        if session_id in self.evaluators:
            del self.evaluators[session_id]
        if session_id in self.llms:
            del self.llms[session_id]
        if session_id in self.indexes:
            del self.indexes[session_id]
        if session_id in self.memories:
            del self.memories[session_id]  # Clear memory only when explicitly clearing session
        if session_id in self.session_configs:
            del self.session_configs[session_id]
        print(f"Completely cleared session {session_id} including memory")

# Initialize session manager
session_manager = ChatSessionManager()
def chat_response(message, history, llm_choice, vector_store_choice, session_state):
    try:
        # Disable UI components at the start of response generation
        ui_state = {
            "llm_dropdown": gr.update(interactive=False),
            "vector_dropdown": gr.update(interactive=False),
            "msg": gr.update(interactive=False),
            "clear_btn": gr.update(interactive=False),
            "status": gr.update(value='<div style="text-align:center; color:#e67e22; font-weight:bold;">⚙️ Processing...</div>')
        }
        # Manage session state
        if not session_state:
            session_id = str(uuid.uuid4())
            session_state = {
                "session_id": session_id,
                "total_score": 0.0,
                "answer_count": 0,
                "current_llm": llm_choice,
                "current_vs": vector_store_choice
            }
        else:
            session_id = session_state["session_id"]
        # Ensure score tracking fields exist and handle model changes
        if "total_score" not in session_state:
            session_state["total_score"] = 0.0
        if "answer_count" not in session_state:
            session_state["answer_count"] = 0
        if "current_llm" not in session_state:
            session_state["current_llm"] = llm_choice
        if "current_vs" not in session_state:
            session_state["current_vs"] = vector_store_choice
        # Check if model or vector store changed
        if session_state["current_llm"] != llm_choice or session_state["current_vs"] != vector_store_choice:
            print("Configuration changed. Resetting scores.")
            session_state["total_score"] = 0.0
            session_state["answer_count"] = 0
            session_state["current_llm"] = llm_choice
            session_state["current_vs"] = vector_store_choice
        chat_engine = session_manager.get_chat_engine(session_id, llm_choice, vector_store_choice)
        evaluator = session_manager.get_evaluator(session_id)
        start_time = time.time()
        # First yield to disable UI components
        yield history, session_state, ui_state["llm_dropdown"], ui_state["vector_dropdown"], ui_state["msg"], ui_state["clear_btn"], ui_state["status"]
        # Use streaming chat
        streamer = chat_engine.stream_chat(message)
        # Simple variables for tracking content
        response_text = ""
        thinking_text = ""
        full_response = ""
        in_thinking = False
        response_source_nodes = streamer.source_nodes
        # Process the streaming response
        for token in streamer.response_gen:
            full_response += token
            # Simple tag handling
            if "<think>" in token:
                in_thinking = True
                token = token.replace("<think>", "")
            if "</think>" in token:
                in_thinking = False
                token = token.replace("</think>", "")
            # Add content to the appropriate buffer
            if in_thinking:
                thinking_text += token
            else:
                response_text += token
            # Update the UI - thinking shown above response
            elapsed_time = time.time() - start_time
            if thinking_text:
                # Show thinking above response during streaming (open by default)
                current_response = f"<details open><summary>🧠 AI Thinking</summary><div>{html.escape(thinking_text)}</div></details>\n\n{response_text}\n\nTime: {round(elapsed_time, 3)}s"
            else:
                current_response = f"{response_text}\n\nTime: {round(elapsed_time, 3)}s"
            yield history + [(message, current_response)], session_state, ui_state["llm_dropdown"], ui_state["vector_dropdown"], ui_state["msg"], ui_state["clear_btn"], ui_state["status"]
        print(f"Full response:\n{full_response}")
| # print(f"Streaming complete. Response: {response_text}") | |
| # After streaming completes, evaluate the response | |
| if evaluator and response_source_nodes: | |
| try: | |
| # Run evaluation with timeout | |
| contexts = [node.get_content() for node in response_source_nodes] | |
| EVAL_TIMEOUT = 5.0 | |
| with concurrent.futures.ThreadPoolExecutor() as executor: | |
| future = executor.submit( | |
| evaluator.evaluate, | |
| query=message, | |
| response=response_text, | |
| contexts=contexts | |
| ) | |
| try: | |
| eval_result = future.result(timeout=EVAL_TIMEOUT) | |
| faithfulness_score = eval_result.score | |
| # Update scoring | |
| session_state["total_score"] += faithfulness_score | |
| session_state["answer_count"] += 1 | |
| avg_score = session_state["total_score"] / session_state["answer_count"] | |
| # Create evaluation info | |
| final_info = f"Time: {round(time.time() - start_time, 3)}s • Score: {faithfulness_score:.3f} • Avg: {avg_score:.3f} ({session_state['answer_count']} questions)" | |
| except concurrent.futures.TimeoutError: | |
| final_info = f"Time: {round(time.time() - start_time, 3)}s • Evaluation timed out" | |
| # Prepare final response with thinking collapsed and evaluation info | |
| if thinking_text: | |
| final_response = f"<details><summary>🧠 Show AI Thinking</summary><div>{html.escape(thinking_text)}</div></details>\n\n{response_text}\n\n{final_info}" | |
| else: | |
| final_response = f"{response_text}\n\n{final_info}" | |
| # Re-enable UI components | |
| enabled_ui = { | |
| "llm_dropdown": gr.update(interactive=True), | |
| "vector_dropdown": gr.update(interactive=True), | |
| "msg": gr.update(interactive=True), | |
| "clear_btn": gr.update(interactive=True), | |
| "status": gr.update(value='<div style="text-align:center; color:#27ae60; font-weight:bold;">✓ Ready</div>') | |
| } | |
| yield history + [(message, final_response)], session_state, enabled_ui["llm_dropdown"], enabled_ui["vector_dropdown"], enabled_ui["msg"], enabled_ui["clear_btn"], enabled_ui["status"] | |
            except Exception as e:
                print(f"Evaluation error: {str(e)}")
                # Simple error handling
                if thinking_text:
                    error_response = f"<details><summary>🧠 Show AI Thinking</summary><div>{html.escape(thinking_text)}</div></details>\n\n{response_text}\n\nError during evaluation"
                else:
                    error_response = f"{response_text}\n\nError during evaluation"
                # Re-enable UI components on error
                enabled_ui = {
                    "llm_dropdown": gr.update(interactive=True),
                    "vector_dropdown": gr.update(interactive=True),
                    "msg": gr.update(interactive=True),
                    "clear_btn": gr.update(interactive=True),
                    "status": gr.update(value='<div style="text-align:center; color:#e74c3c; font-weight:bold;">✗ Error</div>')
                }
                yield history + [(message, error_response)], session_state, enabled_ui["llm_dropdown"], enabled_ui["vector_dropdown"], enabled_ui["msg"], enabled_ui["clear_btn"], enabled_ui["status"]
        else:
            # No evaluation case
            elapsed_time = time.time() - start_time
            if thinking_text:
                final_response = f"<details><summary>🧠 Show AI Thinking</summary><div>{html.escape(thinking_text)}</div></details>\n\n{response_text}\n\nTime: {round(elapsed_time, 3)}s"
            else:
                final_response = f"{response_text}\n\nTime: {round(elapsed_time, 3)}s"
            # Re-enable UI components
            enabled_ui = {
                "llm_dropdown": gr.update(interactive=True),
                "vector_dropdown": gr.update(interactive=True),
                "msg": gr.update(interactive=True),
                "clear_btn": gr.update(interactive=True),
                "status": gr.update(value='<div style="text-align:center; color:#27ae60; font-weight:bold;">✓ Ready</div>')
            }
            yield history + [(message, final_response)], session_state, enabled_ui["llm_dropdown"], enabled_ui["vector_dropdown"], enabled_ui["msg"], enabled_ui["clear_btn"], enabled_ui["status"]
    except Exception as e:
        print(f"Chat error: {str(e)}")
        # Re-enable UI components on error
        enabled_ui = {
            "llm_dropdown": gr.update(interactive=True),
            "vector_dropdown": gr.update(interactive=True),
            "msg": gr.update(interactive=True),
            "clear_btn": gr.update(interactive=True),
            "status": gr.update(value='<div style="text-align:center; color:#e74c3c; font-weight:bold;">✗ Error</div>')
        }
        yield history + [(message, f"Error: {str(e)}")], session_state, enabled_ui["llm_dropdown"], enabled_ui["vector_dropdown"], enabled_ui["msg"], enabled_ui["clear_btn"], enabled_ui["status"]
# Gradio interface with embedding status
with gr.Blocks(
    title="De-KCIB (Deep Knowledge Center for Injury Biomechanics)",
    theme=large_text_theme,  # apply the custom large-text theme defined above
    css="""
details {
    border: 1px solid #e0e0e0;
    border-radius: 5px;
    padding: 0;
    margin: 10px 0;
}
summary {
    background-color: #f5f5f5;
    padding: 8px 15px;
    cursor: pointer;
    font-weight: 500;
    border-radius: 5px 5px 0 0;
    user-select: none;
}
summary:hover {
    background-color: #e6f7f5;
}
details[open] summary {
    border-bottom: 1px solid #e0e0e0;
}
details > div {
    padding: 15px;
    background-color: #fcfcfc;
    border-radius: 0 0 5px 5px;
    white-space: pre-wrap;
    font-family: monospace;
    overflow-x: auto;
}
""") as demo:
    session_state = gr.State()
    with gr.Row():
        gr.Markdown("<img src='/gradio_api/file/logo.png' alt='Innovision Logo' height='150' width='390'>")
    with gr.Row():
        gr.Markdown("# De-KCIB (Deep Knowledge Center for Injury Biomechanics)")
    with gr.Row():
        with gr.Column(scale=1):
            llm_dropdown = gr.Dropdown(
                label="Select Language Model",
                choices=list(session_manager.llm_options.values()),
                value=next(iter(session_manager.llm_options.values()), None)
            )
            vector_dropdown = gr.Dropdown(
                label="Injury Biomechanics Knowledge Base",
                choices=[(v["display_name"], k) for k, v in session_manager.vector_stores.items()],
                value=next(iter(session_manager.vector_stores.keys()), None)
            )
            status_indicator = gr.HTML(
                value='<div style="text-align:center; margin-top:15px;">Ready</div>',
                label="Status"
            )
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                height=500,
                render_markdown=True,
                bubble_full_width=False,
                show_copy_button=True
            )
            msg = gr.Textbox(label="Query")
            clear_btn = gr.Button("Clear Session")
    msg.submit(
        chat_response,
        [msg, chatbot, llm_dropdown, vector_dropdown, session_state],
        [chatbot, session_state, llm_dropdown, vector_dropdown, msg, clear_btn, status_indicator]
    ).then(
        lambda: "",  # Just clear the message box
        None,
        [msg]
    )
    def clear_session(session_state):
        """Clear session and reset state"""
        # Clear resources in the session manager
        if session_state and "session_id" in session_state:
            session_manager.clear_session(session_state["session_id"])
        # Return empty chat and reset state but preserve session ID
        new_state = {"total_score": 0.0, "answer_count": 0}
        if session_state and "session_id" in session_state:
            new_state["session_id"] = session_state["session_id"]
        return [], new_state

    clear_btn.click(
        clear_session,
        [session_state],
        [chatbot, session_state],
        queue=False
    )
# Add queue to enable streaming
demo.queue()

def prewarm_model(model_name):
    """Send a simple query to warm up a specific model"""
    try:
        print(f"Pre-warming model: {model_name}")
        llm = Ollama(
            model=model_name,
            request_timeout=30,
            temperature=0.3
        )
        # Simple query to initialize the model
        _ = llm.complete("Hello world")
        print(f"Successfully pre-warmed model: {model_name}")
        return True
    except Exception as e:
        print(f"Error pre-warming model {model_name}: {e}")
        return False
# Deployment settings
if __name__ == "__main__":
    # Run warm-up to pre-initialize resources
    # warm_up_resources()
    available_models = model_manager.get_available_models()
    for model_name in available_models.values():
        prewarm_model(model_name)
    # Launch the Gradio app
    demo.launch(allowed_paths=["logo.png"])
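# Running this Space locally (assumed setup, not documented in this file): an Ollama
# server must be reachable (the ollama client defaults to localhost:11434), HF_TOKEN
# and DATASET_ID must be set so the dataset snapshot can be pulled into /data, and
# config.json plus logo.png must sit next to app.py. Then:
#
#   python app.py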