"""Gradio dashboard comparing sharded (router-guided) vs. unsharded vector search.

For each embedding model it runs three searches per query — learned-router
(optimized), full sharded scan, and single-index baseline — and renders a
latency / recall / efficiency matrix as HTML.
"""

import gradio as gr
import os
import json
import time
import pandas as pd
import numpy as np

from src.vector_db import UnifiedQdrant
from src.router import LearnedRouter
from src.data_pipeline import get_embedding
from config import (
    COLLECTIONS,
    EMBEDDING_MODELS,
    ROUTER_MODELS,
    NUM_CLUSTERS,
    FRESHNESS_SHARD_ID,
)

# --- Initialize Backend ---
print("Initializing Backend...")

# 1. Vector DB Clients
# We need clients for both Prod (Sharded) and Base (Unsharded) for each model.
# dbs holds three entries per model key: "<key>_prod", "<key>_base", "<key>_sizes".
dbs = {}
for model_key, cols in COLLECTIONS.items():
    # Load Dimension from JSON; fall back to 384 if the sidecar file is
    # missing/corrupt. (Narrowed from a bare `except:` which also swallowed
    # KeyboardInterrupt/SystemExit.)
    try:
        with open(f"models/model_info_{model_key}.json", "r") as f:
            vec_size = json.load(f)["dim"]
    except (OSError, json.JSONDecodeError, KeyError):
        print(f"Warning: Could not load model info for {model_key}. Using default 384.")
        vec_size = 384

    # Load Shard Sizes (cluster_id -> vector count), used to report % scanned.
    try:
        with open(f"models/shard_sizes_{model_key}.json", "r") as f:
            shard_sizes = json.load(f)
        # JSON object keys are strings; convert keys to int cluster ids.
        shard_sizes = {int(k): v for k, v in shard_sizes.items()}
        dbs[f"{model_key}_sizes"] = shard_sizes
    except (OSError, json.JSONDecodeError, ValueError):
        print(f"Warning: Could not load shard sizes for {model_key}.")
        dbs[f"{model_key}_sizes"] = {}

    # Prod: sharded collection driven by the learned router.
    print(f"Initializing DB: {cols['prod']}...")
    db_prod = UnifiedQdrant(
        cols['prod'],
        vector_size=vec_size,
        num_clusters=NUM_CLUSTERS,
        freshness_shard_id=FRESHNESS_SHARD_ID,
    )
    db_prod.initialize(is_baseline=False)
    dbs[f"{model_key}_prod"] = db_prod

    # Base: single-cluster (unsharded) reference collection.
    print(f"Initializing DB: {cols['base']}...")
    db_base = UnifiedQdrant(cols['base'], vector_size=vec_size, num_clusters=1)
    db_base.initialize(is_baseline=True)
    dbs[f"{model_key}_base"] = db_base

# 2. Load Routers — one per (embedding model, router type) combination.
# A missing router is stored as None and mocked at query time.
routers = {}
for model_key in EMBEDDING_MODELS.keys():
    for router_type in ROUTER_MODELS:
        router_path = f"models/router_{model_key}_{router_type}.pkl"
        try:
            print(f"Loading Router: {router_path}...")
            routers[f"{model_key}_{router_type}"] = LearnedRouter.load(router_path)
        except Exception as e:
            print(f"Warning: Could not load {router_path}: {e}. Using None.")
            routers[f"{model_key}_{router_type}"] = None

# --- HTML Templates ---
# NOTE(review): the HTML markup in these templates appears to have been
# stripped in this copy of the file (only text content remains) — restore the
# original tags from version control before shipping.
HEAD_HTML = """ """

NAVBAR_HTML = """
Logo

dashVector Experiment Matrix

"""

FOOTER_INFO_HTML = """

architecture Architecture

Improves search efficiency by using a Router Model to predict specific data shards.

database Vector Database

Utilizes Qdrant for high-performance vector storage and retrieval.

psychology Methodology

Shards are iteratively added until cumulative confidence > 0.9.

"""

EMPTY_STATE_HTML = """
bar_chart

Ready to benchmark

Enter a query above to compare routing architectures.

"""


