"""
Build a Chroma + SQLite index for this RAG system (offline / advanced users).

The index output folder is compatible with the Space runtime bootstrap:

    <output_dir>/
        chroma_db/
        doc_store.db
        manifest.json

Examples:

1) Build directly from the HF dataset (the dataset is downloaded in full;
   streaming mode is not supported):

   python scripts/build_vector_db.py \
       --config config/default_config.yaml \
       --source huggingface \
       --dataset ZhangNy/radiology-dataset \
       --output-dir ./index_out

2) Build from a dataset saved locally via save_to_disk():

   python scripts/build_vector_db.py \
       --config config/default_config.yaml \
       --source local \
       --local-path ./hf_dataset_prepared \
       --output-dir ./index_out

Notes:
- The embedding model used at build time must match the query-time embeddings
  used in the Space; otherwise retrieval quality will degrade.
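
Config keys read by this script (an illustrative sketch; the values below are
placeholders, but the key names and defaults match the code):

    embedding:
      api_base_url: <embedding API base URL>
      api_key: <API key>
      model_name: <embedding model name>
      batch_size: 32
    processing:
      chunk_size: 1024
      chunk_overlap: 200
      batch_size: 32
      # separators and keep_separator are also read (see the splitter setup)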
"""

from __future__ import annotations

import argparse
import json
import shutil
import sys
import time
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Tuple
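
# Make the repo root importable so `radiology_rag` resolves when this script
# is run directly (e.g. `python scripts/build_vector_db.py`).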
sys.path.append(str(Path(__file__).resolve().parents[1]))


def _clean_text(text: str) -> str:
    """Strip markdown links (keeping the link text) and normalize non-breaking spaces."""
    import re

    t = re.sub(r"\[(.*?)\]\(.*?\)", r"\1", text or "")
    return t.replace("\xa0", " ")


def main() -> int:
    parser = argparse.ArgumentParser(description="Build vector index (Chroma + SQLite doc store)")
    parser.add_argument("--config", type=str, default="config/default_config.yaml", help="Config YAML path")
    parser.add_argument("--source", choices=["local", "huggingface"], default="huggingface")
    parser.add_argument("--local-path", type=str, default=None, help="Path to dataset saved via save_to_disk()")
    parser.add_argument("--dataset", type=str, default="ZhangNy/radiology-dataset", help="HF dataset repo id")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of documents (debug)")
    parser.add_argument("--output-dir", type=str, default="./index_out", help="Output directory for index artifacts")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite output dir if exists")
    args = parser.parse_args()
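
    # Heavy dependencies are imported after argument parsing, so `--help` and
    # usage errors do not require them to be installed.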
    from datasets import load_dataset, load_from_disk
    from langchain_chroma import Chroma
    from langchain_core.documents import Document
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    from radiology_rag.config import Config
    from radiology_rag.doc_store import PersistentDocStore
    from radiology_rag.embedding import EmbeddingClient, EmbeddingConfig

    cfg = Config(args.config)

    out_dir = Path(args.output_dir)
    chroma_dir = out_dir / "chroma_db"
    doc_db = out_dir / "doc_store.db"
    manifest_path = out_dir / "manifest.json"

    if (chroma_dir.exists() or doc_db.exists()) and not args.overwrite:
        raise SystemExit(f"Output dir already has index artifacts. Use --overwrite. ({out_dir})")
    if out_dir.exists() and args.overwrite:
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
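
    # Load the corpus from the HF Hub or from a local save_to_disk() directory.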
    if args.source == "local":
        if not args.local_path:
            raise SystemExit("--local-path is required when --source local")
        dataset = load_from_disk(args.local_path)
    else:
        dataset = load_dataset(args.dataset, split=args.split)

    if args.limit:
        # Debug aid: cap the number of documents processed.
        dataset = dataset.select(range(min(int(args.limit), len(dataset))))
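
    # Splitter for child chunks (the units that get embedded); full parent
    # documents are stored separately in the doc store.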
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=cfg.get_int("processing.chunk_size", 1024),
        chunk_overlap=cfg.get_int("processing.chunk_overlap", 200),
        separators=cfg.get("processing.separators", ["\n\n", "\n", " "]),
        keep_separator=cfg.get_bool("processing.keep_separator", True),
    )
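
    # Embedding client for indexing. The model configured here must match the
    # query-time embeddings used in the Space (see the module docstring).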
    emb = EmbeddingClient(
        EmbeddingConfig(
            base_url=cfg.get_str("embedding.api_base_url"),
            api_key=cfg.get_str("embedding.api_key"),
            model_name=cfg.get_str("embedding.model_name"),
            batch_size=cfg.get_int("embedding.batch_size", 32),
        )
    )
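
    # Build-time artifacts: a SQLite doc store for parent documents and a
    # persistent Chroma collection for chunk vectors.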
    doc_store = PersistentDocStore(str(doc_db), read_only=False)
    vectorstore = Chroma(
        collection_name="radiology_docs",
        embedding_function=emb.langchain_embeddings,
        persist_directory=str(chroma_dir),
    )

    start = time.time()
    parent_pairs: List[Tuple[str, Dict[str, Any]]] = []
    child_docs: List[Document] = []
    counts = Counter()
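
    # Single pass over the corpus: collect one parent record per document and
    # split its content into child chunks.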
    for item in dataset:
        doc_id = (item.get("doc_id") or "").strip()
        if not doc_id:
            continue
        source_type = (item.get("source_type") or "").strip()
        title = (item.get("title") or "").strip()
        content = _clean_text(item.get("content") or "")
        url = (item.get("url") or "").strip()
        metadata = item.get("metadata") or {}

        counts[source_type or "unknown"] += 1
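
        # Parent record: the complete document, keyed by doc_id, so retrieval
        # can expand a matched chunk back to full context.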
        parent_pairs.append(
            (
                doc_id,
                {
                    "complete_document": {
                        "doc_id": doc_id,
                        "title": title,
                        "content": content,
                        "url": url,
                        "metadata": metadata,
                    },
                    "main_content": content,
                    "images": [],
                    "source_type": source_type,
                },
            )
        )
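
        # Child chunks: metadata links each chunk back to its parent document.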
        chunks = splitter.split_text(content)
        total = len(chunks)
        for i, chunk in enumerate(chunks):
            child_docs.append(
                Document(
                    page_content=chunk,
                    metadata={
                        "doc_id": f"{doc_id}_chunk_{i}",
                        "parent_id": doc_id,
                        "source_type": source_type,
                        "title": title,
                        "chunk_index": i,
                        "total_chunks": total,
                    },
                )
            )
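
    # Write all parent documents to the SQLite doc store.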
    doc_store.mset(parent_pairs)
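
    # Add child chunks in batches; each add_documents call embeds its batch
    # through the configured embedding API before writing to Chroma.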
    batch_size = int(cfg.get_int("processing.batch_size", 32))
    for i in range(0, len(child_docs), batch_size):
        vectorstore.add_documents(child_docs[i : i + batch_size])

    elapsed = time.time() - start
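
    # Record build provenance (inputs, embedding model, chunking parameters)
    # alongside the artifacts.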
    manifest = {
        "built_at": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),  # UTC
        "seconds": elapsed,
        "dataset": {"source": args.source, "dataset": args.dataset, "split": args.split, "limit": args.limit},
        "embedding": {
            "type": "api",
            "model_name": cfg.get_str("embedding.model_name"),
            "base_url": cfg.get_str("embedding.api_base_url"),
        },
        "processing": {
            "chunk_size": cfg.get_int("processing.chunk_size", 1024),
            "chunk_overlap": cfg.get_int("processing.chunk_overlap", 200),
        },
        "counts_by_source_type": dict(counts),
        "artifacts": {"chroma_dir": "chroma_db", "doc_store": "doc_store.db"},
    }
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)

    print(f"✓ Index built at: {out_dir}")
    print(f"  - documents: {sum(counts.values())} (by type: {dict(counts)})")
    print(f"  - chunks: {len(child_docs)}")
    print(f"  - elapsed: {elapsed:.1f}s")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())