Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import time | |
| import random | |
| import pandas as pd | |
| from src.vector_db import UnifiedQdrant | |
| from src.router import LearnedRouter | |
| from src.data_pipeline import get_embedding | |
| # --- Configuration: values handed to UnifiedQdrant and get_embedding below --- | |
| COLLECTION_NAME = "dashVector_v1" | |
| VECTOR_SIZE = 384 # vector size used for MiniLM-L6-v2 (see EMBEDDING_MODEL) | |
| NUM_CLUSTERS = 16 | |
| EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" | |
| # --- Initialize Backend --- | |
| # Runs once at import time: builds the UnifiedQdrant wrapper and calls initialize() before the UI is defined. | |
| vector_db = UnifiedQdrant(COLLECTION_NAME, VECTOR_SIZE, NUM_CLUSTERS) | |
| vector_db.initialize() | |
| # Load the learned router; on any failure, print a warning and leave router = None (run_benchmark then falls back to a mock prediction). | |
| ROUTER_PATH = "models/router_v1.pkl" | |
| try: | |
| router = LearnedRouter.load(ROUTER_PATH) | |
| except Exception as e: | |
| print(f"Warning: Could not load router: {e}. Using dummy router for UI demo if needed.") | |
| router = None | |
| # --- HTML Templates (Extracted from dashVector_benchmark.html) --- | |
| # HEAD_HTML is passed to gr.Blocks(head=...); NAVBAR_HTML and FOOTER_INFO_HTML are rendered via gr.HTML; EMPTY_STATE_HTML is the initial content of the results panel; LOADER_HTML is what show_loader() returns. | |
| HEAD_HTML = """ | |
| <script src="https://cdn.tailwindcss.com"></script> | |
| <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet"> | |
| <link href="https://fonts.googleapis.com/css2?family=Material+Symbols+Outlined:opsz,wght,FILL,GRAD@24,400,0,0" rel="stylesheet"> | |
| <style> | |
| body { font-family: 'Inter', sans-serif; background-color: #f8f9fa; } | |
| .fade-in { animation: fadeIn 0.5s ease-out forwards; } | |
| @keyframes fadeIn { from { opacity: 0; transform: translateY(10px); } to { opacity: 1; transform: translateY(0); } } | |
| /* Hide Gradio footer */ | |
| footer { display: none !important; } | |
| .gradio-container { max-width: 100% !important; padding: 0 !important; margin: 0 !important; background-color: #f8f9fa; } | |
| /* Custom Scrollbar */ | |
| .custom-scrollbar::-webkit-scrollbar { height: 8px; width: 8px; } | |
| .custom-scrollbar::-webkit-scrollbar-track { background: #f1f1f1; } | |
| .custom-scrollbar::-webkit-scrollbar-thumb { background: #c1c1c1; border-radius: 4px; } | |
| .custom-scrollbar::-webkit-scrollbar-thumb:hover { background: #a8a8a8; } | |
| /* Overwrite Gradio Input Styles to match Reference */ | |
| #custom-input textarea { | |
| background-color: white !important; | |
| border: 1px solid #cbd5e1 !important; | |
| border-radius: 0.75rem !important; /* rounded-xl */ | |
| padding: 0.75rem 1rem !important; | |
| font-size: 1rem !important; | |
| box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05) !important; | |
| height: 50px !important; /* Fixed height for alignment */ | |
| } | |
| #custom-input textarea:focus { | |
| outline: 2px solid #3b82f6 !important; /* blue-500 */ | |
| border-color: #3b82f6 !important; | |
| } | |
| /* Search Bar Layout Fix */ | |
| .search-row { | |
| display: flex !important; | |
| flex-direction: row !important; | |
| align-items: flex-start !important; | |
| gap: 1rem !important; | |
| flex-wrap: nowrap !important; /* Prevent wrapping */ | |
| } | |
| /* Loader Overlay */ | |
| .loader-overlay { | |
| position: absolute; inset: 0; background: rgba(255,255,255,0.8); | |
| backdrop-filter: blur(4px); z-index: 50; | |
| display: flex; flex-direction: column; align-items: center; justify-content: center; | |
| } | |
| .spinner { | |
| width: 4rem; height: 4rem; border: 4px solid #e2e8f0; | |
| border-top-color: #2563eb; border-radius: 50%; | |
| animation: spin 1s linear infinite; | |
| } | |
| @keyframes spin { to { transform: rotate(360deg); } } | |
| </style> | |
| """ | |
| NAVBAR_HTML = """ | |
| <header class="bg-white border-b border-slate-200 sticky top-0 z-40 shadow-sm w-full"> | |
| <div class="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 h-16 flex items-center justify-between"> | |
| <div class="flex items-center gap-2"> | |
| <!-- User Logo Removed --> | |
| <h1 class="text-xl font-bold tracking-tight text-slate-900">dashVector</h1> | |
| </div> | |
| <div class="flex items-center gap-4"> | |
| <div class="hidden md:flex items-center gap-1.5 px-3 py-1 bg-slate-100 rounded-full border border-slate-200"> | |
| <span class="material-symbols-outlined text-slate-500 text-sm">database</span> | |
| <span class="text-xs font-medium text-slate-600">Dataset: <span class="font-bold text-slate-800">MS Marco (25k)</span></span> | |
| </div> | |
| </div> | |
| </div> | |
| </header> | |
| """ | |
| FOOTER_INFO_HTML = """ | |
| <div class="grid grid-cols-1 md:grid-cols-3 gap-4 text-sm mt-6"> | |
| <div class="bg-blue-50 border border-blue-100 p-4 rounded-xl"> | |
| <h3 class="font-semibold text-blue-900 mb-2 flex items-center gap-2"> | |
| <span class="material-symbols-outlined text-base">architecture</span> | |
| Architecture | |
| </h3> | |
| <p class="text-blue-800/80"> | |
| Improves search efficiency by using a <span class="font-bold">Router Model</span> to predict specific data shards, reducing the search space on the Vector DB. | |
| </p> | |
| </div> | |
| <div class="bg-orange-50 border border-orange-100 p-4 rounded-xl"> | |
| <h3 class="font-semibold text-orange-900 mb-2 flex items-center gap-2"> | |
| <span class="material-symbols-outlined text-base">database</span> | |
| Vector Database | |
| </h3> | |
| <p class="text-orange-800/80"> | |
| Utilizes <span class="font-bold">Qdrant</span> for high-performance vector storage and retrieval, benchmarking direct search vs. routed search across 16 shards. | |
| </p> | |
| </div> | |
| <div class="bg-purple-50 border border-purple-100 p-4 rounded-xl"> | |
| <h3 class="font-semibold text-purple-900 mb-2 flex items-center gap-2"> | |
| <span class="material-symbols-outlined text-base">psychology</span> | |
| Methodology | |
| </h3> | |
| <p class="text-purple-800/80"> | |
| Router predicts shard probabilities. Shards are iteratively added to the search scope until the <strong>cumulative confidence > 0.9</strong>, balancing accuracy and speed. | |
| </p> | |
| </div> | |
| </div> | |
| """ | |
| EMPTY_STATE_HTML = """ | |
| <div class="bg-white rounded-2xl shadow-sm border border-slate-200 overflow-hidden flex flex-col min-h-[400px] items-center justify-center text-slate-400"> | |
| <div class="bg-slate-50 p-6 rounded-full mb-4"> | |
| <span class="material-symbols-outlined text-6xl text-slate-200">bar_chart</span> | |
| </div> | |
| <p class="text-lg font-medium text-slate-500">Ready to benchmark</p> | |
| <p class="text-sm">Enter a query above to compare routing architectures.</p> | |
| </div> | |
| """ | |
| LOADER_HTML = """ | |
| <div class="bg-white rounded-2xl shadow-sm border border-slate-200 overflow-hidden flex flex-col min-h-[400px] relative"> | |
| <div class="loader-overlay"> | |
| <div class="spinner"></div> | |
| <p class="mt-4 text-slate-600 font-medium animate-pulse">Running inferences & calculating metrics...</p> | |
| <div class="text-xs text-slate-400 mt-2">Router Model predicting shards...</div> | |
| </div> | |
| </div> | |
| """ | |
def generate_table_html(rows):
    """Render benchmark rows as the "Performance Metrics" HTML table card.

    Args:
        rows: Iterable of dicts, one per table row. Each dict must provide the
            keys ``embedding``, ``router``, ``optimizedTime``,
            ``shardsSearched``, ``accuracy``, ``confDisplay``, ``directTime``,
            ``totalShards`` and ``efficiency``. ``accuracy`` must be accepted
            by ``float()``; it is read as a 0..1 fraction and drives the width
            of the progress bar. All other values are interpolated as-is.

    Returns:
        str: The complete card markup (header, column titles, and one ``<tr>``
        per input row). An empty ``rows`` yields a card with an empty body.

    Raises:
        KeyError: If a row lacks one of the required keys.
        ValueError: If ``accuracy`` cannot be converted to a float.
    """
    row_fragments = []
    for i, row in enumerate(rows):
        # Stagger the fade-in animation by 100 ms per row.
        delay = i * 100
        # Round instead of truncating: float("0.29") * 100 is 28.999999999999996,
        # which int() would render as 28%. Clamp so the CSS width stays valid.
        width_pct = min(100, max(0, round(float(row['accuracy']) * 100)))
        row_fragments.append(f"""
<tr class="hover:bg-slate-50 transition-colors fade-in" style="animation-delay: {delay}ms; opacity: 0;">
  <td class="px-6 py-4 whitespace-nowrap">
    <div class="flex items-center">
      <div class="h-8 w-8 rounded bg-indigo-100 text-indigo-600 flex items-center justify-center mr-3 font-bold text-xs">EM</div>
      <div class="text-sm font-medium text-slate-900">{row['embedding']}</div>
    </div>
  </td>
  <td class="px-6 py-4 whitespace-nowrap">
    <div class="text-sm text-slate-700 font-medium">{row['router']}</div>
    <div class="text-xs text-slate-400">Classifier</div>
  </td>
  <td class="px-6 py-4 whitespace-nowrap bg-blue-50/30 border-l border-r border-blue-100">
    <div class="flex flex-col gap-1">
      <div class="flex items-center justify-between">
        <span class="text-xs text-slate-500">Time:</span>
        <span class="text-sm font-bold text-blue-700">{row['optimizedTime']}</span>
      </div>
      <div class="flex items-center justify-between">
        <span class="text-xs text-slate-500">Shards:</span>
        <span class="text-xs font-mono bg-blue-100 text-blue-800 px-1.5 rounded">{row['shardsSearched']}</span>
      </div>
      <div class="w-full bg-slate-200 rounded-full h-1.5 mt-1">
        <div class="bg-blue-500 h-1.5 rounded-full" style="width: {width_pct}%"></div>
      </div>
      <div class="flex justify-between text-[10px] text-slate-400 mt-0.5">
        <span>Acc: {row['accuracy']}</span>
        <span>Conf: {row['confDisplay']}</span>
      </div>
    </div>
  </td>
  <td class="px-6 py-4 whitespace-nowrap">
    <div class="flex flex-col gap-1">
      <span class="text-sm font-semibold text-slate-600">{row['directTime']}</span>
      <span class="text-xs text-slate-400">Full Scan ({row['totalShards']} Shards)</span>
    </div>
  </td>
  <td class="px-6 py-4 whitespace-nowrap">
    <div class="flex items-center">
      <span class="text-lg font-bold text-green-600">{row['efficiency']}</span>
      <span class="material-symbols-outlined text-green-600 text-sm ml-1">trending_up</span>
    </div>
    <div class="text-xs text-green-700/70">Faster</div>
  </td>
</tr>
""")
    # Join once instead of growing a string with += inside the loop.
    rows_html = "".join(row_fragments)
    return f"""
<div class="bg-white rounded-2xl shadow-sm border border-slate-200 overflow-hidden flex flex-col flex-grow min-h-[500px]">
  <div class="px-6 py-4 border-b border-slate-100 flex justify-between items-center bg-slate-50/50">
    <h2 class="text-lg font-semibold text-slate-800 flex items-center gap-2">
      <span class="material-symbols-outlined text-slate-500">table_chart</span>
      Performance Metrics
    </h2>
    <div class="text-xs text-slate-500 flex items-center gap-2">
      <span class="flex items-center gap-1"><div class="w-2 h-2 rounded-full bg-green-500"></div> High Efficiency</span>
      <span class="flex items-center gap-1"><div class="w-2 h-2 rounded-full bg-slate-300"></div> Baseline</span>
    </div>
  </div>
  <div class="overflow-x-auto custom-scrollbar flex-grow relative">
    <table class="min-w-full divide-y divide-slate-200">
      <thead class="bg-slate-50 sticky top-0 z-10">
        <tr>
          <th class="px-6 py-3 text-left text-xs font-bold text-slate-500 uppercase tracking-wider">Embedding Model</th>
          <th class="px-6 py-3 text-left text-xs font-bold text-slate-500 uppercase tracking-wider">Router Model</th>
          <th class="px-6 py-3 text-left text-xs font-bold text-slate-500 uppercase tracking-wider bg-blue-50/50 border-l border-r border-blue-100 text-blue-800">dashVector Search (Optimized)</th>
          <th class="px-6 py-3 text-left text-xs font-bold text-slate-500 uppercase tracking-wider">Direct Qdrant Search (Baseline)</th>
          <th class="px-6 py-3 text-left text-xs font-bold text-slate-500 uppercase tracking-wider text-green-700">Efficiency Gain</th>
        </tr>
      </thead>
      <tbody class="bg-white divide-y divide-slate-100">
        {rows_html}
      </tbody>
    </table>
  </div>
</div>
"""
def show_loader():
    """Return the module-level LOADER_HTML placeholder markup unchanged."""
    loader_markup = LOADER_HTML
    return loader_markup
