| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| PROJECT_ROOT = Path(__file__).resolve().parents[1] | |
| if str(PROJECT_ROOT) not in sys.path: | |
| sys.path.insert(0, str(PROJECT_ROOT)) | |
| from app.agent.kb_embedding import KBEmbeddingService | |
| from app.db.chroma_client import get_collection, get_vector_backend | |
| DATA_PATH = PROJECT_ROOT / "data" / "medical_kb.jsonl" | |
| BATCH_SIZE = 256 | |
| def _iter_records(): | |
| with DATA_PATH.open("r", encoding="utf-8") as file_obj: | |
| for line in file_obj: | |
| yield json.loads(line) | |
| def _is_duplicate_error(exc: Exception) -> bool: | |
| message = str(exc).lower() | |
| return "duplicate" in message or "already exists" in message or "unique" in message | |
| def _is_disk_full_error(exc: Exception) -> bool: | |
| message = str(exc).lower() | |
| return "disk is full" in message or "(code: 13)" in message | |
def _ingest_batch(batch: list[dict]) -> tuple[int, int]:
    """Embed and insert one batch of KB records into the vector collection.

    Each record must carry ``id``, ``content``, and ``metadata`` keys.
    Returns ``(inserted, skipped)``, where *skipped* counts both duplicate
    ids inside the batch and ids already present in the collection.
    Raises RuntimeError when the backend reports a disk-full condition,
    so the caller can abort the whole run.
    """
    collection = get_collection()
    # NOTE(review): a fresh embedding service is constructed per batch —
    # presumably cheap to build; confirm, otherwise hoist to the caller.
    embedder = KBEmbeddingService()
    # De-duplicate ids within the batch itself (first occurrence wins).
    unique_records: dict[str, dict] = {}
    duplicate_ids_in_batch = 0
    for record in batch:
        record_id = record["id"]
        if record_id in unique_records:
            duplicate_ids_in_batch += 1
            continue
        unique_records[record_id] = record
    batch_ids = list(unique_records)
    # Drop ids already stored in the collection (enables resuming a run).
    existing = set(collection.get(ids=batch_ids).get("ids", []))
    pending_records = [record for record_id, record in unique_records.items() if record_id not in existing]
    if not pending_records:
        # Entire batch was duplicates: nothing inserted, everything skipped.
        return 0, len(batch_ids) + duplicate_ids_in_batch
    ids = [record["id"] for record in pending_records]
    documents = [record["content"] for record in pending_records]
    metadatas = [record["metadata"] for record in pending_records]
    embeddings = embedder.embed_batch(documents)
    try:
        # Fast path: insert the whole batch in one call.
        collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )
        return len(pending_records), duplicate_ids_in_batch
    except Exception as exc:
        # Disk-full is fatal — escalate to RuntimeError and stop the run.
        if _is_disk_full_error(exc):
            raise RuntimeError(f"Disk is full while writing batch: {exc}") from exc
        if not _is_duplicate_error(exc):
            print(f"Batch add fallback triggered: {exc}", flush=True)
        # Fallback: insert records one at a time so a single bad record
        # does not discard the rest of the batch.
        inserted = 0
        skipped = 0
        for record_id, embedding, document, metadata in zip(ids, embeddings, documents, metadatas):
            try:
                collection.add(
                    ids=[record_id],
                    embeddings=[embedding],
                    documents=[document],
                    metadatas=[metadata],
                )
                inserted += 1
            except Exception as exc:
                if _is_disk_full_error(exc):
                    raise RuntimeError(f"Disk is full while ingesting {record_id}: {exc}") from exc
                if _is_duplicate_error(exc):
                    skipped += 1
                    continue
                # Other per-record failures are logged but counted as
                # neither inserted nor skipped.
                print(f"Failed to ingest {record_id}: {exc}", flush=True)
        return inserted, skipped + duplicate_ids_in_batch
def _batched_records():
    """Yield lists of up to BATCH_SIZE records from the KB dataset,
    including a final partial batch."""
    batch: list[dict] = []
    for record in _iter_records():
        batch.append(record)
        if len(batch) == BATCH_SIZE:
            yield batch
            batch = []
    if batch:
        yield batch


def main(force: bool = False) -> None:
    """Ingest the KB JSONL dataset into the vector DB in batches.

    Args:
        force: Run even when the collection already holds records
            (resume / check-for-new-records mode).

    Raises:
        FileNotFoundError: When the prepared dataset file is missing.
    """
    if not DATA_PATH.exists():
        raise FileNotFoundError(
            f"Input file not found: {DATA_PATH}. Run scripts/prepare_dataset.py first."
        )
    collection = get_collection()
    backend = get_vector_backend()
    print(f"Vector backend: {backend}", flush=True)
    existing_count = collection.count()
    if existing_count > 0 and not force:
        print(
            f"Vector DB already has {existing_count} chunks. "
            "Skipping ingestion to avoid re-embedding. "
            "Use --force if you want to resume/check for new records.",
            flush=True,
        )
        return
    total_processed = 0
    total_inserted = 0
    total_skipped = 0
    batch_num = 0
    try:
        # Single loop over full and tail batches — the flush logic used to
        # be duplicated for the final partial batch.
        for batch in _batched_records():
            batch_num += 1
            total_processed += len(batch)
            print(f"Ingesting batch {batch_num}... ({total_processed} records processed)", flush=True)
            inserted, skipped = _ingest_batch(batch)
            total_inserted += inserted
            total_skipped += skipped
    except RuntimeError as exc:
        # Raised by _ingest_batch on disk-full; report partial progress below.
        print(f"Ingestion stopped early: {exc}", flush=True)
    print(f"Ingestion complete. Total chunks in vector DB: {collection.count()}", flush=True)
    if total_skipped:
        print(f"Skipped duplicate records: {total_skipped}", flush=True)
    if total_inserted:
        print(f"Inserted records in this run: {total_inserted}", flush=True)
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Ingest medical KB records into vector DB.") | |
| parser.add_argument( | |
| "--force", | |
| action="store_true", | |
| help="Run ingestion even if vector DB already has records.", | |
| ) | |
| args = parser.parse_args() | |
| main(force=args.force) | |
Xet Storage Details
- Size:
- 5.35 kB
- Xet hash:
- 99237cc3b81d125e6711ff40b5012502789b843887be167a306546339cbdb360
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.