| """Main DSPy reasoning pipeline β optimized for speed. | |
| Reduced from 9 stages to 4 LLM calls in the happy path: | |
| 1. AnalyzeAndPlan (question understanding + schema analysis + query planning) | |
| 2. SQLGeneration | |
| 3. SQLCritiqueAndFix (one pass; only retries on failure) | |
| 4. InterpretAndInsight (interpretation + insights in one call) | |
| """ | |
| import json | |
| import logging | |
| import re | |
| from typing import Any | |
| import dspy | |
| from ai.groq_setup import get_lm | |
| from ai.signatures import ( | |
| AnalyzeAndPlan, | |
| SQLGeneration, | |
| SQLRepair, | |
| InterpretAndInsight, | |
| ) | |
| from ai.validator import validate_sql, check_sql_against_schema | |
| from db.schema import format_schema | |
| from db.relationships import format_relationships | |
| from db.profiler import get_data_profile | |
| from db.executor import execute_sql | |
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)
# Maximum number of LLM-based repair attempts after a failed SQL execution.
MAX_REPAIR_RETRIES = 2
class SQLAnalystPipeline:
    """End-to-end reasoning pipeline: question -> SQL -> results -> insights.

    Happy path costs three LLM calls (analyze/plan, SQL generation,
    interpret/insight); additional calls happen only when generated SQL
    fails the code-based schema check or fails to execute and must be
    repaired.
    """

    def __init__(self, provider: str = "groq"):
        """Create the pipeline and its DSPy predict modules.

        Args:
            provider: LLM provider name handed to ``get_lm`` (default "groq").
        """
        self.provider = provider
        self._lm = get_lm(provider)
        # DSPy predict modules
        self.analyze = dspy.Predict(AnalyzeAndPlan)
        self.generate_sql = dspy.Predict(SQLGeneration)
        self.interpret = dspy.Predict(InterpretAndInsight)
        self.repair = dspy.Predict(SQLRepair)

    # -- public API ------------------------------------------------------

    def run(self, question: str) -> dict[str, Any]:
        """Run the full pipeline and return {sql, data, answer, insights}.

        On any unrecoverable failure (unsafe SQL, repeated execution
        errors) this returns a dict with empty ``data``/``insights`` and an
        explanatory ``answer`` instead of raising.
        """
        schema_str = format_schema()
        rels_str = format_relationships()
        profile_str = get_data_profile()

        # 1. Analyze & Plan (single LLM call replaces 3 former stages)
        logger.info("Stage 1 - Analyze & Plan")
        plan = self.analyze(
            question=question,
            schema_info=schema_str,
            relationships=rels_str,
            data_profile=profile_str,
        )
        plan_text = self._format_plan(plan)

        # 2-3. SQL generation + instant code-based schema validation,
        # with at most one regeneration using the validator's feedback.
        sql = self._generate_schema_checked_sql(question, schema_str, plan_text)

        # 4. Safety validation (no LLM call)
        is_safe, reason = validate_sql(sql)
        if not is_safe:
            return {
                "sql": sql,
                "data": [],
                "answer": f"Query rejected: {reason}",
                "insights": "",
            }

        # 5. SQL execution + repair loop (each repaired query re-passes
        # the same safety gate before being executed).
        logger.info("Stage 4 - Executing SQL")
        exec_result = execute_sql(sql)
        for attempt in range(MAX_REPAIR_RETRIES):
            if exec_result["success"]:
                break
            logger.warning("SQL error (attempt %s): %s", attempt + 1, exec_result["error"])
            repair_result = self.repair(
                sql_query=sql,
                error_message=exec_result["error"],
                schema_info=schema_str,
                question=question,
            )
            sql = self._clean_sql(repair_result.corrected_sql)
            is_safe, reason = validate_sql(sql)
            if not is_safe:
                return {
                    "sql": sql,
                    "data": [],
                    "answer": f"Repaired query rejected: {reason}",
                    "insights": "",
                }
            exec_result = execute_sql(sql)
        if not exec_result["success"]:
            return {
                "sql": sql,
                "data": [],
                "answer": f"Failed after {MAX_REPAIR_RETRIES} repairs. Error: {exec_result['error']}",
                "insights": "",
            }

        data = exec_result["data"]
        # Cap what the LLM sees to keep the prompt small; the caller still
        # receives the full result set in "data".
        data_for_llm = data[:50]
        results_json = json.dumps(data_for_llm, default=str)

        # 6. Interpret & Insight (single LLM call replaces 2 former stages)
        logger.info("Stage 5 - Interpret & Insight")
        result = self.interpret(
            question=question,
            sql_query=sql,
            query_results=results_json,
        )
        return {
            "sql": sql,
            "data": data,
            "answer": result.answer,
            "insights": result.insights,
        }

    def generate_sql_only(self, question: str) -> str:
        """Run the pipeline up to SQL generation and return just the SQL."""
        schema_str = format_schema()
        rels_str = format_relationships()
        profile_str = get_data_profile()
        plan = self.analyze(
            question=question,
            schema_info=schema_str,
            relationships=rels_str,
            data_profile=profile_str,
        )
        plan_text = self._format_plan(plan)
        return self._generate_schema_checked_sql(question, schema_str, plan_text)

    # -- helpers ---------------------------------------------------------

    @staticmethod
    def _format_plan(plan) -> str:
        """Render an AnalyzeAndPlan prediction as a plain-text query plan."""
        return (
            f"Intent: {plan.intent}\n"
            f"Tables: {plan.relevant_tables}\n"
            f"Columns: {plan.relevant_columns}\n"
            f"Joins: {plan.join_conditions}\n"
            f"Where: {plan.where_conditions}\n"
            f"Aggregations: {plan.aggregations}\n"
            f"Group By: {plan.group_by}\n"
            f"Order By: {plan.order_by}\n"
            f"Limit: {plan.limit_val}"
        )

    def _generate_schema_checked_sql(
        self, question: str, schema_str: str, plan_text: str
    ) -> str:
        """Generate SQL from the plan, then run the instant code-based
        schema check; on failure, regenerate once with the issues appended
        to the plan as feedback (no extra LLM call on the happy path)."""
        logger.info("Stage 2 - SQL Generation")
        sql_result = self.generate_sql(
            question=question,
            schema_info=schema_str,
            query_plan=plan_text,
        )
        sql = self._clean_sql(sql_result.sql_query)

        logger.info("Stage 3 - Schema Validation")
        # Local import, presumably to avoid a circular import at module
        # load time -- TODO confirm before hoisting to the top of the file.
        from db.schema import get_schema
        schema_valid, schema_issues = check_sql_against_schema(sql, get_schema())
        if not schema_valid:
            logger.warning("Schema issues detected: %s", schema_issues)
            # Try regenerating SQL once with the issues as feedback
            sql_result = self.generate_sql(
                question=question,
                schema_info=schema_str,
                query_plan=plan_text + f"\n\nPREVIOUS SQL HAD ISSUES: {schema_issues}. Fix them.",
            )
            sql = self._clean_sql(sql_result.sql_query)
        return sql

    @staticmethod
    def _clean_sql(raw: str) -> str:
        """Strip markdown fences, trailing prose, and whitespace from LLM SQL.

        Note: @staticmethod is required here -- the original bare
        ``def _clean_sql(raw)`` would have received ``self`` as ``raw``
        when invoked via ``self._clean_sql(...)``.
        """
        sql = raw.strip()
        # 1. Remove ```sql ... ``` wrappers
        if sql.startswith("```"):
            lines = sql.split("\n")
            lines = [l for l in lines if not l.strip().startswith("```")]
            sql = "\n".join(lines).strip()
        # 2. Extract only the first valid SQL statement: stop at a ";" or
        # at a blank line followed by prose-looking text.
        match = re.search(
            r"((?:SELECT|WITH)\b[\s\S]*?)(;|\n\n(?=[A-Z][a-z])|$)",
            sql,
            re.IGNORECASE,
        )
        if match:
            sql = match.group(1).strip()
        # 3. Remove trailing lines that look like natural language
        cleaned_lines: list[str] = []
        for line in sql.split("\n"):
            stripped = line.strip()
            if not stripped:
                cleaned_lines.append(line)
                continue
            if re.match(
                r"^(However|Note|This|The|Please|But|Also|In |It |I |Here|Since|Because|Although|Unfortunately)",
                stripped,
            ):
                break
            cleaned_lines.append(line)
        sql = "\n".join(cleaned_lines).strip()
        # 4. Remove trailing semicolons
        sql = sql.rstrip(";")
        return sql