Spaces:
Running
Running
| import sys | |
| import os | |
| import numpy as np | |
| from tqdm import tqdm | |
| # Add project root to path | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from config import ( | |
| NUM_CLUSTERS, FRESHNESS_SHARD_ID, MRL_DIMS, | |
| EMBEDDING_MODELS, ROUTER_MODELS, COLLECTION_NAME, | |
| QDRANT_URL, QDRANT_API_KEY | |
| ) | |
| from src.data_pipeline import get_embeddings, load_ms_marco | |
| from src.router import LearnedRouter | |
| from src.vector_db import UnifiedQdrant | |
def ingest_data(
    n_samples: int = 1000,
    model_key: str = "minilm",
    router_path: str = "models/router_v1.pkl",
) -> None:
    """Load MS MARCO passages, embed them, train the router and index them into Qdrant.

    Pipeline steps:
      1. Load ``n_samples`` passages via ``load_ms_marco``.
      2. Embed them with the model ``EMBEDDING_MODELS[model_key]``.
      3. Train a LightGBM ``LearnedRouter`` on the embeddings and save it to
         ``router_path``.
      4. Assign every passage a cluster id using the router's internal KMeans.
      5. Create the Qdrant collection and index each passage (with a
         ``{"text", "origin"}`` payload) under its cluster id.

    Args:
        n_samples: Number of MS MARCO passages to ingest. The default of 1000
            keeps demo runs fast; raise it for a more significant index.
        model_key: Key into ``EMBEDDING_MODELS`` selecting the embedding model.
            Defaults to ``"minilm"`` (384-dim, quick on CPU). A key such as
            ``"nomic"`` (768-dim), if present in config, is the better choice
            when genuine MRL truncation matters.
        router_path: File the trained router is saved to. A relative path
            resolves against the current working directory; its parent
            directory is created if missing.

    Raises:
        ValueError: If ``n_samples`` is not positive, ``model_key`` is not a
            key of ``EMBEDDING_MODELS``, or no passages were loaded.
    """
    # Validate arguments before doing any expensive work.
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")
    if model_key not in EMBEDDING_MODELS:
        raise ValueError(
            f"Unknown embedding model key {model_key!r}; "
            f"available keys: {list(EMBEDDING_MODELS)}"
        )

    print(">>> Starting Ingestion Pipeline for Qdrant Cloud...")
    if QDRANT_URL == ":memory:":
        print(
            "WARNING: QDRANT_URL is still :memory:. "
            "Please set QDRANT_URL env var for production."
        )
        # Continue anyway so the pipeline logic can still be exercised locally.

    # 1. Load data.
    print(f"Loading {n_samples} samples from MS MARCO...")
    # Assumes load_ms_marco returns a sized sequence (it is iterated twice
    # below: once for embedding, once for payloads) -- TODO confirm.
    raw_texts = load_ms_marco(n_samples)
    if len(raw_texts) == 0:
        # Fail early with a clear message instead of an obscure error from the
        # embedding step or from `embeddings.shape[1]`.
        raise ValueError("load_ms_marco returned no passages; nothing to ingest")

    # 2. Generate embeddings.
    model_name = EMBEDDING_MODELS[model_key]
    print(f"Generating embeddings using {model_name}...")
    embeddings = get_embeddings(model_name, raw_texts)
    # Assumes a 2-D (n_texts, dim) array; the collection's vector size is dim.
    vector_dim = embeddings.shape[1]

    # 3. Train the router on this data so it learns the cluster layout, then persist it.
    print("Training Router...")
    router = LearnedRouter(model_type="lightgbm", n_clusters=NUM_CLUSTERS, mrl_dims=MRL_DIMS)
    router.train(embeddings)
    router_dir = os.path.dirname(router_path)
    if router_dir:  # os.makedirs("") raises, so skip when router_path is a bare filename
        os.makedirs(router_dir, exist_ok=True)
    router.save(router_path)

    # 4. Assign clusters (ground truth for indexing).
    print("Assigning clusters...")
    # Use the router's internal KMeans so each point is stored where the router
    # *should* predict it to be (up to classifier error).
    # NOTE(review): the router was built with mrl_dims=MRL_DIMS; if its KMeans is
    # fit on MRL-truncated vectors, predicting on full-width embeddings here would
    # mismatch feature counts -- confirm against src/router.
    cluster_ids = router.kmeans.predict(embeddings)

    # 5. Index into Qdrant.
    print("Initializing Qdrant...")
    db = UnifiedQdrant(
        collection_name=COLLECTION_NAME,
        vector_size=vector_dim,
        num_clusters=NUM_CLUSTERS,
        freshness_shard_id=FRESHNESS_SHARD_ID,
    )
    db.initialize()

    print("Indexing data...")
    payloads = [{"text": t, "origin": "ms_marco"} for t in raw_texts]
    # All points are passed in a single call; index_data groups them by shard,
    # which suits custom sharding.
    # NOTE(review): for much larger n_samples, consider chunking here to bound memory.
    db.index_data(embeddings, payloads, cluster_ids)
    print(">>> Ingestion Complete!")
if __name__ == "__main__":
    # Script entry point: run the ingestion pipeline only when executed
    # directly, so importing this module has no side effects.
    ingest_data()