Spaces:
Running
Running
File size: 3,440 Bytes
b92d96d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import time
import numpy as np
from typing import List, Dict, Any, Tuple
from src.vector_db import UnifiedQdrant
from src.router import LearnedRouter
from src.data_pipeline import get_embeddings
from config import NUM_CLUSTERS, FRESHNESS_SHARD_ID, EMBEDDING_MODELS
class ComparisonEngine:
    """Compare a brute-force vector search against a router-targeted ("xVector") search.

    Both search modes embed the query text with the same embedding model and
    return a dict containing the raw results, the measured latency in
    milliseconds, the number of shards the search touched, and a mode label.
    Embedding time is excluded from the reported latency in both modes.
    """

    def __init__(self, db: UnifiedQdrant, router: LearnedRouter, embedding_model_name: str = "minilm"):
        """Store the database/router handles and resolve the embedding model name.

        Args:
            db: Vector-database wrapper. This class reads its ``client``,
                ``collection_name`` and ``num_clusters`` attributes and calls
                its ``search_hybrid`` method.
            router: Router whose ``predict`` is called with a ``(1, dim)``
                query vector and returns ``(target_cluster, confidence)``.
            embedding_model_name: Alias looked up in ``EMBEDDING_MODELS``; if
                the alias is not a key there, the value itself is used as the
                model name.
        """
        self.db = db
        self.router = router
        self.embedding_model_name = EMBEDDING_MODELS.get(embedding_model_name, embedding_model_name)

    def get_query_embedding(self, query: str) -> np.ndarray:
        """Embed a single query string and return its vector.

        ``get_embeddings`` is called with a one-element batch and the first
        row is returned. ``xvector_search`` calls ``reshape(1, -1)`` on the
        result, so it is expected to be a 1-D numpy array.
        NOTE(review): confirm ``get_embeddings`` returns an array, not a list.
        """
        return get_embeddings(self.embedding_model_name, [query])[0]

    def _total_shard_count(self) -> int:
        """Return the number of shards a global search touches: all clusters + the freshness shard."""
        return self.db.num_clusters + 1

    def direct_search(self, query: str) -> Dict[str, Any]:
        """Brute-force baseline: search ALL shards of the collection.

        Args:
            query: Free-text query, embedded via ``get_query_embedding``.

        Returns:
            Dict with keys:
              - ``results``: the ``.points`` of the Qdrant query response (top 10).
              - ``latency_ms``: time spent in the search call only.
              - ``shards_searched``: every cluster shard plus the freshness shard.
              - ``mode``: ``"Brute Force"``.
        """
        query_vec = self.get_query_embedding(query)

        # perf_counter is monotonic and high-resolution; time.time() is a wall
        # clock that can jump and is unsuitable for measuring intervals.
        start = time.perf_counter()
        # Query the client directly (rather than search_hybrid) so the baseline
        # is a plain, untargeted search. No shard_key_selector is passed: with
        # Qdrant custom sharding that searches every shard, and in local mode
        # the data lives in a single collection anyway, so one call serves both
        # deployments.
        results = self.db.client.query_points(
            collection_name=self.db.collection_name,
            query=query_vec,
            limit=10,
        ).points
        latency_ms = (time.perf_counter() - start) * 1000

        return {
            "results": results,
            "latency_ms": latency_ms,
            # Compute units: all cluster shards + the freshness shard.
            "shards_searched": self._total_shard_count(),
            "mode": "Brute Force",
        }

    def xvector_search(self, query: str) -> Dict[str, Any]:
        """Optimized search: route the query to a cluster, then run a targeted search.

        The router predicts a target cluster and a confidence score. Both are
        handed to ``db.search_hybrid``, which performs either a targeted search
        (target cluster + freshness shard) or a global fallback and reports
        which one it used in ``search_mode``. Presumably the choice depends on
        ``confidence``; that logic lives in ``search_hybrid`` and is not
        visible here.

        Args:
            query: Free-text query, embedded via ``get_query_embedding``.

        Returns:
            Dict with keys:
              - ``results``: results returned by ``search_hybrid``.
              - ``latency_ms``: router prediction + search time.
              - ``shards_searched``: all shards if ``search_mode`` contains
                ``"GLOBAL"``, otherwise 2 (target + freshness).
              - ``mode``: ``"xVector (<search_mode>)"``.
              - ``confidence`` and ``target_cluster``: the router's prediction.
        """
        query_vec = self.get_query_embedding(query)

        start = time.perf_counter()
        # 1. Router prediction; the router expects a 2-D (1, dim) batch.
        target_cluster, confidence = self.router.predict(query_vec.reshape(1, -1))
        # 2. Hybrid search (target + freshness OR global fallback).
        results, search_mode = self.db.search_hybrid(query_vec, target_cluster, confidence)
        latency_ms = (time.perf_counter() - start) * 1000

        if "GLOBAL" in search_mode:
            # Fallback path touched every shard, same as the brute-force baseline.
            shards_searched = self._total_shard_count()
        else:
            shards_searched = 2  # Target cluster + freshness shard

        return {
            "results": results,
            "latency_ms": latency_ms,
            "shards_searched": shards_searched,
            "mode": f"xVector ({search_mode})",
            "confidence": confidence,
            "target_cluster": target_cluster,
        }
|