File size: 4,855 Bytes
b92d96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import pandas as pd
import os
import time
from src.vector_db import UnifiedQdrant
from src.router import LearnedRouter
from src.comparison import ComparisonEngine
from config import COLLECTION_NAME, NUM_CLUSTERS, FRESHNESS_SHARD_ID, MRL_DIMS

# --- Initialization ---
# Module-level setup: runs once at import time (i.e., at app startup).
print("Initializing dashVectorspace App...")

# 1. Initialize DB
# Note: In a real HF Space, secrets are in os.environ
# NOTE(review): vector_size=384 is hard-coded to match the MiniLM embedding
# model named below — keep these two in sync if the model changes.
db = UnifiedQdrant(
    collection_name=COLLECTION_NAME,
    vector_size=384, # Assuming MiniLM for demo
    num_clusters=NUM_CLUSTERS,
    freshness_shard_id=FRESHNESS_SHARD_ID
)
db.initialize()

# 2. Initialize Router
# Load a trained router if one is present; otherwise fall back to a stub so
# the UI can still start.
ROUTER_PATH = "models/router_v1.pkl"
if os.path.exists(ROUTER_PATH):
    router = LearnedRouter.load(ROUTER_PATH)
else:
    print("WARNING: Router model not found. Creating a DUMMY router for demo UI.")
    router = LearnedRouter(model_type="lightgbm", n_clusters=NUM_CLUSTERS, mrl_dims=MRL_DIMS)
    # We can't really predict without training, but let's mock it or fail gracefully.
    # For the UI to load, we need an object.
    # If we try to predict, it will crash if not trained.
    # Let's mock the predict method if not trained.
    # HACK: monkey-patch predict so an untrained router never crashes the demo.
    router.predict = lambda x: (0, 0.99) # Mock prediction: Cluster 0, High Confidence

# 3. Initialize Engine
# The engine wraps both search strategies (direct vs. routed) over the same DB.
engine = ComparisonEngine(db, router, embedding_model_name="minilm")

# --- UI Logic ---
def run_comparison(query):
    """Run both search strategies for *query* and format results for the UI.

    Args:
        query: Free-text search string from the Gradio textbox.

    Returns:
        A 5-tuple of (brute-force results text, xVector results text,
        metrics DataFrame, savings markdown, telemetry markdown). On empty
        input, returns a prompt message followed by four Nones so each
        output component is cleared.
    """
    if not query:
        return "Please enter a query.", None, None, None, None

    # Run both searches; each result dict carries results + timing metadata.
    res_direct = engine.direct_search(query)
    res_xvector = engine.xvector_search(query)

    def _format_results(res_dict):
        # Render each hit as "- [score] <first 100 chars of text>...".
        lines = []
        for p in res_dict["results"]:
            # Payload might be dict or object depending on client version/mock
            payload = p.payload
            text = payload.get("text", "No text") if payload else "No text"
            lines.append(f"- [{p.score:.4f}] {text[:100]}...\n")
        # join instead of += — avoids quadratic string building.
        return "".join(lines)

    out_direct = _format_results(res_direct)
    out_xvector = _format_results(res_xvector)

    # Metrics table comparing the two strategies side by side.
    metrics_df = pd.DataFrame({
        "Metric": ["Latency (ms)", "Shards Searched"],
        "Brute Force": [res_direct["latency_ms"], res_direct["shards_searched"]],
        "xVector": [res_xvector["latency_ms"], res_xvector["shards_searched"]]
    })

    # Compute Savings — guard against division by zero when the baseline
    # scanned no shards (e.g. an empty collection); the unguarded division
    # previously raised ZeroDivisionError here.
    direct_shards = res_direct["shards_searched"]
    if direct_shards:
        savings = (1 - (res_xvector["shards_searched"] / direct_shards)) * 100
        savings_text = f"Compute Savings: {savings:.1f}%"
    else:
        savings_text = "Compute Savings: N/A"

    # Telemetry panel: routing decision details for the optimized path.
    telemetry = f"""
    **Search Mode:** {res_xvector['mode']}
    **Router Confidence:** {res_xvector.get('confidence', 0):.4f}
    **Target Cluster:** {res_xvector.get('target_cluster', 'N/A')}
    **Shards Scanned:** {res_xvector['shards_searched']} vs {res_direct['shards_searched']}
    """

    return out_direct, out_xvector, metrics_df, savings_text, telemetry

# --- Gradio Layout ---
# Builds the Blocks UI: query input row, side-by-side result panes, then a
# metrics/telemetry row. All components are wired to run_comparison below.
with gr.Blocks(title="dashVectorspace: Learned Hybrid Retrieval", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 dashVectorspace: Learned Hybrid Retrieval Engine")
    gr.Markdown("Comparing **Brute Force Vector Search** vs **xVector (Learned Router + Custom Sharding)**.")
    
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query", placeholder="e.g., What is the impact of AI on healthcare?", lines=2)
        submit_btn = gr.Button("🚀 Run Comparison", variant="primary")
    
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🐢 Brute Force (Standard)")
            out_baseline = gr.Textbox(label="Results", lines=10)
        
        with gr.Column(scale=1):
            gr.Markdown("### ⚡ xVector (Optimized)")
            out_optimized = gr.Textbox(label="Results", lines=10)
            
    with gr.Row():
        with gr.Column():
            # NOTE(review): metrics_plot is rendered but never wired to the
            # click handler below (only metrics_table receives data) — it will
            # sit empty in the UI. Consider removing it or adding it to the
            # callback outputs.
            metrics_plot = gr.BarPlot(
                x="Metric",
                y="Brute Force",
                title="Performance Comparison",
                tooltip=["Metric", "Brute Force", "xVector"],
                # Gradio BarPlot expects long format usually, but let's try simple DF display first if BarPlot is complex
            )
            # Actually, let's use a simple DataFrame for metrics first, it's cleaner.
            metrics_table = gr.Dataframe(label="Performance Metrics")
            
        with gr.Column():
            # Placeholders updated in-place by the callback's markdown outputs.
            savings_display = gr.Markdown("### Compute Savings: --%")
            telemetry_display = gr.Markdown("### Telemetry\nWaiting for query...")

    # Wire the button: outputs must match run_comparison's 5-tuple order.
    submit_btn.click(
        run_comparison, 
        inputs=[query_input], 
        outputs=[out_baseline, out_optimized, metrics_table, savings_display, telemetry_display]
    )

if __name__ == "__main__":
    demo.launch()