"""Gradio demo comparing brute-force vector search with xVector routed search.

At import time this module initializes the UnifiedQdrant store, a LearnedRouter
(falling back to a stubbed dummy router when no trained model file exists) and
a ComparisonEngine. It then builds a UI that runs one query through both search
paths and shows their results, latency, shard counts and router telemetry side
by side.
"""

import os
import time
from collections.abc import Mapping

import gradio as gr
import pandas as pd

from src.vector_db import UnifiedQdrant
from src.router import LearnedRouter
from src.comparison import ComparisonEngine
from config import COLLECTION_NAME, NUM_CLUSTERS, FRESHNESS_SHARD_ID, MRL_DIMS

# --- Initialization ---
print("Initializing dashVectorspace App...")

# 1. Initialize DB
# Note: In a real HF Space, secrets are in os.environ
db = UnifiedQdrant(
    collection_name=COLLECTION_NAME,
    vector_size=384,  # Assuming MiniLM for demo
    num_clusters=NUM_CLUSTERS,
    freshness_shard_id=FRESHNESS_SHARD_ID,
)
db.initialize()

# 2. Initialize Router
ROUTER_PATH = "models/router_v1.pkl"
if os.path.exists(ROUTER_PATH):
    # NOTE(review): ROUTER_PATH is a .pkl file; if LearnedRouter.load unpickles
    # it, only ever point this at trusted model artifacts.
    router = LearnedRouter.load(ROUTER_PATH)
else:
    print("WARNING: Router model not found. Creating a DUMMY router for demo UI.")
    router = LearnedRouter(model_type="lightgbm", n_clusters=NUM_CLUSTERS, mrl_dims=MRL_DIMS)
    # An untrained router cannot predict, so stub predict() to let the UI load:
    # every query is routed to cluster 0 with confidence 0.99.
    router.predict = lambda x: (0, 0.99)

# 3. Initialize Engine
engine = ComparisonEngine(db, router, embedding_model_name="minilm")

# --- UI Logic ---

# Maximum number of payload characters shown per hit in the results boxes.
RESULT_PREVIEW_CHARS = 100


def _payload_text(payload):
    """Return the "text" field of a point payload, or "No text" if it is absent.

    Payloads may be mappings or attribute-style objects depending on the client
    version / mock, so both shapes are handled. A missing or None text maps to
    "No text"; any other value is converted with str().
    """
    if not payload:
        return "No text"
    if isinstance(payload, Mapping):
        text = payload.get("text")
    else:
        text = getattr(payload, "text", None)
    return "No text" if text is None else str(text)


def _format_results(res_dict):
    """Render a search result dict as one "- [score] text" line per hit.

    Args:
        res_dict: Search output whose "results" entry is a sequence of points
            exposing `.payload` and `.score`.

    Returns:
        Newline-separated result lines, or "No results." when there are no hits.
        Text longer than RESULT_PREVIEW_CHARS is truncated and suffixed with "...".
    """
    points = res_dict.get("results") or []
    if not points:
        return "No results."
    lines = []
    for point in points:
        text = _payload_text(point.payload)
        preview = text[:RESULT_PREVIEW_CHARS]
        if len(text) > RESULT_PREVIEW_CHARS:
            preview += "..."
        lines.append(f"- [{point.score:.4f}] {preview}")
    return "\n".join(lines)


def _compute_savings_pct(shards_optimized, shards_baseline):
    """Return the percentage of baseline shards skipped, or None if undefined.

    The result is undefined (None) when the baseline searched no shards, which
    would otherwise be a division by zero.
    """
    if not shards_baseline:
        return None
    return (1 - shards_optimized / shards_baseline) * 100


def run_comparison(query):
    """Run `query` through brute-force and xVector search and summarize both.

    Args:
        query: User query text from the input textbox.

    Returns:
        A 5-tuple matching the click-handler outputs: (brute-force results text,
        xVector results text, metrics DataFrame, savings Markdown, telemetry
        Markdown). For an empty or whitespace-only query, the first item is a
        prompt and the remaining four are None.
    """
    if not query or not query.strip():
        return "Please enter a query.", None, None, None, None

    res_direct = engine.direct_search(query)
    res_xvector = engine.xvector_search(query)

    out_direct = _format_results(res_direct)
    out_xvector = _format_results(res_xvector)

    metrics_df = pd.DataFrame({
        "Metric": ["Latency (ms)", "Shards Searched"],
        "Brute Force": [res_direct["latency_ms"], res_direct["shards_searched"]],
        "xVector": [res_xvector["latency_ms"], res_xvector["shards_searched"]],
    })

    savings = _compute_savings_pct(res_xvector["shards_searched"], res_direct["shards_searched"])
    # Keep the "###" header so the panel looks the same as its initial value.
    if savings is None:
        savings_text = "### Compute Savings: --%"
    else:
        savings_text = f"### Compute Savings: {savings:.1f}%"

    confidence = res_xvector.get("confidence")
    confidence_text = "N/A" if confidence is None else f"{confidence:.4f}"
    # Bullet lines render as separate rows; plain consecutive lines would be
    # merged into a single paragraph by Markdown.
    telemetry = "\n".join([
        "### Telemetry",
        f"- **Search Mode:** {res_xvector['mode']}",
        f"- **Router Confidence:** {confidence_text}",
        f"- **Target Cluster:** {res_xvector.get('target_cluster', 'N/A')}",
        f"- **Shards Scanned:** {res_xvector['shards_searched']} vs {res_direct['shards_searched']}",
    ])

    return out_direct, out_xvector, metrics_df, savings_text, telemetry


# --- Gradio Layout ---
with gr.Blocks(title="dashVectorspace: Learned Hybrid Retrieval", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 dashVectorspace: Learned Hybrid Retrieval Engine")
    gr.Markdown(
        "Comparing **Brute Force Vector Search** vs **xVector (Learned Router + Custom Sharding)**."
    )

    with gr.Row():
        query_input = gr.Textbox(
            label="Enter your query",
            placeholder="e.g., What is the impact of AI on healthcare?",
            lines=2,
        )
        submit_btn = gr.Button("🚀 Run Comparison", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🐢 Brute Force (Standard)")
            out_baseline = gr.Textbox(label="Results", lines=10)
        with gr.Column(scale=1):
            gr.Markdown("### ⚡ xVector (Optimized)")
            out_optimized = gr.Textbox(label="Results", lines=10)

    with gr.Row():
        with gr.Column():
            metrics_table = gr.Dataframe(label="Performance Metrics")
        with gr.Column():
            savings_display = gr.Markdown("### Compute Savings: --%")
            telemetry_display = gr.Markdown("### Telemetry\nWaiting for query...")

    # Output order must match the tuple returned by run_comparison.
    submit_btn.click(
        run_comparison,
        inputs=[query_input],
        outputs=[out_baseline, out_optimized, metrics_table, savings_display, telemetry_display],
    )

if __name__ == "__main__":
    demo.launch()