""" Load precomputed embeddings + chunk metadata and upsert into Qdrant. Supports both local Qdrant (docker) and Qdrant Cloud via env vars. Creates the collection with proper HNSW config if it doesn't exist. Usage: python store_qdrant.py # full upsert python store_qdrant.py --recreate # drop + recreate collection first """ import argparse import json import numpy as np from tqdm import tqdm from qdrant_client import QdrantClient from qdrant_client.models import ( Distance, VectorParams, PointStruct, HnswConfigDiff, OptimizersConfigDiff, PayloadSchemaType, ) from config import ( RAG_CHUNKS_PATH, EMBEDDINGS_FILE, EMBEDDING_DIM, QDRANT_HOST, QDRANT_PORT, QDRANT_COLLECTION, QDRANT_URL, QDRANT_API_KEY, PROVIDER_NAME, PROVIDER_SLUG, ) UPSERT_BATCH_SIZE = 100 def get_client() -> QdrantClient: if QDRANT_URL: return QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, timeout=60) return QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, timeout=60) def ensure_collection(client: QdrantClient, recreate: bool = False): exists = client.collection_exists(QDRANT_COLLECTION) if exists and recreate: print(f"Dropping existing collection '{QDRANT_COLLECTION}'...") client.delete_collection(QDRANT_COLLECTION) exists = False if not exists: print(f"Creating collection '{QDRANT_COLLECTION}' (dim={EMBEDDING_DIM})...") client.create_collection( collection_name=QDRANT_COLLECTION, vectors_config=VectorParams( size=EMBEDDING_DIM, distance=Distance.COSINE, on_disk=False, ), hnsw_config=HnswConfigDiff( m=16, ef_construct=100, ), optimizers_config=OptimizersConfigDiff( indexing_threshold=20000, ), ) print("Collection created.") print("Ensuring payload indexes for filtered search...") for field in ("section", "policy_name", "plan_type", "doc_type", "provider"): client.create_payload_index( collection_name=QDRANT_COLLECTION, field_name=field, field_schema=PayloadSchemaType.KEYWORD, ) print(" Indexes created: section, policy_name, plan_type, doc_type, provider") def load_data(): print("Loading embeddings...") data = np.load(EMBEDDINGS_FILE, allow_pickle=True) ids = data["ids"] embeddings = data["embeddings"] print(f" Loaded {len(ids)} embeddings of dim {embeddings.shape[1]}") print("Loading chunk metadata...") with open(RAG_CHUNKS_PATH, "r", encoding="utf-8") as f: chunks = json.load(f) chunk_map = {c["id"]: c for c in chunks} print(f" Loaded {len(chunks)} chunks") return ids, embeddings, chunk_map def build_payload(chunk: dict) -> dict: return { "policy_name": chunk.get("policy_name", ""), "policy_number": chunk.get("policy_number", ""), "effective_date": chunk.get("effective_date", ""), "plan_type": chunk.get("plan_type", ""), "doc_type": chunk.get("doc_type", ""), "section": chunk.get("section", ""), "page_start": chunk.get("page_start", 0), "page_end": chunk.get("page_end", 0), "chunk_index": chunk.get("chunk_index", 0), "total_chunks_in_section": chunk.get("total_chunks_in_section", 0), "text": chunk.get("text", ""), "provider": PROVIDER_SLUG, } def upsert_points(client, ids, embeddings, chunk_map): points = [] skipped = 0 for i, (chunk_id, vector) in enumerate(zip(ids, embeddings)): chunk_id_str = str(chunk_id) if chunk_id_str not in chunk_map: skipped += 1 continue payload = build_payload(chunk_map[chunk_id_str]) points.append( PointStruct( id=i, vector=vector.tolist(), payload=payload, ) ) if skipped: print(f" Skipped {skipped} embeddings (no matching chunk metadata)") print(f" Upserting {len(points)} points in batches of {UPSERT_BATCH_SIZE}...") for batch_start in tqdm(range(0, len(points), UPSERT_BATCH_SIZE), desc="Upserting"): batch = points[batch_start : batch_start + UPSERT_BATCH_SIZE] client.upsert(collection_name=QDRANT_COLLECTION, points=batch, wait=True) return len(points) def main(): parser = argparse.ArgumentParser(description="Store embeddings in Qdrant") parser.add_argument("--recreate", action="store_true", help="Drop and recreate collection") args = parser.parse_args() client = get_client() ensure_collection(client, recreate=args.recreate) ids, embeddings, chunk_map = load_data() total = upsert_points(client, ids, embeddings, chunk_map) info = client.get_collection(QDRANT_COLLECTION) print(f"\nDone. Collection '{QDRANT_COLLECTION}' now has {info.points_count} points.") print(f" Vectors dim: {EMBEDDING_DIM}") print(f" Distance: COSINE") print(f" Provider: {PROVIDER_NAME}") if __name__ == "__main__": main()