sql-agent / src /orchestrator /pipeline.py
DanielRegaladoCardoso's picture
Result narration (Qwen-powered) + UI polish for upload area + dark-mode fixes
7af78a0 verified
"""
SQL Agent orchestrator. Models are constructed (loaded onto cuda) at
import time per ZeroGPU best practices. The pipeline runs inference only.
"""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import pandas as pd
from src.models.chart_reasoner import ChartReasoner
from src.models.sql_generator import SQLGenerator
from src.models.svg_renderer import SVGRenderer
from src.rag.engine import RAGEngine
from src.utils.sql_executor import SQLExecutor
# Module-level logger, named after this module; SQLAgentOrchestrator uses it
# to report failed SQL attempts, narration failures and pipeline errors.
logger = logging.getLogger(__name__)
class SQLAgentOrchestrator:
    """End-to-end NL -> SQL -> chart pipeline backed by DuckDB.

    The orchestrator owns a ``SQLExecutor`` and a ``RAGEngine`` bound to the
    executor's ``con`` attribute. The three model wrappers are injected by
    the caller, so nothing is constructed or loaded here beyond the executor
    and the RAG engine.
    """

    # Upper bound on generate-and-execute rounds per question in ``process``.
    # Values below 1 are treated as 1.
    MAX_SQL_ATTEMPTS: int = 3

    def __init__(
        self,
        sql_generator: SQLGenerator,
        chart_reasoner: ChartReasoner,
        svg_renderer: SVGRenderer,
    ) -> None:
        """Store the injected models and create a fresh executor and RAG engine.

        Args:
            sql_generator: Provides ``generate`` (question -> SQL) and
                ``narrate`` (results -> short text), as used by ``process``.
            chart_reasoner: Provides ``generate`` returning a chart spec.
            svg_renderer: Provides ``generate`` turning a spec plus rows
                into an SVG.
        """
        self.executor = SQLExecutor()
        # The RAG engine reads from the same connection the executor uses.
        self.rag = RAGEngine(self.executor.con)
        self.sql_generator = sql_generator
        self.chart_reasoner = chart_reasoner
        self.svg_renderer = svg_renderer

    def load_data(
        self,
        source: Union[str, Path, pd.DataFrame],
        table_name: Optional[str] = None,
    ) -> str:
        """Register a data source with the executor.

        Args:
            source: Either an in-memory DataFrame or a path to a file.
            table_name: Name to register under. For DataFrames a falsy value
                falls back to ``"data"``; for files it is forwarded as-is to
                ``SQLExecutor.register_file``.

        Returns:
            For a DataFrame, the name it was registered under. For a file,
            whatever ``register_file`` returns (presumably the table name --
            confirm against ``SQLExecutor``).
        """
        if isinstance(source, pd.DataFrame):
            name = table_name or "data"
            self.executor.register_dataframe(name, source)
            return name
        return self.executor.register_file(source, table_name)

    def schema_text(self) -> str:
        """Return the RAG engine's result for an empty query with ``top_k=5``.

        ``process`` uses this as the schema context for SQL generation and
        treats a falsy result as "no data loaded".
        """
        return self.rag.retrieve("", top_k=5)

    def list_tables(self) -> List[str]:
        """Return the table names reported by the executor."""
        return self.executor.get_table_names()

    def sample(self, table: str, n: int = 5) -> pd.DataFrame:
        """Return the executor's sample of ``table``; ``n`` is forwarded as-is."""
        return self.executor.get_sample(table, n)

    def process(self, question: str) -> Dict[str, Any]:
        """Inference-only pipeline; models already loaded at module level.

        Steps: fetch schema text, generate SQL and execute it (retrying up to
        ``MAX_SQL_ATTEMPTS`` times, feeding the previous SQL and error back
        to the generator), build a chart spec, render an SVG, then narrate.

        Args:
            question: The natural-language question to answer.

        Returns:
            A dict with keys ``question``, ``sql``, ``results``, ``columns``,
            ``chart_spec``, ``svg``, ``narration`` and ``error``. This method
            does not raise for ordinary failures: ``error`` is ``None`` on
            success and a non-empty string otherwise. Narration is
            best-effort and is left as ``None`` if it fails.
        """
        result: Dict[str, Any] = {
            "question": question,
            "sql": None,
            "results": [],
            "columns": [],
            "chart_spec": None,
            "svg": None,
            "narration": None,
            "error": None,
        }
        try:
            schema = self.schema_text()
            if not schema:
                result["error"] = "No data loaded. Upload a CSV/JSON first."
                return result

            # SQL with self-correction loop. Each retry feeds the previous
            # SQL + error back to the model so it can fix its own mistake.
            max_attempts = max(1, self.MAX_SQL_ATTEMPTS)
            sql: Optional[str] = None
            last_error: Optional[str] = None
            rows, cols = [], []
            for attempt in range(max_attempts):
                # On the first pass both ``sql`` and ``last_error`` are
                # still None, so no feedback is sent.
                sql = self.sql_generator.generate(
                    question=question,
                    schema=schema,
                    previous_sql=sql,
                    previous_error=last_error,
                )
                # Try to actually execute (validate + run in one shot).
                try:
                    rows, cols = self.executor.execute(sql)
                except Exception as e:
                    # Some exceptions carry no message; fall back to repr()
                    # so the failure is never recorded as an empty string.
                    last_error = str(e) or repr(e)
                    logger.info(
                        "SQL attempt %d failed: %s", attempt + 1, last_error
                    )
                    rows, cols = [], []
                else:
                    last_error = None
                    break

            result["sql"] = sql
            # Compare against None rather than truthiness so that a failure
            # can never be mistaken for success.
            if last_error is not None:
                result["error"] = (
                    "Could not produce a valid SQL query "
                    f"after {max_attempts} attempts.\n"
                    f"Last error: {last_error}"
                )
                return result

            result["results"] = rows
            result["columns"] = cols

            spec = self.chart_reasoner.generate(
                question=question, sql=sql, results=rows, columns=cols,
            )
            result["chart_spec"] = spec

            svg = self.svg_renderer.generate(spec, rows)
            result["svg"] = svg

            # Result narration: reuse the SQL model for a short finding.
            # Deliberately best-effort: a failure here must not discard the
            # query results and chart that were already produced.
            try:
                result["narration"] = self.sql_generator.narrate(
                    question=question, sql=sql, results=rows, columns=cols,
                )
            except Exception as e:
                logger.warning("narration step failed: %s", e)
                result["narration"] = None
            return result
        except Exception as e:
            logger.exception("Pipeline failed")
            # Keep the error non-empty even for message-less exceptions, so
            # callers testing ``result["error"]`` see the failure.
            result["error"] = str(e) or repr(e)
            return result

    def reset(self) -> None:
        """Close the current executor, create a new one and rebind the RAG engine.

        Previously registered data presumably does not carry over to the new
        executor -- confirm against ``SQLExecutor``.
        """
        self.executor.close()
        self.executor = SQLExecutor()
        self.rag.bind(self.executor.con)