DataViz-Agent / src/pipelines.py
Muhammad Mustehson
Update Old Code
4a84072
import logging
import os
from typing import Any
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from duckdb import DuckDBPyConnection
from src.models import (
Charts,
Continuous,
Data,
DateTime,
Nominal,
PlotConfig,
Route,
SmallCardNum,
SQLQueryModel,
TableData,
)
# Populate os.environ from a local .env file first; every setting below is
# read from the environment afterwards.
load_dotenv()

logger = logging.getLogger(__name__)

# When a bar chart's x axis has more distinct values than this, the chart
# data is truncated to this many rows.
MAX_BARS_COUNT = 20

# Extra SQL generation attempts allowed after the first one fails.
SQL_GENERATION_RETRIES = int(os.environ.get("SQL_GENERATION_RETRIES", "5"))

# Prompt templates. Each is None when the variable is not set; the
# templates are later filled with str.format.
SQL_PROMPT = os.environ.get("SQL_PROMPT")
USER_PROMPT = os.environ.get("USER_PROMPT")
ROUTER_SYSTEM_PROMPT = os.environ.get("ROUTER_SYSTEM_PROMPT")
CHART_CONFIG_SYSTEM_PROMPT = os.environ.get("CHART_CONFIG_SYSTEM_PROMPT")
CHART_CONFIG_USER_PROMPT = os.environ.get("CHART_CONFIG_USER_PROMPT")
class SQLPipeline:
    """Generate SQL for a user question with an LLM chain and run it on DuckDB."""

    def __init__(
        self,
        duckdb: DuckDBPyConnection,
        chain,
    ) -> None:
        """Store the DuckDB connection and the chain used for SQL generation.

        Args:
            duckdb: Connection the generated SQL is executed on.
            chain: Object exposing ``run(system_prompt=..., user_prompt=...,
                format_name=..., response_format=...)``; presumably an LLM
                client wrapper (confirm against the caller).
        """
        self._duckdb = duckdb
        self.chain = chain

    def generate_sql(
        self, user_question: str, context: str, errors: str | None = None
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """Ask the chain for a SQL query (and description) answering a question.

        Args:
            user_question: The natural-language question.
            context: Extra context substituted into the USER_PROMPT template.
            errors: Failure messages from earlier attempts. When non-empty they
                are appended to the prompt so the model can avoid repeating
                the same mistake.

        Returns:
            Whatever ``chain.run`` produces for the ``SQLQueryModel`` response
            format. Callers read the query from its ``"sql_query"`` key when it
            is a dict, or use it directly when it is a string.
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        if errors:
            # Start on a new line so the retry instructions are not fused to
            # the end of the formatted prompt. Implicit literal concatenation
            # keeps source indentation out of the prompt text.
            user_prompt_formatted += (
                "\nCarefully review the previous error or exception and "
                "rewrite the SQL so that the error does not occur again. "
                "Try a different approach or rewrite SQL if needed. "
                f"Errors from previous attempts: {errors}"
            )
        sql = self.chain.run(
            system_prompt=SQL_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="sql_query",
            response_format=SQLQueryModel,
        )
        logger.info(f"SQL Generated Successfully: {sql}")
        return sql

    def run_query(self, sql_query: str) -> pd.DataFrame | None:
        """Execute *sql_query* on the DuckDB connection.

        Returns:
            The result as a DataFrame, or None when ``query()`` yields no
            relation (e.g. for a statement that produces no result set).
        """
        # NOTE(review): this SQL is produced by an LLM from user input, so it
        # is effectively untrusted; consider a read-only DuckDB connection.
        logger.info("Query Execution Started.")
        relation = self._duckdb.query(sql_query)
        if relation is None:
            return None
        return relation.df()

    def try_sql_with_retries(
        self,
        user_question: str,
        context: str,
        max_retries: int = SQL_GENERATION_RETRIES,
    ) -> tuple[
        str | dict[str, str | int | float | None] | list[str] | None,
        pd.DataFrame | None,
    ]:
        """Generate and execute SQL, retrying with error feedback on failure.

        One initial attempt is made, followed by up to ``max_retries`` retries.
        An attempt fails when generation or execution raises, or when the
        query yields no rows. Every failure reason is accumulated and fed back
        to the next generation so the model can correct itself.

        Returns:
            ``(sql, dataframe)`` from the first attempt that yields a non-empty
            result, or ``(None, None)`` if the chain returns no SQL or every
            attempt fails.
        """
        all_errors = ""
        total_attempts = max_retries + 1  # the first attempt is not a retry
        for attempt in range(1, total_attempts + 1):
            if attempt > 1:
                logger.info(f"Retrying: {attempt - 1}")
            try:
                sql = self.generate_sql(
                    user_question, context, errors=all_errors or None
                )
                if not sql:
                    return None, None
                sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
                if not isinstance(sql_query_str, str):
                    raise ValueError(
                        f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
                    )
                query_df = self.run_query(sql_query_str)
            except Exception as e:
                # Boundary of the retry loop: record the failure and try again.
                last_error = f"\n[Attempt {attempt}] {type(e).__name__}: {e}"
                logger.error(f"Error during SQL generation or execution: {last_error}")
                all_errors += last_error
                continue
            if query_df is not None and not query_df.empty:
                return sql, query_df
            # The query ran but produced nothing. Tell the model explicitly so
            # the retry does not just regenerate the same query.
            empty_note = (
                f"\n[Attempt {attempt}] The query executed but returned no rows."
            )
            logger.warning(f"Empty query result: {empty_note}")
            all_errors += empty_note
        logger.error(f"Failed after {total_attempts} attempts. Errors: {all_errors}")
        return None, None
class QueryRouter:
    """Classify a user question into a route code using the LLM chain."""

    def __init__(self, chain) -> None:
        # Object exposing run(system_prompt=..., user_prompt=..., format_name=...,
        # response_format=...); presumably an LLM wrapper.
        self.chain = chain

    def route_request(self, user_question: str, context: str) -> int:
        """Return the route code the chain selects for *user_question*.

        The user prompt is USER_PROMPT filled with the question and context.
        It is sent with ROUTER_SYSTEM_PROMPT and parsed with the ``Route``
        response format. The log message below describes the codes as
        0 = irrelevant, 1 = visualizable, 2 = SQL only, 3 = datetime.
        NOTE(review): confirm that set against the ``Route`` model.
        """
        prompt = USER_PROMPT.format(question=user_question, context=context)
        request = {
            "system_prompt": ROUTER_SYSTEM_PROMPT,
            "user_prompt": prompt,
            "format_name": "route_queries",
            "response_format": Route,
        }
        route = self.chain.run(**request)
        logger.info(
            f"Query routed to: {route} Where if query is routed to 0 its irrelevant, if 1 its visualizable, if 2 its only sql, and 3 if its datetime."
        )
        return route
class ChartFormatter:
    """Choose a chart type for a query result and reshape its data for plotting.

    Each column is classified as Continuous, SmallCardNum, DateTime or Nominal.
    The combination of those kinds decides the chart type.
    """

    # Numeric columns with fewer distinct values than this are treated as
    # categorical (SmallCardNum) rather than Continuous.
    SMALL_CARDINALITY_THRESHOLD = 10

    def _build_xy_data(self, label_data, value_data, limit_unique_x=False):
        """Zip labels and values into ``[{"x": ..., "y": ...}, ...]`` records.

        When *limit_unique_x* is set and the labels have more than
        MAX_BARS_COUNT distinct values, only the first MAX_BARS_COUNT rows are
        kept. Rows are truncated in their existing order, not ranked.
        """
        df = pd.DataFrame({"x": label_data, "y": value_data})
        if limit_unique_x and df["x"].nunique() > MAX_BARS_COUNT:
            df = df.head(MAX_BARS_COUNT)
        return df.to_dict(orient="records")

    def is_continuous(self, dtype) -> bool:
        """Return True for numeric dtypes, excluding booleans."""
        # is_numeric_dtype also accepts bool, so exclude it explicitly; it
        # already covers every integer and float dtype.
        if pd.api.types.is_bool_dtype(dtype):
            return False
        return pd.api.types.is_numeric_dtype(dtype)

    def is_datetime(self, dtype) -> bool:
        """Return True for datetime64 (naive or tz-aware) and timedelta64 dtypes."""
        return pd.api.types.is_datetime64_any_dtype(
            dtype
        ) or pd.api.types.is_timedelta64_dtype(dtype)

    def detect_dtype(self, data):
        """Classify every column of *data* into a column-kind marker.

        Returns:
            Dict mapping column label to a Continuous, SmallCardNum, DateTime
            or Nominal instance.

        Raises:
            TypeError: if a column label selects more than one column
                (duplicate column labels), so no single kind can be assigned.
        """
        type_ = {}
        for col_name in data.columns:
            col_data = data[col_name]
            # Check before reading .dtype: duplicate labels return a DataFrame,
            # which has no .dtype attribute.
            if not isinstance(col_data, pd.Series):
                raise TypeError(
                    f"unprocessed column type for {col_name!r}: "
                    f"{type(col_data).__name__}"
                )
            if self.is_continuous(col_data.dtype):
                # Numbers with only a few distinct values behave like categories.
                if col_data.nunique() < self.SMALL_CARDINALITY_THRESHOLD:
                    type_[col_name] = SmallCardNum()
                else:
                    type_[col_name] = Continuous()
            elif self.is_datetime(col_data.dtype):
                type_[col_name] = DateTime()
            else:
                type_[col_name] = Nominal()
        return type_

    def build_bar_chart(self, label_data, value_data):
        """x/y records, truncated when there are too many distinct labels."""
        return self._build_xy_data(label_data, value_data, limit_unique_x=True)

    def build_line_chart(self, label_data, value_data):
        """x/y records, in input order."""
        return self._build_xy_data(label_data, value_data)

    def build_pie_chart(self, label_data, value_data):
        """x/y records, in input order."""
        return self._build_xy_data(label_data, value_data)

    def build_histogram(self, data, bins: int = 50):
        """Bucket *data* into equal-width bins spanning its min..max.

        Args:
            data: 1-D numeric values without NaN (callers pass ``dropna()``).
            bins: Number of bins (default 50).

        Returns:
            List of ``{"bin_start", "bin_end", "frequency"}`` dicts, one per
            bin, or an empty list when *data* is empty.
        """
        if len(data) == 0:
            # min()/max() of empty data are NaN, which np.histogram rejects.
            return []
        counts, edges = np.histogram(data, bins=bins, range=(data.min(), data.max()))
        return [
            {"bin_start": start, "bin_end": end, "frequency": freq}
            for start, end, freq in zip(edges[:-1], edges[1:], counts)
        ]

    def format_and_select_chart(self, df: pd.DataFrame):
        """Pick a chart type for *df* and format its data.

        Only 1- and 2-column frames are charted. For two columns, the
        categorical/temporal column becomes the x axis.

        Returns:
            ``(chart_type, formatted_data, dtypes)`` where chart_type is one of
            "hist", "pie", "bar", "line", or ``(None, None, None)`` when no
            rule matches.
        """
        cols = df.columns.tolist()
        dtypes = self.detect_dtype(df)
        if len(cols) == 1:
            col = cols[0]
            dtype = dtypes[col]
            if isinstance(dtype, Continuous):
                return "hist", self.build_histogram(df[col].dropna()), dtypes
            if isinstance(dtype, (SmallCardNum, Nominal)):
                counts = df[col].value_counts()
                # Few categories read well as a pie; otherwise use bars.
                chart = "pie" if counts.size <= 6 else "bar"
                builder = (
                    self.build_pie_chart if chart == "pie" else self.build_bar_chart
                )
                return chart, builder(counts.index, counts.values), dtypes
        if len(cols) == 2:
            x, y = cols
            dtype_x = dtypes[x]
            dtype_y = dtypes[y]
            data_x = df[x]
            data_y = df[y]
            kinds = {type(dtype_x), type(dtype_y)}
            if kinds == {Nominal, Continuous}:
                label, value = (
                    (data_x, data_y)
                    if isinstance(dtype_x, Nominal)
                    else (data_y, data_x)
                )
                return "bar", self.build_bar_chart(label, value), dtypes
            elif kinds == {Continuous}:
                # Both columns are Continuous: keep the original column order.
                return "bar", self.build_bar_chart(data_x, data_y), dtypes
            elif kinds == {SmallCardNum, Continuous}:
                label, value = (
                    (data_x, data_y)
                    if isinstance(dtype_x, SmallCardNum)
                    else (data_y, data_x)
                )
                return "bar", self.build_bar_chart(label, value), dtypes
            elif isinstance(dtype_x, SmallCardNum) and isinstance(
                dtype_y, SmallCardNum
            ):
                return "bar", self.build_bar_chart(data_x, data_y), dtypes
            elif kinds == {DateTime, Continuous}:
                label, value = (
                    (data_x, data_y)
                    if isinstance(dtype_x, DateTime)
                    else (data_y, data_x)
                )
                return "line", self.build_line_chart(label, value), dtypes
            elif (
                isinstance(dtype_x, DateTime) and isinstance(dtype_y, SmallCardNum)
            ) or (isinstance(dtype_y, DateTime) and isinstance(dtype_x, SmallCardNum)):
                label, value = (
                    (data_x, data_y)
                    if isinstance(dtype_x, DateTime)
                    else (data_y, data_x)
                )
                return "line", self.build_line_chart(label, value), dtypes
            elif kinds == {Nominal, SmallCardNum}:
                label, value = (
                    (data_x, data_y)
                    if isinstance(dtype_x, Nominal)
                    else (data_y, data_x)
                )
                return "bar", self.build_bar_chart(label, value), dtypes
        return None, None, None
class SQLVizChain:
    """Main pipeline: route a question, generate and run SQL, build a chart."""

    def __init__(self, duckdb: DuckDBPyConnection, chain):
        """Wire the router, SQL generator and chart formatter to one chain.

        Args:
            duckdb: Connection the generated SQL is executed on.
            chain: Shared chain object used for routing, SQL generation and
                chart-config generation. It must expose ``run(...)`` and
                ``total_tokens`` (both used below).
        """
        self._duckdb = duckdb
        self.chain = chain
        self.router = QueryRouter(chain=self.chain)
        self.sql_generator = SQLPipeline(duckdb, chain=self.chain)
        self.charting = ChartFormatter()

    @staticmethod
    def _empty_response() -> dict[str, Any]:
        """Result used when there is nothing to show (every field None)."""
        return {
            "chart_data": None,
            "chart_config": None,
            "chart_type": None,
            "sql_config": None,
        }

    @staticmethod
    def _extract_sql_query(sql_config: Any) -> str | None:
        """Return the SQL text from a SQL generator result.

        The generator may return a dict holding the query under
        ``"sql_query"`` or a bare string. Anything else yields None.
        """
        if isinstance(sql_config, dict):
            return sql_config.get("sql_query")
        if isinstance(sql_config, str):
            return sql_config
        return None

    def create_chart_config(
        self, query_df: pd.DataFrame, user_question: str, sql: str
    ) -> tuple[list[dict[Any, Any]] | None, dict[str, Any] | None, str | None]:
        """Format data for visualization and ask the chain for a chart config.

        Returns:
            ``(formatted_data, chart_config, chart_type)``, or
            ``(None, None, None)`` when no chart rule matches *query_df*.
        """
        (
            chart_type,
            formatted_data,
            dtypes,
        ) = self.charting.format_and_select_chart(df=query_df)
        if not all([formatted_data, dtypes, chart_type]):
            return None, None, None
        chart_config = self.chain.run(
            system_prompt=CHART_CONFIG_SYSTEM_PROMPT,
            user_prompt=CHART_CONFIG_USER_PROMPT.format(
                question=user_question,
                sql_query=sql,
                dtypes=dtypes,
                chart_type=chart_type,
            ),
            format_name="chart_config",
            response_format=PlotConfig,
        )
        logger.info(f"Chart Config Generated: {chart_config}")
        return formatted_data, chart_config, chart_type

    def create_viz_with_text_response(
        self,
        query_df: pd.DataFrame,
        user_question: str,
        sql_config: dict[Any, Any] | str,
    ) -> dict[str, Any]:
        """Build the final response: a chart when possible, else the table.

        Args:
            query_df: Result of the executed SQL.
            user_question: The original question (forwarded to the chain).
            sql_config: SQL generator output, either a dict with a
                ``"sql_query"`` key or a bare SQL string.

        Returns:
            Dict with ``chart_data`` (a Charts object, or TableData when no
            chart could be built), ``chart_config``, ``chart_type`` and
            ``sql_config``.

        Raises:
            ValueError: if the chart config names an unsupported plot type.
        """
        formatted_data, chart_config, chart_type = self.create_chart_config(
            query_df, user_question, self._extract_sql_query(sql_config)
        )
        table_data = TableData(data=query_df)
        if not all([formatted_data, chart_config, chart_type]):
            logger.info("Failed to format data or generate chart config.")
            logger.info(f"Total Token Counts: {self.chain.total_tokens}")
            return {
                "chart_data": table_data,
                "chart_config": None,
                "chart_type": None,
                "sql_config": sql_config,
            }
        chart_data = Data.validate_data(data=formatted_data)
        plot_type = chart_config.get("type")
        if plot_type not in {"bar", "line", "pie", "hist"}:
            raise ValueError(
                "Invalid Plot Type. Must be one of 'bar', 'line', 'pie', 'hist'"
                f" (got {plot_type!r})"
            )
        data = Charts(**{plot_type: chart_data})
        logger.info("Visualization Chain Completed Successfully.")
        logger.info(f"Total Token Counts: {self.chain.total_tokens}")
        return {
            "chart_data": data,
            "chart_config": chart_config,
            "chart_type": chart_type,
            "sql_config": sql_config,
        }

    def run(self, user_question: str, context: str) -> dict[str, Any]:
        """Main pipeline: question → SQL → data → chart config.

        Returns an all-None response when the router returns 0 or when SQL
        generation/execution fails after its retries.
        """
        route = self.router.route_request(user_question=user_question, context=context)
        if route == 0:
            return self._empty_response()
        sql_config, query_df = self.sql_generator.try_sql_with_retries(
            user_question=user_question, context=context
        )
        if sql_config is None or query_df is None:
            logger.info("Failed to generate or execute SQL after retries.")
            logger.info(f"Total Token Counts: {self.chain.total_tokens}")
            return self._empty_response()
        return self.create_viz_with_text_response(
            query_df=query_df, user_question=user_question, sql_config=sql_config
        )