"""Embed all chunks at every Matryoshka dim + SPLADE + BM25, upsert to Qdrant. For each (module, strategy) pair: 1. Load chunks from data/processed/{module}/chunks_{strategy}.jsonl 2. Compute mxbai full-1024-dim dense embeddings (one forward pass), truncate to all 5 dims 3. Compute SPLADE sparse vectors 4. Compute BM25 sparse vectors 5. Build PointStructs (chunk_id as point ID, payload = chunk dict, vectors as named map) 6. Upsert in batches of 64 to the matching Qdrant collection Idempotency: skips a collection if `points_count` already equals the chunk count. Use `--force` to re-embed everything (slow). `--module` / `--strategy` to filter. Total compute on CPU: ~30-90 minutes depending on machine, dominated by mxbai dense embedding pass. """ from __future__ import annotations import argparse import json import sys import time from pathlib import Path from qdrant_client.http import models as rest from tqdm import tqdm ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT)) from pipelines.shared.embedder import DIMENSIONS, MatryoshkaEmbedder from pipelines.shared.qdrant_client import ( SPARSE_VECTOR_NAMES, _dense_name, all_collection_specs, get_client, ) from pipelines.shared.sparse_encoder import BM25Encoder, SpladeEncoder PROCESSED_DIR = ROOT / "data" / "processed" UPSERT_BATCH = 64 # points per Qdrant upsert call EMBED_BATCH = 64 # docs per dense embedding forward pass (MPS handles 64 well) # Fields excluded from payload — large or redundant. PAYLOAD_DROP = {"chunk_id"} # chunk_id is the Qdrant point ID; no need to store twice def load_chunks(module: str, strategy: str) -> list[dict]: path = PROCESSED_DIR / module / f"chunks_{strategy}.jsonl" if not path.exists(): raise FileNotFoundError(f"missing chunk file: {path}") return [json.loads(line) for line in path.open()] def chunk_to_payload(chunk: dict) -> dict: return {k: v for k, v in chunk.items() if k not in PAYLOAD_DROP} def build_points( chunks: list[dict], dense_by_dim: dict[int, "np.ndarray"], splade_vecs: list, bm25_vecs: list, ) -> list[rest.PointStruct]: points: list[rest.PointStruct] = [] for i, chunk in enumerate(chunks): named_dense = {_dense_name(d): dense_by_dim[d][i].tolist() for d in DIMENSIONS} named_sparse = { "splade": rest.SparseVector( indices=splade_vecs[i].indices, values=splade_vecs[i].values), "bm25": rest.SparseVector( indices=bm25_vecs[i].indices, values=bm25_vecs[i].values), } points.append(rest.PointStruct( id=chunk["chunk_id"], # UUIDv5 from chunking_base.make_chunk_id vector={**named_dense, **named_sparse}, payload=chunk_to_payload(chunk), )) return points def process_collection( client, module: str, strategy: str, coll_name: str, embedder: MatryoshkaEmbedder, splade: SpladeEncoder, bm25: BM25Encoder, *, force: bool, ) -> dict: chunks = load_chunks(module, strategy) n_target = len(chunks) info = client.get_collection(coll_name) n_existing = info.points_count or 0 if n_existing >= n_target and not force: print(f" ⏭ {coll_name}: {n_existing} points already ≥ {n_target} chunks. Skipping (use --force to re-embed).") return {"collection": coll_name, "skipped": True, "n_chunks": n_target, "n_existing": n_existing} print(f" ▶ {coll_name}: embedding {n_target} chunks " f"(currently {n_existing} points; force={force})") texts = [c["content"] for c in chunks] # 1. Dense — single mxbai forward pass for all chunks, then truncate to 5 dims. t0 = time.perf_counter() dense_by_dim = embedder.embed_documents_all_dims(texts, show_progress=True) dense_t = time.perf_counter() - t0 print(f" dense embeddings ({n_target} × {DIMENSIONS}): {dense_t:.1f}s " f"({n_target / max(dense_t, 1e-6):.1f} chunks/s)") # 2. SPLADE sparse t0 = time.perf_counter() splade_vecs = splade.encode(texts, batch_size=16) splade_t = time.perf_counter() - t0 print(f" SPLADE sparse: {splade_t:.1f}s ({n_target / max(splade_t, 1e-6):.1f} chunks/s)") # 3. BM25 sparse t0 = time.perf_counter() bm25_vecs = bm25.encode_documents(texts, batch_size=64) bm25_t = time.perf_counter() - t0 print(f" BM25 sparse: {bm25_t:.1f}s ({n_target / max(bm25_t, 1e-6):.1f} chunks/s)") # 4. Build + upsert in batches points = build_points(chunks, dense_by_dim, splade_vecs, bm25_vecs) t0 = time.perf_counter() n_upserted = 0 for i in tqdm(range(0, len(points), UPSERT_BATCH), desc=f" upsert({coll_name.split('_')[-1]})", leave=False): batch = points[i:i + UPSERT_BATCH] client.upsert(collection_name=coll_name, points=batch, wait=False) n_upserted += len(batch) upsert_t = time.perf_counter() - t0 print(f" upserted {n_upserted} points: {upsert_t:.1f}s") # Verify after a brief pause (wait=False above) info = client.get_collection(coll_name) n_final = info.points_count or 0 return { "collection": coll_name, "skipped": False, "n_chunks": n_target, "n_existing_before": n_existing, "n_existing_after": n_final, "n_upserted": n_upserted, "dense_seconds": round(dense_t, 1), "splade_seconds": round(splade_t, 1), "bm25_seconds": round(bm25_t, 1), "upsert_seconds": round(upsert_t, 1), } def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--module", choices=["compliance", "credit"], default=None, help="Filter to one module") ap.add_argument("--strategy", default=None, help="Filter to one strategy name") ap.add_argument("--force", action="store_true", help="Re-embed even if collection already has points") args = ap.parse_args() client = get_client() embedder = MatryoshkaEmbedder(batch_size=EMBED_BATCH) splade = SpladeEncoder() bm25 = BM25Encoder() specs = all_collection_specs() if args.module: specs = [s for s in specs if s[0] == args.module] if args.strategy: specs = [s for s in specs if s[1] == args.strategy] print(f"\nProcessing {len(specs)} collection(s)\n") results = [] grand_t0 = time.perf_counter() for module, strategy, coll_name in specs: try: r = process_collection(client, module, strategy, coll_name, embedder, splade, bm25, force=args.force) results.append(r) except Exception as e: print(f" ✗ {coll_name}: {type(e).__name__}: {e}") import traceback; traceback.print_exc() results.append({"collection": coll_name, "error": str(e)}) print(f"\nTotal elapsed: {time.perf_counter() - grand_t0:.1f}s") # Summary summary_path = PROCESSED_DIR / "_qdrant_load_summary.json" summary_path.write_text(json.dumps({"results": results}, indent=2)) print(f"\nSummary → {summary_path.relative_to(ROOT)}") # Final state print("\n=== Final collection state ===") for module, strategy, coll_name in all_collection_specs(): info = client.get_collection(coll_name) print(f" {coll_name}: {info.points_count or 0} points") return 0 if __name__ == "__main__": sys.exit(main())