def _render_error_html(error_text):
    """Return the red "Runtime Error" card with *error_text* HTML-escaped."""
    import html  # local import: the file-level import list lies outside this block

    # Tracebacks can contain <, >, & and quotes (including text echoed from the
    # query), so escape before interpolating into markup.
    safe_text = html.escape(error_text)
    return f"""
<div class="bg-red-50 border border-red-200 rounded-2xl p-6 text-red-800">
  <h3 class="font-bold text-lg mb-2 flex items-center gap-2">
    <span class="material-symbols-outlined">error</span>
    Runtime Error
  </h3>
  <p class="mb-4">An error occurred while running the benchmark:</p>
  <pre class="bg-red-100 p-4 rounded-lg text-xs font-mono overflow-x-auto">{safe_text}</pre>
</div>
"""


def _build_live_row(latency_ms, shards_searched, total_shards, confidence):
    """Build the table row describing the live (MiniLM + Logistic Regression) run."""
    # The baseline is an estimate, not a measurement: the routed latency is
    # scaled by total/searched shards plus a 10% margin.
    scan_factor = (total_shards / max(shards_searched, 1)) * 1.1
    direct_time = latency_ms * scan_factor
    # Derived from scan_factor (equivalent to 1 - latency/direct_time) so a
    # 0 ms timer reading cannot trigger a division by zero.
    efficiency_pct = (1 - 1 / scan_factor) * 100
    return {
        "embedding": "MiniLM-L6-v2 (Active)",
        "router": "Logistic Regression",
        "optimizedTime": f"{latency_ms:.1f} ms",
        "shardsSearched": f"{shards_searched} / {total_shards}",
        "totalShards": total_shards,
        "accuracy": f"{confidence:.2f}",
        "confDisplay": f"{confidence*100:.1f}%",
        "directTime": f"{direct_time:.1f} ms",
        # ":+" emits the sign itself, so a negative gain is not shown as "+-x%".
        "efficiency": f"{efficiency_pct:+.1f}%",
    }


