File size: 3,440 Bytes
b92d96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import time
import numpy as np
from typing import List, Dict, Any, Tuple
from src.vector_db import UnifiedQdrant
from src.router import LearnedRouter
from src.data_pipeline import get_embeddings
from config import NUM_CLUSTERS, FRESHNESS_SHARD_ID, EMBEDDING_MODELS

class ComparisonEngine:
    """Benchmarks brute-force vector search against router-guided (xVector) search.

    Wraps a UnifiedQdrant store and a LearnedRouter and exposes two search
    entry points that return comparable result/latency/cost dictionaries.
    """

    def __init__(self, db: "UnifiedQdrant", router: "LearnedRouter", embedding_model_name: str = "minilm"):
        """
        Args:
            db: Vector store wrapper (local or cloud Qdrant).
            router: Model mapping a query embedding to a target cluster id.
            embedding_model_name: Short alias resolved via EMBEDDING_MODELS;
                unknown aliases are passed through verbatim.
        """
        self.db = db
        self.router = router
        # Resolve the alias (e.g. "minilm") to the full model name; fall back
        # to the raw value so callers may pass a full name directly.
        self.embedding_model_name = EMBEDDING_MODELS.get(embedding_model_name, embedding_model_name)

    def get_query_embedding(self, query: str) -> np.ndarray:
        """Embed a single query string and return its 1-D embedding vector."""
        emb = get_embeddings(self.embedding_model_name, [query])
        return emb[0]

    def direct_search(self, query: str) -> Dict[str, Any]:
        """Brute-force baseline: search ALL shards.

        Omitting ``shard_key_selector`` makes Qdrant search every shard; in
        local mode everything lives in one collection anyway, so a single
        call covers both deployments (the original had two identical
        branches for local vs. cloud — collapsed here).

        Returns:
            Dict with "results", "latency_ms", "shards_searched", "mode".
        """
        query_vec = self.get_query_embedding(query)

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so sub-millisecond latencies are measured reliably.
        start_time = time.perf_counter()

        # No shard_key_selector -> global search across all shards.
        results = self.db.client.query_points(
            collection_name=self.db.collection_name,
            query=query_vec,
            limit=10,
        ).points

        latency_ms = (time.perf_counter() - start_time) * 1000

        # Compute units: every cluster shard plus the freshness shard.
        shards_searched = self.db.num_clusters + 1

        return {
            "results": results,
            "latency_ms": latency_ms,
            "shards_searched": shards_searched,
            "mode": "Brute Force",
        }

    def xvector_search(self, query: str) -> Dict[str, Any]:
        """Optimized search: router picks a target shard, then hybrid search.

        Falls back to a global search when the router is not confident
        (signalled by "GLOBAL" appearing in the mode string returned by
        ``search_hybrid``).

        Returns:
            Dict with "results", "latency_ms", "shards_searched", "mode",
            "confidence", "target_cluster".
        """
        query_vec = self.get_query_embedding(query)

        start_time = time.perf_counter()

        # 1. Router predicts the most likely cluster for this query.
        target_cluster, confidence = self.router.predict(query_vec.reshape(1, -1))

        # 2. Hybrid search: target + freshness shard, or global fallback.
        results, search_mode = self.db.search_hybrid(query_vec, target_cluster, confidence)

        latency_ms = (time.perf_counter() - start_time) * 1000

        # Global fallback touches every shard; targeted mode touches two
        # (the predicted cluster plus the freshness shard).
        if "GLOBAL" in search_mode:
            shards_searched = self.db.num_clusters + 1
        else:
            shards_searched = 2

        return {
            "results": results,
            "latency_ms": latency_ms,
            "shards_searched": shards_searched,
            "mode": f"xVector ({search_mode})",
            "confidence": confidence,
            "target_cluster": target_cluster,
        }