| """Resource bootstrap + pipeline factory for the Streamlit UI.""" |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Any |
|
|
| import chromadb |
| import streamlit as st |
|
|
| from nl_sql.agent.graph import PipelineConfig, build_pipeline |
| from nl_sql.config import get_settings |
| from nl_sql.db.registry import DatabaseRegistry, get_default_registry |
| from nl_sql.llm.cache import CachingEmbeddingProvider, CachingLLMProvider |
| from nl_sql.llm.providers import build_provider |
| from nl_sql.llm.providers.base import EmbeddingProvider, LLMProvider |
| from nl_sql.llm.providers.mistral import MistralProvider |
| from nl_sql.schema_index.indexer import SchemaIndex |
|
|
|
|
@st.cache_resource(show_spinner="Initialising providers + Chroma index…")
def bootstrap() -> tuple[DatabaseRegistry, SchemaIndex, LLMProvider, LLMProvider]:
    """Build the long-lived resources shared by the Streamlit UI.

    Decorated with ``st.cache_resource`` so the returned objects are reused
    across script reruns instead of being rebuilt each time.

    Returns:
        ``(registry, schema_index, sql_provider, explain_provider)``. The two
        providers are the same ``CachingLLMProvider`` instance.

    Raises:
        RuntimeError: if ``settings.mistral_api_key`` is empty, or if the
            Chroma persist directory does not exist.
    """
    settings = get_settings()
    if not settings.mistral_api_key:
        raise RuntimeError(
            "MISTRAL_API_KEY is not set in .env — required for codestral + mistral-embed."
        )

    registry = get_default_registry()

    # Relative path: resolved against the current working directory, so the
    # app has to be launched from the directory that holds `chroma_data`
    # (presumably the project root — confirm against the launch script).
    persist_dir = Path("chroma_data")
    if not persist_dir.is_dir():
        # Report the absolute location that was checked; `!r` on a Path would
        # print "PosixPath('chroma_data')" and hide the CWD dependency.
        raise RuntimeError(
            f"Chroma persist dir '{persist_dir.resolve()}' not found. "
            "Run `uv run python scripts/build_index.py --db all` first."
        )
    chroma_client = chromadb.PersistentClient(path=str(persist_dir))

    # Embeddings come straight from a MistralProvider, wrapped in a cache.
    raw_embedder = MistralProvider(
        api_key=settings.mistral_api_key,
        gen_model=settings.mistral_gen_model,
        embed_model=settings.mistral_embed_model,
        base_url=settings.mistral_base_url,
    )
    embedder: EmbeddingProvider = CachingEmbeddingProvider(
        raw_embedder,
        cache_dir=settings.llm_cache_dir,
        size_limit_gb=settings.llm_cache_size_limit_gb,
    )
    schema_index = SchemaIndex(persist_dir=persist_dir, embedder=embedder, client=chroma_client)

    # The generation provider shares the embedder's cache_dir and size limit.
    raw_sql = build_provider("mistral", settings=settings)
    sql_provider: LLMProvider = CachingLLMProvider(
        raw_sql,
        cache_dir=settings.llm_cache_dir,
        size_limit_gb=settings.llm_cache_size_limit_gb,
    )
    # The same provider instance serves both SQL generation and explanation.
    explain_provider = sql_provider

    return registry, schema_index, sql_provider, explain_provider
|
|
|
|
def make_pipeline(
    registry: DatabaseRegistry,
    schema_index: SchemaIndex,
    sql_provider: LLMProvider,
    explain_provider: LLMProvider,
    *,
    schema_top_k: int,
    fk_hops: int,
    table_budget: int,
    sort_schema_block: bool,
    extended_sample_size: int,
    fewshot_top_k: int = 3,
    cross_db_fewshot: bool = True,
    verify_retry_on_empty: bool = True,
    primary_sample_size: int = 3,
) -> Any:
    """Assemble a ``PipelineConfig`` and return the pipeline built from it.

    All arguments are forwarded unchanged to the ``PipelineConfig`` field of
    the same name; see ``PipelineConfig`` for their meaning. This function is
    not cached, so each call builds a fresh pipeline.

    ``primary_sample_size`` defaults to 3, the value that was previously
    hard-coded. It is exposed as a parameter alongside
    ``extended_sample_size``.

    Returns:
        Whatever ``build_pipeline(config)`` returns.
    """
    config = PipelineConfig(
        sql_provider=sql_provider,
        explain_provider=explain_provider,
        schema_index=schema_index,
        registry=registry,
        schema_top_k=schema_top_k,
        fewshot_top_k=fewshot_top_k,
        fk_hops=fk_hops,
        table_budget=table_budget,
        sort_schema_block=sort_schema_block,
        primary_sample_size=primary_sample_size,
        extended_sample_size=extended_sample_size,
        cross_db_fewshot=cross_db_fewshot,
        verify_retry_on_empty=verify_retry_on_empty,
    )
    return build_pipeline(config)
|
|