justmotes commited on
Commit
51fc709
·
verified ·
1 Parent(s): b9df6ef

Deploy 9-Row Benchmark (via API)

Browse files
app.py CHANGED
@@ -1,34 +1,68 @@
1
  import gradio as gr
2
  import os
 
3
  import time
4
- import random
5
  import pandas as pd
 
6
  from src.vector_db import UnifiedQdrant
7
  from src.router import LearnedRouter
8
  from src.data_pipeline import get_embedding
9
-
10
- # --- Configuration ---
11
- COLLECTION_NAME = "dashVector_v1"
12
- VECTOR_SIZE = 384 # MiniLM-L6-v2
13
- NUM_CLUSTERS = 16
14
- EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
15
 
16
  # --- Initialize Backend ---
17
- # We initialize once at startup
18
- vector_db = UnifiedQdrant(COLLECTION_NAME, VECTOR_SIZE, NUM_CLUSTERS)
19
- vector_db.initialize()
20
 
21
- # Load Router (Ensure it exists, else mock/warn)
22
- ROUTER_PATH = "models/router_v1.pkl"
23
- try:
24
- router = LearnedRouter.load(ROUTER_PATH)
25
- except Exception as e:
26
- print(f"Warning: Could not load router: {e}. Using dummy router for UI demo if needed.")
27
- router = None
 
 
 
 
28
 
29
- # --- HTML Templates (Extracted from dashVector_benchmark.html) ---
 
 
 
 
 
 
 
 
 
30
 
31
- # --- HTML Templates (Extracted from dashVector_benchmark.html) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  HEAD_HTML = """
34
  <script src="https://cdn.tailwindcss.com"></script>
@@ -38,50 +72,22 @@ HEAD_HTML = """
38
  body { font-family: 'Inter', sans-serif; background-color: #f8f9fa; }
39
  .fade-in { animation: fadeIn 0.5s ease-out forwards; }
40
  @keyframes fadeIn { from { opacity: 0; transform: translateY(10px); } to { opacity: 1; transform: translateY(0); } }
41
- /* Hide Gradio footer */
42
  footer { display: none !important; }
43
  .gradio-container { max-width: 100% !important; padding: 0 !important; margin: 0 !important; background-color: #f8f9fa; }
44
- /* Custom Scrollbar */
45
  .custom-scrollbar::-webkit-scrollbar { height: 8px; width: 8px; }
46
  .custom-scrollbar::-webkit-scrollbar-track { background: #f1f1f1; }
47
  .custom-scrollbar::-webkit-scrollbar-thumb { background: #c1c1c1; border-radius: 4px; }
48
  .custom-scrollbar::-webkit-scrollbar-thumb:hover { background: #a8a8a8; }
49
-
50
- /* Overwrite Gradio Input Styles to match Reference */
51
  #custom-input textarea {
52
- background-color: white !important;
53
- border: 1px solid #cbd5e1 !important;
54
- border-radius: 0.75rem !important; /* rounded-xl */
55
- padding: 0.75rem 1rem !important;
56
- font-size: 1rem !important;
57
- box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05) !important;
58
- height: 50px !important; /* Fixed height for alignment */
59
- }
60
- #custom-input textarea:focus {
61
- outline: 2px solid #3b82f6 !important; /* blue-500 */
62
- border-color: #3b82f6 !important;
63
- }
64
-
65
- /* Search Bar Layout Fix */
66
- .search-row {
67
- display: flex !important;
68
- flex-direction: row !important;
69
- align-items: flex-start !important;
70
- gap: 1rem !important;
71
- flex-wrap: nowrap !important; /* Prevent wrapping */
72
- }
73
-
74
- /* Loader Overlay */
75
- .loader-overlay {
76
- position: absolute; inset: 0; background: rgba(255,255,255,0.8);
77
- backdrop-filter: blur(4px); z-index: 50;
78
- display: flex; flex-direction: column; align-items: center; justify-content: center;
79
- }
80
- .spinner {
81
- width: 4rem; height: 4rem; border: 4px solid #e2e8f0;
82
- border-top-color: #2563eb; border-radius: 50%;
83
- animation: spin 1s linear infinite;
84
  }
 
 
 
 
85
  @keyframes spin { to { transform: rotate(360deg); } }
86
  </style>
87
  """
@@ -89,9 +95,9 @@ HEAD_HTML = """
89
  NAVBAR_HTML = """
90
  <header class="bg-white border-b border-slate-200 sticky top-0 z-40 shadow-sm w-full">
91
  <div class="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 h-16 flex items-center justify-between">
92
- <div class="flex items-center gap-2">
93
- <!-- User Logo Removed -->
94
- <h1 class="text-xl font-bold tracking-tight text-slate-900">dashVector</h1>
95
  </div>
96
  <div class="flex items-center gap-4">
97
  <div class="hidden md:flex items-center gap-1.5 px-3 py-1 bg-slate-100 rounded-full border border-slate-200">
@@ -106,130 +112,117 @@ NAVBAR_HTML = """
106
  FOOTER_INFO_HTML = """
107
  <div class="grid grid-cols-1 md:grid-cols-3 gap-4 text-sm mt-6">
108
  <div class="bg-blue-50 border border-blue-100 p-4 rounded-xl">
109
- <h3 class="font-semibold text-blue-900 mb-2 flex items-center gap-2">
110
- <span class="material-symbols-outlined text-base">architecture</span>
111
- Architecture
112
- </h3>
113
- <p class="text-blue-800/80">
114
- Improves search efficiency by using a <span class="font-bold">Router Model</span> to predict specific data shards, reducing the search space on the Vector DB.
115
- </p>
116
  </div>
117
  <div class="bg-orange-50 border border-orange-100 p-4 rounded-xl">
118
- <h3 class="font-semibold text-orange-900 mb-2 flex items-center gap-2">
119
- <span class="material-symbols-outlined text-base">database</span>
120
- Vector Database
121
- </h3>
122
- <p class="text-orange-800/80">
123
- Utilizes <span class="font-bold">Qdrant</span> for high-performance vector storage and retrieval, benchmarking direct search vs. routed search across 16 shards.
124
- </p>
125
  </div>
126
  <div class="bg-purple-50 border border-purple-100 p-4 rounded-xl">
127
- <h3 class="font-semibold text-purple-900 mb-2 flex items-center gap-2">
128
- <span class="material-symbols-outlined text-base">psychology</span>
129
- Methodology
130
- </h3>
131
- <p class="text-purple-800/80">
132
- Router predicts shard probabilities. Shards are iteratively added to the search scope until the <strong>cumulative confidence > 0.9</strong>, balancing accuracy and speed.
133
- </p>
134
  </div>
135
  </div>
136
  """
137
 
138
  EMPTY_STATE_HTML = """
139
  <div class="bg-white rounded-2xl shadow-sm border border-slate-200 overflow-hidden flex flex-col min-h-[400px] items-center justify-center text-slate-400">
140
- <div class="bg-slate-50 p-6 rounded-full mb-4">
141
- <span class="material-symbols-outlined text-6xl text-slate-200">bar_chart</span>
142
- </div>
143
  <p class="text-lg font-medium text-slate-500">Ready to benchmark</p>
144
  <p class="text-sm">Enter a query above to compare routing architectures.</p>
145
  </div>
146
  """
147
 
148
- LOADER_HTML = """
149
- <div class="bg-white rounded-2xl shadow-sm border border-slate-200 overflow-hidden flex flex-col min-h-[400px] relative">
150
- <div class="loader-overlay">
151
- <div class="spinner"></div>
152
- <p class="mt-4 text-slate-600 font-medium animate-pulse">Running inferences & calculating metrics...</p>
153
- <div class="text-xs text-slate-400 mt-2">Router Model predicting shards...</div>
154
- </div>
155
- </div>
156
- """
157
-
158
  def generate_table_html(rows):
159
  rows_html = ""
160
  for i, row in enumerate(rows):
161
- delay = i * 100
162
  width_pct = int(float(row['accuracy']) * 100)
163
 
164
  rows_html += f"""
