Spaces:
Running on Zero
Running on Zero
| """ | |
| SQL Agent orchestrator. Models are constructed (loaded onto cuda) at | |
| import time per ZeroGPU best practices. The pipeline runs inference only. | |
| """ | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Union | |
| import pandas as pd | |
| from src.models.chart_reasoner import ChartReasoner | |
| from src.models.sql_generator import SQLGenerator | |
| from src.models.svg_renderer import SVGRenderer | |
| from src.rag.engine import RAGEngine | |
| from src.utils.sql_executor import SQLExecutor | |
| logger = logging.getLogger(__name__) | |
class SQLAgentOrchestrator:
    """End-to-end NL -> SQL -> chart pipeline backed by DuckDB.

    The orchestrator owns a ``SQLExecutor`` and a ``RAGEngine`` bound to the
    executor's connection. The SQL generator, chart reasoner and SVG renderer
    are injected already constructed, so ``process`` only calls into them.
    """

    def __init__(
        self,
        sql_generator: "SQLGenerator",
        chart_reasoner: "ChartReasoner",
        svg_renderer: "SVGRenderer",
    ) -> None:
        """Create the executor and RAG engine and store the injected models."""
        self.executor = SQLExecutor()
        self.rag = RAGEngine(self.executor.con)
        self.sql_generator = sql_generator
        self.chart_reasoner = chart_reasoner
        self.svg_renderer = svg_renderer

    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return ``str(exc)``, or the exception class name when that is empty.

        An exception raised without arguments stringifies to ``""``; callers
        need a non-empty description so a failure is never mistaken for
        "no error".
        """
        return str(exc) or type(exc).__name__

    def load_data(
        self,
        source: Union[str, Path, pd.DataFrame],
        table_name: Optional[str] = None,
    ) -> str:
        """Register ``source`` with the executor and return the table name.

        A DataFrame is registered under ``table_name`` (``"data"`` when no
        name is given). Any other source is handed to
        ``executor.register_file`` together with ``table_name`` and that
        call's return value is passed through.
        """
        if isinstance(source, pd.DataFrame):
            name = table_name or "data"
            self.executor.register_dataframe(name, source)
            return name
        return self.executor.register_file(source, table_name)

    def schema_text(self) -> str:
        """Return the schema context obtained from the RAG engine.

        The engine is queried with an empty string and ``top_k=5``.
        ``process`` treats a falsy result as "no data loaded".
        """
        # NOTE(review): presumably an empty query yields a question-independent
        # schema summary -- confirm against RAGEngine.retrieve.
        return self.rag.retrieve("", top_k=5)

    def list_tables(self) -> List[str]:
        """Return the table names reported by the executor."""
        return self.executor.get_table_names()

    def sample(self, table: str, n: int = 5) -> pd.DataFrame:
        """Return the executor's sample of ``n`` rows from ``table``."""
        return self.executor.get_sample(table, n)

    def process(self, question: str, *, max_attempts: int = 3) -> Dict[str, Any]:
        """Inference-only pipeline; models already loaded at module level.

        Generates SQL for ``question``, executes it (retrying with the
        previous SQL and error fed back to the generator), then produces a
        chart spec, an SVG and a short narration.

        Args:
            question: Natural-language question to answer.
            max_attempts: Maximum number of generate-and-execute rounds.
                Values below 1 are clamped to 1.

        Returns:
            A dict with the keys ``question``, ``sql``, ``results``,
            ``columns``, ``chart_spec``, ``svg``, ``narration`` and ``error``.
            ``error`` is ``None`` on success and a non-empty string
            otherwise. Exceptions derived from ``Exception`` are caught and
            reported through ``error`` rather than raised.
        """
        result: Dict[str, Any] = {
            "question": question,
            "sql": None,
            "results": [],
            "columns": [],
            "chart_spec": None,
            "svg": None,
            "narration": None,
            "error": None,
        }
        try:
            schema = self.schema_text()
            if not schema:
                result["error"] = "No data loaded. Upload a CSV/JSON first."
                return result

            # Always try at least once so `sql`, `rows` and `cols` are defined.
            attempts = max(1, max_attempts)

            # SQL with self-correction loop. Each retry feeds the previous
            # SQL + error back to the model so it can correct its own mistake.
            # On the first round both are still None.
            sql = None
            last_error: Optional[str] = None
            rows, cols = [], []
            for attempt in range(1, attempts + 1):
                sql = self.sql_generator.generate(
                    question=question,
                    schema=schema,
                    previous_sql=sql,
                    previous_error=last_error,
                )
                # Executing the query doubles as validation.
                try:
                    rows, cols = self.executor.execute(sql)
                except Exception as exc:
                    # Never store "" here: an empty description would make a
                    # failed run indistinguishable from a successful one.
                    last_error = self._describe_error(exc)
                    logger.info("SQL attempt %d failed: %s", attempt, last_error)
                    rows, cols = [], []
                else:
                    last_error = None
                    break

            result["sql"] = sql
            if last_error is not None:
                result["error"] = (
                    f"Could not produce a valid SQL query after {attempts} attempts.\n"
                    f"Last error: {last_error}"
                )
                return result

            result["results"] = rows
            result["columns"] = cols

            spec = self.chart_reasoner.generate(
                question=question, sql=sql, results=rows, columns=cols,
            )
            result["chart_spec"] = spec
            result["svg"] = self.svg_renderer.generate(spec, rows)

            # Result narration is best-effort: a failure here is logged and
            # leaves the rest of the result intact.
            try:
                result["narration"] = self.sql_generator.narrate(
                    question=question, sql=sql, results=rows, columns=cols,
                )
            except Exception as exc:
                logger.warning("narration step failed: %s", exc)
                result["narration"] = None
            return result
        except Exception as exc:
            logger.exception("Pipeline failed")
            result["error"] = self._describe_error(exc)
            return result

    def reset(self) -> None:
        """Close the current executor, create a new one and rebind the RAG engine."""
        self.executor.close()
        self.executor = SQLExecutor()
        self.rag.bind(self.executor.con)