Spaces:
Running
Running
| import sys | |
| import os | |
| import numpy as np | |
| import json | |
| from tqdm import tqdm | |
# Make the project root (the parent of this script's directory) importable so
# the `config` and `src.*` imports that follow resolve when run as a script.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_PROJECT_ROOT)

# --- Configuration -----------------------------------------------------------
N_SAMPLES = 25_000  # Full Benchmark
NUM_CLUSTERS = 32  # Production Cluster Count
FRESHNESS_SHARD_ID = 0
| from config import ( | |
| MRL_DIMS, | |
| EMBEDDING_MODELS, ROUTER_MODELS, COLLECTIONS, | |
| QDRANT_URL, QDRANT_API_KEY | |
| ) | |
| from src.data_pipeline import get_embeddings, load_ms_marco | |
| from src.router import LearnedRouter | |
| from src.vector_db import UnifiedQdrant | |
def _save_json(path, obj):
    """Write *obj* to *path* as JSON (the parent directory must already exist)."""
    with open(path, "w") as f:
        json.dump(obj, f)


def _index_baseline(col_name, vector_dim, embeddings, payloads):
    """Create the unsharded baseline collection and upsert every vector into it."""
    print(f"\n--- Setting up Baseline Collection: {col_name} ---")
    db_base = UnifiedQdrant(
        collection_name=col_name,
        vector_size=vector_dim,
        num_clusters=1,  # Unsharded
    )
    db_base.initialize(is_baseline=True)
    print("Indexing data into Baseline...")
    db_base.index_data(embeddings, payloads, cluster_ids=None)  # None = Standard Upsert
    print("Baseline Indexing Complete.")


def _define_physical_shards(embeddings, num_clusters):
    """Fit the "master" router whose KMeans model defines the physical shard layout.

    ``LearnedRouter.train`` runs KMeans internally; the fitted model is exposed
    as ``router.kmeans`` and its ``labels_`` give each vector's shard.

    Returns:
        The trained master ``LearnedRouter``.
    """
    print("Running K-Means to define Physical Shards...")
    master_router = LearnedRouter(
        model_type="logistic", n_clusters=num_clusters, mrl_dims=MRL_DIMS
    )
    master_router.train(embeddings)
    return master_router


def _index_prod(col_name, vector_dim, embeddings, payloads, cluster_labels,
                num_clusters, freshness_shard_id):
    """Create the sharded prod collection and index each vector into its KMeans shard.

    Returns:
        The ``UnifiedQdrant`` handle, so the caller can query shard sizes.
    """
    db_prod = UnifiedQdrant(
        collection_name=col_name,
        vector_size=vector_dim,
        num_clusters=num_clusters,
        freshness_shard_id=freshness_shard_id,
    )
    db_prod.initialize(is_baseline=False)
    print("Indexing data into Prod (Sharded)...")
    # Convert labels to plain Python ints (they are presumably numpy integers
    # from KMeans.labels_ — confirm) before handing them to the DB layer.
    target_clusters = [int(c) for c in cluster_labels]
    db_prod.index_data(embeddings, payloads, cluster_ids=target_clusters)
    return db_prod


def _train_and_save_routers(model_key, embeddings, kmeans_model, cluster_labels,
                            num_clusters, models_dir):
    """Train one router per ``ROUTER_MODELS`` entry on the shared shard layout and save it.

    Every router receives the SAME fitted KMeans model and the SAME labels.
    Re-fitting KMeans per router could converge to different clusters, and the
    routers would then disagree with the physical layout of the prod collection.
    Each router is saved as ``{models_dir}/router_{model_key}_{router_type}.pkl``.
    """
    print("\n--- Training Routers ---")
    for router_type in ROUTER_MODELS:
        print(f"Training {router_type.upper()}...")
        router = LearnedRouter(
            model_type=router_type, n_clusters=num_clusters, mrl_dims=MRL_DIMS
        )
        router.kmeans = kmeans_model  # Inject shared KMeans (physical layout)
        # NOTE(review): relies on LearnedRouter.train accepting precomputed
        # ``labels`` so that only the classifier is fit — confirm in src/router.py.
        router.train(embeddings, labels=cluster_labels)
        save_name = f"router_{model_key}_{router_type}.pkl"
        save_path = os.path.abspath(os.path.join(models_dir, save_name))
        router.save(save_path)
        print(f"Saved {save_name}")