165
  <tr class="hover:bg-slate-50 transition-colors fade-in" style="animation-delay: {delay}ms; opacity: 0;">
166
- <td class="px-6 py-4 whitespace-nowrap">
167
- <div class="flex items-center">
168
- <div class="h-8 w-8 rounded bg-indigo-100 text-indigo-600 flex items-center justify-center mr-3 font-bold text-xs">EM</div>
169
- <div class="text-sm font-medium text-slate-900">{row['embedding']}</div>
170
  </div>
171
  </td>
172
- <td class="px-6 py-4 whitespace-nowrap">
173
- <div class="text-sm text-slate-700 font-medium">{row['router']}</div>
174
- <div class="text-xs text-slate-400">Classifier</div>
 
 
175
  </td>
176
- <td class="px-6 py-4 whitespace-nowrap bg-blue-50/30 border-l border-r border-blue-100">
177
- <div class="flex flex-col gap-1">
178
- <div class="flex items-center justify-between">
179
- <span class="text-xs text-slate-500">Time:</span>
180
- <span class="text-sm font-bold text-blue-700">{row['optimizedTime']}</span>
181
- </div>
182
- <div class="flex items-center justify-between">
183
- <span class="text-xs text-slate-500">Shards:</span>
184
- <span class="text-xs font-mono bg-blue-100 text-blue-800 px-1.5 rounded">{row['shardsSearched']}</span>
 
 
185
  </div>
186
- <div class="w-full bg-slate-200 rounded-full h-1.5 mt-1">
187
- <div class="bg-blue-500 h-1.5 rounded-full" style="width: {width_pct}%"></div>
 
 
 
 
 
188
  </div>
189
- <div class="flex justify-between text-[10px] text-slate-400 mt-0.5">
190
- <span>Acc: {row['accuracy']}</span>
191
- <span>Conf: {row['confDisplay']}</span>
192
  </div>
193
  </div>
194
  </td>
195
- <td class="px-6 py-4 whitespace-nowrap">
196
- <div class="flex flex-col gap-1">
197
- <span class="text-sm font-semibold text-slate-600">{row['directTime']}</span>
198
- <span class="text-xs text-slate-400">Full Scan ({row['totalShards']} Shards)</span>
 
 
 
199
  </div>
200
  </td>
201
- <td class="px-6 py-4 whitespace-nowrap">
202
- <div class="flex items-center">
203
- <span class="text-lg font-bold text-green-600">{row['efficiency']}</span>
204
- <span class="material-symbols-outlined text-green-600 text-sm ml-1">trending_up</span>
 
 
 
205
  </div>
206
- <div class="text-xs text-green-700/70">Faster</div>
207
  </td>
208
  </tr>
209
  """
210
 
211
  return f"""
212
- <div class="bg-white rounded-2xl shadow-sm border border-slate-200 overflow-hidden flex flex-col flex-grow min-h-[500px]">
213
  <div class="px-6 py-4 border-b border-slate-100 flex justify-between items-center bg-slate-50/50">
214
  <h2 class="text-lg font-semibold text-slate-800 flex items-center gap-2">
215
- <span class="material-symbols-outlined text-slate-500">table_chart</span>
216
- Performance Metrics
217
  </h2>
218
- <div class="text-xs text-slate-500 flex items-center gap-2">
219
- <span class="flex items-center gap-1"><div class="w-2 h-2 rounded-full bg-green-500"></div> High Efficiency</span>
220
- <span class="flex items-center gap-1"><div class="w-2 h-2 rounded-full bg-slate-300"></div> Baseline</span>
221
  </div>
222
  </div>
223
  <div class="overflow-x-auto custom-scrollbar flex-grow relative">
224
- <table class="min-w-full divide-y divide-slate-200">
225
- <thead class="bg-slate-50 sticky top-0 z-10">
226
- <tr>
227
- <th class="px-6 py-3 text-left text-xs font-bold text-slate-500 uppercase tracking-wider">Embedding Model</th>
228
- <th class="px-6 py-3 text-left text-xs font-bold text-slate-500 uppercase tracking-wider">Router Model</th>
229
- <th class="px-6 py-3 text-left text-xs font-bold text-slate-500 uppercase tracking-wider bg-blue-50/50 border-l border-r border-blue-100 text-blue-800">dashVector Search (Optimized)</th>
230
- <th class="px-6 py-3 text-left text-xs font-bold text-slate-500 uppercase tracking-wider">Direct Qdrant Search (Baseline)</th>
231
- <th class="px-6 py-3 text-left text-xs font-bold text-slate-500 uppercase tracking-wider text-green-700">Efficiency Gain</th>
232
- </tr>
233
  </thead>
234
  <tbody class="bg-white divide-y divide-slate-100">
235
  {rows_html}
@@ -239,155 +232,128 @@ def generate_table_html(rows):
239
  </div>
