import logging
import os
from typing import Any

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from duckdb import DuckDBPyConnection

from src.models import (
    Charts,
    Continuous,
    Data,
    DateTime,
    Nominal,
    PlotConfig,
    Route,
    SmallCardNum,
    SQLQueryModel,
    TableData,
)

load_dotenv()

logger = logging.getLogger(__name__)

# Row cap applied to bar-chart data when the x-axis has more distinct values
# than this (see ChartFormatter._build_xy_data).
MAX_BARS_COUNT = 20
# Number of *retries* after the first SQL attempt (total attempts = value + 1).
SQL_GENERATION_RETRIES = int(os.getenv("SQL_GENERATION_RETRIES", "5"))

# Prompt templates come from the environment (optionally populated from a .env
# file by load_dotenv above); each is None when its variable is unset.
SQL_PROMPT = os.getenv("SQL_PROMPT")
USER_PROMPT = os.getenv("USER_PROMPT")
ROUTER_SYSTEM_PROMPT = os.getenv("ROUTER_SYSTEM_PROMPT")
CHART_CONFIG_SYSTEM_PROMPT = os.getenv("CHART_CONFIG_SYSTEM_PROMPT")
CHART_CONFIG_USER_PROMPT = os.getenv("CHART_CONFIG_USER_PROMPT")


