File size: 7,542 Bytes
657d287 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 | """Embed all chunks at every Matryoshka dim + SPLADE + BM25, upsert to Qdrant.
For each (module, strategy) pair:
1. Load chunks from data/processed/{module}/chunks_{strategy}.jsonl
2. Compute mxbai full-1024-dim dense embeddings (one forward pass), truncate to all 5 dims
3. Compute SPLADE sparse vectors
4. Compute BM25 sparse vectors
5. Build PointStructs (chunk_id as point ID, payload = chunk dict, vectors as named map)
6. Upsert in batches of 64 to the matching Qdrant collection
Idempotency: skips a collection if `points_count` already equals the chunk count.
Use `--force` to re-embed everything (slow). `--module` / `--strategy` to filter.
Total compute on CPU: ~30-90 minutes depending on machine, dominated by mxbai dense
embedding pass.
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
from qdrant_client.http import models as rest
from tqdm import tqdm
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from pipelines.shared.embedder import DIMENSIONS, MatryoshkaEmbedder
from pipelines.shared.qdrant_client import (
SPARSE_VECTOR_NAMES,
_dense_name,
all_collection_specs,
get_client,
)
from pipelines.shared.sparse_encoder import BM25Encoder, SpladeEncoder
PROCESSED_DIR = ROOT / "data" / "processed"
UPSERT_BATCH = 64 # points per Qdrant upsert call
EMBED_BATCH = 64 # docs per dense embedding forward pass (MPS handles 64 well)
# Fields excluded from payload — large or redundant.
PAYLOAD_DROP = {"chunk_id"} # chunk_id is the Qdrant point ID; no need to store twice
def load_chunks(module: str, strategy: str) -> list[dict]:
path = PROCESSED_DIR / module / f"chunks_{strategy}.jsonl"
if not path.exists():
raise FileNotFoundError(f"missing chunk file: {path}")
return [json.loads(line) for line in path.open()]
def chunk_to_payload(chunk: dict) -> dict:
return {k: v for k, v in chunk.items() if k not in PAYLOAD_DROP}
def build_points(
chunks: list[dict],
dense_by_dim: dict[int, "np.ndarray"],
splade_vecs: list,
bm25_vecs: list,
) -> list[rest.PointStruct]:
points: list[rest.PointStruct] = []
for i, chunk in enumerate(chunks):
named_dense = {_dense_name(d): dense_by_dim[d][i].tolist() for d in DIMENSIONS}
named_sparse = {
"splade": rest.SparseVector(
indices=splade_vecs[i].indices, values=splade_vecs[i].values),
"bm25": rest.SparseVector(
indices=bm25_vecs[i].indices, values=bm25_vecs[i].values),
}
points.append(rest.PointStruct(
id=chunk["chunk_id"], # UUIDv5 from chunking_base.make_chunk_id
vector={**named_dense, **named_sparse},
payload=chunk_to_payload(chunk),
))
return points
def process_collection(
client,
module: str,
strategy: str,
coll_name: str,
embedder: MatryoshkaEmbedder,
splade: SpladeEncoder,
bm25: BM25Encoder,
*,
force: bool,
) -> dict:
chunks = load_chunks(module, strategy)
n_target = len(chunks)
info = client.get_collection(coll_name)
n_existing = info.points_count or 0
if n_existing >= n_target and not force:
print(f" ⏭ {coll_name}: {n_existing} points already ≥ {n_target} chunks. Skipping (use --force to re-embed).")
return {"collection": coll_name, "skipped": True, "n_chunks": n_target,
"n_existing": n_existing}
print(f" ▶ {coll_name}: embedding {n_target} chunks "
f"(currently {n_existing} points; force={force})")
texts = [c["content"] for c in chunks]
# 1. Dense — single mxbai forward pass for all chunks, then truncate to 5 dims.
t0 = time.perf_counter()
dense_by_dim = embedder.embed_documents_all_dims(texts, show_progress=True)
dense_t = time.perf_counter() - t0
print(f" dense embeddings ({n_target} × {DIMENSIONS}): {dense_t:.1f}s "
f"({n_target / max(dense_t, 1e-6):.1f} chunks/s)")
# 2. SPLADE sparse
t0 = time.perf_counter()
splade_vecs = splade.encode(texts, batch_size=16)
splade_t = time.perf_counter() - t0
print(f" SPLADE sparse: {splade_t:.1f}s ({n_target / max(splade_t, 1e-6):.1f} chunks/s)")
# 3. BM25 sparse
t0 = time.perf_counter()
bm25_vecs = bm25.encode_documents(texts, batch_size=64)
bm25_t = time.perf_counter() - t0
print(f" BM25 sparse: {bm25_t:.1f}s ({n_target / max(bm25_t, 1e-6):.1f} chunks/s)")
# 4. Build + upsert in batches
points = build_points(chunks, dense_by_dim, splade_vecs, bm25_vecs)
t0 = time.perf_counter()
n_upserted = 0
for i in tqdm(range(0, len(points), UPSERT_BATCH), desc=f" upsert({coll_name.split('_')[-1]})", leave=False):
batch = points[i:i + UPSERT_BATCH]
client.upsert(collection_name=coll_name, points=batch, wait=False)
n_upserted += len(batch)
upsert_t = time.perf_counter() - t0
print(f" upserted {n_upserted} points: {upsert_t:.1f}s")
# Verify after a brief pause (wait=False above)
info = client.get_collection(coll_name)
n_final = info.points_count or 0
return {
"collection": coll_name,
"skipped": False,
"n_chunks": n_target,
"n_existing_before": n_existing,
"n_existing_after": n_final,
"n_upserted": n_upserted,
"dense_seconds": round(dense_t, 1),
"splade_seconds": round(splade_t, 1),
"bm25_seconds": round(bm25_t, 1),
"upsert_seconds": round(upsert_t, 1),
}
def main() -> int:
ap = argparse.ArgumentParser()
ap.add_argument("--module", choices=["compliance", "credit"], default=None,
help="Filter to one module")
ap.add_argument("--strategy", default=None,
help="Filter to one strategy name")
ap.add_argument("--force", action="store_true",
help="Re-embed even if collection already has points")
args = ap.parse_args()
client = get_client()
embedder = MatryoshkaEmbedder(batch_size=EMBED_BATCH)
splade = SpladeEncoder()
bm25 = BM25Encoder()
specs = all_collection_specs()
if args.module:
specs = [s for s in specs if s[0] == args.module]
if args.strategy:
specs = [s for s in specs if s[1] == args.strategy]
print(f"\nProcessing {len(specs)} collection(s)\n")
results = []
grand_t0 = time.perf_counter()
for module, strategy, coll_name in specs:
try:
r = process_collection(client, module, strategy, coll_name,
embedder, splade, bm25, force=args.force)
results.append(r)
except Exception as e:
print(f" ✗ {coll_name}: {type(e).__name__}: {e}")
import traceback; traceback.print_exc()
results.append({"collection": coll_name, "error": str(e)})
print(f"\nTotal elapsed: {time.perf_counter() - grand_t0:.1f}s")
# Summary
summary_path = PROCESSED_DIR / "_qdrant_load_summary.json"
summary_path.write_text(json.dumps({"results": results}, indent=2))
print(f"\nSummary → {summary_path.relative_to(ROOT)}")
# Final state
print("\n=== Final collection state ===")
for module, strategy, coll_name in all_collection_specs():
info = client.get_collection(coll_name)
print(f" {coll_name}: {info.points_count or 0} points")
return 0
if __name__ == "__main__":
sys.exit(main())
|