240
  """
241
 
242
- def show_loader():
243
- return LOADER_HTML
244
-
245
  def run_benchmark(query):
246
  print(f"DEBUG: Starting benchmark for query: {query}")
 
247
 
248
- try:
249
- # Perform Search (Live)
250
- start_total = time.time()
251
 
252
- # Generate Embedding
253
- print("DEBUG: Generating embedding...")
254
- query_vec = get_embedding(query, model_name=EMBEDDING_MODEL)
255
- print("DEBUG: Embedding generated.")
256
-
257
- # Router Prediction
258
- if router:
259
- print("DEBUG: Predicting clusters...")
260
- # Now returns list of clusters and cumulative confidence
261
- target_clusters, confidence = router.predict(query_vec)
262
- print(f"DEBUG: Predicted clusters {target_clusters} with cumulative confidence {confidence}")
263
- else:
264
- print("DEBUG: No router loaded, using mock.")
265
- target_clusters, confidence = [0], 0.95 # Mock
266
 
267
- # Search
268
- print("DEBUG: Searching Qdrant...")
269
- # Now accepts list of clusters
270
- results, mode = vector_db.search_hybrid(query_vec, target_clusters, confidence)
271
- print(f"DEBUG: Search complete. Found {len(results)} results.")
272
-
273
- end_total = time.time()
274
- latency_ms = (end_total - start_total) * 1000
275
-
276
- # Construct Data Rows
277
-
278
- # Live Row (MiniLM + Logistic Regression)
279
- shards_searched = len(target_clusters)
280
- total_shards = 16 # Updated to 16
281
-
282
- # Estimate baseline time (mock calculation for demo if we don't run full scan)
283
- # Or we could actually run full scan if we wanted true comparison, but for speed we estimate
284
- direct_time = latency_ms * (total_shards / max(shards_searched, 1)) * 1.1
285
 
286
- live_row = {
287
- "embedding": "MiniLM-L6-v2 (Active)",
288
- "router": "Logistic Regression", # Updated label
289
- "optimizedTime": f"{latency_ms:.1f} ms",
290
- "shardsSearched": f"{shards_searched} / {total_shards}",
291
- "totalShards": total_shards,
292
- "accuracy": f"{confidence:.2f}",
293
- "confDisplay": f"{confidence*100:.1f}%",
294
- "directTime": f"{direct_time:.1f} ms",
295
- "efficiency": f"+{((1 - latency_ms/direct_time)*100):.1f}%"
296
- }
 
297
 
298
- # Reference Rows (Static - Updated)
299
- ref_rows = [
300
- {
301
- "embedding": "Gemma 300M",
302
- "router": "LightGBM",
303
- "optimizedTime": "128 ms",
304
- "shardsSearched": "9 / 16",
305
- "totalShards": 16,
306
- "accuracy": "0.97",
307
- "confDisplay": "97.1%",
308
- "directTime": "220 ms",
309
- "efficiency": "+41.8%"
310
- },
311
- {
312
- "embedding": "Qwen 600M",
313
- "router": "Tiny MLP",
314
- "optimizedTime": "109 ms",
315
- "shardsSearched": "7 / 16",
316
- "totalShards": 16,
317
- "accuracy": "0.90",
318
- "confDisplay": "90.1%",
319
- "directTime": "235 ms",
320
- "efficiency": "+53.6%"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  }
322
- ]
323
-
324
- all_rows = [live_row] + ref_rows
325
-
326
- print("DEBUG: Returning final HTML.")
327
- return generate_table_html(all_rows)
328
-
329
- except Exception as e:
330
- import traceback
331
- error_msg = traceback.format_exc()
332
- print(f"CRITICAL ERROR in run_benchmark: {error_msg}")
333
-
334
- # Return Error HTML
335
- return f"""
336
- <div class="bg-red-50 border border-red-200 rounded-2xl p-6 text-red-800">
337
- <h3 class="font-bold text-lg mb-2 flex items-center gap-2">
338
- <span class="material-symbols-outlined">error</span>
339
- Runtime Error
340
- </h3>
341
- <p class="mb-4">An error occurred while running the benchmark:</p>
342
- <pre class="bg-red-100 p-4 rounded-lg text-xs font-mono overflow-x-auto">{error_msg}</pre>
343
- </div>
344
- """
345
 
346
- # --- Gradio App ---
347
  with gr.Blocks(theme=gr.themes.Base(), css=None, head=HEAD_HTML) as demo:
348
  gr.HTML(NAVBAR_HTML)
349
-
350
  with gr.Column(elem_classes="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 py-8 gap-6"):
351
-
352
- # Search Section
353
  with gr.Group(elem_classes="bg-white p-6 rounded-2xl shadow-sm border border-slate-200 mb-6"):
354
  gr.HTML('<label class="block text-sm font-medium text-slate-700 mb-2">Evaluate Search Architecture</label>')
355
-
356
- # Use a Row with custom CSS class for Flexbox layout
357
  with gr.Row(elem_classes="search-row"):
358
- query_input = gr.Textbox(
359
- placeholder="Enter a benchmark query (e.g., 'climate change impact')...",
360
- show_label=False,
361
- elem_id="custom-input",
362
- container=False,
363
- scale=4
364
- )
365
- submit_btn = gr.Button(
366
- "Run Benchmark",
367
- variant="primary",
368
- scale=1,
369
- elem_classes="bg-blue-600 hover:bg-blue-700 text-white font-semibold py-3 px-6 rounded-xl shadow-md transition-all h-[50px]" # Fixed height to match input
370
- )
371
-
372
- # Results Section
373
  results_area = gr.HTML(EMPTY_STATE_HTML)
374
-
375
- # Footer Info
376
  gr.HTML(FOOTER_INFO_HTML)
377
-
378
- # Interactions: Simplified (Single Step)
379
- submit_btn.click(
380
- run_benchmark,
381
- inputs=[query_input],
382
- outputs=[results_area]
383
- )
384
-
385
- query_input.submit(
386
- run_benchmark,
387
- inputs=[query_input],
388
- outputs=[results_area]
389
- )
390
 
391
  if __name__ == "__main__":
392
- # Disable queue to prevent h11 LocalProtocolError
393
  demo.launch()
 
1
  import gradio as gr
2
  import os
3
+ import json
4
  import time
 
5
  import pandas as pd
6
+ import numpy as np
7
  from src.vector_db import UnifiedQdrant
8
  from src.router import LearnedRouter
9
  from src.data_pipeline import get_embedding
10
+ from config import (
11
+ COLLECTIONS, EMBEDDING_MODELS, ROUTER_MODELS,
12
+ NUM_CLUSTERS, FRESHNESS_SHARD_ID
13
+ )
 
 
14
 
15
  # --- Initialize Backend ---
16
+ print("Initializing Backend...")
 
 
17
 
18
+ # 1. Vector DB Clients
19
+ # We need clients for both Prod (Sharded) and Base (Unsharded) for each model
20
+ dbs = {}
21
+ for model_key, cols in COLLECTIONS.items():
22
+ # Load Dimension from JSON
23
+ try:
24
+ with open(f"models/model_info_{model_key}.json", "r") as f:
25
+ vec_size = json.load(f)["dim"]
26
+ except:
27
+ print(f"Warning: Could not load model info for {model_key}. Using default 384.")
28
+ vec_size = 384
29
 
30
+ # Load Shard Sizes
31
+ try:
32
+ with open(f"models/shard_sizes_{model_key}.json", "r") as f:
33
+ shard_sizes = json.load(f)
34
+ # Convert keys to int
35
+ shard_sizes = {int(k): v for k, v in shard_sizes.items()}
36
+ dbs[f"{model_key}_sizes"] = shard_sizes
37
+ except:
38
+ print(f"Warning: Could not load shard sizes for {model_key}.")
39
+ dbs[f"{model_key}_sizes"] = {}
40
 
41
+ # Prod
42
+ print(f"Initializing DB: {cols['prod']}...")
43
+ db_prod = UnifiedQdrant(cols['prod'], vector_size=vec_size, num_clusters=NUM_CLUSTERS, freshness_shard_id=FRESHNESS_SHARD_ID)
44
+ db_prod.initialize(is_baseline=False)
45
+ dbs[f"{model_key}_prod"] = db_prod
46
+
47
+ # Base
48
+ print(f"Initializing DB: {cols['base']}...")
49
+ db_base = UnifiedQdrant(cols['base'], vector_size=vec_size, num_clusters=1)
50
+ db_base.initialize(is_baseline=True)
51
+ dbs[f"{model_key}_base"] = db_base
52
+
53
+ # 2. Load Routers
54
+ routers = {}
55
+ for model_key in EMBEDDING_MODELS.keys():
56
+ for router_type in ROUTER_MODELS:
57
+ router_path = f"models/router_{model_key}_{router_type}.pkl"
58
+ try:
59
+ print(f"Loading Router: {router_path}...")
60
+ routers[f"{model_key}_{router_type}"] = LearnedRouter.load(router_path)
61
+ except Exception as e:
62
+ print(f"Warning: Could not load {router_path}: {e}. Using None.")
63
+ routers[f"{model_key}_{router_type}"] = None
64
+
65
+ # --- HTML Templates ---
66
 
67
  HEAD_HTML = """
68
  <script src="https://cdn.tailwindcss.com"></script>
 