def run_benchmark(query):
    """Run one routed search for *query* and return the results table as HTML.

    Steps: embed the query, ask the router for target clusters and cumulative
    confidence (or use a fixed mock when no router was loaded), run the hybrid
    search, then render the live measurement next to two static reference rows.

    Args:
        query: Text entered in the search box.

    Returns:
        str: The performance table on success, the empty-state panel when the
        query is blank, or an error card containing the escaped traceback when
        any step raises.
    """
    if not query or not query.strip():
        # Nothing to embed; keep the "Ready to benchmark" placeholder.
        return EMPTY_STATE_HTML

    print(f"DEBUG: Starting benchmark for query: {query}")
    try:
        # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
        start_total = time.perf_counter()

        print("DEBUG: Generating embedding...")
        query_vec = get_embedding(query, model_name=EMBEDDING_MODEL)
        print("DEBUG: Embedding generated.")

        if router is not None:
            print("DEBUG: Predicting clusters...")
            # predict() yields the list of clusters and their cumulative confidence.
            target_clusters, confidence = router.predict(query_vec)
            print(f"DEBUG: Predicted clusters {target_clusters} with cumulative confidence {confidence}")
        else:
            print("DEBUG: No router loaded, using mock.")
            target_clusters, confidence = [0], 0.95  # Mock

        print("DEBUG: Searching Qdrant...")
        results, _mode = vector_db.search_hybrid(query_vec, target_clusters, confidence)
        print(f"DEBUG: Search complete. Found {len(results)} results.")

        latency_ms = (time.perf_counter() - start_total) * 1000

        live_row = _build_live_row(
            latency_ms,
            shards_searched=len(target_clusters),
            # Same shard count the vector DB was configured with.
            total_shards=NUM_CLUSTERS,
            confidence=confidence,
        )

        # Reference rows are static figures, not measured in this process.
        ref_rows = [
            {
                "embedding": "Gemma 300M",
                "router": "LightGBM",
                "optimizedTime": "128 ms",
                "shardsSearched": "9 / 16",
                "totalShards": 16,
                "accuracy": "0.97",
                "confDisplay": "97.1%",
                "directTime": "220 ms",
                "efficiency": "+41.8%",
            },
            {
                "embedding": "Qwen 600M",
                "router": "Tiny MLP",
                "optimizedTime": "109 ms",
                "shardsSearched": "7 / 16",
                "totalShards": 16,
                "accuracy": "0.90",
                "confDisplay": "90.1%",
                "directTime": "235 ms",
                "efficiency": "+53.6%",
            },
        ]

        print("DEBUG: Returning final HTML.")
        return generate_table_html([live_row] + ref_rows)
    except Exception:
        # UI boundary: surface the failure in the results panel instead of
        # letting the Gradio callback crash.
        import traceback

        error_msg = traceback.format_exc()
        print(f"CRITICAL ERROR in run_benchmark: {error_msg}")
        return _render_error_html(error_msg)