def generate_table_html(rows):
    """Render the benchmark result rows as the experiment-matrix HTML table.

    Args:
        rows: list of dicts produced by run_benchmark(), one per
            (embedding model, router model) combination.

    Returns:
        HTML string for the results area.
    """
    rows_html = ""
    for i, row in enumerate(rows):
        # NOTE(review): delay/width_pct look like they fed style attributes in
        # the (stripped) markup — confirm against the original template.
        delay = i * 50  # Faster stagger
        width_pct = int(float(row['accuracy']) * 100)
        rows_html += f"""
{row['embedding_name']} {row['dims']} dim
{row['router_name']} {row['router_desc']}
Total Latency {row['optimizedTime']}
Router Overhead {row['overhead']}
Scanned: {row['shardsSearched']}
check_circle Router Conf: {row['confDisplay']}
Time: {row['directTime']}
Recall@10: {row['recall_10']}
Recall@5: {row['recall_5']}
Time: {row['baselineTime']}
Single Index
{row['efficiency_sharded']} bolt
[vs Sharded]
{row['efficiency_base']} bolt
[vs Base]
"""
    return f"""

grid_view Experiment Matrix (3x3)

Optimized Baseline
{rows_html}
Embedding Model Router Model dashVector Performance (Optimized) Direct Search Efficiency Gain
With Sharding (16) No Sharding (1)
"""


def _recall_pct(base_results, prod_results, k=None):
    """Percent of baseline results recovered by the optimized search.

    Matches on payload content (the 'text' field, falling back to the whole
    payload repr) because point IDs may differ if the collections were
    indexed separately. Returns 0.0 when the baseline set is empty.

    Args:
        base_results: baseline (ground-truth) scored points.
        prod_results: optimized-search scored points.
        k: optional cutoff — compare only the top-k of each list.
    """
    if k is not None:
        base_results = base_results[:k]
        prod_results = prod_results[:k]
    base_contents = set(p.payload.get('text', str(p.payload)) for p in base_results)
    prod_contents = set(p.payload.get('text', str(p.payload)) for p in prod_results)
    if not base_contents:
        return 0.0
    return (len(base_contents.intersection(prod_contents)) / len(base_contents)) * 100