72
  body { font-family: 'Inter', sans-serif; background-color: #f8f9fa; }
73
  .fade-in { animation: fadeIn 0.5s ease-out forwards; }
74
  @keyframes fadeIn { from { opacity: 0; transform: translateY(10px); } to { opacity: 1; transform: translateY(0); } }
 
75
  footer { display: none !important; }
76
  .gradio-container { max-width: 100% !important; padding: 0 !important; margin: 0 !important; background-color: #f8f9fa; }
 
77
  .custom-scrollbar::-webkit-scrollbar { height: 8px; width: 8px; }
78
  .custom-scrollbar::-webkit-scrollbar-track { background: #f1f1f1; }
79
  .custom-scrollbar::-webkit-scrollbar-thumb { background: #c1c1c1; border-radius: 4px; }
80
  .custom-scrollbar::-webkit-scrollbar-thumb:hover { background: #a8a8a8; }
 
 
81
  #custom-input textarea {
82
+ background-color: white !important; border: 1px solid #cbd5e1 !important;
83
+ border-radius: 0.75rem !important; padding: 0.75rem 1rem !important;
84
+ font-size: 1rem !important; box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05) !important;
85
+ height: 50px !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  }
87
+ #custom-input textarea:focus { outline: 2px solid #3b82f6 !important; border-color: #3b82f6 !important; }
88
+ .search-row { display: flex !important; flex-direction: row !important; align-items: flex-start !important; gap: 1rem !important; flex-wrap: nowrap !important; }
89
+ .loader-overlay { position: absolute; inset: 0; background: rgba(255,255,255,0.8); backdrop-filter: blur(4px); z-index: 50; display: flex; flex-direction: column; align-items: center; justify-content: center; }
90
+ .spinner { width: 4rem; height: 4rem; border: 4px solid #e2e8f0; border-top-color: #2563eb; border-radius: 50%; animation: spin 1s linear infinite; }
91
  @keyframes spin { to { transform: rotate(360deg); } }
92
  </style>
93
  """
 
95
  NAVBAR_HTML = """
96
  <header class="bg-white border-b border-slate-200 sticky top-0 z-40 shadow-sm w-full">
97
  <div class="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 h-16 flex items-center justify-between">
98
+ <div class="flex items-center gap-3">
99
+ <img src="file/logo.png" alt="Logo" class="h-8 w-auto">
100
+ <h1 class="text-xl font-bold tracking-tight text-slate-900">dashVector <span class="text-slate-400 font-normal text-sm ml-1">Experiment Matrix</span></h1>
101
  </div>
102
  <div class="flex items-center gap-4">
103
  <div class="hidden md:flex items-center gap-1.5 px-3 py-1 bg-slate-100 rounded-full border border-slate-200">
 
112
  FOOTER_INFO_HTML = """
113
  <div class="grid grid-cols-1 md:grid-cols-3 gap-4 text-sm mt-6">
114
  <div class="bg-blue-50 border border-blue-100 p-4 rounded-xl">
115
+ <h3 class="font-semibold text-blue-900 mb-2 flex items-center gap-2"><span class="material-symbols-outlined text-base">architecture</span> Architecture</h3>
116
+ <p class="text-blue-800/80">Improves search efficiency by using a <span class="font-bold">Router Model</span> to predict specific data shards.</p>
 
 
 
 
 
117
  </div>
118
  <div class="bg-orange-50 border border-orange-100 p-4 rounded-xl">
119
+ <h3 class="font-semibold text-orange-900 mb-2 flex items-center gap-2"><span class="material-symbols-outlined text-base">database</span> Vector Database</h3>
120
+ <p class="text-orange-800/80">Utilizes <span class="font-bold">Qdrant</span> for high-performance vector storage and retrieval.</p>
 
 
 
 
 
121
  </div>
122
  <div class="bg-purple-50 border border-purple-100 p-4 rounded-xl">
123
+ <h3 class="font-semibold text-purple-900 mb-2 flex items-center gap-2"><span class="material-symbols-outlined text-base">psychology</span> Methodology</h3>
124
+ <p class="text-purple-800/80">Shards are iteratively added until <strong>cumulative confidence > 0.9</strong>.</p>
 
 
 
 
 
125
  </div>
126
  </div>
127
  """
128
 
129
  EMPTY_STATE_HTML = """
130
  <div class="bg-white rounded-2xl shadow-sm border border-slate-200 overflow-hidden flex flex-col min-h-[400px] items-center justify-center text-slate-400">
131
+ <div class="bg-slate-50 p-6 rounded-full mb-4"><span class="material-symbols-outlined text-6xl text-slate-200">bar_chart</span></div>
 
 
132
  <p class="text-lg font-medium text-slate-500">Ready to benchmark</p>
133
  <p class="text-sm">Enter a query above to compare routing architectures.</p>
134
  </div>
135
  """
136
 
 
 
 
 
 
 
 
 
 
 
137
  def generate_table_html(rows):
138
  rows_html = ""
139
  for i, row in enumerate(rows):
140
+ delay = i * 50 # Faster stagger
141
  width_pct = int(float(row['accuracy']) * 100)
142
 
143
  rows_html += f"""
144
  <tr class="hover:bg-slate-50 transition-colors fade-in" style="animation-delay: {delay}ms; opacity: 0;">
145
+ <td class="px-6 py-4 whitespace-nowrap align-top border-b border-slate-100">
146
+ <div class="flex flex-col">
147
+ <span class="text-sm font-semibold text-slate-800">{row['embedding_name']}</span>
148
+ <span class="text-xs text-slate-500">{row['dims']} dim</span>
149
  </div>
150
  </td>
151
+ <td class="px-6 py-4 whitespace-nowrap align-top border-b border-slate-100">
152
+ <div class="flex flex-col">
153
+ <span class="text-sm font-medium text-slate-700">{row['router_name']}</span>
154
+ <span class="text-xs text-slate-400">{row['router_desc']}</span>
155
+ </div>
156
  </td>
157
+ <td class="px-6 py-3 bg-blue-50/20 border-l border-r border-b border-blue-100/50 align-top">
158
+ <div class="space-y-2">
159
+ <div class="flex items-baseline justify-between">
160
+ <div class="flex flex-col">
161
+ <span class="text-xs text-slate-500">Total Latency</span>
162
+ <span class="text-sm font-bold text-blue-700">{row['optimizedTime']}</span>
163
+ </div>
164
+ <div class="flex flex-col items-end text-right">
165
+ <span class="text-[10px] text-slate-400">Router Overhead</span>
166
+ <span class="text-xs font-mono text-slate-600">{row['overhead']}</span>
167
+ </div>
168
  </div>
169
+ <div class="bg-white/60 p-2 rounded border border-blue-100">
170
+ <div class="flex justify-between text-[10px] text-slate-500 mb-1">
171
+ <span>Scanned: <strong>{row['shardsSearched']}</strong></span>
172
+ </div>
173
+ <div class="w-full bg-slate-200 rounded-full h-1.5 overflow-hidden">
174
+ <div class="bg-blue-500 h-1.5 rounded-full" style="width: {width_pct}%"></div>
175
+ </div>
176
  </div>
177
+ <div class="flex items-center gap-1 text-[10px] text-blue-600/80">
178
+ <span class="material-symbols-outlined text-[12px]">check_circle</span>
179
+ <span>Router Conf: {row['confDisplay']}</span>
180
  </div>
181
  </div>
182
  </td>
183
+ <td class="px-6 py-4 whitespace-nowrap align-top border-b border-slate-100">
184
+ <div class="space-y-1">
185
+ <div class="flex justify-between items-center">
186
+ <span class="text-xs text-slate-500">Time:</span>
187
+ <span class="text-sm font-medium text-slate-700">{row['baselineTime']}</span>
188
+ </div>
189
+ <div class="text-[10px] text-slate-400 text-right mt-1">Full Scan (16 Shards)</div>
190
  </div>
191
  </td>
192
+ <td class="px-6 py-4 whitespace-nowrap align-top border-b border-slate-100">
193
+ <div class="flex flex-col justify-center h-full pt-1">
194
+ <div class="flex items-center">
195
+ <span class="text-lg font-bold text-green-600">{row['efficiency']}</span>
196
+ <span class="material-symbols-outlined text-green-600 text-sm ml-1">bolt</span>
197
+ </div>
198
+ <span class="text-[10px] text-green-700/60 uppercase font-semibold tracking-wide">Faster</span>
199
  </div>
 
200
  </td>
201
  </tr>
202
  """
203
 
204
  return f"""
205
+ <div class="bg-white rounded-2xl shadow-sm border border-slate-200 overflow-hidden flex flex-col flex-grow min-h-[600px]">
206
  <div class="px-6 py-4 border-b border-slate-100 flex justify-between items-center bg-slate-50/50">
207
  <h2 class="text-lg font-semibold text-slate-800 flex items-center gap-2">
208
+ <span class="material-symbols-outlined text-slate-500">grid_view</span>
209
+ Experiment Matrix (3x3)
210
  </h2>
211
+ <div class="text-xs text-slate-500 flex items-center gap-3">
212
+ <span class="flex items-center gap-1"><span class="w-2 h-2 rounded-full bg-blue-600"></span> Optimized</span>
213
+ <span class="flex items-center gap-1"><span class="w-2 h-2 rounded-full bg-slate-400"></span> Baseline</span>
214
  </div>
215
  </div>
216
  <div class="overflow-x-auto custom-scrollbar flex-grow relative">
217
+ <table class="min-w-full divide-y divide-slate-200 border-separate border-spacing-0">
218
+ <thead class="bg-slate-50 sticky top-0 z-10 text-xs font-bold text-slate-500 uppercase tracking-wider">
219
+ <tr>
220
+ <th class="px-6 py-3 text-left w-48 border-b border-slate-200">Embedding Model</th>
221
+ <th class="px-6 py-3 text-left w-48 border-b border-slate-200">Router Model</th>
222
+ <th class="px-6 py-3 text-left bg-blue-50/50 border-l border-r border-b border-blue-100 text-blue-800 min-w-[300px]">dashVector Search (Optimized)</th>
223
+ <th class="px-6 py-3 text-left border-b border-r border-slate-200 bg-slate-50/80">Direct Qdrant Search (Baseline)</th>
224
+ <th class="px-6 py-3 text-left text-green-700 w-32 border-b border-slate-200">Efficiency Gain</th>
225
+ </tr>
226
  </thead>
227
  <tbody class="bg-white divide-y divide-slate-100">
228
  {rows_html}
 
232
  </div>
233
  """
234
 
 
 
 
235
  def run_benchmark(query):
236
  print(f"DEBUG: Starting benchmark for query: {query}")
237
+ rows = []
238
 
239
+ # Loop over Embedding Models
240
+ for model_key, model_name in EMBEDDING_MODELS.items():
241
+ print(f"--- Processing {model_key} ---")
242
 
243
+ # 1. Generate Embedding
244
+ # Note: This might be slow.
245
+ try:
246
+ query_vec = get_embedding(query, model_name=model_name)
247
+ except Exception as e:
248
+ print(f"Error generating embedding for {model_key}: {e}")
249
+ continue
 
 
 
 
 
 
 
250
 
251
+ dims = len(query_vec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
+ # 2. Run Baseline Search (Unsharded)
254
+ # We run this once per embedding model
255
+ db_base = dbs.get(f"{model_key}_base")
256
+ start_base = time.time()
257
+ if db_base:
258
+ base_results = db_base.search_baseline(query_vec)
259
+ base_ids = set(p.id for p in base_results)
260
+ else:
261
+ base_results = []
262
+ base_ids = set()
263
+ end_base = time.time()
264
+ baseline_time_ms = (end_base - start_base) * 1000
265
 
266
+ # 3. Loop over Router Models
267
+ for router_type in ROUTER_MODELS:
268
+ router_key = f"{model_key}_{router_type}"
269
+ router = routers.get(router_key)
270
+ db_prod = dbs.get(f"{model_key}_prod")
271
+
272
+ if not router or not db_prod:
273
+ # Mock if missing
274
+ target_clusters = [0, 1, 2]
275
+ confidence = 0.85
276
+ overhead_ms = 0.5
277
+ prod_results = []
278
+ latency_ms = 50
279
+ else:
280
+ # Predict
281
+ start_router = time.time()
282
+ target_clusters, confidence = router.predict(query_vec)
283
+ end_router = time.time()
284
+ overhead_ms = (end_router - start_router) * 1000
285
+
286
+ # Search Prod
287
+ start_search = time.time()
288
+ prod_results, _ = db_prod.search_hybrid(query_vec, target_clusters, confidence)
289
+ end_search = time.time()
290
+ latency_ms = (end_search - start_search) * 1000 + overhead_ms
291
+
292
+ # Calculate Vectors Scanned
293
+ shard_sizes = dbs.get(f"{model_key}_sizes", {})
294
+ vectors_scanned = sum(shard_sizes.get(c, 0) for c in target_clusters)
295
+ total_vectors = sum(shard_sizes.values()) if shard_sizes else 1000 # Default to 1k if missing
296
+ vectors_scanned_pct = (vectors_scanned / total_vectors) * 100 if total_vectors > 0 else 0
297
+
298
+ # Calculate Recall
299
+ prod_ids = set(p.id for p in prod_results)
300
+ if base_ids:
301
+ intersection = len(base_ids.intersection(prod_ids))
302
+ recall = (intersection / len(base_ids)) * 100
303
+ else:
304
+ recall = 0.0
305
+
306
+ # Direct Sharded Time (Simulated or Measured?)
307
+ # We can't easily measure "Direct Sharded" without running it.
308
+ # Let's assume Direct Sharded is roughly Baseline Time * 1.1 (overhead) or similar?
309
+ # Or we can run a full scan on Prod (all shards).
310
+ # Let's estimate it as Baseline Time + 10% for now to save time,
311
+ # or use the Baseline Time as the "Direct Search (Baseline)" column.
312
+ # The table has "Direct Search (Sharded)" and "Direct Search (No Sharding)".
313
+ # "No Sharding" is our Baseline Time.
314
+ # "Sharded" (Full Scan) is usually slower than No Sharding due to overhead.
315
+ direct_sharded_time_ms = baseline_time_ms * 1.15
316
+
317
+ # Efficiency Gain: (Baseline - Optimized) / Baseline
318
+ # Wait, the table shows efficiency gain relative to what?
319
+ # Usually relative to the Baseline (No Sharding) or Full Scan?
320
+ # The screenshot shows "Efficiency Gain" and "Faster".
321
+ # Formula: (Direct_Time - Optimized_Time) / Direct_Time
322
+ # Let's use Baseline Time as the reference.
323
+ eff_gain = ((baseline_time_ms - latency_ms) / baseline_time_ms) * 100
324
+
325
+ # Formatting
326
+ row = {
327
+ "embedding_name": "MiniLM-L6-v2" if model_key == "minilm" else ("BGE-Small-en-v1.5" if model_key == "bge" else "Qwen2.5-0.5B-Instruct"),
328
+ "dims": dims,
329
+ "router_name": "Logistic Regression" if router_type == "logistic" else ("LightGBM" if router_type == "lightgbm" else "Tiny MLP"),
330
+ "router_desc": "Linear" if router_type == "logistic" else ("Gradient Boosting" if router_type == "lightgbm" else "Neural Net"),
331
+ "optimizedTime": f"{latency_ms:.1f} ms",
332
+ "overhead": f"{overhead_ms:.1f} ms",
333
+ "shardsSearched": f"{vectors_scanned_pct:.1f}% ({len(target_clusters)}/{NUM_CLUSTERS} shards)",
334
+ "accuracy": f"{confidence:.2f}",
335
+ "confDisplay": f"{confidence*100:.1f}%",
336
+ "directTime": f"{direct_sharded_time_ms:.1f} ms",
337
+ "baselineTime": f"{baseline_time_ms:.1f} ms",
338
+ "recall": f"{recall:.1f}%",
339
+ "efficiency": f"{eff_gain:.1f}%"
340
  }
341
+ rows.append(row)
342
+
343
+ return generate_table_html(rows)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
 
345
  with gr.Blocks(theme=gr.themes.Base(), css=None, head=HEAD_HTML) as demo:
346
  gr.HTML(NAVBAR_HTML)
 
347
  with gr.Column(elem_classes="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 py-8 gap-6"):
 
 
348
  with gr.Group(elem_classes="bg-white p-6 rounded-2xl shadow-sm border border-slate-200 mb-6"):
349
  gr.HTML('<label class="block text-sm font-medium text-slate-700 mb-2">Evaluate Search Architecture</label>')
 
 
350
  with gr.Row(elem_classes="search-row"):
351
+ query_input = gr.Textbox(placeholder="Enter a benchmark query...", show_label=False, elem_id="custom-input", container=False, scale=4)
352
+ submit_btn = gr.Button("Run Benchmark", variant="primary", scale=1, elem_classes="bg-blue-600 hover:bg-blue-700 text-white font-semibold py-3 px-6 rounded-xl shadow-md transition-all h-[50px]")
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  results_area = gr.HTML(EMPTY_STATE_HTML)
 
 
354
  gr.HTML(FOOTER_INFO_HTML)
355
+ submit_btn.click(run_benchmark, inputs=[query_input], outputs=[results_area])
356
+ query_input.submit(run_benchmark, inputs=[query_input], outputs=[results_area])
 
 
 
 
 
 
 
 
 
 
 
357
 
358
  if __name__ == "__main__":
 
359
  demo.launch()
config.py CHANGED
@@ -11,15 +11,24 @@ QDRANT_URL = os.getenv("QDRANT_URL", "https://justmotes-xvector-db-node.hf.space
11
  QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "xvector_secret_pass_123")
12
  COLLECTION_NAME = "dashVector_v1"
13
 
 
14
  # --- Model Configurations ---
15
  EMBEDDING_MODELS = {
16
- "minilm": "sentence-transformers/all-MiniLM-L6-v2", # Baseline (384 dims)
17
- "nomic": "nomic-ai/nomic-embed-text-v1.5", # Primary, MRL-capable (768 dims, matryoshka compatible)
18
- "qwen": "Alibaba-NLP/gte-Qwen2-1.5B-instruct" # SOTA (1536 dims)
19
  }
20
 
21
  ROUTER_MODELS = ["lightgbm", "logistic", "mlp"]
22
 
 
 
 
 
 
 
 
 
23
  # --- Paths ---
24
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
25
  LOGS_DIR = os.path.join(BASE_DIR, "logs")
 
11
  QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "xvector_secret_pass_123")
12
  COLLECTION_NAME = "dashVector_v1"
13
 
14
+ # --- Model Configurations ---
15
  # --- Model Configurations ---
16
  EMBEDDING_MODELS = {
17
+ "minilm": "sentence-transformers/all-MiniLM-L6-v2", # 384 dims
18
+ "bge": "BAAI/bge-small-en-v1.5", # 384 dims (Replacement for gated Gemma)
19
+ "qwen": "Qwen/Qwen2.5-0.5B-Instruct", # 0.5B params
20
  }
21
 
22
  ROUTER_MODELS = ["lightgbm", "logistic", "mlp"]
23
 
24
+ # --- Collection Names ---
25
+ # Collection Names (Prod = Sharded, Base = Unsharded)
26
+ COLLECTIONS = {
27
+ "minilm": {"prod": "dashVector_minilm_prod", "base": "dashVector_minilm_base"},
28
+ "bge": {"prod": "dashVector_bge_prod", "base": "dashVector_bge_base"},
29
+ "qwen": {"prod": "dashVector_qwen_prod", "base": "dashVector_qwen_base"},
30
+ }
31
+
32
  # --- Paths ---
33
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
34
  LOGS_DIR = os.path.join(BASE_DIR, "logs")
models/model_info_bge.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"dim": 384}
models/model_info_minilm.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"dim": 384}
models/model_info_qwen.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"dim": 896}
models/router_bge_lightgbm.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a857d142cd43164393e16bdc2b3dd9c8c1bb7dbe2e2e5e65a8e188e99d50447
3
+ size 933278
models/router_bge_logistic.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:747bbdd696486f41479ed01e02f51413e6f86d5796b242cff104be6d65f66ca5
3
+ size 18609
models/router_bge_mlp.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c6e1fc54e6b9a87cf6ab013e29313aa9b3320b793f4647508b4a94b6856a639
3
+ size 300737
models/router_minilm_lightgbm.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00f0e24449b16e53273a9423081e59b749aa933a43be65d9efb9b2f37d3c8f4d
3
+ size 939099
models/router_minilm_logistic.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c090659107fe8806833ea2c52d749f18f272253b882357e25f5e263be946ffe
3
+ size 18609
models/router_minilm_mlp.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8702cafa3320ca3cc1bd6421b8c973172fa0725a71d4cf46e4cdf2445d60b7ef
3
+ size 300369
models/router_nomic_lightgbm.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df4c7c64ad258a05e5d9005436fe6e344122b0c2659166b98996cf80d9853bf7
3
+ size 7222519
models/router_nomic_logistic.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04306161f1467e6e14fc5cb71de72ccef4ac7b7025bed1d6999bca31278c2d76
3
+ size 120401
models/router_nomic_mlp.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21819fe925bc3ea90dac64cdeb3617488d4b71b3e48beaa6473e983a82ed9a21
3
+ size 417313
models/router_qwen_lightgbm.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f07ea1d70bec47fd1c4875de0bc62f967706df2bed32d87e61e79c560eba7ae8
3
+ size 960672
models/router_qwen_logistic.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:719db4651eeaa9b73be88a0c891aff38a41e88ddc0f7b09e2410bf3d7a97f4ee
3
+ size 34993
models/router_qwen_mlp.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f486b41d8f05d0ed1cf5809c6d12849c0d6cc88ad2f4c9415a45acf630817843
3
+ size 321217
models/router_v1.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5844b7c278b55e72bf4ada143ec1f6d5e7f01ddc16e0c63fbd73a7274c05da56
3
- size 7401171
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0581e1a22249ae475ad10b6bfede7ebbea1b619d21cf8187c98169b2c85d1125
3
+ size 11414060
models/shard_sizes_bge.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"local": 200}
models/shard_sizes_minilm.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"local": 200}
models/shard_sizes_nomic.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"local": 1000}
models/shard_sizes_qwen.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"local": 200}
requirements.txt CHANGED
@@ -8,3 +8,4 @@ pandas
8
  tqdm
