Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| from typing import Any | |
| import numpy as np | |
| import pandas as pd | |
| from dotenv import load_dotenv | |
| from duckdb import DuckDBPyConnection | |
| from src.models import ( | |
| Charts, | |
| Continuous, | |
| Data, | |
| DateTime, | |
| Nominal, | |
| PlotConfig, | |
| Route, | |
| SmallCardNum, | |
| SQLQueryModel, | |
| TableData, | |
| ) | |
# Load variables from a .env file into the process environment before the
# settings below are read.
load_dotenv()

logger = logging.getLogger(__name__)

# Bar charts with more distinct labels than this are truncated to this many rows.
MAX_BARS_COUNT = 20

# Number of retries allowed after the initial SQL generation attempt.
SQL_GENERATION_RETRIES = int(os.environ.get("SQL_GENERATION_RETRIES", "5"))

# Prompt templates are supplied via the environment; each is None when unset.
SQL_PROMPT = os.environ.get("SQL_PROMPT")
USER_PROMPT = os.environ.get("USER_PROMPT")
ROUTER_SYSTEM_PROMPT = os.environ.get("ROUTER_SYSTEM_PROMPT")
CHART_CONFIG_SYSTEM_PROMPT = os.environ.get("CHART_CONFIG_SYSTEM_PROMPT")
CHART_CONFIG_USER_PROMPT = os.environ.get("CHART_CONFIG_USER_PROMPT")
class SQLPipeline:
    """Generate SQL with an LLM chain and execute it on DuckDB, retrying on failure."""

    def __init__(
        self,
        duckdb: DuckDBPyConnection,
        chain,
    ) -> None:
        # Connection used to execute the generated SQL.
        self._duckdb = duckdb
        # LLM wrapper; ``chain.run(...)`` is called with prompts and a response model.
        self.chain = chain

    def generate_sql(
        self, user_question: str, context: str, errors: str | None = None
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """Generate SQL + description.

        Args:
            user_question: The natural-language question to answer.
            context: Extra context interpolated into ``USER_PROMPT``.
            errors: Accumulated errors from earlier attempts. When non-empty,
                they are appended to the prompt so the model can self-correct.

        Returns:
            Whatever ``self.chain.run`` returns for the ``SQLQueryModel``
            response format. Callers expect a dict with a ``"sql_query"`` key
            or a plain SQL string.
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        if errors:
            # Adjacent-literal concatenation avoids embedding source
            # indentation in the prompt; the blank line separates the feedback
            # from the templated prompt.
            user_prompt_formatted += (
                "\n\nCarefully review the previous error or exception and "
                "rewrite the SQL so that the error does not occur again. "
                "Try a different approach or rewrite SQL if needed. "
                f"Previous errors: {errors}"
            )
        sql = self.chain.run(
            system_prompt=SQL_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="sql_query",
            response_format=SQLQueryModel,
        )
        logger.info(f"SQL Generated Successfully: {sql}")
        return sql

    def run_query(self, sql_query: str) -> pd.DataFrame | None:
        """Execute SQL and return dataframe."""
        logger.info("Query Execution Started.")
        return self._duckdb.query(sql_query).df()

    def try_sql_with_retries(
        self,
        user_question: str,
        context: str,
        max_retries: int = SQL_GENERATION_RETRIES,
    ) -> tuple[
        str | dict[str, str | int | float | None] | list[str] | None,
        pd.DataFrame | None,
    ]:
        """Try SQL generation + execution with retries.

        Makes one initial attempt plus up to ``max_retries`` retries. An
        attempt fails when generation or execution raises, or when the query
        returns no rows. Each failure is recorded and fed back into the next
        prompt.

        Returns:
            ``(sql, dataframe)`` for the first attempt that yields a non-empty
            result. ``(None, None)`` when the model returns an empty response
            or every attempt fails.
        """
        all_errors = ""
        total_attempts = max_retries + 1  # the first attempt is not a retry
        for attempt in range(1, total_attempts + 1):
            if attempt > 1:
                logger.info(f"Retrying: {attempt - 1}")
            try:
                # Feed back every earlier failure (None on a clean first try).
                sql = self.generate_sql(
                    user_question, context, errors=all_errors or None
                )
                if not sql:
                    return None, None
                sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
                if not isinstance(sql_query_str, str):
                    raise ValueError(
                        f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
                    )
                query_df = self.run_query(sql_query_str)
            except Exception as e:
                error = f"\n[Attempt {attempt}] {type(e).__name__}: {e}"
                logger.error(f"Error during SQL generation or execution: {error}")
                all_errors += error
                continue
            if query_df is not None and not query_df.empty:
                return sql, query_df
            # The query ran but produced nothing. Tell the model, so the retry
            # is not an identical prompt.
            all_errors += (
                f"\n[Attempt {attempt}] Query returned no rows: {sql_query_str}"
            )
        logger.error(f"Failed after {total_attempts} attempts. Errors: {all_errors}")
        return None, None
class QueryRouter:
    """Ask the LLM chain which processing route a user question should take."""

    def __init__(self, chain) -> None:
        # LLM wrapper exposing ``run(system_prompt=..., user_prompt=..., ...)``.
        self.chain = chain

    def route_request(self, user_question: str, context: str) -> int:
        """Route the user question to a numeric route code.

        The log message below documents the codes as: 0 irrelevant,
        1 visualizable, 2 SQL only, 3 datetime. The value returned is whatever
        ``self.chain.run`` produces for the ``Route`` response format.
        """
        prompt = USER_PROMPT.format(question=user_question, context=context)
        decision = self.chain.run(
            system_prompt=ROUTER_SYSTEM_PROMPT,
            user_prompt=prompt,
            format_name="route_queries",
            response_format=Route,
        )
        logger.info(
            f"Query routed to: {decision} Where if query is routed to 0 its irrelevant, if 1 its visualizable, if 2 its only sql, and 3 if its datetime."
        )
        return decision
class ChartFormatter:
    """Pick a chart type for a query result and reshape its data for plotting."""

    # Numeric columns with fewer distinct values than this are treated as
    # categorical (SmallCardNum) rather than Continuous.
    SMALL_CARDINALITY_THRESHOLD = 10

    def _build_xy_data(
        self, label_data, value_data, limit_unique_x=False, sort_x=False
    ):
        """Zip labels and values into ``[{"x": ..., "y": ...}, ...]`` records.

        Args:
            label_data: Values for the ``x`` key (Series, Index or array-like).
            value_data: Values for the ``y`` key, aligned with ``label_data``.
            limit_unique_x: If True and there are more than ``MAX_BARS_COUNT``
                distinct labels, keep only the first ``MAX_BARS_COUNT`` rows.
            sort_x: If True, order records by ``x`` ascending (stable sort).
        """
        df = pd.DataFrame({"x": label_data, "y": value_data})
        if sort_x:
            df = df.sort_values("x", kind="stable")
        if limit_unique_x and df["x"].nunique() > MAX_BARS_COUNT:
            df = df.head(MAX_BARS_COUNT)
        return df.to_dict(orient="records")

    def is_continuous(self, dtype) -> bool:
        """Return True for numeric dtypes, excluding booleans."""
        # is_numeric_dtype already covers integer and float dtypes; booleans
        # count as numeric in pandas, so they are excluded explicitly.
        return not pd.api.types.is_bool_dtype(dtype) and pd.api.types.is_numeric_dtype(
            dtype
        )

    def is_datetime(self, dtype) -> bool:
        """Return True for datetime64 (any tz) or timedelta64 dtypes."""
        return pd.api.types.is_datetime64_any_dtype(
            dtype
        ) or pd.api.types.is_timedelta64_dtype(dtype)

    def detect_dtype(self, data):
        """Classify each column as SmallCardNum, Continuous, DateTime or Nominal.

        Returns:
            A dict mapping column name to an instance of the chosen type.

        Raises:
            TypeError: If a column label selects more than one column
                (duplicate column names), so it cannot be classified.
        """
        type_ = {}
        for col_name in data.columns:
            col_data = data[col_name]
            # Duplicate labels make data[col_name] a DataFrame. Check that
            # before touching .dtype, which only a Series has.
            if not isinstance(col_data, pd.Series):
                raise TypeError(
                    f"unprocessed column type:{type(col_data)} for column {col_name!r}"
                )
            if self.is_continuous(col_data.dtype):
                # Numbers with few distinct values behave like categories.
                if col_data.nunique() < self.SMALL_CARDINALITY_THRESHOLD:
                    type_[col_name] = SmallCardNum()
                else:
                    type_[col_name] = Continuous()
            elif self.is_datetime(col_data.dtype):
                type_[col_name] = DateTime()
            else:
                type_[col_name] = Nominal()
        return type_

    def build_bar_chart(self, label_data, value_data):
        """Bar-chart records, truncated to ``MAX_BARS_COUNT`` rows when there are many labels."""
        return self._build_xy_data(label_data, value_data, limit_unique_x=True)

    def build_line_chart(self, label_data, value_data):
        """Line-chart records, ordered by x so the line is drawn left to right."""
        return self._build_xy_data(label_data, value_data, sort_x=True)

    def build_pie_chart(self, label_data, value_data):
        """Pie-chart records, one per label."""
        return self._build_xy_data(label_data, value_data)

    def build_histogram(self, data, bins: int = 50):
        """Bin numeric ``data`` into ``bins`` equal-width buckets spanning its min..max.

        Non-finite values (NaN, +/-inf) are ignored, and ``[]`` is returned
        when no finite values remain. Values are native Python ``float``/``int``,
        consistent with the records produced by ``_build_xy_data``.
        """
        values = np.asarray(data, dtype=float)
        values = values[np.isfinite(values)]
        if values.size == 0:
            return []
        counts, edges = np.histogram(
            values, bins=bins, range=(values.min(), values.max())
        )
        return [
            {
                "bin_start": float(start),
                "bin_end": float(end),
                "frequency": int(freq),
            }
            for start, end, freq in zip(edges[:-1], edges[1:], counts)
        ]

    def _select_single_column_chart(self, df: pd.DataFrame, col, dtypes):
        """Choose a chart for a one-column result (histogram, pie or bar of counts)."""
        dtype = dtypes[col]
        if isinstance(dtype, Continuous):
            return "hist", self.build_histogram(df[col].dropna()), dtypes
        if isinstance(dtype, (SmallCardNum, Nominal)):
            counts = df[col].value_counts()
            # Few categories read better as a pie; otherwise use bars.
            if counts.size <= 6:
                return "pie", self.build_pie_chart(counts.index, counts.values), dtypes
            return "bar", self.build_bar_chart(counts.index, counts.values), dtypes
        return None, None, None

    def _select_two_column_chart(self, df: pd.DataFrame, x, y, dtypes):
        """Choose a chart for a two-column result based on the pair of column types."""
        dtype_x, dtype_y = dtypes[x], dtypes[y]
        data_x, data_y = df[x], df[y]
        kinds = {type(dtype_x), type(dtype_y)}

        def oriented(label_kind):
            # Put the column whose detected type is ``label_kind`` on the x axis.
            if isinstance(dtype_x, label_kind):
                return data_x, data_y
            return data_y, data_x

        if kinds == {Nominal, Continuous}:
            return "bar", self.build_bar_chart(*oriented(Nominal)), dtypes
        if kinds == {Continuous}:
            return "bar", self.build_bar_chart(data_x, data_y), dtypes
        if kinds == {SmallCardNum, Continuous}:
            return "bar", self.build_bar_chart(*oriented(SmallCardNum)), dtypes
        if isinstance(dtype_x, SmallCardNum) and isinstance(dtype_y, SmallCardNum):
            return "bar", self.build_bar_chart(data_x, data_y), dtypes
        if kinds == {DateTime, Continuous}:
            return "line", self.build_line_chart(*oriented(DateTime)), dtypes
        if (isinstance(dtype_x, DateTime) and isinstance(dtype_y, SmallCardNum)) or (
            isinstance(dtype_y, DateTime) and isinstance(dtype_x, SmallCardNum)
        ):
            return "line", self.build_line_chart(*oriented(DateTime)), dtypes
        if kinds == {Nominal, SmallCardNum}:
            return "bar", self.build_bar_chart(*oriented(Nominal)), dtypes
        return None, None, None

    def format_and_select_chart(self, df: pd.DataFrame):
        """Pick a chart type for ``df`` and format its data accordingly.

        Returns:
            ``(chart_type, formatted_data, dtypes)``, or ``(None, None, None)``
            when no chart fits (unsupported type combination, or a result
            with a number of columns other than one or two).
        """
        cols = df.columns.tolist()
        dtypes = self.detect_dtype(df)
        if len(cols) == 1:
            return self._select_single_column_chart(df, cols[0], dtypes)
        if len(cols) == 2:
            return self._select_two_column_chart(df, cols[0], cols[1], dtypes)
        return None, None, None
class SQLVizChain:
    """End-to-end pipeline: question -> route -> SQL -> DataFrame -> chart payload."""

    # Plot types accepted from the LLM-generated chart config.
    SUPPORTED_PLOT_TYPES = frozenset({"bar", "line", "pie", "hist"})

    def __init__(self, duckdb: DuckDBPyConnection, chain):
        self._duckdb = duckdb
        self.chain = chain
        self.router = QueryRouter(chain=self.chain)
        self.sql_generator = SQLPipeline(duckdb, chain=self.chain)
        self.charting = ChartFormatter()

    @staticmethod
    def _empty_response() -> dict[str, Any]:
        """Return a fresh response with no chart data, config, type or SQL."""
        return {
            "chart_data": None,
            "chart_config": None,
            "chart_type": None,
            "sql_config": None,
        }

    @staticmethod
    def _sql_text(sql_config: dict[Any, Any] | str) -> Any:
        """Return the SQL string from a generated SQL payload.

        ``SQLPipeline.try_sql_with_retries`` may return either a dict with a
        ``"sql_query"`` key or a bare SQL string; both are accepted.
        """
        if isinstance(sql_config, dict):
            return sql_config["sql_query"]
        return sql_config

    def create_chart_config(
        self, query_df: pd.DataFrame, user_question: str, sql: str
    ) -> tuple[list[dict[Any, Any]] | None, dict[str, Any] | None, str | None]:
        """Format data for visualization and return chart config.

        Returns:
            ``(formatted_data, chart_config, chart_type)``, or
            ``(None, None, None)`` when the result cannot be charted.
        """
        (
            chart_type,
            formatted_data,
            dtypes,
        ) = self.charting.format_and_select_chart(df=query_df)
        if not all([formatted_data, dtypes, chart_type]):
            return None, None, None
        chart_config = self.chain.run(
            system_prompt=CHART_CONFIG_SYSTEM_PROMPT,
            user_prompt=CHART_CONFIG_USER_PROMPT.format(
                question=user_question,
                sql_query=sql,
                dtypes=dtypes,
                chart_type=chart_type,
            ),
            format_name="chart_config",
            response_format=PlotConfig,
        )
        logger.info(f"Chart Config Generated: {chart_config}")
        return formatted_data, chart_config, chart_type

    def create_viz_with_text_response(
        self,
        query_df: pd.DataFrame,
        user_question: str,
        sql_config: dict[Any, Any] | str,
    ) -> dict[str, Any]:
        """Build the response payload for a successfully executed query.

        Falls back to the raw table (with no chart config/type) when the
        result cannot be charted.

        Raises:
            ValueError: If the generated chart config names an unsupported
                plot type.
        """
        formatted_data, chart_config, chart_type = self.create_chart_config(
            query_df, user_question, self._sql_text(sql_config)
        )
        table_data = TableData(data=query_df)
        if not all([formatted_data, chart_config, chart_type]):
            logger.info("Failed to format data or generate chart config.")
            logger.info(f"Total Token Counts: {self.chain.total_tokens}")
            return {
                "chart_data": table_data,
                "chart_config": None,
                "chart_type": None,
                "sql_config": sql_config,
            }
        chart_data = Data.validate_data(data=formatted_data)
        # NOTE(review): the data was shaped for ``chart_type`` (picked by
        # ChartFormatter) but is stored under the LLM-chosen
        # ``chart_config["type"]``; confirm these are expected to agree.
        if chart_config["type"] not in self.SUPPORTED_PLOT_TYPES:
            raise ValueError(
                "Invalid Plot Type. Must be one of 'bar', 'line', 'pie', 'hist'"
            )
        data = Charts(**{chart_config["type"]: chart_data})
        logger.info("Visualization Chain Completed Successfully.")
        logger.info(f"Total Token Counts: {self.chain.total_tokens}")
        return {
            "chart_data": data,
            "chart_config": chart_config,
            "chart_type": chart_type,
            "sql_config": sql_config,
        }

    def run(self, user_question: str, context: str) -> dict[str, Any]:
        """Main pipeline: question -> SQL -> data -> chart config."""
        route = self.router.route_request(user_question=user_question, context=context)
        if route == 0:
            # Route 0 is logged by the router as "irrelevant".
            return self._empty_response()
        sql_config, query_df = self.sql_generator.try_sql_with_retries(
            user_question=user_question, context=context
        )
        if sql_config is None or query_df is None:
            logger.info("Failed to generate or execute SQL after retries.")
            logger.info(f"Total Token Counts: {self.chain.total_tokens}")
            return self._empty_response()
        return self.create_viz_with_text_response(
            query_df=query_df, user_question=user_question, sql_config=sql_config
        )