download
raw
5.35 kB
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# Make the project root importable when this script is run directly
# (e.g. `python scripts/<this_file>.py`) rather than as a package module.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from app.agent.kb_embedding import KBEmbeddingService
from app.db.chroma_client import get_collection, get_vector_backend

# Prepared dataset: one JSON record per line (produced by scripts/prepare_dataset.py).
DATA_PATH = PROJECT_ROOT / "data" / "medical_kb.jsonl"
# Number of records embedded and inserted per collection.add() call.
BATCH_SIZE = 256
def _iter_records():
    """Yield KB records from DATA_PATH, one parsed JSON object per line.

    Whitespace-only lines are skipped so a trailing newline or stray blank
    line in the JSONL file does not raise ``json.JSONDecodeError``.
    """
    with DATA_PATH.open("r", encoding="utf-8") as file_obj:
        for line in file_obj:
            # Guard against blank lines; json.loads("") would raise.
            if line.strip():
                yield json.loads(line)
def _is_duplicate_error(exc: Exception) -> bool:
message = str(exc).lower()
return "duplicate" in message or "already exists" in message or "unique" in message
def _is_disk_full_error(exc: Exception) -> bool:
message = str(exc).lower()
return "disk is full" in message or "(code: 13)" in message
def _ingest_batch(batch: list[dict]) -> tuple[int, int]:
    """Embed and insert one batch of KB records into the vector collection.

    Each record must carry "id", "content", and "metadata" keys.

    Returns:
        A ``(inserted, skipped)`` tuple, where ``skipped`` counts both
        in-batch duplicate ids and records already present in the collection.

    Raises:
        RuntimeError: when the vector store reports a disk-full condition
            (so the caller can stop the run cleanly).
    """
    collection = get_collection()
    embedder = KBEmbeddingService()
    # Deduplicate by record id within this batch, keeping the first occurrence.
    unique_records: dict[str, dict] = {}
    duplicate_ids_in_batch = 0
    for record in batch:
        record_id = record["id"]
        if record_id in unique_records:
            duplicate_ids_in_batch += 1
            continue
        unique_records[record_id] = record
    batch_ids = list(unique_records)
    # Ask the collection which ids are already stored so already-ingested
    # records are not re-embedded (supports resumed / --force runs).
    existing = set(collection.get(ids=batch_ids).get("ids", []))
    pending_records = [record for record_id, record in unique_records.items() if record_id not in existing]
    if not pending_records:
        # Everything in this batch was a duplicate or already ingested.
        return 0, len(batch_ids) + duplicate_ids_in_batch
    ids = [record["id"] for record in pending_records]
    documents = [record["content"] for record in pending_records]
    metadatas = [record["metadata"] for record in pending_records]
    embeddings = embedder.embed_batch(documents)
    try:
        collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )
        return len(pending_records), duplicate_ids_in_batch
    except Exception as exc:
        # Disk-full is unrecoverable: propagate as RuntimeError immediately.
        if _is_disk_full_error(exc):
            raise RuntimeError(f"Disk is full while writing batch: {exc}") from exc
        if not _is_duplicate_error(exc):
            print(f"Batch add fallback triggered: {exc}", flush=True)
        # Fall back to one-record-at-a-time inserts (reusing the embeddings
        # already computed) so one bad or duplicate record does not discard
        # the entire batch.
        inserted = 0
        skipped = 0
        for record_id, embedding, document, metadata in zip(ids, embeddings, documents, metadatas):
            try:
                collection.add(
                    ids=[record_id],
                    embeddings=[embedding],
                    documents=[document],
                    metadatas=[metadata],
                )
                inserted += 1
            except Exception as exc:
                if _is_disk_full_error(exc):
                    raise RuntimeError(f"Disk is full while ingesting {record_id}: {exc}") from exc
                if _is_duplicate_error(exc):
                    skipped += 1
                    continue
                # Best-effort: log the failure and keep ingesting the rest.
                print(f"Failed to ingest {record_id}: {exc}", flush=True)
        return inserted, skipped + duplicate_ids_in_batch
def main(force: bool = False) -> None:
    """Ingest every record from DATA_PATH into the vector DB in batches.

    Args:
        force: When False (default), skip ingestion entirely if the
            collection already holds records; when True, run anyway so new
            records can be appended (already-present ids are skipped by
            ``_ingest_batch``).

    Raises:
        FileNotFoundError: when the prepared dataset file is missing.
    """
    if not DATA_PATH.exists():
        raise FileNotFoundError(
            f"Input file not found: {DATA_PATH}. Run scripts/prepare_dataset.py first."
        )
    collection = get_collection()
    backend = get_vector_backend()
    print(f"Vector backend: {backend}", flush=True)
    existing_count = collection.count()
    if existing_count > 0 and not force:
        print(
            f"Vector DB already has {existing_count} chunks. "
            "Skipping ingestion to avoid re-embedding. "
            "Use --force if you want to resume/check for new records.",
            flush=True,
        )
        return
    batch: list[dict] = []
    total_processed = 0
    total_inserted = 0
    total_skipped = 0
    batch_num = 0

    def _flush(current: list[dict]) -> None:
        # Ingest one full (or final partial) batch and accumulate totals.
        nonlocal batch_num, total_processed, total_inserted, total_skipped
        batch_num += 1
        total_processed += len(current)
        print(f"Ingesting batch {batch_num}... ({total_processed} records processed)", flush=True)
        inserted, skipped = _ingest_batch(current)
        total_inserted += inserted
        total_skipped += skipped

    try:
        for record in _iter_records():
            batch.append(record)
            if len(batch) == BATCH_SIZE:
                _flush(batch)
                batch = []
        if batch:
            # Final partial batch.
            _flush(batch)
    except RuntimeError as exc:
        # _ingest_batch raises RuntimeError on disk-full; stop early but
        # still report progress so the run can be resumed later.
        print(f"Ingestion stopped early: {exc}", flush=True)
    print(f"Ingestion complete. Total chunks in vector DB: {collection.count()}", flush=True)
    if total_skipped:
        print(f"Skipped duplicate records: {total_skipped}", flush=True)
    if total_inserted:
        print(f"Inserted records in this run: {total_inserted}", flush=True)
if __name__ == "__main__":
    # CLI entry point: parse flags, then run ingestion.
    arg_parser = argparse.ArgumentParser(description="Ingest medical KB records into vector DB.")
    arg_parser.add_argument(
        "--force",
        action="store_true",
        help="Run ingestion even if vector DB already has records.",
    )
    cli_args = arg_parser.parse_args()
    main(force=cli_args.force)

Xet Storage Details

Size:
5.35 kB
·
Xet hash:
99237cc3b81d125e6711ff40b5012502789b843887be167a306546339cbdb360

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.