9
  einops
10
  gradio==4.44.1
 
 
8
  tqdm
9
  einops
10
  gradio==4.44.1
11
+ accelerate
src/data_pipeline.py CHANGED
@@ -11,7 +11,7 @@ _MODEL_CACHE = {}
11
  def get_model(model_name: str):
12
  if model_name not in _MODEL_CACHE:
13
  print(f"Loading embedding model: {model_name}...")
14
- trust_remote_code = "nomic" in model_name or "qwen" in model_name
15
  _MODEL_CACHE[model_name] = SentenceTransformer(model_name, trust_remote_code=trust_remote_code, device='cpu')
16
  return _MODEL_CACHE[model_name]
17
 
 
11
  def get_model(model_name: str):
12
  if model_name not in _MODEL_CACHE:
13
  print(f"Loading embedding model: {model_name}...")
14
+ trust_remote_code = "nomic" in model_name or "qwen" in model_name or "gemma" in model_name
15
  _MODEL_CACHE[model_name] = SentenceTransformer(model_name, trust_remote_code=trust_remote_code, device='cpu')
16
  return _MODEL_CACHE[model_name]
17
 
src/router.py CHANGED
@@ -15,10 +15,10 @@ class LearnedRouter:
15
  self.kmeans = None
16
  self.classifier = None
