import time
from typing import List, Dict, Any, Tuple

import numpy as np

from src.vector_db import UnifiedQdrant
from src.router import LearnedRouter
from src.data_pipeline import get_embeddings
from config import NUM_CLUSTERS, FRESHNESS_SHARD_ID, EMBEDDING_MODELS


class ComparisonEngine:
    """Run a query through brute-force search and router-targeted search.

    Both ``direct_search`` and ``xvector_search`` return a dict with the raw
    ``results``, the measured ``latency_ms``, the number of
    ``shards_searched`` and a ``mode`` label, so the two strategies can be
    compared side by side.
    """

    def __init__(
        self,
        db: UnifiedQdrant,
        router: LearnedRouter,
        embedding_model_name: str = "minilm",
    ):
        """Store the DB wrapper and router, and resolve the embedding model.

        Args:
            db: Vector DB wrapper; this class uses its ``client``,
                ``collection_name``, ``num_clusters`` and ``search_hybrid``.
            router: Object whose ``predict`` returns
                ``(target_cluster, confidence)`` for a query vector batch.
            embedding_model_name: Alias looked up in ``EMBEDDING_MODELS``.
                If it is not a known alias it is used verbatim as the model
                name.
        """
        self.db = db
        self.router = router
        self.embedding_model_name = EMBEDDING_MODELS.get(
            embedding_model_name, embedding_model_name
        )

    def get_query_embedding(self, query: str) -> np.ndarray:
        """Embed a single query string and return its vector.

        ``get_embeddings`` is called with a one-element batch and the first
        row is returned. This assumes ``get_embeddings`` returns one row per
        input text, so the result is a 1-D vector; confirm against
        ``get_embeddings``.
        """
        return get_embeddings(self.embedding_model_name, [query])[0]

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Return milliseconds elapsed since ``start``, a ``time.perf_counter()`` reading."""
        return (time.perf_counter() - start) * 1000

    def _total_shards(self) -> int:
        """Return the shard count of a global search: all cluster shards plus the freshness shard."""
        return self.db.num_clusters + 1

    def direct_search(self, query: str, *, limit: int = 10) -> Dict[str, Any]:
        """Brute-force baseline: query the whole collection with no shard filter.

        Args:
            query: Raw query text. It is embedded before the timer starts.
            limit: Maximum number of points to return. The default is 10.
                NOTE(review): keep this in line with the limit that
                ``UnifiedQdrant.search_hybrid`` uses, so the comparison
                with ``xvector_search`` stays fair.

        Returns:
            Dict with ``results`` (the ``.points`` of the query response),
            ``latency_ms`` (DB call only), ``shards_searched``
            (every cluster shard plus the freshness shard) and ``mode``.
        """
        query_vec = self.get_query_embedding(query)

        # Use perf_counter, not time.time(), for latency: it is monotonic and
        # high-resolution, so clock adjustments cannot distort the result.
        start = time.perf_counter()
        # Omitting shard_key_selector searches every shard when custom
        # sharding is used. In local mode everything is one collection
        # anyway. So a single call covers both deployments; the former
        # local/cloud branches issued the identical request.
        results = self.db.client.query_points(
            collection_name=self.db.collection_name,
            query=query_vec,
            limit=limit,
        ).points
        latency_ms = self._elapsed_ms(start)

        return {
            "results": results,
            "latency_ms": latency_ms,
            "shards_searched": self._total_shards(),
            "mode": "Brute Force",
        }

    def xvector_search(self, query: str) -> Dict[str, Any]:
        """Optimized search: route the query to a cluster, then search that shard.

        Args:
            query: Raw query text. It is embedded before the timer starts;
                routing and the hybrid search are both timed.

        Returns:
            Dict with ``results``, ``latency_ms``, ``shards_searched``,
            ``mode`` (which embeds the ``search_mode`` reported by
            ``search_hybrid``), plus the router's ``confidence`` and
            ``target_cluster``.
        """
        query_vec = self.get_query_embedding(query)

        start = time.perf_counter()
        # 1. Router prediction. predict is given a single-row 2-D batch.
        target_cluster, confidence = self.router.predict(query_vec.reshape(1, -1))
        # 2. Hybrid search: the target shard plus the freshness shard, or a
        #    global fallback. search_hybrid reports which one it used via
        #    search_mode.
        results, search_mode = self.db.search_hybrid(
            query_vec, target_cluster, confidence
        )
        latency_ms = self._elapsed_ms(start)

        if "GLOBAL" in search_mode:
            shards_searched = self._total_shards()
        else:
            shards_searched = 2  # target cluster shard + freshness shard

        return {
            "results": results,
            "latency_ms": latency_ms,
            "shards_searched": shards_searched,
            "mode": f"xVector ({search_mode})",
            "confidence": confidence,
            "target_cluster": target_cluster,
        }