| """Build the Chroma schema index for one (or all) registered databases. |
| |
| Live tool — calls Mistral `mistral-embed` for vectors. Idempotent: re-runs |
| upsert chunks under the same `chunk_id` (db::table), so vectors get refreshed |
| in place and stale chunks for renamed tables are NOT auto-pruned (run with |
| ``--reset`` to clear the collection first if you have schema deletions). |
| |
| The default ``--sample-size`` is imported from ``PipelineConfig.primary_sample_size`` |
| so the index is built with the same density runtime expects. Pass an explicit |
| value only if you want to rebuild for a non-default runtime configuration. |
| |
| Usage: |
| uv run python scripts/build_index.py --db chinook |
| uv run python scripts/build_index.py --db all --persist chroma_data |
| uv run python scripts/build_index.py --db chinook --reset |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from dataclasses import fields |
| from pathlib import Path |
|
|
| import chromadb |
|
|
| from nl_sql.agent.graph import PipelineConfig |
| from nl_sql.config import get_settings |
| from nl_sql.db.registry import get_default_registry |
| from nl_sql.llm.cache import CachingEmbeddingProvider |
| from nl_sql.llm.providers.base import EmbeddingProvider |
| from nl_sql.llm.providers.mistral import MistralProvider |
| from nl_sql.schema_index.chunker import to_chunks |
| from nl_sql.schema_index.indexer import SCHEMA_COLLECTION, SchemaIndex |
| from nl_sql.schema_index.introspector import introspect |
|
|
|
|
def _runtime_sample_size_default() -> int:
    """Return the declared default of ``PipelineConfig.primary_sample_size``.

    The value is taken from the dataclass field definitions instead of an
    instance: constructing ``PipelineConfig`` would need live providers and a
    registry that this script does not have.

    Returns:
        The ``int`` default declared for the ``primary_sample_size`` field.

    Raises:
        RuntimeError: if no ``primary_sample_size`` field with an ``int``
            default is found.
    """
    int_defaults = (
        spec.default
        for spec in fields(PipelineConfig)
        # Only inspect ``default`` on the field we care about.
        if spec.name == "primary_sample_size" and isinstance(spec.default, int)
    )
    # First match wins; falling out of the loop means nothing qualified.
    for value in int_defaults:
        return value
    raise RuntimeError("PipelineConfig.primary_sample_size default missing")
|
|
|
|
# Computed once at import time from the PipelineConfig field definitions;
# importing this module raises RuntimeError if that default cannot be found.
# Used as the default for both `build_for_db(sample_size=...)` and the
# `--sample-size` CLI option.
DEFAULT_SAMPLE_SIZE: int = _runtime_sample_size_default()
"""Source of truth for the sample density baked into Chroma chunks.
Runtime expects this to equal `PipelineConfig.primary_sample_size`; the
mixture appendix breaks if the index is built with more samples than
runtime advertises."""
|
|
|
|
def build_for_db(
    idx: SchemaIndex,
    db_id: str,
    *,
    sample_size: int = DEFAULT_SAMPLE_SIZE,
) -> int:
    """Introspect one registered database and index its schema chunks.

    Progress is reported on stdout at each stage (introspect, chunk, index,
    done).

    Args:
        idx: Schema index that receives the chunks through ``index_schema``.
        db_id: Id of the database to look up in the default registry.
        sample_size: Passed to ``introspect`` as ``sample_size``; defaults to
            ``DEFAULT_SAMPLE_SIZE``.

    Returns:
        Whatever ``idx.index_schema`` returns, reported as the number of
        chunks indexed.
    """
    db_spec = get_default_registry().get(db_id)
    print(f"[introspect] {db_id} ({db_spec.url})")

    engine = db_spec.make_engine()
    db_tables = introspect(engine, sample_size=sample_size)
    print(f"[chunk] {len(db_tables)} tables → chunks")

    schema_chunks = to_chunks(db_tables, db_id=db_id)
    print(f"[index] embedding + upserting {len(schema_chunks)} chunks")

    indexed = idx.index_schema(schema_chunks)
    print(f"[done] {db_id}: {indexed} chunks indexed")
    return indexed
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the index build script.

    Options:
        --db: required database id, or ``all``.
        --persist: Chroma persist directory (default ``chroma_data``).
        --sample-size: integer, defaults to ``DEFAULT_SAMPLE_SIZE``.
        --reset: flag; drop the schema collection before indexing.
        --no-cache: flag; use the embedding provider without the cache wrapper.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring carries a pre-formatted "Usage:" section; the
        # default formatter would re-wrap it into a single paragraph and make
        # the example commands unreadable in ``--help``. This formatter only
        # affects the description/epilog; per-option help is still wrapped.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--db",
        required=True,
        help="Database id (e.g. 'chinook', 'bird_california_schools') or 'all'.",
    )
    parser.add_argument(
        "--persist",
        default="chroma_data",
        help="Chroma persist directory (default: chroma_data/)",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=DEFAULT_SAMPLE_SIZE,
        help=(
            "Top-K sample values per column to bake into each chunk "
            f"(default: {DEFAULT_SAMPLE_SIZE} = PipelineConfig.primary_sample_size). "
            "Keep aligned with runtime or the sample-mixture appendix breaks."
        ),
    )
    parser.add_argument(
        "--reset",
        action="store_true",
        help="Drop the schema_chunks collection before indexing.",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable diskcache wrapper around the embedding provider.",
    )
    return parser
|
|
|
|
def main(argv: list[str] | None = None) -> int:
    """Run the index build for the database(s) selected on the command line.

    Steps: parse arguments, resolve the target database ids, open the
    persistent Chroma client (optionally dropping the schema collection),
    build the embedding provider (cached unless ``--no-cache``), then index
    each target with ``build_for_db`` and print a summary.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        ``0`` on completion. Failures surface as exceptions, not as a
        non-zero return value.
    """
    args = build_parser().parse_args(argv)

    settings = get_settings()

    # Resolve every target BEFORE anything destructive or billable happens.
    # Previously a mistyped ``--db`` combined with ``--reset`` dropped the
    # collection first and only then failed on the registry lookup.
    # Assumes ``registry.get`` rejects unknown ids (``build_for_db`` relies on
    # the same lookup) -- TODO confirm against the registry implementation.
    registry = get_default_registry()
    # ``list(...)`` so the targets can be iterated twice and measured.
    targets = list(registry.ids()) if args.db == "all" else [args.db]
    for db_id in targets:
        registry.get(db_id)

    persist = Path(args.persist)
    persist.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(path=str(persist))
    if args.reset:
        try:
            client.delete_collection(SCHEMA_COLLECTION)
            print(f"[reset] dropped {SCHEMA_COLLECTION}")
        except Exception as exc:
            # Deliberately best-effort: a missing collection is not an error
            # for a rebuild. The caught exception is echoed for diagnosis.
            print(f"[reset] no existing {SCHEMA_COLLECTION} to drop ({exc})")

    raw_embedder = MistralProvider(
        api_key=settings.mistral_api_key,
        gen_model=settings.mistral_gen_model,
        embed_model=settings.mistral_embed_model,
        base_url=settings.mistral_base_url,
    )
    embedder: EmbeddingProvider = (
        raw_embedder
        if args.no_cache
        else CachingEmbeddingProvider(
            raw_embedder,
            cache_dir=settings.llm_cache_dir,
            size_limit_gb=settings.llm_cache_size_limit_gb,
        )
    )
    idx = SchemaIndex(persist_dir=persist, embedder=embedder, client=client)

    total = 0
    for db_id in targets:
        total += build_for_db(idx, db_id, sample_size=args.sample_size)
    print(f"[summary] indexed {total} chunks across {len(targets)} db(s)")
    return 0
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|