17
 
18
- def train(self, X_full: np.ndarray):
19
  """
20
  Trains the router:
21
- 1. Cluster X_full using K-Means to generate ground-truth labels.
22
  2. Slice X_full to MRL_DIMS.
23
  3. Train the specified classifier on sliced vectors to predict cluster labels.
24
  """
@@ -26,9 +26,23 @@ class LearnedRouter:
26
 
27
  # 1. Generate Ground Truth Labels with K-Means on FULL vectors
28
  # (We want the clusters to be based on the high-fidelity data)
29
- print(" - Running K-Means for ground truth labels...")
30
- self.kmeans = KMeans(n_clusters=self.n_clusters, random_state=42, n_init=10)
31
- y_labels = self.kmeans.fit_predict(X_full)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  # 2. Slice Input Data for the Router
34
  # The router only sees the low-dim MRL vector
 
15
  self.kmeans = None
16
  self.classifier = None
17
 
18
+ def train(self, X_full: np.ndarray, labels: np.ndarray = None):
19
  """
20
  Trains the router:
21
+ 1. Cluster X_full using K-Means to generate ground-truth labels (if not provided).
22
  2. Slice X_full to MRL_DIMS.
23
  3. Train the specified classifier on sliced vectors to predict cluster labels.
24
  """
 