def run_benchmark(query):
    """Run the full 3x3 benchmark for one query and return the results HTML.

    For each embedding model: embed the query once, time the unsharded
    baseline and the full sharded scan, then evaluate every router model
    against those references (latency, recall, efficiency gain).
    """
    print(f"DEBUG: Starting benchmark for query: {query}")
    rows = []
    # Loop over Embedding Models
    for model_key, model_name in EMBEDDING_MODELS.items():
        print(f"--- Processing {model_key} ---")

        # 1. Generate Embedding. Note: this might be slow.
        try:
            query_vec = get_embedding(query, model_name=model_name)
        except Exception as e:
            print(f"Error generating embedding for {model_key}: {e}")
            continue
        dims = len(query_vec)

        # 2. Run Baseline Search (Unsharded) — once per embedding model.
        db_base = dbs.get(f"{model_key}_base")
        start_base = time.time()
        if db_base:
            base_results = db_base.search_baseline(query_vec)
        else:
            base_results = []
        end_base = time.time()
        baseline_time_ms = (end_base - start_base) * 1000

        # 3. Reference: Direct Sharded Search (full scan on Prod).
        # This gives us the "With Sharding" latency.
        db_prod = dbs.get(f"{model_key}_prod")
        if db_prod:
            start_sharded = time.time()
            # search_baseline on the prod UnifiedQdrant performs a full scan
            # when no shard selector is given.
            _ = db_prod.search_baseline(query_vec)
            end_sharded = time.time()
            direct_sharded_time_ms = (end_sharded - start_sharded) * 1000
        else:
            direct_sharded_time_ms = baseline_time_ms * 1.2  # Fallback estimate

        # 4. Loop over Router Models
        for router_type in ROUTER_MODELS:
            router_key = f"{model_key}_{router_type}"
            router = routers.get(router_key)
            db_prod = dbs.get(f"{model_key}_prod")

            if not router or not db_prod:
                # Mock numbers if the router or collection is missing, so the
                # table still renders a complete matrix.
                target_clusters = [0, 1, 2]
                confidence = 0.85
                overhead_ms = 0.5
                prod_results = []
                latency_ms = 50
            else:
                # Predict which shards to search (router overhead is timed
                # separately so it can be reported on its own).
                start_router = time.time()
                target_clusters, confidence = router.predict(query_vec)
                end_router = time.time()
                overhead_ms = (end_router - start_router) * 1000

                # Search Prod restricted to the predicted shards.
                start_search = time.time()
                prod_results, _ = db_prod.search_hybrid(query_vec, target_clusters, confidence)
                end_search = time.time()
                latency_ms = (end_search - start_search) * 1000 + overhead_ms

            # Calculate Vectors Scanned (% of the corpus touched).
            shard_sizes = dbs.get(f"{model_key}_sizes", {})
            vectors_scanned = sum(shard_sizes.get(c, 0) for c in target_clusters)
            total_vectors = sum(shard_sizes.values()) if shard_sizes else 1000  # Default to 1k if missing
            vectors_scanned_pct = (vectors_scanned / total_vectors) * 100 if total_vectors > 0 else 0

            # Recall for Optimized vs Baseline (content matching — IDs may
            # differ if the collections were indexed separately).
            recall_10 = _recall_pct(base_results, prod_results)
            recall_5 = _recall_pct(base_results, prod_results, k=5) if recall_10 or base_results else 0.0

            # Efficiency Gain: (Reference - Optimized) / Reference.
            # 1. vs Direct Sharded full scan
            if direct_sharded_time_ms > 0:
                eff_gain_sharded = ((direct_sharded_time_ms - latency_ms) / direct_sharded_time_ms) * 100
            else:
                eff_gain_sharded = 0.0
            # 2. vs Base (No Sharding)
            if baseline_time_ms > 0:
                eff_gain_base = ((baseline_time_ms - latency_ms) / baseline_time_ms) * 100
            else:
                eff_gain_base = 0.0

            # Formatting. (Fixed: the dict previously listed "baselineTime"
            # twice — the duplicate key was silently dropped by Python.)
            row = {
                "embedding_name": "MiniLM-L6-v2" if model_key == "minilm" else ("BGE-Small-en-v1.5" if model_key == "bge" else "Qwen2.5-0.5B-Instruct"),
                "dims": dims,
                "router_name": "Logistic Regression" if router_type == "logistic" else ("LightGBM" if router_type == "lightgbm" else "Tiny MLP"),
                "router_desc": "Linear" if router_type == "logistic" else ("Gradient Boosting" if router_type == "lightgbm" else "Neural Net"),
                "optimizedTime": f"{latency_ms:.1f} ms",
                "overhead": f"{overhead_ms:.1f} ms",
                "shardsSearched": f"{vectors_scanned_pct:.1f}% ({len(target_clusters)}/{NUM_CLUSTERS} shards)",
                "accuracy": f"{confidence:.2f}",
                "confDisplay": f"{confidence*100:.1f}%",
                "directTime": f"{direct_sharded_time_ms:.1f} ms",
                "baselineTime": f"{baseline_time_ms:.1f} ms",
                "recall_10": f"{recall_10:.1f}%",
                "recall_5": f"{recall_5:.1f}%",
                "efficiency_sharded": f"{eff_gain_sharded:.1f}%",
                "efficiency_base": f"{eff_gain_base:.1f}%",
            }
            rows.append(row)

    return generate_table_html(rows)


with gr.Blocks(theme=gr.themes.Base(), css=None, head=HEAD_HTML) as demo:
    gr.HTML(NAVBAR_HTML)
    with gr.Column(elem_classes="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 py-8 gap-6"):
        with gr.Group(elem_classes="bg-white p-6 rounded-2xl shadow-sm border border-slate-200 mb-6"):
            gr.HTML('')
            with gr.Row(elem_classes="search-row"):
                query_input = gr.Textbox(
                    placeholder="Enter a benchmark query...",
                    show_label=False,
                    elem_id="custom-input",
                    container=False,
                    scale=4,
                )
                submit_btn = gr.Button(
                    "Run Benchmark",
                    variant="primary",
                    scale=1,
                    elem_classes="bg-blue-600 hover:bg-blue-700 text-white font-semibold py-3 px-6 rounded-xl shadow-md transition-all h-[50px]",
                )
        results_area = gr.HTML(EMPTY_STATE_HTML)
        gr.HTML(FOOTER_INFO_HTML)

    # Both the button click and pressing Enter in the textbox run the benchmark.
    submit_btn.click(run_benchmark, inputs=[query_input], outputs=[results_area])
    query_input.submit(run_benchmark, inputs=[query_input], outputs=[results_area])

if __name__ == "__main__":
    demo.launch()