# dashVectorspace demo app (Hugging Face Space)
import gradio as gr
import pandas as pd
import os
import time

from src.vector_db import UnifiedQdrant
from src.router import LearnedRouter
from src.comparison import ComparisonEngine
from config import COLLECTION_NAME, NUM_CLUSTERS, FRESHNESS_SHARD_ID, MRL_DIMS

# --- Initialization ---
print("Initializing dashVectorspace App...")

# 1. Initialize DB
# Note: In a real HF Space, secrets are in os.environ
db = UnifiedQdrant(
    collection_name=COLLECTION_NAME,
    vector_size=384,  # Assuming MiniLM for demo
    num_clusters=NUM_CLUSTERS,
    freshness_shard_id=FRESHNESS_SHARD_ID,
)
db.initialize()

# 2. Initialize Router
ROUTER_PATH = "models/router_v1.pkl"
if os.path.exists(ROUTER_PATH):
    # A trained router loaded from disk keeps its real predict().
    router = LearnedRouter.load(ROUTER_PATH)
else:
    print("WARNING: Router model not found. Creating a DUMMY router for demo UI.")
    router = LearnedRouter(model_type="lightgbm", n_clusters=NUM_CLUSTERS, mrl_dims=MRL_DIMS)
    # An untrained router would crash on predict(); mock it so the demo UI
    # can still load. This mock is scoped to the dummy-router branch only,
    # so a real, trained router is never clobbered.
    router.predict = lambda x: (0, 0.99)  # Mock prediction: Cluster 0, High Confidence

# 3. Initialize Engine
engine = ComparisonEngine(db, router, embedding_model_name="minilm")
| # --- UI Logic --- | |
def run_comparison(query):
    """Run one query through both search paths and format all UI outputs.

    Args:
        query: Free-text query string from the input textbox.

    Returns:
        Tuple of (direct results text, xVector results text, metrics
        DataFrame, savings markdown, telemetry markdown). When ``query``
        is empty, returns an error message followed by four ``None``s.
    """
    if not query:
        return "Please enter a query.", None, None, None, None

    # Run Direct (brute force) Search and xVector (routed) Search on the
    # same query so latency/shard metrics are directly comparable.
    res_direct = engine.direct_search(query)
    res_xvector = engine.xvector_search(query)

    def format_results(res_dict):
        """Render scored points as one '- [score] text...' line each."""
        parts = []
        for p in res_dict["results"]:
            # Payload might be dict or object depending on client version/mock
            payload = p.payload
            text = payload.get("text", "No text") if payload else "No text"
            parts.append(f"- [{p.score:.4f}] {text[:100]}...\n")
        return "".join(parts)

    out_direct = format_results(res_direct)
    out_xvector = format_results(res_xvector)

    # Metrics table comparing the two strategies side by side.
    metrics_df = pd.DataFrame({
        "Metric": ["Latency (ms)", "Shards Searched"],
        "Brute Force": [res_direct["latency_ms"], res_direct["shards_searched"]],
        "xVector": [res_xvector["latency_ms"], res_xvector["shards_searched"]],
    })

    # Compute Savings. Guard against a zero-shard baseline (e.g. empty DB)
    # which would otherwise raise ZeroDivisionError.
    baseline_shards = res_direct["shards_searched"]
    if baseline_shards:
        savings = (1 - (res_xvector["shards_searched"] / baseline_shards)) * 100
    else:
        savings = 0.0
    savings_text = f"Compute Savings: {savings:.1f}%"

    # Telemetry panel; .get() is used for keys that the mocked/dummy
    # router path may not populate.
    telemetry = f"""
    **Search Mode:** {res_xvector['mode']}
    **Router Confidence:** {res_xvector.get('confidence', 0):.4f}
    **Target Cluster:** {res_xvector.get('target_cluster', 'N/A')}
    **Shards Scanned:** {res_xvector['shards_searched']} vs {res_direct['shards_searched']}
    """
    return out_direct, out_xvector, metrics_df, savings_text, telemetry
| # --- Gradio Layout --- | |
| with gr.Blocks(title="dashVectorspace: Learned Hybrid Retrieval", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# 🚀 dashVectorspace: Learned Hybrid Retrieval Engine") | |
| gr.Markdown("Comparing **Brute Force Vector Search** vs **xVector (Learned Router + Custom Sharding)**.") | |
| with gr.Row(): | |
| query_input = gr.Textbox(label="Enter your query", placeholder="e.g., What is the impact of AI on healthcare?", lines=2) | |
| submit_btn = gr.Button("🚀 Run Comparison", variant="primary") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 🐢 Brute Force (Standard)") | |
| out_baseline = gr.Textbox(label="Results", lines=10) | |
| with gr.Column(scale=1): | |
| gr.Markdown("### ⚡ xVector (Optimized)") | |
| out_optimized = gr.Textbox(label="Results", lines=10) | |
| with gr.Row(): | |
| with gr.Column(): | |
| metrics_plot = gr.BarPlot( | |
| x="Metric", | |
| y="Brute Force", | |
| title="Performance Comparison", | |
| tooltip=["Metric", "Brute Force", "xVector"], | |
| # Gradio BarPlot expects long format usually, but let's try simple DF display first if BarPlot is complex | |
| ) | |
| # Actually, let's use a simple DataFrame for metrics first, it's cleaner. | |
| metrics_table = gr.Dataframe(label="Performance Metrics") | |
| with gr.Column(): | |
| savings_display = gr.Markdown("### Compute Savings: --%") | |
| telemetry_display = gr.Markdown("### Telemetry\nWaiting for query...") | |
| submit_btn.click( | |
| run_comparison, | |
| inputs=[query_input], | |
| outputs=[out_baseline, out_optimized, metrics_table, savings_display, telemetry_display] | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |