nl-sql / app /bootstrap.py
liovina's picture
Deploy NL_SQL HEAD to HF Space
424ea19 verified
Raw
History Blame Contribute Delete
3.18 kB
"""Resource bootstrap + pipeline factory for the Streamlit UI."""
from __future__ import annotations
from pathlib import Path
from typing import Any
import chromadb
import streamlit as st
from nl_sql.agent.graph import PipelineConfig, build_pipeline
from nl_sql.config import get_settings
from nl_sql.db.registry import DatabaseRegistry, get_default_registry
from nl_sql.llm.cache import CachingEmbeddingProvider, CachingLLMProvider
from nl_sql.llm.providers import build_provider
from nl_sql.llm.providers.base import EmbeddingProvider, LLMProvider
from nl_sql.llm.providers.mistral import MistralProvider
from nl_sql.schema_index.indexer import SchemaIndex
@st.cache_resource(show_spinner="Initialising providers + Chroma index…")
def bootstrap() -> tuple[DatabaseRegistry, SchemaIndex, LLMProvider, LLMProvider]:
    """Create the shared resources the Streamlit UI needs.

    The function is wrapped in ``st.cache_resource``, so Streamlit shows the
    spinner text while it runs and keeps the returned objects as a cached
    resource.

    Returns:
        A 4-tuple ``(registry, schema_index, sql_provider, explain_provider)``.
        The last two entries are the same caching LLM provider object.

    Raises:
        RuntimeError: If the settings carry no Mistral API key, or if the
            ``chroma_data`` directory does not exist relative to the current
            working directory.
    """
    cfg = get_settings()
    api_key = cfg.mistral_api_key
    if not api_key:
        raise RuntimeError(
            "MISTRAL_API_KEY is not set in .env — required for codestral + mistral-embed."
        )

    db_registry = get_default_registry()

    # The index is expected to have been built ahead of time; refuse to start
    # without it rather than creating a fresh directory here.
    index_dir = Path("chroma_data")
    if not index_dir.is_dir():
        raise RuntimeError(
            f"Chroma persist dir {index_dir!r} not found. "
            "Run `uv run python scripts/build_index.py --db all` first."
        )
    client = chromadb.PersistentClient(path=str(index_dir))

    mistral = MistralProvider(
        api_key=api_key,
        gen_model=cfg.mistral_gen_model,
        embed_model=cfg.mistral_embed_model,
        base_url=cfg.mistral_base_url,
    )

    # Both caching wrappers are configured with the same cache settings.
    cache_opts = {
        "cache_dir": cfg.llm_cache_dir,
        "size_limit_gb": cfg.llm_cache_size_limit_gb,
    }
    embedder: EmbeddingProvider = CachingEmbeddingProvider(mistral, **cache_opts)
    index = SchemaIndex(persist_dir=index_dir, embedder=embedder, client=client)

    sql_llm: LLMProvider = CachingLLMProvider(
        build_provider("mistral", settings=cfg),
        **cache_opts,
    )

    # SQL generation and explanation currently share one provider instance.
    return db_registry, index, sql_llm, sql_llm
def make_pipeline(
    registry: DatabaseRegistry,
    schema_index: SchemaIndex,
    sql_provider: LLMProvider,
    explain_provider: LLMProvider,
    *,
    schema_top_k: int,
    fk_hops: int,
    table_budget: int,
    sort_schema_block: bool,
    extended_sample_size: int,
    fewshot_top_k: int = 3,
    cross_db_fewshot: bool = True,
    verify_retry_on_empty: bool = True,
    primary_sample_size: int = 3,
) -> Any:
    """Assemble a ``PipelineConfig`` and return the pipeline built from it.

    Every argument is forwarded unchanged to the ``PipelineConfig`` field of
    the same name; this function adds no logic of its own beyond supplying
    defaults.

    Args:
        registry: Database registry handed to the pipeline.
        schema_index: Schema index handed to the pipeline.
        sql_provider: LLM provider stored as ``sql_provider`` in the config.
        explain_provider: LLM provider stored as ``explain_provider``.
        schema_top_k: Forwarded as ``schema_top_k``.
        fk_hops: Forwarded as ``fk_hops``.
        table_budget: Forwarded as ``table_budget``.
        sort_schema_block: Forwarded as ``sort_schema_block``.
        extended_sample_size: Forwarded as ``extended_sample_size``.
        fewshot_top_k: Forwarded as ``fewshot_top_k``. Defaults to 3.
        cross_db_fewshot: Forwarded as ``cross_db_fewshot``. Defaults to True.
        verify_retry_on_empty: Forwarded as ``verify_retry_on_empty``.
            Defaults to True.
        primary_sample_size: Forwarded as ``primary_sample_size``. Defaults
            to 3, the value that was previously fixed inside this function.

    Returns:
        Whatever ``build_pipeline`` returns for the assembled config.
    """
    config = PipelineConfig(
        sql_provider=sql_provider,
        explain_provider=explain_provider,
        schema_index=schema_index,
        registry=registry,
        schema_top_k=schema_top_k,
        fewshot_top_k=fewshot_top_k,
        fk_hops=fk_hops,
        table_budget=table_budget,
        sort_schema_block=sort_schema_block,
        primary_sample_size=primary_sample_size,
        extended_sample_size=extended_sample_size,
        cross_db_fewshot=cross_db_fewshot,
        verify_retry_on_empty=verify_retry_on_empty,
    )
    return build_pipeline(config)