def ingest_full_benchmark(*, n_samples=None, num_clusters=None,
                          freshness_shard_id=None, models_dir="models"):
    """Run the full ingestion pipeline for every configured embedding model.

    Loads ``n_samples`` MS MARCO texts once, then for each entry in
    ``EMBEDDING_MODELS``:

    1. Embed the texts and write ``{models_dir}/model_info_{key}.json``
       containing ``{"dim": <vector dim>}``.
    2. Build the unsharded baseline collection ``COLLECTIONS[key]["base"]``.
    3. Fit KMeans (via a logistic ``LearnedRouter``) to define the physical
       shards. Then index every vector into the sharded prod collection
       ``COLLECTIONS[key]["prod"]`` and write
       ``{models_dir}/shard_sizes_{key}.json``.
    4. Train every router type in ``ROUTER_MODELS`` against that same shard
       layout and save each one under ``models_dir``.

    Args:
        n_samples: Number of MS MARCO samples to load; defaults to ``N_SAMPLES``.
        num_clusters: Number of prod shards / KMeans clusters; defaults to
            ``NUM_CLUSTERS``.
        freshness_shard_id: Value passed to the prod collection as its
            ``freshness_shard_id``; defaults to ``FRESHNESS_SHARD_ID``.
        models_dir: Directory for the JSON and router artifacts. It is
            relative to the current working directory unless absolute, and is
            created if missing.

    Raises:
        KeyError: If ``COLLECTIONS`` lacks a ``"base"``/``"prod"`` entry for a
            model key in ``EMBEDDING_MODELS``.
    """
    # Resolve defaults at call time so the module-level constants stay the
    # single source of truth.
    if n_samples is None:
        n_samples = N_SAMPLES
    if num_clusters is None:
        num_clusters = NUM_CLUSTERS
    if freshness_shard_id is None:
        freshness_shard_id = FRESHNESS_SHARD_ID

    print(">>> Starting Full Benchmark Ingestion Pipeline...")
    # Create the artifact directory before the first JSON write below (and
    # before the expensive embedding step, so a bad path fails fast).
    os.makedirs(models_dir or os.curdir, exist_ok=True)

    # 1. Load Data
    print(f"Loading {n_samples} samples from MS MARCO...")
    raw_texts = load_ms_marco(n_samples)

    for model_key, model_name in EMBEDDING_MODELS.items():
        print("\n==================================================")
        print(f"Processing Embedding Model: {model_key.upper()} ({model_name})")
        print("==================================================")

        # 2. Generate Embeddings
        print("Generating embeddings...")
        embeddings = get_embeddings(model_name, raw_texts)
        vector_dim = int(embeddings.shape[1])  # plain int -> always JSON-serializable
        print(f"Embeddings generated. Shape: {embeddings.shape}")

        # Save Model Info (Dimension) for App
        model_info_path = os.path.join(models_dir, f"model_info_{model_key}.json")
        _save_json(model_info_path, {"dim": vector_dim})
        print(f"Saved model info to {model_info_path}")

        # 3. Baseline Collection (Unsharded)
        # Payloads are rebuilt per model rather than hoisted, in case
        # index_data mutates the dicts.
        payloads = [{"text": text, "source": "ms_marco"} for text in raw_texts]
        _index_baseline(COLLECTIONS[model_key]["base"], vector_dim, embeddings, payloads)

        # 4. Prod Collection (Sharded by the master KMeans labels)
        prod_col_name = COLLECTIONS[model_key]["prod"]
        print(f"\n--- Setting up Prod Collection: {prod_col_name} ---")
        master_router = _define_physical_shards(embeddings, num_clusters)
        cluster_labels = master_router.kmeans.labels_
        db_prod = _index_prod(prod_col_name, vector_dim, embeddings, payloads,
                              cluster_labels, num_clusters, freshness_shard_id)

        # Save Shard Sizes
        print("Saving Shard Sizes...")
        size_path = os.path.join(models_dir, f"shard_sizes_{model_key}.json")
        _save_json(size_path, db_prod.get_shard_sizes())
        print(f"Shard sizes saved to {size_path}")

        # 5. Train & Save All Routers against the same physical layout
        _train_and_save_routers(model_key, embeddings, master_router.kmeans,
                                cluster_labels, num_clusters, models_dir)

    print("\n>>> Full Benchmark Ingestion Complete!")


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    ingest_full_benchmark()