File size: 5,293 Bytes
"""Build the Chroma schema index for one (or all) registered databases.
Live tool — calls Mistral `mistral-embed` for vectors. Idempotent: re-runs
upsert chunks under the same `chunk_id` (db::table), so vectors get refreshed
in place and stale chunks for renamed tables are NOT auto-pruned (run with
``--reset`` to clear the collection first if you have schema deletions).
The default ``--sample-size`` is imported from ``PipelineConfig.primary_sample_size``
so the index is built with the same density runtime expects. Pass an explicit
value only if you want to rebuild for a non-default runtime configuration.
Usage:
uv run python scripts/build_index.py --db chinook
uv run python scripts/build_index.py --db all --persist chroma_data
uv run python scripts/build_index.py --db chinook --reset
"""
from __future__ import annotations
import argparse
import sys
from dataclasses import fields
from pathlib import Path
import chromadb
from nl_sql.agent.graph import PipelineConfig
from nl_sql.config import get_settings
from nl_sql.db.registry import get_default_registry
from nl_sql.llm.cache import CachingEmbeddingProvider
from nl_sql.llm.providers.base import EmbeddingProvider
from nl_sql.llm.providers.mistral import MistralProvider
from nl_sql.schema_index.chunker import to_chunks
from nl_sql.schema_index.indexer import SCHEMA_COLLECTION, SchemaIndex
from nl_sql.schema_index.introspector import introspect
def _runtime_sample_size_default() -> int:
"""Read `PipelineConfig.primary_sample_size` default without constructing
the dataclass (it requires live providers/registry we don't have here)."""
for field_ in fields(PipelineConfig):
if field_.name == "primary_sample_size":
default = field_.default
if isinstance(default, int):
return default
raise RuntimeError("PipelineConfig.primary_sample_size default missing")
# Sample density baked into Chroma chunks at build time. The value is read
# from `PipelineConfig.primary_sample_size`, which stays the source of truth:
# runtime expects the index to carry exactly that density, and the
# sample-mixture appendix breaks if the index is built with more samples than
# runtime advertises.
DEFAULT_SAMPLE_SIZE: int = _runtime_sample_size_default()
def build_for_db(idx: SchemaIndex, db_id: str, *, sample_size: int = DEFAULT_SAMPLE_SIZE) -> int:
    """Introspect one registered database and upsert its schema chunks into ``idx``.

    Steps: look up ``db_id`` in the default registry, introspect the database
    with ``sample_size`` sample values per column, convert the tables into
    chunks tagged with ``db_id``, and hand them to ``idx.index_schema``.
    A progress line is printed to stdout at each step.

    Args:
        idx: Schema index that embeds and stores the chunks.
        db_id: Registry id of the database to index.
        sample_size: Sample values per column, forwarded to ``introspect``.

    Returns:
        The count returned by ``idx.index_schema``.
    """
    spec = get_default_registry().get(db_id)
    print(f"[introspect] {db_id} ({spec.url})")
    # NOTE(review): the engine returned by make_engine() is never disposed
    # here; confirm whether its resources should be released after introspection.
    schema_tables = introspect(spec.make_engine(), sample_size=sample_size)
    print(f"[chunk] {len(schema_tables)} tables → chunks")
    schema_chunks = to_chunks(schema_tables, db_id=db_id)
    print(f"[index] embedding + upserting {len(schema_chunks)} chunks")
    indexed = idx.index_schema(schema_chunks)
    print(f"[done] {db_id}: {indexed} chunks indexed")
    return indexed
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    The module docstring, including its ``Usage:`` examples, becomes the
    parser description. ``RawDescriptionHelpFormatter`` preserves its line
    breaks. The default formatter would re-wrap the text into a single
    paragraph and run the usage examples together in ``--help``.

    Returns:
        A parser defining ``--db`` (required), ``--persist``,
        ``--sample-size``, ``--reset`` and ``--no-cache``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--db",
        required=True,
        help="Database id (e.g. 'chinook', 'bird_california_schools') or 'all'.",
    )
    parser.add_argument(
        "--persist",
        default="chroma_data",
        help="Chroma persist directory (default: chroma_data/)",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=DEFAULT_SAMPLE_SIZE,
        help=(
            "Top-K sample values per column to bake into each chunk "
            f"(default: {DEFAULT_SAMPLE_SIZE} = PipelineConfig.primary_sample_size). "
            "Keep aligned with runtime or the sample-mixture appendix breaks."
        ),
    )
    parser.add_argument(
        "--reset",
        action="store_true",
        help="Drop the schema_chunks collection before indexing.",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable diskcache wrapper around the embedding provider.",
    )
    return parser
def main(argv: list[str] | None = None) -> int:
    """Parse CLI arguments and build the schema index for the selected database(s).

    Every target id is resolved against the registry *before* any side effect.
    A mistyped ``--db`` therefore fails without creating the persist directory
    and without dropping the existing collection when ``--reset`` is given.
    After that, the function optionally drops ``SCHEMA_COLLECTION``, builds a
    Mistral embedder (wrapped in the disk cache unless ``--no-cache``) and
    indexes each target.

    Args:
        argv: Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        Process exit status; 0 once all targets have been indexed.
    """
    args = build_parser().parse_args(argv)
    settings = get_settings()

    registry = get_default_registry()
    # Materialize once: the ids are iterated for validation and for indexing.
    targets = list(registry.ids()) if args.db == "all" else [args.db]
    # Resolve every target up front so an unknown id raises here, before the
    # destructive --reset below and before anything is written to disk.
    for db_id in targets:
        registry.get(db_id)

    persist = Path(args.persist)
    persist.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(path=str(persist))

    if args.reset:
        try:
            client.delete_collection(SCHEMA_COLLECTION)
            print(f"[reset] dropped {SCHEMA_COLLECTION}")
        except Exception as exc:
            # Best-effort drop: a missing collection is not an error here.
            # NOTE(review): kept broad because the exception type for a missing
            # collection may differ between chromadb versions; confirm before
            # narrowing.
            print(f"[reset] no existing {SCHEMA_COLLECTION} to drop ({exc})")

    raw_embedder = MistralProvider(
        api_key=settings.mistral_api_key,
        gen_model=settings.mistral_gen_model,
        embed_model=settings.mistral_embed_model,
        base_url=settings.mistral_base_url,
    )
    embedder: EmbeddingProvider = (
        raw_embedder
        if args.no_cache
        else CachingEmbeddingProvider(
            raw_embedder,
            cache_dir=settings.llm_cache_dir,
            size_limit_gb=settings.llm_cache_size_limit_gb,
        )
    )
    idx = SchemaIndex(persist_dir=persist, embedder=embedder, client=client)

    total = 0
    for db_id in targets:
        total += build_for_db(idx, db_id, sample_size=args.sample_size)
    print(f"[summary] indexed {total} chunks across {len(targets)} db(s)")
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    raise SystemExit(main())
|