dashVectorSpace / src /comparison.py
justmotes's picture
Deploy dashVectorspace v1 (Full)
b92d96d
raw
history blame
3.44 kB
import time
import numpy as np
from typing import List, Dict, Any, Tuple
from src.vector_db import UnifiedQdrant
from src.router import LearnedRouter
from src.data_pipeline import get_embeddings
from config import NUM_CLUSTERS, FRESHNESS_SHARD_ID, EMBEDDING_MODELS
class ComparisonEngine:
    """Run one query through brute-force and router-guided vector search.

    Both search methods return a dict containing at least ``"results"``,
    ``"latency_ms"``, ``"shards_searched"`` and ``"mode"`` so the two
    strategies can be compared side by side.
    """

    # Number of hits requested from the brute-force baseline query.
    _BASELINE_LIMIT = 10

    def __init__(self, db: UnifiedQdrant, router: LearnedRouter, embedding_model_name: str = "minilm"):
        """Store the collaborators and resolve the embedding model name.

        Args:
            db: Vector database wrapper; this class reads ``client``,
                ``collection_name`` and ``num_clusters`` from it and calls
                ``search_hybrid``.
            router: Router whose ``predict`` yields a target cluster and a
                confidence for a query vector.
            embedding_model_name: Key into ``EMBEDDING_MODELS``. A name that
                is not a key is used unchanged as the model name.
        """
        self.db = db
        self.router = router
        # Unknown aliases fall through unchanged so a full model name also works.
        self.embedding_model_name = EMBEDDING_MODELS.get(embedding_model_name, embedding_model_name)

    def get_query_embedding(self, query: str) -> np.ndarray:
        """Embed a single query string.

        ``get_embeddings`` is called with a one-element batch and the first
        row is returned (a 1D array, per the original author's note).
        """
        emb = get_embeddings(self.embedding_model_name, [query])
        return emb[0]

    def direct_search(self, query: str) -> Dict[str, Any]:
        """Brute Force Search (Baseline): query the collection with no shard filter.

        The query is sent straight through the raw client rather than
        ``search_hybrid`` so that no routing is involved. Omitting a shard key
        selector makes the request span every shard; the same call is used
        whether the database runs locally or in the cloud.

        Args:
            query: Text to search for.

        Returns:
            Dict with ``"results"`` (the returned points), ``"latency_ms"``
            (search time only; embedding the query is not included),
            ``"shards_searched"`` and ``"mode"``.
        """
        query_vec = self.get_query_embedding(query)

        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        results = self.db.client.query_points(
            collection_name=self.db.collection_name,
            query=query_vec,
            limit=self._BASELINE_LIMIT,
            # No shard_key_selector -> Global Search
        ).points
        latency_ms = (time.perf_counter() - start_time) * 1000

        # Compute Units: All Clusters + Freshness
        shards_searched = self.db.num_clusters + 1

        return {
            "results": results,
            "latency_ms": latency_ms,
            "shards_searched": shards_searched,
            "mode": "Brute Force",
        }

    def xvector_search(self, query: str) -> Dict[str, Any]:
        """xVector Search (Optimized): route the query, then search targeted shards.

        The router predicts a target cluster and a confidence, and
        ``search_hybrid`` uses them to search either the target shard plus
        the freshness shard or, as a fallback, globally. The reported
        latency covers both the router prediction and the search; embedding
        the query is not included.

        Args:
            query: Text to search for.

        Returns:
            Dict with ``"results"``, ``"latency_ms"``, ``"shards_searched"``,
            ``"mode"``, ``"confidence"`` and ``"target_cluster"``.
        """
        query_vec = self.get_query_embedding(query)

        start_time = time.perf_counter()

        # 1. Router Prediction (the router is given a single-row 2D batch).
        target_cluster, confidence = self.router.predict(query_vec.reshape(1, -1))

        # 2. Hybrid Search (Target + Freshness OR Global Fallback)
        results, search_mode = self.db.search_hybrid(query_vec, target_cluster, confidence)

        latency_ms = (time.perf_counter() - start_time) * 1000

        # A mode string containing "GLOBAL" marks the fallback that touches
        # every cluster shard plus the freshness shard; otherwise only the
        # target shard and the freshness shard were searched.
        if "GLOBAL" in search_mode:
            shards_searched = self.db.num_clusters + 1
        else:
            shards_searched = 2  # Target + Freshness

        return {
            "results": results,
            "latency_ms": latency_ms,
            "shards_searched": shards_searched,
            "mode": f"xVector ({search_mode})",
            "confidence": confidence,
            "target_cluster": target_cluster,
        }