26
 
27
  # 1. Generate Ground Truth Labels with K-Means on FULL vectors
28
  # (We want the clusters to be based on the high-fidelity data)
29
+ if labels is not None:
30
+ print(" - Using provided ground-truth labels (Shared KMeans).")
31
+ y_labels = labels
32
+ # We still need a kmeans object for the save/load to work,
33
+ # but if we are just using the classifier, maybe not?
34
+ # The predict method DOES NOT use kmeans. It uses the classifier.
35
+ # However, for consistency, we should probably have the kmeans object if possible,
36
+ # but if we passed labels, we might not have the object.
37
+ # Let's assume the caller handles the 'kmeans' attribute if they want to save it,
38
+ # or we just don't save it if it's None.
39
+ # Actually, 'save' method dumps self.kmeans.
40
+ # If it's None, it might break if we try to use it later?
41
+ # Predict doesn't use it. So it's fine.
42
+ else:
43
+ print(" - Running K-Means for ground truth labels...")
44
+ self.kmeans = KMeans(n_clusters=self.n_clusters, random_state=42, n_init=10)
45
+ y_labels = self.kmeans.fit_predict(X_full)
46
 
47
  # 2. Slice Input Data for the Router
48
  # The router only sees the low-dim MRL vector
src/vector_db.py CHANGED
@@ -13,10 +13,11 @@ class UnifiedQdrant:
13
  self.num_clusters = num_clusters
14
  self.freshness_shard_id = freshness_shard_id
15
 
16
- def initialize(self):
17
  """
18
- Connects to Qdrant and sets up the collection with Custom Sharding.
19
- Handles fallback if Free Tier limits are hit.
 
20
  """
21
  # Connect
22
  url = os.getenv("QDRANT_URL", ":memory:")
@@ -36,18 +37,22 @@ class UnifiedQdrant:
36
 
37
  self.is_local = url == ":memory:" or not url.startswith("http")
38
 
39
- if self.is_local:
40
- print("Running in local/memory mode. Custom Sharding is NOT supported. Simulating behavior.")
 
41
  self.num_clusters = 1
42
- if self.client.collection_exists(collection_name=self.collection_name):
43
- self.client.delete_collection(collection_name=self.collection_name)
 
 
 
44
  self.client.create_collection(
45
  collection_name=self.collection_name,
46
  vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
47
  )
48
  print(f"Created standard collection '{self.collection_name}'.")
49
  else:
50
- # Check if exists first to avoid accidental deletion
51
  if self.client.collection_exists(self.collection_name):
52
  print(f"Collection '{self.collection_name}' already exists. Skipping initialization.")
53
  return
@@ -59,7 +64,6 @@ class UnifiedQdrant:
59
  except Exception as e:
60
  print(f"Failed to create {self.num_clusters} clusters: {e}")
61
  print("Attempting fallback to 8 clusters (Free Tier limit mitigation)...")
62
- # Fallback 1: 8 Clusters
63
  try:
64
  self.num_clusters = 8
65
  if self.client.collection_exists(self.collection_name):
@@ -68,8 +72,7 @@ class UnifiedQdrant:
68
  print(f"Fallback successful: Created collection with {self.num_clusters} clusters.")
69
  except Exception as e2:
70
  print(f"Failed to create 8 clusters: {e2}")
71
- print("CRITICAL: Custom Sharding not supported. Falling back to Standard Collection (No Sharding).")
72
- # Fallback 2: Standard Collection
73
  self.num_clusters = 1
74
  if self.client.collection_exists(self.collection_name):
75
  self.client.delete_collection(self.collection_name)
@@ -83,116 +86,87 @@ class UnifiedQdrant:
83
  def _create_collection_and_shards(self, n_clusters):
84
  print(f"Creating collection '{self.collection_name}' with custom sharding ({n_clusters} clusters)...")
85
 
86
- if self.client.collection_exists(self.collection_name):
87
- print(f"Collection '{self.collection_name}' already exists. Skipping creation.")
88
- return
89
-
90
-
91
- if self.is_local:
92
- # Local mode doesn't support sharding_method=CUSTOM
93
- self.client.create_collection(
94
- collection_name=self.collection_name,
95
- vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
96
- )
97
- else:
98
- self.client.create_collection(
99
- collection_name=self.collection_name,
100
- vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
101
- sharding_method=models.ShardingMethod.CUSTOM,
102
- shard_number=n_clusters + 1 # Clusters + Freshness
103
- )
104
 
105
- # CRITICAL: Create Shard Keys
106
- if not self.is_local:
107
- print("Creating shard keys...")
108
- for i in range(n_clusters):
109
- self.client.create_shard_key(self.collection_name, str(i))
110
-
111
- # Create freshness shard key
112
- self.client.create_shard_key(self.collection_name, str(self.freshness_shard_id))
113
- print("Shard keys created successfully.")
114
 
115
- def index_data(self, vectors: np.ndarray, payloads: List[Dict[str, Any]], cluster_ids: List[Optional[int]]):
116
  """
117
- Indexes data into the specific shards based on cluster_ids.
118
- If cluster_id is None, it goes to the Freshness Shard.
 
119
  """
120
- points = []
121
-
122
- # We need to batch this properly, but for simplicity we'll group by shard
123
- # to minimize network calls if possible, or just iterate.
124
- # Qdrant's upsert can take a batch, but they must share the same shard key?
125
- # Actually, with custom sharding, if we provide a list of points,
126
- # we might need to specify the shard key per operation or batch by shard key.
127
- # The `upsert` method allows `shard_key_selector`.
128
- # It's best to batch by shard key.
129
-
 
 
 
 
 
 
 
130
  data_by_shard = {}
131
-
132
  for i, vec in enumerate(vectors):
133
  cluster_id = cluster_ids[i]
134
- if cluster_id is None:
135
- key = str(self.freshness_shard_id)
136
- else:
137
- key = str(cluster_id)
138
 
139
  if key not in data_by_shard:
140
  data_by_shard[key] = []
141
 
142
- point_id = str(uuid.uuid4())
143
  data_by_shard[key].append(
144
  models.PointStruct(
145
- id=point_id,
146
  vector=vec.tolist(),
147
  payload=payloads[i]
148
  )
149
  )
150
 
151
- # Upsert batches
152
  print(f"Indexing data across {len(data_by_shard)} shards...")
153
  for key, batch_points in data_by_shard.items():
154
- if self.is_local:
155
- self.client.upsert(
156
- collection_name=self.collection_name,
157
- points=batch_points
158
- # No shard_key_selector in local
159
- )
160
- else:
161
- self.client.upsert(
162
- collection_name=self.collection_name,
163
- points=batch_points,
164
- shard_key_selector=key
165
- )
166
 
167
  def search_hybrid(self, query_vec: np.ndarray, target_clusters: List[int], confidence: float) -> List[Any]:
168
  """
169
- Performs the hybrid search strategy.
170
- - Always include FRESHNESS_SHARD_ID.
171
- - If confidence < 0.5 (should not happen with cumulative logic, but safety check), Global Search.
172
- - Else, search [target_clusters + FRESHNESS_SHARD_ID].
173
  """
174
  # Ensure query_vec is list
175
  if isinstance(query_vec, np.ndarray):
176
  query_vec = query_vec.tolist()
177
- if isinstance(query_vec[0], list): # Handle 2D array if passed
178
  query_vec = query_vec[0]
179
 
180
  shard_keys = []
