# sqlbot/ai/pipeline.py
# Initial Hugging Face sqlbot setup (commit 28035e9)
"""Main DSPy reasoning pipeline — optimized for speed.

Reduced from 9 stages to 4 LLM signatures (3 calls in the happy path):
1. AnalyzeAndPlan (question understanding + schema analysis + query planning)
2. SQLGeneration (re-run once if the code-based schema check reports issues)
3. SQLRepair (only when execution fails; up to MAX_REPAIR_RETRIES attempts)
4. InterpretAndInsight (interpretation + insights in one call)
"""
import json
import logging
import re
from typing import Any
import dspy
from ai.groq_setup import get_lm
from ai.signatures import (
AnalyzeAndPlan,
SQLGeneration,
SQLRepair,
InterpretAndInsight,
)
from ai.validator import validate_sql, check_sql_against_schema
from db.schema import format_schema
from db.relationships import format_relationships
from db.profiler import get_data_profile
from db.executor import execute_sql
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Upper bound on the number of SQLRepair attempts SQLAnalystPipeline.run makes
# after a query fails to execute.
MAX_REPAIR_RETRIES = 2
class SQLAnalystPipeline:
    """End-to-end reasoning pipeline: question → SQL → results → insights.

    The pipeline plans the query, generates SQL, checks it against the schema
    and the safety validator, executes it (repairing on failure), and finally
    asks the LLM to interpret the returned rows.
    """

    # Number of result rows serialized into the interpretation prompt; the
    # complete result set is still returned to the caller.
    _MAX_ROWS_FOR_LLM = 50

    # First SELECT/WITH statement, ending at a semicolon, at a blank line
    # followed by a capitalized prose word, or at end of input.  Only the
    # keyword is case-insensitive: a global IGNORECASE would make the
    # ``[A-Z][a-z]`` lookahead match any two letters and cut the statement at
    # a blank line followed by e.g. ``FROM``.
    _SQL_STATEMENT_RE = re.compile(
        r"((?i:SELECT|WITH)\b[\s\S]*?)(?:;|\n\n(?=[A-Z][a-z])|$)"
    )

    # Line openers that indicate the LLM has switched from SQL to commentary.
    _PROSE_LINE_RE = re.compile(
        r"^(However|Note|This|The|Please|But|Also|In |It |I |Here|Since|Because|Although|Unfortunately)"
    )

    def __init__(self, provider: str = "groq"):
        """Create the pipeline and its DSPy predict modules.

        Args:
            provider: Name of the LLM provider handed to ``get_lm``.
        """
        self.provider = provider
        # The LM handle is kept on the instance but is not passed to the
        # Predict modules below.
        # NOTE(review): presumably get_lm also configures DSPy globally — confirm.
        self._lm = get_lm(provider)

        # DSPy predict modules
        self.analyze = dspy.Predict(AnalyzeAndPlan)
        self.generate_sql = dspy.Predict(SQLGeneration)
        self.interpret = dspy.Predict(InterpretAndInsight)
        self.repair = dspy.Predict(SQLRepair)

    # ── public API ──────────────────────────────────────────────────────

    def run(self, question: str) -> dict[str, Any]:
        """Run the full pipeline and return {sql, data, answer, insights}.

        Args:
            question: Natural-language question to answer.

        Returns:
            A dict with keys ``sql`` (the final query text), ``data`` (the
            executed rows, or ``[]`` on failure), ``answer`` (the LLM's
            interpretation, or an error message) and ``insights`` (empty on
            failure).
        """
        schema_str = format_schema()

        # Stages 1-3: plan, generate, and schema-check the SQL.
        sql = self._plan_and_generate(question, schema_str)

        # Safety validation (no LLM call)
        is_safe, reason = validate_sql(sql)
        if not is_safe:
            return self._failure(sql, f"Query rejected: {reason}")

        # SQL execution + repair loop
        logger.info("Stage 4 — Executing SQL")
        exec_result = execute_sql(sql)
        for attempt in range(1, MAX_REPAIR_RETRIES + 1):
            if exec_result["success"]:
                break
            logger.warning(
                "SQL error (attempt %d): %s", attempt, exec_result["error"]
            )
            repair_result = self.repair(
                sql_query=sql,
                error_message=exec_result["error"],
                schema_info=schema_str,
                question=question,
            )
            sql = self._clean_sql(repair_result.corrected_sql)
            # A repaired query is LLM output too, so it is re-validated
            # before it is allowed to run.
            is_safe, reason = validate_sql(sql)
            if not is_safe:
                return self._failure(sql, f"Repaired query rejected: {reason}")
            exec_result = execute_sql(sql)

        if not exec_result["success"]:
            return self._failure(
                sql,
                f"Failed after {MAX_REPAIR_RETRIES} repairs. Error: {exec_result['error']}",
            )

        data = exec_result["data"]
        # Only a prefix of the rows goes into the prompt; default=str
        # stringifies values json cannot encode natively.
        results_json = json.dumps(data[: self._MAX_ROWS_FOR_LLM], default=str)

        # Interpret & Insight (single LLM call replaces 2 former stages)
        logger.info("Stage 5 — Interpret & Insight")
        result = self.interpret(
            question=question,
            sql_query=sql,
            query_results=results_json,
        )
        return {
            "sql": sql,
            "data": data,
            "answer": result.answer,
            "insights": result.insights,
        }

    def generate_sql_only(self, question: str) -> str:
        """Run the pipeline up to SQL generation and return just the SQL.

        Performs planning, generation and the code-based schema check (with
        one regeneration on reported issues); the query is neither
        safety-validated nor executed.
        """
        return self._plan_and_generate(question, format_schema())

    # ── helpers ─────────────────────────────────────────────────────────

    def _plan_and_generate(self, question: str, schema_str: str) -> str:
        """Plan the query, generate SQL, and regenerate once on schema issues."""
        rels_str = format_relationships()
        profile_str = get_data_profile()

        # 1. Analyze & Plan (single LLM call replaces 3 former stages)
        logger.info("Stage 1 — Analyze & Plan")
        plan = self.analyze(
            question=question,
            schema_info=schema_str,
            relationships=rels_str,
            data_profile=profile_str,
        )
        plan_text = self._format_plan(plan)

        # 2. SQL Generation
        logger.info("Stage 2 — SQL Generation")
        sql = self._generate_clean_sql(question, schema_str, plan_text)

        # 3. Code-based schema validation (instant — no LLM call)
        logger.info("Stage 3 — Schema Validation")
        # Function-local import: get_schema is not among the module-level imports.
        from db.schema import get_schema

        schema_valid, schema_issues = check_sql_against_schema(sql, get_schema())
        if not schema_valid:
            logger.warning("Schema issues detected: %s", schema_issues)
            # Regenerate once with the issues as feedback; the second attempt
            # is not re-checked against the schema.
            sql = self._generate_clean_sql(
                question,
                schema_str,
                plan_text + f"\n\nPREVIOUS SQL HAD ISSUES: {schema_issues}. Fix them.",
            )
        return sql

    def _generate_clean_sql(
        self, question: str, schema_str: str, plan_text: str
    ) -> str:
        """Call the SQLGeneration module and return its cleaned SQL text."""
        sql_result = self.generate_sql(
            question=question,
            schema_info=schema_str,
            query_plan=plan_text,
        )
        return self._clean_sql(sql_result.sql_query)

    @staticmethod
    def _format_plan(plan: Any) -> str:
        """Render the AnalyzeAndPlan prediction as the query-plan prompt text."""
        return (
            f"Intent: {plan.intent}\n"
            f"Tables: {plan.relevant_tables}\n"
            f"Columns: {plan.relevant_columns}\n"
            f"Joins: {plan.join_conditions}\n"
            f"Where: {plan.where_conditions}\n"
            f"Aggregations: {plan.aggregations}\n"
            f"Group By: {plan.group_by}\n"
            f"Order By: {plan.order_by}\n"
            f"Limit: {plan.limit_val}"
        )

    @staticmethod
    def _failure(sql: str, answer: str) -> dict[str, Any]:
        """Build the result dict returned when the pipeline stops early."""
        return {"sql": sql, "data": [], "answer": answer, "insights": ""}

    @staticmethod
    def _clean_sql(raw: str) -> str:
        """Strip markdown fences, trailing prose, and whitespace from LLM SQL.

        Args:
            raw: Raw text of the LLM's SQL field; ``None`` is treated as empty.

        Returns:
            The first SELECT/WITH statement found, without code fences,
            trailing commentary or a terminating semicolon.  If no such
            statement is found the de-fenced text is returned after the same
            prose/semicolon trimming.
        """
        sql = (raw or "").strip()

        # 1. Remove ``` fence lines wherever they occur, not only when the
        #    text starts with one (the model may put prose before the fence).
        if "```" in sql:
            sql = "\n".join(
                line
                for line in sql.split("\n")
                if not line.strip().startswith("```")
            ).strip()

        # 2. Extract only the first SQL statement
        match = SQLAnalystPipeline._SQL_STATEMENT_RE.search(sql)
        if match:
            sql = match.group(1).strip()

        # 3. Drop everything from the first line that looks like natural language
        cleaned_lines: list[str] = []
        for line in sql.split("\n"):
            stripped = line.strip()
            if stripped and SQLAnalystPipeline._PROSE_LINE_RE.match(stripped):
                break
            cleaned_lines.append(line)
        sql = "\n".join(cleaned_lines).strip()

        # 4. Remove trailing semicolons together with any whitespace around them
        return sql.rstrip("; \t\r\n")