Spaces:
Sleeping
Sleeping
File size: 5,237 Bytes
"""
Load precomputed embeddings + chunk metadata and upsert them into Qdrant.

Supports both a local Qdrant instance (Docker) and Qdrant Cloud, selected via
environment variables. Creates the collection with an explicit HNSW config if
it doesn't already exist.

Usage:
    python store_qdrant.py             # full upsert
    python store_qdrant.py --recreate  # drop + recreate the collection first
"""
import argparse
import json
import numpy as np
from tqdm import tqdm
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
VectorParams,
PointStruct,
HnswConfigDiff,
OptimizersConfigDiff,
PayloadSchemaType,
)
from config import (
RAG_CHUNKS_PATH,
EMBEDDINGS_FILE,
EMBEDDING_DIM,
QDRANT_HOST,
QDRANT_PORT,
QDRANT_COLLECTION,
QDRANT_URL,
QDRANT_API_KEY,
PROVIDER_NAME,
PROVIDER_SLUG,
)
# Number of points sent to Qdrant in each upsert request (used by upsert_points).
UPSERT_BATCH_SIZE = 100
def get_client() -> QdrantClient:
    """Build a Qdrant client for either Qdrant Cloud or a local instance.

    ``QDRANT_URL`` wins when it is set: the client connects to that URL and
    authenticates with ``QDRANT_API_KEY``. Otherwise it connects to
    ``QDRANT_HOST``:``QDRANT_PORT``. Both variants pass ``timeout=60``.
    """
    if not QDRANT_URL:
        # No URL configured: talk to a host/port deployment (e.g. local docker).
        return QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, timeout=60)
    return QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, timeout=60)
def ensure_collection(client: QdrantClient, recreate: bool = False):
    """Make sure the target collection and its keyword payload indexes exist.

    Args:
        client: Connected Qdrant client.
        recreate: When true and the collection already exists, it is deleted
            first so it gets rebuilt with the settings below.

    A newly created collection uses cosine distance with vectors of size
    ``EMBEDDING_DIM`` (in RAM, ``on_disk=False``), HNSW ``m=16`` /
    ``ef_construct=100`` and an optimizer ``indexing_threshold`` of 20000.
    Keyword payload indexes for the filterable fields are then requested.
    """
    keyword_fields = ("section", "policy_name", "plan_type", "doc_type", "provider")

    collection_present = client.collection_exists(QDRANT_COLLECTION)
    if recreate and collection_present:
        print(f"Dropping existing collection '{QDRANT_COLLECTION}'...")
        client.delete_collection(QDRANT_COLLECTION)
        collection_present = False

    if not collection_present:
        print(f"Creating collection '{QDRANT_COLLECTION}' (dim={EMBEDDING_DIM})...")
        vector_params = VectorParams(
            size=EMBEDDING_DIM,
            distance=Distance.COSINE,
            on_disk=False,
        )
        hnsw_params = HnswConfigDiff(m=16, ef_construct=100)
        optimizer_params = OptimizersConfigDiff(indexing_threshold=20000)
        client.create_collection(
            collection_name=QDRANT_COLLECTION,
            vectors_config=vector_params,
            hnsw_config=hnsw_params,
            optimizers_config=optimizer_params,
        )
        print("Collection created.")

    print("Ensuring payload indexes for filtered search...")
    for field_name in keyword_fields:
        client.create_payload_index(
            collection_name=QDRANT_COLLECTION,
            field_name=field_name,
            field_schema=PayloadSchemaType.KEYWORD,
        )
    # Derived from keyword_fields so the message can't drift from the real list.
    print(" Indexes created: " + ", ".join(keyword_fields))
