Spaces:
Running
Running
| import time | |
| import numpy as np | |
| from typing import List, Dict, Any, Tuple | |
| from src.vector_db import UnifiedQdrant | |
| from src.router import LearnedRouter | |
| from src.data_pipeline import get_embeddings | |
| from config import NUM_CLUSTERS, FRESHNESS_SHARD_ID, EMBEDDING_MODELS | |
class ComparisonEngine:
    """Run one query through two search strategies and report comparable stats.

    ``direct_search`` is the baseline: a single query with no shard key
    selector.  ``xvector_search`` asks the router for a target cluster and
    delegates to ``db.search_hybrid``.  Both return a dict containing the
    hits, the measured latency in milliseconds, and a nominal count of
    shards searched.
    """

    # Number of hits requested by the brute-force baseline query.
    _DIRECT_SEARCH_LIMIT = 10

    def __init__(self, db: "UnifiedQdrant", router: "LearnedRouter", embedding_model_name: str = "minilm"):
        """Store the collaborators and resolve the embedding model name.

        Args:
            db: Vector-store wrapper.  This class reads ``client``,
                ``collection_name`` and ``num_clusters`` from it and calls
                ``search_hybrid``.
            router: Object whose ``predict`` returns a
                ``(target_cluster, confidence)`` pair.
            embedding_model_name: Key looked up in ``EMBEDDING_MODELS``; if
                the key is absent the given string is used unchanged.
        """
        self.db = db
        self.router = router
        self.embedding_model_name = EMBEDDING_MODELS.get(embedding_model_name, embedding_model_name)

    def get_query_embedding(self, query: str) -> np.ndarray:
        """Embed a single query string.

        ``get_embeddings`` is called with a one-element batch and the first
        row of its result is returned (expected to be a 1-D vector --
        NOTE(review): confirm against ``get_embeddings``).
        """
        batch = get_embeddings(self.embedding_model_name, [query])
        return batch[0]

    def _total_shard_count(self) -> int:
        """Return the nominal shard count for a global search: all clusters plus freshness."""
        return self.db.num_clusters + 1

    @staticmethod
    def _elapsed_ms(started_at: float) -> float:
        """Return milliseconds elapsed since a ``time.perf_counter()`` reading."""
        return (time.perf_counter() - started_at) * 1000

    def direct_search(self, query: str) -> Dict[str, Any]:
        """Brute-force baseline: query the collection without a shard key selector.

        The client is queried directly rather than through
        ``db.search_hybrid`` so that no routing is involved.  The call is the
        same for local and cloud deployments: omitting the shard key selector
        is intended to search everything (per the original author's notes, a
        local instance holds a single collection anyway).

        Returns:
            Dict with ``results`` (the ``.points`` of the query response),
            ``latency_ms`` (search time only; embedding the query is not
            included), ``shards_searched`` (nominal: clusters + freshness)
            and ``mode`` (``"Brute Force"``).
        """
        query_vec = self.get_query_embedding(query)

        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments the way time.time() can.
        started_at = time.perf_counter()
        response = self.db.client.query_points(
            collection_name=self.db.collection_name,
            query=query_vec,
            limit=self._DIRECT_SEARCH_LIMIT,
            # No shard_key_selector -> global search.
        )
        results = response.points
        latency_ms = self._elapsed_ms(started_at)

        return {
            "results": results,
            "latency_ms": latency_ms,
            "shards_searched": self._total_shard_count(),
            "mode": "Brute Force",
        }

    def xvector_search(self, query: str) -> Dict[str, Any]:
        """Routed search: predict a target cluster, then run the hybrid search.

        The measured latency covers both the router prediction and the
        ``search_hybrid`` call; embedding the query is not included.

        Returns:
            Dict with ``results``, ``latency_ms``, ``shards_searched``,
            ``mode`` (``"xVector (<search_mode>)"``), ``confidence`` and
            ``target_cluster``.
        """
        query_vec = self.get_query_embedding(query)

        started_at = time.perf_counter()
        # 1. Router prediction; the router is given a single-row 2-D batch.
        target_cluster, confidence = self.router.predict(query_vec.reshape(1, -1))
        # 2. Hybrid search; the returned mode string tells us which path ran.
        results, search_mode = self.db.search_hybrid(query_vec, target_cluster, confidence)
        latency_ms = self._elapsed_ms(started_at)

        if "GLOBAL" in search_mode:
            # Fallback path: counted as touching every shard.
            shards_searched = self._total_shard_count()
        else:
            # Targeted path: target cluster shard + freshness shard.
            shards_searched = 2

        return {
            "results": results,
            "latency_ms": latency_ms,
            "shards_searched": shards_searched,
            "mode": f"xVector ({search_mode})",
            "confidence": confidence,
            "target_cluster": target_cluster,
        }