class SQLPipeline:
    """Generate SQL for a user question via an LLM chain and run it on DuckDB."""

    def __init__(
        self,
        duckdb: DuckDBPyConnection,
        chain,
    ) -> None:
        self._duckdb = duckdb
        # NOTE(review): `chain` is assumed to expose
        # `run(system_prompt=..., user_prompt=..., format_name=..., response_format=...)`;
        # its concrete type is not visible in this module.
        self.chain = chain

    def generate_sql(
        self, user_question: str, context: str, errors: str | None = None
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """Ask the LLM chain for a SQL query (plus description).

        Args:
            user_question: The natural-language question.
            context: Extra context substituted into the ``USER_PROMPT`` template.
            errors: Accumulated error text from earlier failed attempts. When
                non-empty it is appended to the user prompt so the model can
                avoid repeating the same mistake.

        Returns:
            Whatever ``chain.run`` returns for the ``SQLQueryModel`` response
            format -- presumably a dict with a ``"sql_query"`` key, although
            callers also accept a bare SQL string.
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        if errors:
            # Implicit literal concatenation (instead of backslash continuation
            # inside one literal) keeps source indentation out of the prompt.
            user_prompt_formatted += (
                "\n\nCarefully review the previous error or exception and "
                "rewrite the SQL so that the error does not occur again. "
                "Try a different approach or rewrite SQL if needed.\n"
                f"Previous errors: {errors}"
            )
        sql = self.chain.run(
            system_prompt=SQL_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="sql_query",
            response_format=SQLQueryModel,
        )
        logger.info("SQL Generated Successfully: %s", sql)
        return sql

    def run_query(self, sql_query: str) -> pd.DataFrame | None:
        """Execute ``sql_query`` on the DuckDB connection and return a DataFrame."""
        logger.info("Query Execution Started.")
        return self._duckdb.query(sql_query).df()

    def try_sql_with_retries(
        self,
        user_question: str,
        context: str,
        max_retries: int = SQL_GENERATION_RETRIES,
    ) -> tuple[
        str | dict[str, str | int | float | None] | list[str] | None,
        pd.DataFrame | None,
    ]:
        """Generate and execute SQL, retrying on failure.

        Makes one initial attempt plus up to ``max_retries`` retries. An
        attempt fails when generation/execution raises, or when the query runs
        but returns no rows. Each failure is appended to an error log that is
        fed back to the LLM on the next attempt.

        Returns:
            ``(sql, dataframe)`` from the first attempt that yields a non-empty
            result. ``(None, None)`` if the chain returns no SQL or if every
            attempt fails.
        """
        all_errors = ""
        total_attempts = max_retries + 1  # the first attempt is not a retry

        for attempt in range(1, total_attempts + 1):
            if attempt > 1:
                logger.info("Retrying: %d", attempt - 1)
            try:
                # `errors=None` on the first attempt (or after no recorded
                # failure) sends the plain prompt.
                sql = self.generate_sql(
                    user_question, context, errors=all_errors or None
                )
                if not sql:
                    return None, None

                sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
                if not isinstance(sql_query_str, str):
                    raise ValueError(
                        f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
                    )
                query_df = self.run_query(sql_query_str)
            except Exception as e:
                last_error = f"\nAttempt {attempt}: {type(e).__name__}: {e}"
                logger.error(
                    "Error during SQL generation or execution: %s", last_error
                )
                all_errors += last_error
                continue

            if query_df is not None and not query_df.empty:
                return sql, query_df

            # The query ran but produced nothing; tell the model so the next
            # attempt does not silently repeat the same SQL.
            last_error = f"\nAttempt {attempt}: query returned no rows."
            logger.info("SQL executed but returned no rows: %s", sql_query_str)
            all_errors += last_error

        logger.error(
            "Failed after %d attempts. Last error: %s", total_attempts, all_errors
        )
        return None, None


class QueryRouter:
    """Classify a user question into a route code using the LLM chain."""

    def __init__(self, chain) -> None:
        self.chain = chain

    def route_request(self, user_question: str, context: str) -> int:
        """Route the user question to a route code.

        Codes, as described by the log message below: 0 = irrelevant,
        1 = visualizable, 2 = SQL only, 3 = datetime. The value is whatever
        ``chain.run`` returns for the ``Route`` response format (annotated as
        ``int``).
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        route = self.chain.run(
            system_prompt=ROUTER_SYSTEM_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="route_queries",
            response_format=Route,
        )
        logger.info(
            "Query routed to: %s Where if query is routed to 0 its irrelevant, "
            "if 1 its visualizable, if 2 its only sql, and 3 if its datetime.",
            route,
        )
        return route


class ChartFormatter:
    """Pick a chart type for a query result and shape its data for plotting."""

    # Numeric columns with fewer distinct values than this are treated as
    # categorical (SmallCardNum) rather than Continuous.
    SMALL_CARDINALITY_THRESHOLD = 10

    def _build_xy_data(self, label_data, value_data, limit_unique_x=False):
        """Zip labels and values into ``[{"x": ..., "y": ...}, ...]`` records.

        With ``limit_unique_x``, only the first ``MAX_BARS_COUNT`` rows are kept
        when x has more than ``MAX_BARS_COUNT`` distinct values.
        """
        df = pd.DataFrame({"x": label_data, "y": value_data})
        if limit_unique_x and df["x"].nunique() > MAX_BARS_COUNT:
            df = df.head(MAX_BARS_COUNT)
        return df.to_dict(orient="records")

    def is_continuous(self, dtype) -> bool:
        """Return True for numeric, non-boolean dtypes."""
        # is_numeric_dtype already covers integer and float dtypes; bool is
        # numeric to pandas, so it is excluded explicitly.
        if pd.api.types.is_bool_dtype(dtype):
            return False
        return pd.api.types.is_numeric_dtype(dtype)

    def is_datetime(self, dtype) -> bool:
        """Return True for datetime64 (any tz) or timedelta64 dtypes."""
        return pd.api.types.is_datetime64_any_dtype(
            dtype
        ) or pd.api.types.is_timedelta64_dtype(dtype)

    def detect_dtype(self, data):
        """Classify every column of ``data`` into a chart-level dtype.

        Returns:
            Mapping of column name to ``SmallCardNum``, ``Continuous``,
            ``DateTime`` or ``Nominal``.

        Raises:
            TypeError: If a numeric column label selects something other than a
                Series (e.g. duplicated column names yield a DataFrame).
        """
        type_ = {}
        for col_name in data.columns:
            col_data = data[col_name]
            if self.is_continuous(col_data.dtype):
                # Numeric columns with few distinct values are categorical.
                if not isinstance(col_data, pd.Series):
                    raise TypeError(
                        f"unprocessed column type:{type(col_data)} for column {col_name!r}"
                    )
                if col_data.nunique() < self.SMALL_CARDINALITY_THRESHOLD:
                    type_[col_name] = SmallCardNum()
                else:
                    type_[col_name] = Continuous()
            elif self.is_datetime(col_data.dtype):
                type_[col_name] = DateTime()
            else:
                type_[col_name] = Nominal()
        return type_

    def build_bar_chart(self, label_data, value_data):
        """Bar-chart records, capped per ``MAX_BARS_COUNT``."""
        return self._build_xy_data(label_data, value_data, limit_unique_x=True)

    def build_line_chart(self, label_data, value_data):
        """Line-chart records (no row cap)."""
        return self._build_xy_data(label_data, value_data)

    def build_pie_chart(self, label_data, value_data):
        """Pie-chart records (no row cap)."""
        return self._build_xy_data(label_data, value_data)

    def build_histogram(self, data):
        """Bin ``data`` into 50 equal-width bins spanning its min..max.

        Returns:
            ``[{"bin_start", "bin_end", "frequency"}, ...]``, or an empty list
            when ``data`` is empty. ``min()``/``max()`` of an empty Series are
            NaN, and ``np.histogram`` rejects a non-finite range.
        """
        if len(data) == 0:
            return []
        range_ = (data.min(), data.max())
        counts, bins = np.histogram(data, bins=50, range=range_)
        return [
            {"bin_start": start, "bin_end": end, "frequency": freq}
            for start, end, freq in zip(bins[:-1], bins[1:], counts)
        ]

    def format_and_select_chart(self, df: pd.DataFrame):
        """Choose a chart type from column count/dtypes and format the data.

        Returns:
            ``(chart_type, formatted_data, dtypes)``, where ``chart_type`` is
            one of ``"hist"``, ``"pie"``, ``"bar"``, ``"line"``. Returns
            ``(None, None, None)`` when no rule matches, including any frame
            with more than two columns.
        """
        cols = df.columns.tolist()
        dtypes = self.detect_dtype(df)

        if len(cols) == 1:
            col = cols[0]
            dtype = dtypes[col]
            if isinstance(dtype, Continuous):
                return "hist", self.build_histogram(df[col].dropna()), dtypes
            if isinstance(dtype, (SmallCardNum, Nominal)):
                counts = df[col].value_counts()
                chart = "pie" if counts.size <= 6 else "bar"
                builder = (
                    self.build_pie_chart if chart == "pie" else self.build_bar_chart
                )
                return chart, builder(counts.index, counts.values), dtypes

        if len(cols) == 2:
            x, y = cols
            dtype_x = dtypes[x]
            dtype_y = dtypes[y]
            data_x = df[x]
            data_y = df[y]
            pair = {type(dtype_x), type(dtype_y)}

            if pair == {Nominal, Continuous}:
                label, value = (
                    (data_x, data_y)
                    if isinstance(dtype_x, Nominal)
                    else (data_y, data_x)
                )
                return "bar", self.build_bar_chart(label, value), dtypes
            elif pair == {Continuous}:
                # Both columns Continuous: keep the original column order.
                return "bar", self.build_bar_chart(data_x, data_y), dtypes
            elif pair == {SmallCardNum, Continuous}:
                label, value = (
                    (data_x, data_y)
                    if isinstance(dtype_x, SmallCardNum)
                    else (data_y, data_x)
                )
                return "bar", self.build_bar_chart(label, value), dtypes
            elif isinstance(dtype_x, SmallCardNum) and isinstance(
                dtype_y, SmallCardNum
            ):
                return "bar", self.build_bar_chart(data_x, data_y), dtypes
            elif pair == {DateTime, Continuous}:
                label, value = (
                    (data_x, data_y)
                    if isinstance(dtype_x, DateTime)
                    else (data_y, data_x)
                )
                return "line", self.build_line_chart(label, value), dtypes
            elif (
                isinstance(dtype_x, DateTime) and isinstance(dtype_y, SmallCardNum)
            ) or (isinstance(dtype_y, DateTime) and isinstance(dtype_x, SmallCardNum)):
                label, value = (
                    (data_x, data_y)
                    if isinstance(dtype_x, DateTime)
                    else (data_y, data_x)
                )
                return "line", self.build_line_chart(label, value), dtypes
            elif pair == {Nominal, SmallCardNum}:
                label, value = (
                    (data_x, data_y)
                    if isinstance(dtype_x, Nominal)
                    else (data_y, data_x)
                )
                return "bar", self.build_bar_chart(label, value), dtypes

        return None, None, None


class SQLVizChain:
    """End-to-end pipeline: route question -> generate/run SQL -> chart config."""

    def __init__(self, duckdb: DuckDBPyConnection, chain):
        self._duckdb = duckdb
        self.chain = chain
        self.router = QueryRouter(chain=self.chain)
        self.sql_generator = SQLPipeline(duckdb, chain=self.chain)
        self.charting = ChartFormatter()

    @staticmethod
    def _empty_response() -> dict[str, Any]:
        """Result returned when there is nothing to show."""
        return {
            "chart_data": None,
            "chart_config": None,
            "chart_type": None,
            "sql_config": None,
        }

    def create_chart_config(
        self, query_df: pd.DataFrame, user_question: str, sql: str
    ) -> tuple[list[dict[Any, Any]] | None, dict[str, Any] | None, str | None]:
        """Format ``query_df`` for plotting and ask the LLM for a chart config.

        Returns:
            ``(formatted_data, chart_config, chart_type)``, or
            ``(None, None, None)`` when no chart type fits or the formatted data
            is empty.
        """
        (
            chart_type,
            formatted_data,
            dtypes,
        ) = self.charting.format_and_select_chart(df=query_df)

        if not all([formatted_data, dtypes, chart_type]):
            return None, None, None

        chart_config = self.chain.run(
            system_prompt=CHART_CONFIG_SYSTEM_PROMPT,
            user_prompt=CHART_CONFIG_USER_PROMPT.format(
                question=user_question,
                sql_query=sql,
                dtypes=dtypes,
                chart_type=chart_type,
            ),
            format_name="chart_config",
            response_format=PlotConfig,
        )
        logger.info("Chart Config Generated: %s", chart_config)
        return formatted_data, chart_config, chart_type

    def create_viz_with_text_response(
        self, query_df: pd.DataFrame, user_question: str, sql_config: dict[Any, Any]
    ) -> dict[str, Any]:
        """Build the final response dict, falling back to a plain table.

        Args:
            sql_config: The SQL generation result. Usually a dict with a
                ``"sql_query"`` key; a bare SQL string is also accepted, since
                ``SQLPipeline.try_sql_with_retries`` may return one.

        Raises:
            ValueError: If the chart config's ``"type"`` is not one of
                ``bar``/``line``/``pie``/``hist``.
        """
        sql_query = (
            sql_config["sql_query"] if isinstance(sql_config, dict) else sql_config
        )
        formatted_data, chart_config, chart_type = self.create_chart_config(
            query_df, user_question, sql_query
        )
        table_data = TableData(data=query_df)

        if not all([formatted_data, chart_config, chart_type]):
            logger.info("Failed to format data or generate chart config.")
            logger.info("Total Token Counts: %s", self.chain.total_tokens)
            return {
                "chart_data": table_data,
                "chart_config": None,
                "chart_type": None,
                "sql_config": sql_config,
            }

        chart_data = Data.validate_data(data=formatted_data)
        if chart_config and chart_config["type"] in {"bar", "line", "pie", "hist"}:
            data = Charts(**{chart_config["type"]: chart_data})
        else:
            raise ValueError(
                "Invalid Plot Type. Must be one of 'bar', 'line', 'pie', 'hist'"
            )

        logger.info("Visualization Chain Completed Successfully.")
        logger.info("Total Token Counts: %s", self.chain.total_tokens)
        return {
            "chart_data": data,
            "chart_config": chart_config,
            "chart_type": chart_type,
            "sql_config": sql_config,
        }

    def run(self, user_question: str, context: str) -> dict[str, Any]:
        """Main pipeline: question -> SQL -> data -> chart config."""
        route = self.router.route_request(user_question=user_question, context=context)
        if route == 0:
            # Route 0 means the question is irrelevant to the data.
            return self._empty_response()

        sql_config, query_df = self.sql_generator.try_sql_with_retries(
            user_question=user_question, context=context
        )
        if sql_config is None or query_df is None:
            logger.info("Failed to generate or execute SQL after retries.")
            logger.info("Total Token Counts: %s", self.chain.total_tokens)
            return self._empty_response()

        return self.create_viz_with_text_response(
            query_df=query_df, user_question=user_question, sql_config=sql_config
        )