def load_data(embeddings_file=None, chunks_path=None):
    """Load the precomputed embeddings and the chunk metadata they refer to.

    Args:
        embeddings_file: Path to the ``.npz`` archive holding ``ids`` and
            ``embeddings`` arrays. Defaults to ``EMBEDDINGS_FILE`` from config.
        chunks_path: Path to a JSON list of chunk dicts, each with an ``"id"``
            key. Defaults to ``RAG_CHUNKS_PATH`` from config.

    Returns:
        ``(ids, embeddings, chunk_map)``: ``chunk_map`` maps each chunk's
        ``"id"`` to its dict (a later duplicate id overwrites an earlier one).

    Raises:
        ValueError: If ``embeddings`` is not 2-D or its row count differs from
            the number of ids. Mismatched rows would otherwise be silently
            dropped when ids and vectors are zipped together later.
    """
    if embeddings_file is None:
        embeddings_file = EMBEDDINGS_FILE
    if chunks_path is None:
        chunks_path = RAG_CHUNKS_PATH

    print("Loading embeddings...")
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays (e.g.
    # ids saved with dtype=object); only load archives produced by this pipeline.
    # The NpzFile is used as a context manager so its file handle is closed;
    # indexing it reads each array fully into memory before the file closes.
    with np.load(embeddings_file, allow_pickle=True) as data:
        ids = data["ids"]
        embeddings = data["embeddings"]

    if embeddings.ndim != 2 or embeddings.shape[0] != len(ids):
        raise ValueError(
            f"Embeddings array has shape {embeddings.shape}; expected "
            f"({len(ids)}, dim) to match the {len(ids)} ids"
        )
    print(f" Loaded {len(ids)} embeddings of dim {embeddings.shape[1]}")

    print("Loading chunk metadata...")
    with open(chunks_path, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    chunk_map = {c["id"]: c for c in chunks}
    print(f" Loaded {len(chunks)} chunks")
    return ids, embeddings, chunk_map
def build_payload(chunk: dict) -> dict:
    """Project a chunk's metadata into the Qdrant point payload.

    Missing text-like fields default to ``""`` and missing numeric position
    fields default to ``0``. The chunk text is copied as-is and every payload
    is tagged with this deployment's ``PROVIDER_SLUG`` under ``"provider"``.
    """
    text_keys = (
        "policy_name",
        "policy_number",
        "effective_date",
        "plan_type",
        "doc_type",
        "section",
    )
    numeric_keys = ("page_start", "page_end", "chunk_index", "total_chunks_in_section")

    payload = {key: chunk.get(key, "") for key in text_keys}
    for key in numeric_keys:
        payload[key] = chunk.get(key, 0)
    payload["text"] = chunk.get("text", "")
    payload["provider"] = PROVIDER_SLUG
    return payload
def upsert_points(client, ids, embeddings, chunk_map):
    """Turn embeddings into Qdrant points and upsert them in batches.

    Args:
        client: Connected Qdrant client.
        ids: Chunk ids, aligned index-for-index with *embeddings*.
        embeddings: Vectors; each row must support ``.tolist()``.
        chunk_map: Mapping from chunk id (as ``str``) to its metadata dict.

    Returns:
        Number of points sent. Embeddings whose id has no entry in
        *chunk_map* are skipped and not counted.

    Raises:
        ValueError: If *ids* and *embeddings* have different lengths.
    """
    if len(ids) != len(embeddings):
        # zip() would silently drop the tail of the longer sequence and upload
        # a partial collection without any error; fail loudly instead.
        raise ValueError(
            f"ids ({len(ids)}) and embeddings ({len(embeddings)}) differ in length"
        )

    points = []
    skipped = 0
    for i, (chunk_id, vector) in enumerate(zip(ids, embeddings)):
        chunk_id_str = str(chunk_id)
        if chunk_id_str not in chunk_map:
            skipped += 1
            continue
        points.append(
            PointStruct(
                # Point id is the row position in the embeddings file, not the
                # chunk id. Qdrant only accepts unsigned ints or UUIDs as ids.
                # A reordered embeddings file will therefore overwrite
                # different points on re-upsert.
                id=i,
                vector=vector.tolist(),
                payload=build_payload(chunk_map[chunk_id_str]),
            )
        )

    if skipped:
        print(f" Skipped {skipped} embeddings (no matching chunk metadata)")

    print(f" Upserting {len(points)} points in batches of {UPSERT_BATCH_SIZE}...")
    for batch_start in tqdm(range(0, len(points), UPSERT_BATCH_SIZE), desc="Upserting"):
        batch = points[batch_start : batch_start + UPSERT_BATCH_SIZE]
        # wait=True: block until Qdrant has applied this batch before sending the next.
        client.upsert(collection_name=QDRANT_COLLECTION, points=batch, wait=True)

    return len(points)
def main():
    """CLI entry point: prepare the collection, load data, upsert, and report.

    ``--recreate`` drops and re-creates the collection before loading. The
    summary reports both the number of points sent in this run and the
    collection's total point count read back from Qdrant. The two can differ
    when points from earlier runs remain in the collection.
    """
    parser = argparse.ArgumentParser(description="Store embeddings in Qdrant")
    parser.add_argument("--recreate", action="store_true", help="Drop and recreate collection")
    args = parser.parse_args()

    client = get_client()
    ensure_collection(client, recreate=args.recreate)

    ids, embeddings, chunk_map = load_data()
    total = upsert_points(client, ids, embeddings, chunk_map)

    info = client.get_collection(QDRANT_COLLECTION)
    print(f"\nDone. Collection '{QDRANT_COLLECTION}' now has {info.points_count} points.")
    print(f" Upserted this run: {total}")
    print(f" Vectors dim: {EMBEDDING_DIM}")
    print(" Distance: COSINE")
    print(f" Provider: {PROVIDER_NAME}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|