# Ingestion script: embed MS MARCO samples, train the cluster router, and index into Qdrant.
| import sys | |
| import os | |
| import numpy as np | |
| from tqdm import tqdm | |
| # Add project root to path | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from config import ( | |
| NUM_CLUSTERS, FRESHNESS_SHARD_ID, MRL_DIMS, | |
| EMBEDDING_MODELS, ROUTER_MODELS, COLLECTION_NAME, | |
| QDRANT_URL, QDRANT_API_KEY | |
| ) | |
| from src.data_pipeline import get_embeddings, load_ms_marco | |
| from src.router import LearnedRouter | |
| from src.vector_db import UnifiedQdrant | |
ROUTER_PATH = "models/router_v1.pkl"


def ingest_data():
    """Run the end-to-end ingestion pipeline.

    Loads MS MARCO samples, embeds them, trains and persists the learned
    cluster router, then indexes the vectors into Qdrant with each point
    routed to its KMeans cluster shard. Finally reports the indexed count.
    """
    print(">>> Starting Ingestion Pipeline for Qdrant Cloud...")
    if QDRANT_URL == ":memory:":
        # Non-fatal: an in-memory Qdrant still lets us exercise the pipeline logic.
        print("WARNING: QDRANT_URL is still :memory:. Please set QDRANT_URL env var for production.")

    # 1. Load data.
    N_SAMPLES = 25000
    print(f"Loading {N_SAMPLES} samples from MS MARCO...")
    raw_texts = load_ms_marco(N_SAMPLES)

    # 2. Generate embeddings.
    # NOTE: 'minilm' (384-dim baseline) is used here because it is fast to
    # download and run on CPU; switch to 'nomic' (768-dim) for full MRL power.
    # MRL_DIMS from config (64) truncates either model's vectors downstream.
    MODEL_NAME = EMBEDDING_MODELS["minilm"]
    print(f"Generating embeddings using {MODEL_NAME}...")
    embeddings = get_embeddings(MODEL_NAME, raw_texts)
    vector_dim = embeddings.shape[1]

    # 3. Train the router (runs KMeans internally when labels are not provided).
    print("Training Router...")
    router = LearnedRouter(model_type="lightgbm", n_clusters=NUM_CLUSTERS, mrl_dims=MRL_DIMS)
    router.train(embeddings)
    # Cluster labels double as the shard assignment for indexing below.
    cluster_labels = router.kmeans.predict(embeddings)
    print("Router training complete.")

    # Persist the trained router so query-time code can load it.
    abs_router_path = os.path.abspath(ROUTER_PATH)
    print(f"Saving router to {abs_router_path}...")
    os.makedirs(os.path.dirname(abs_router_path), exist_ok=True)
    router.save(abs_router_path)
    print("Router saved.")

    # 4. Index data, sharded by the router's cluster assignment.
    print("Assigning clusters...")
    print("Initializing Qdrant...")
    db = UnifiedQdrant(
        collection_name=COLLECTION_NAME,
        vector_size=vector_dim,
        num_clusters=NUM_CLUSTERS,
        freshness_shard_id=FRESHNESS_SHARD_ID,
    )
    db.initialize()

    print("Indexing data...")
    payloads = [{"text": text, "source": "ms_marco"} for text in raw_texts]
    # KMeans labels are numpy int32; the DB layer expects plain Python ints.
    target_clusters = [int(c) for c in cluster_labels]
    db.index_data(embeddings, payloads, target_clusters)

    # Sanity check: confirm how many points actually landed in the collection.
    print("Verifying index count...")
    info = db.client.get_collection(COLLECTION_NAME)
    print(f"Collection '{COLLECTION_NAME}' has {info.points_count} points.")
    print(">>> Ingestion Complete!")
# Script entry point: run the full ingestion pipeline when executed directly.
if __name__ == "__main__":
    ingest_data()