File size: 6,325 Bytes
b92d96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import time
import numpy as np
import pandas as pd
from tabulate import tabulate
import itertools

from config import (
    NUM_CLUSTERS, FRESHNESS_SHARD_ID, MRL_DIMS, 
    EMBEDDING_MODELS, ROUTER_MODELS, COLLECTION_NAME
)
from src.data_pipeline import get_embeddings, mrl_slice, load_ms_marco, generate_synthetic_data
from src.router import LearnedRouter
from src.vector_db import UnifiedQdrant
from src.active_learning import log_for_retraining

def _evaluate_queries(router, db, queries, query_texts):
    """Route and search every query vector, scoring self-recall and shard usage.

    Args:
        router: Trained ``LearnedRouter``; ``predict(vec)`` is unpacked as
            ``(target_cluster, confidence)``.
        db: Populated ``UnifiedQdrant``; ``search_hybrid`` is unpacked as
            ``(results, search_mode)``.
        queries: Query embeddings, one per row.
        query_texts: Source text of each query, aligned with ``queries``.

    Returns:
        Tuple ``(latencies_ms, hits, shards_searched)``:
        - ``latencies_ms``: per-query routing+search latency in milliseconds.
        - ``hits``: number of queries whose own text appeared in the results.
        - ``shards_searched``: total shards touched across all queries.
    """
    latencies_ms = []
    hits = 0
    shards_searched = 0

    for query_vec, target_text in zip(queries, query_texts):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the system clock (which can jump) and is too coarse for ms timings.
        start = time.perf_counter()

        # Router Prediction, then Search
        target_cluster, confidence = router.predict(query_vec)
        results, search_mode = db.search_hybrid(query_vec, target_cluster, confidence)

        latencies_ms.append((time.perf_counter() - start) * 1000)

        # Self-Recall: each query was drawn from the indexed corpus, so its own
        # text should come back among the results.
        if any(res.payload['text'] == target_text for res in results):
            hits += 1

        # Log for Active Learning
        log_for_retraining(target_text, confidence, results)

        # Track efficiency: a "GLOBAL" search fans out to every cluster shard
        # plus the freshness shard; otherwise only target + freshness shard.
        if "GLOBAL" in search_mode:
            shards_searched += NUM_CLUSTERS + 1
        else:
            shards_searched += 2

    return latencies_ms, hits, shards_searched


def run_benchmark(n_samples=2000, n_queries=20, n_fresh=100, seed=0):
    """Benchmark every (embedding model, router model) pair and print a summary.

    The MS MARCO corpus and the synthetic "fresh" items are loaded once and
    shared by all combinations. For each embedding model the corpus is embedded
    once and split in half. The first half trains each router. The second half
    is indexed into a new Qdrant collection, clustered with the router's
    k-means, together with the fresh items, which carry no cluster label.
    A sample of indexed items is then reused as queries to measure self-recall
    accuracy, latency and the share of shards skipped.

    Args:
        n_samples: Number of MS MARCO passages to request from the loader.
        n_queries: Number of indexed items reused as test queries (capped at
            the index size). Must be at least 1.
        n_fresh: Number of synthetic "fresh" items indexed without a cluster.
        seed: Seed for query sampling. A new generator with this seed is made
            per experiment, so every combination is scored on the same queries.

    Raises:
        ValueError: If ``n_queries`` < 1 or the loader returns fewer than two
            passages (nothing to split into train/index halves).
    """
    if n_queries < 1:
        raise ValueError(f"n_queries must be at least 1, got {n_queries}")

    print("============================================================")
    print("   xVector / dashVector: Learned Hybrid Retrieval Engine    ")
    print("============================================================")

    results_table = []
    total_shards = NUM_CLUSTERS + 1  # every cluster shard plus the freshness shard

    # 1. Load data once: it does not depend on the model pair, so sharing it
    # avoids redundant loading and keeps the combinations comparable.
    raw_texts = load_ms_marco(n_samples)
    fresh_texts = generate_synthetic_data(n_fresh)

    # Split on what the loader actually returned, not the requested count:
    # first half trains the router, second half is indexed.
    split_idx = len(raw_texts) // 2
    if split_idx == 0:
        raise ValueError(
            f"need at least 2 passages to split train/index, got {len(raw_texts)}"
        )
    texts_index = raw_texts[split_idx:]

    # P&C Matrix: Embedding Models x Router Models, in the same order as
    # itertools.product(EMBEDDING_MODELS, ROUTER_MODELS). Embeddings depend
    # only on the embedding model, so compute them once per model.
    for embed_name, model_id in EMBEDDING_MODELS.items():
        embeddings = get_embeddings(model_id, raw_texts)
        fresh_embeddings = get_embeddings(model_id, fresh_texts)
        vector_dim = embeddings.shape[1]
        X_train = embeddings[:split_idx]
        X_index = embeddings[split_idx:]

        for router_name in ROUTER_MODELS:
            print(f"\n>>> Running Experiment: Embedding='{embed_name}' | Router='{router_name}'")

            # 2. Train Router
            router = LearnedRouter(model_type=router_name, n_clusters=NUM_CLUSTERS, mrl_dims=MRL_DIMS)
            router.train(X_train)

            # 3. Index Data
            # Partition the index with the router's own k-means so ingestion
            # and routing share one cluster definition; at query time the
            # router must predict which partition holds the answer.
            ground_truth_labels = router.kmeans.predict(X_index)

            db = UnifiedQdrant(
                collection_name=COLLECTION_NAME,
                vector_size=vector_dim,
                num_clusters=NUM_CLUSTERS,
                freshness_shard_id=FRESHNESS_SHARD_ID
            )
            db.initialize()

            # Payloads are rebuilt per experiment in case the DB layer mutates
            # them while indexing (NOTE(review): confirm whether it does).
            payloads = [{"text": t, "origin": "historical"} for t in texts_index]
            db.index_data(X_index, payloads, ground_truth_labels)

            # Fresh data: no cluster assigned -> Freshness Shard
            fresh_payloads = [{"text": t, "origin": "fresh"} for t in fresh_texts]
            db.index_data(fresh_embeddings, fresh_payloads, [None] * len(fresh_texts))

            # 4. Run Test Queries (Self-Recall on a sample of indexed items)
            # A new seeded generator per experiment makes every combination
            # draw the identical query sample and keeps runs reproducible.
            rng = np.random.default_rng(seed)
            n_test = min(n_queries, len(X_index))
            test_indices = rng.choice(len(X_index), size=n_test, replace=False)
            test_queries = X_index[test_indices]
            test_query_texts = [texts_index[i] for i in test_indices]

            print("  - Running Test Queries...")
            latencies, hits, shards_searched = _evaluate_queries(
                router, db, test_queries, test_query_texts
            )

            # 5. Metrics
            avg_latency = np.mean(latencies)
            accuracy = hits / n_test
            avg_shards = shards_searched / n_test
            savings = (1 - (avg_shards / total_shards)) * 100

            results_table.append([
                embed_name, router_name,
                f"{accuracy:.2%}", f"{avg_latency:.2f} ms",
                f"{savings:.1f}%"
            ])

    # Print Summary
    print("\n\n================ RESULTS SUMMARY ================")
    headers = ["Embedding", "Router", "Accuracy", "Latency", "Compute Savings"]
    print(tabulate(results_table, headers=headers, tablefmt="grid"))
    
# Script entry point: run the full embedding x router benchmark matrix only
# when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    run_benchmark()