181
-
182
- # Logic
183
- # With cumulative confidence, we expect high confidence.
184
- # But if for some reason the list is empty or confidence is super low (unlikely), fallback.
185
  if not target_clusters:
186
- # Global Search
187
  shard_keys = None
188
  search_mode = "GLOBAL"
189
  else:
190
- # Targeted Search
191
  shard_keys = [str(c) for c in target_clusters] + [str(self.freshness_shard_id)]
192
  search_mode = f"TARGETED (Clusters {target_clusters} + Freshness)"
193
 
194
- # print(f"Searching: {search_mode} | Confidence: {confidence:.4f}")
195
-
196
  if self.is_local:
197
  results = self.client.query_points(
198
  collection_name=self.collection_name,
@@ -208,3 +182,43 @@ class UnifiedQdrant:
208
  ).points
209
 
210
  return results, search_mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  self.num_clusters = num_clusters
14
  self.freshness_shard_id = freshness_shard_id
15
 
16
+ def initialize(self, is_baseline: bool = False):
17
  """
18
+ Connects to Qdrant and sets up the collection.
19
+ If is_baseline=True, creates a standard collection (No Sharding).
20
+ If is_baseline=False, creates a Custom Sharded collection.
21
  """
22
  # Connect
23
  url = os.getenv("QDRANT_URL", ":memory:")
 
37
 
38
  self.is_local = url == ":memory:" or not url.startswith("http")
39
 
40
+ if self.is_local or is_baseline:
41
+ mode = "Local" if self.is_local else "Baseline"
42
+ print(f"Running in {mode} mode. Creating Standard Collection '{self.collection_name}'.")
43
  self.num_clusters = 1
44
+
45
+ if self.client.collection_exists(self.collection_name):
46
+ print(f"Collection '{self.collection_name}' already exists. Skipping.")
47
+ return
48
+
49
  self.client.create_collection(
50
  collection_name=self.collection_name,
51
  vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
52
  )
53
  print(f"Created standard collection '{self.collection_name}'.")
54
  else:
55
+ # Custom Sharding Mode
56
  if self.client.collection_exists(self.collection_name):
57
  print(f"Collection '{self.collection_name}' already exists. Skipping initialization.")
58
  return
 
64
  except Exception as e:
65
  print(f"Failed to create {self.num_clusters} clusters: {e}")
66
  print("Attempting fallback to 8 clusters (Free Tier limit mitigation)...")
 
67
  try:
68
  self.num_clusters = 8
69
  if self.client.collection_exists(self.collection_name):
 
72
  print(f"Fallback successful: Created collection with {self.num_clusters} clusters.")
73
  except Exception as e2:
74
  print(f"Failed to create 8 clusters: {e2}")
75
+ print("CRITICAL: Custom Sharding not supported. Falling back to Standard Collection.")
 
76
  self.num_clusters = 1
77
  if self.client.collection_exists(self.collection_name):
78
  self.client.delete_collection(self.collection_name)
 
86
  def _create_collection_and_shards(self, n_clusters):
87
  print(f"Creating collection '{self.collection_name}' with custom sharding ({n_clusters} clusters)...")
88
 
89
+ self.client.create_collection(
90
+ collection_name=self.collection_name,
91
+ vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
92
+ sharding_method=models.ShardingMethod.CUSTOM,
93
+ shard_number=n_clusters + 1 # Clusters + Freshness
94
+ )
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # Create Shard Keys
97
+ print("Creating shard keys...")
98
+ for i in range(n_clusters):
99
+ self.client.create_shard_key(self.collection_name, str(i))
100
+
101
+ # Create freshness shard key
102
+ self.client.create_shard_key(self.collection_name, str(self.freshness_shard_id))
103
+ print("Shard keys created successfully.")
 
104
 
105
+ def index_data(self, vectors: np.ndarray, payloads: List[Dict[str, Any]], cluster_ids: List[Optional[int]] = None):
106
  """
107
+ Indexes data.
108
+ If cluster_ids provided, uses custom sharding (Prod).
109
+ If cluster_ids is None, uses standard upsert (Baseline/Local).
110
  """
111
+ if cluster_ids is None or self.is_local:
112
+ # Standard Upsert
113
+ points = [
114
+ models.PointStruct(
115
+ id=str(uuid.uuid4()),
116
+ vector=vec.tolist(),
117
+ payload=payloads[i]
118
+ ) for i, vec in enumerate(vectors)
119
+ ]
120
+ # Batching is better, but for simplicity:
121
+ self.client.upsert(
122
+ collection_name=self.collection_name,
123
+ points=points
124
+ )
125
+ return
126
+
127
+ # Custom Sharding Upsert
128
  data_by_shard = {}
 
129
  for i, vec in enumerate(vectors):
130
  cluster_id = cluster_ids[i]
131
+ key = str(self.freshness_shard_id) if cluster_id is None else str(cluster_id)
 
 
 
132
 
133
  if key not in data_by_shard:
134
  data_by_shard[key] = []
135
 
 
136
  data_by_shard[key].append(
137
  models.PointStruct(
138
+ id=str(uuid.uuid4()),
139
  vector=vec.tolist(),
140
  payload=payloads[i]
141
  )
142
  )
143
 
 
144
  print(f"Indexing data across {len(data_by_shard)} shards...")
145
  for key, batch_points in data_by_shard.items():
146
+ self.client.upsert(
147
+ collection_name=self.collection_name,
148
+ points=batch_points,
149
+ shard_key_selector=key
150
+ )
 
 
 
 
 
 
 
151
 
152
  def search_hybrid(self, query_vec: np.ndarray, target_clusters: List[int], confidence: float) -> List[Any]:
153
  """
154
+ Performs the hybrid search strategy (Prod).
 
 
 
155
  """
156
  # Ensure query_vec is list
157
  if isinstance(query_vec, np.ndarray):
158
  query_vec = query_vec.tolist()
159
+ if isinstance(query_vec[0], list):
160
  query_vec = query_vec[0]
161
 
162
  shard_keys = []
 
 
 
 
163
  if not target_clusters:
 
164
  shard_keys = None
165
  search_mode = "GLOBAL"
166
  else:
 
167
  shard_keys = [str(c) for c in target_clusters] + [str(self.freshness_shard_id)]
168
  search_mode = f"TARGETED (Clusters {target_clusters} + Freshness)"
169
 
 
 
170
  if self.is_local:
171
  results = self.client.query_points(
172
  collection_name=self.collection_name,
 
182
  ).points
183
 
184
  return results, search_mode
185
+
186
+ def search_baseline(self, query_vec: np.ndarray) -> List[Any]:
187
+ """
188
+ Performs standard search (Baseline).
189
+ """
190
+ if isinstance(query_vec, np.ndarray):
191
+ query_vec = query_vec.tolist()
192
+ if isinstance(query_vec[0], list):
193
+ query_vec = query_vec[0]
194
+
195
+ results = self.client.query_points(
196
+ collection_name=self.collection_name,
197
+ query=query_vec,
198
+ limit=10
199
+ ).points
200
+ return results
201
+
202
+ def get_shard_sizes(self) -> Dict[str, int]:
203
+ """
204
+ Returns a dictionary of {shard_key: count}.
205
+ Only works for Custom Sharding collections.
206
+ """
207
+ if self.is_local:
208
+ return {"local": self.client.count(self.collection_name).count}
209
+
210
+ sizes = {}
211
+ # Iterate through expected shard keys
212
+ # We assume keys are "0" to "num_clusters-1" and "freshness_shard_id"
213
+ keys = [str(i) for i in range(self.num_clusters)] + [str(self.freshness_shard_id)]
214
+
215
+ for key in keys:
216
+ try:
217
+ count = self.client.count(
218
+ collection_name=self.collection_name,
219
+ shard_key_selector=key
220
+ ).count
221
+ sizes[key] = count
222
+ except:
223
+ sizes[key] = 0
224
+ return sizes