# --- Gradio App ---
# Layout: navbar on top, then one centered column holding the search card,
# the results panel, and the footer info cards.
# NOTE(review): nesting is reconstructed from a flattened listing — confirm
# against the deployed layout.
with gr.Blocks(head=HEAD_HTML, css=None, theme=gr.themes.Base()) as demo:
    gr.HTML(NAVBAR_HTML)

    with gr.Column(elem_classes="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 py-8 gap-6"):
        # Search card
        with gr.Group(elem_classes="bg-white p-6 rounded-2xl shadow-sm border border-slate-200 mb-6"):
            gr.HTML('<label class="block text-sm font-medium text-slate-700 mb-2">Evaluate Search Architecture</label>')
            # The "search-row" class (styled in HEAD_HTML) keeps input and button on one line.
            with gr.Row(elem_classes="search-row"):
                query_input = gr.Textbox(
                    scale=4,
                    container=False,
                    elem_id="custom-input",
                    show_label=False,
                    placeholder="Enter a benchmark query (e.g., 'climate change impact')...",
                )
                submit_btn = gr.Button(
                    "Run Benchmark",
                    scale=1,
                    variant="primary",
                    # h-[50px] matches the fixed textarea height set in HEAD_HTML.
                    elem_classes="bg-blue-600 hover:bg-blue-700 text-white font-semibold py-3 px-6 rounded-xl shadow-md transition-all h-[50px]",
                )

        # Results panel starts with the empty-state placeholder.
        results_area = gr.HTML(EMPTY_STATE_HTML)

        # Footer info cards
        gr.HTML(FOOTER_INFO_HTML)

    # Single-step interaction: button click and Enter in the textbox both run the benchmark.
    for register in (submit_btn.click, query_input.submit):
        register(run_benchmark, inputs=[query_input], outputs=[results_area])
| if __name__ == "__main__": | |
| # NOTE(review): launch() is called with default arguments; nothing here disables the queue, although this spot was previously annotated as an h11 LocalProtocolError workaround — confirm which is intended. | |
| demo.launch() | |