import gradio as gr
import pandas as pd
import os

from src.vector_db import UnifiedQdrant
from src.router import LearnedRouter
from src.comparison import ComparisonEngine
from config import COLLECTION_NAME, NUM_CLUSTERS, FRESHNESS_SHARD_ID, MRL_DIMS

# --- Initialization ---
print("Initializing dashVectorspace App...")
# 1. Initialize DB
# Note: In a real HF Space, secrets are in os.environ
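# A minimal sketch of pulling such secrets (the variable names below are
# hypothetical, not defined by this repo):
#   qdrant_url = os.environ.get("QDRANT_URL")
#   qdrant_api_key = os.environ.get("QDRANT_API_KEY")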
db = UnifiedQdrant(
    collection_name=COLLECTION_NAME,
    vector_size=384,  # MiniLM embedding dimension, matching the demo model below
    num_clusters=NUM_CLUSTERS,
    freshness_shard_id=FRESHNESS_SHARD_ID,
)
db.initialize()

# 2. Initialize Router
ROUTER_PATH = "models/router_v1.pkl"
if os.path.exists(ROUTER_PATH):
    router = LearnedRouter.load(ROUTER_PATH)
else:
    print("WARNING: Router model not found. Creating a DUMMY router for demo UI.")
    router = LearnedRouter(model_type="lightgbm", n_clusters=NUM_CLUSTERS, mrl_dims=MRL_DIMS)
    # An untrained router would crash on predict(), but the UI needs a working
    # object to load, so mock the method: always cluster 0, high confidence.
    router.predict = lambda x: (0, 0.99)
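    # The mock mirrors the contract the rest of the app assumes:
    #   router.predict(features) -> (target_cluster, confidence)
    # where "features" stands for whatever the ComparisonEngine feeds the router.
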
# 3. Initialize Engine
engine = ComparisonEngine(db, router, embedding_model_name="minilm")

# --- UI Logic ---
def run_comparison(query):
    if not query:
        return "Please enter a query.", None, None, None, None

    # Run Direct Search
    res_direct = engine.direct_search(query)
    # Run xVector Search
    res_xvector = engine.xvector_search(query)
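    # direct_search scans the full collection, while xvector_search consults
    # the learned router first; each result dict reports its own latency_ms
    # and shards_searched, which drive the metrics below.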

    # Format Results
    def format_results(res_dict):
        points = res_dict["results"]
        text_res = ""
        for p in points:
            # Payload might be dict or object depending on client version/mock
            payload = p.payload
            text = payload.get("text", "No text") if payload else "No text"
            score = p.score
            text_res += f"- [{score:.4f}] {text[:100]}...\n"
        return text_res

    out_direct = format_results(res_direct)
    out_xvector = format_results(res_xvector)

    # Metrics
    metrics_df = pd.DataFrame({
        "Metric": ["Latency (ms)", "Shards Searched"],
        "Brute Force": [res_direct["latency_ms"], res_direct["shards_searched"]],
        "xVector": [res_xvector["latency_ms"], res_xvector["shards_searched"]],
    })

    # Compute Savings (guard against a zero baseline shard count)
    direct_shards = res_direct["shards_searched"] or 1
    savings = (1 - (res_xvector["shards_searched"] / direct_shards)) * 100
    savings_text = f"Compute Savings: {savings:.1f}%"
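    # e.g. scanning 1 shard instead of 16 yields (1 - 1/16) * 100 = 93.75%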

    # Telemetry
    telemetry = f"""
**Search Mode:** {res_xvector['mode']}
**Router Confidence:** {res_xvector.get('confidence', 0):.4f}
**Target Cluster:** {res_xvector.get('target_cluster', 'N/A')}
**Shards Scanned:** {res_xvector['shards_searched']} vs {res_direct['shards_searched']}
"""
    return out_direct, out_xvector, metrics_df, savings_text, telemetry


# --- Gradio Layout ---
with gr.Blocks(title="dashVectorspace: Learned Hybrid Retrieval", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 dashVectorspace: Learned Hybrid Retrieval Engine")
    gr.Markdown("Comparing **Brute Force Vector Search** vs **xVector (Learned Router + Custom Sharding)**.")

    with gr.Row():
        query_input = gr.Textbox(
            label="Enter your query",
            placeholder="e.g., What is the impact of AI on healthcare?",
            lines=2,
        )
        submit_btn = gr.Button("🚀 Run Comparison", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🐢 Brute Force (Standard)")
            out_baseline = gr.Textbox(label="Results", lines=10)
        with gr.Column(scale=1):
            gr.Markdown("### ⚡ xVector (Optimized)")
            out_optimized = gr.Textbox(label="Results", lines=10)

    with gr.Row():
        with gr.Column():
            # A plain DataFrame is cleaner here than gr.BarPlot, which
            # expects long-format data.
            metrics_table = gr.Dataframe(label="Performance Metrics")
        with gr.Column():
            savings_display = gr.Markdown("### Compute Savings: --%")
            telemetry_display = gr.Markdown("### Telemetry\nWaiting for query...")

    submit_btn.click(
        run_comparison,
        inputs=[query_input],
        outputs=[out_baseline, out_optimized, metrics_table, savings_display, telemetry_display],
    )

if __name__ == "__main__":
    demo.launch()