Spaces:
Sleeping
Sleeping
Muhammad Mustehson commited on
Commit ·
4a84072
1
Parent(s): 332cf4d
Update Old Code
Browse files- .gitignore +1 -0
- __pycache__/prompt.cpython-311.pyc +0 -0
- app.py +86 -103
- requirements.txt +7 -8
- src/__init__.py +0 -0
- src/client.py +120 -0
- src/models.py +142 -0
- src/pipelines.py +401 -0
- src/utils.py +57 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
__pycache__/prompt.cpython-311.pyc
DELETED
|
Binary file (2.63 kB)
|
|
|
app.py
CHANGED
|
@@ -1,47 +1,35 @@
|
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import duckdb
|
| 3 |
import gradio as gr
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
-
|
| 6 |
-
from transformers.agents import Tool
|
| 7 |
-
from langsmith import traceable
|
| 8 |
-
from langchain import hub
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
| 12 |
TAB_LINES = 8
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
#----------CONNECT TO DATABASE----------
|
| 16 |
-
md_token = os.getenv('MD_TOKEN')
|
| 17 |
-
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
|
| 18 |
-
#---------------------------------------
|
| 19 |
-
|
| 20 |
-
#-------LOAD HUGGINGFACE MODEL-------
|
| 21 |
-
models = ["Qwen/Qwen2.5-72B-Instruct","meta-llama/Meta-Llama-3-70B-Instruct",
|
| 22 |
-
"meta-llama/Llama-3.1-70B-Instruct"]
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
try:
|
| 27 |
-
llm_engine = HfEngine(model=model)
|
| 28 |
-
info = llm_engine.client.get_endpoint_info()
|
| 29 |
-
model_loaded = True
|
| 30 |
-
break
|
| 31 |
-
except Exception as e:
|
| 32 |
-
print(f"Error for model {model}: {e}")
|
| 33 |
-
continue
|
| 34 |
|
| 35 |
-
if not model_loaded:
|
| 36 |
-
gr.Warning(f"❌ None of the model form {models} are available. {e}")
|
| 37 |
-
#---------------------------------------
|
| 38 |
|
| 39 |
-
|
| 40 |
-
prompt = hub.pull("viz-prompt")
|
| 41 |
-
#-------------------------------------
|
| 42 |
|
| 43 |
|
| 44 |
-
#--------------ALL UTILS----------------
|
| 45 |
def get_schemas():
|
| 46 |
schemas = conn.execute("""
|
| 47 |
SELECT DISTINCT schema_name
|
|
@@ -50,22 +38,26 @@ def get_schemas():
|
|
| 50 |
""").fetchall()
|
| 51 |
return [item[0] for item in schemas]
|
| 52 |
|
| 53 |
-
|
| 54 |
def get_tables(schema_name):
|
| 55 |
-
tables = conn.execute(
|
|
|
|
|
|
|
| 56 |
return [table[0] for table in tables]
|
| 57 |
|
| 58 |
-
|
| 59 |
def update_tables(schema_name):
|
| 60 |
tables = get_tables(schema_name)
|
| 61 |
return gr.update(choices=tables)
|
| 62 |
|
| 63 |
-
|
| 64 |
def get_table_schema(table):
|
| 65 |
-
result = conn.sql(
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
full_path = f"{parent_database}.{schema_name}.{table}"
|
| 70 |
if schema_name != "main":
|
| 71 |
old_path = f"{schema_name}.{table}"
|
|
@@ -75,62 +67,39 @@ def get_table_schema(table):
|
|
| 75 |
return ddl_create, full_path
|
| 76 |
|
| 77 |
|
| 78 |
-
class SQLExecutorTool(Tool):
|
| 79 |
-
name = "sql_engine"
|
| 80 |
-
inputs = {
|
| 81 |
-
"query": {
|
| 82 |
-
"type": "text",
|
| 83 |
-
"description": f"The query to perform. This should be correct DuckDB SQL.",
|
| 84 |
-
}
|
| 85 |
-
}
|
| 86 |
-
description = """Allows you to perform SQL queries on the table. Returns a pandas dataframe representation of the result."""
|
| 87 |
-
output_type = "pandas.core.frame.DataFrame"
|
| 88 |
-
|
| 89 |
-
def forward(self, query: str) -> str:
|
| 90 |
-
output_df = conn.sql(query).df()
|
| 91 |
-
return output_df
|
| 92 |
-
|
| 93 |
-
tool = SQLExecutorTool()
|
| 94 |
-
|
| 95 |
-
def process_outputs(output) :
|
| 96 |
-
return {
|
| 97 |
-
'sql': output.get('sql', None),
|
| 98 |
-
'code': output.get('code', None)
|
| 99 |
-
}
|
| 100 |
-
|
| 101 |
-
@traceable(process_outputs=process_outputs)
|
| 102 |
-
def get_visualization(question, schema, table_name):
|
| 103 |
-
agent = ReactCodeAgent(tools=[tool], llm_engine=llm_engine, add_base_tools=True,
|
| 104 |
-
additional_authorized_imports=['matplotlib.pyplot',
|
| 105 |
-
'pandas', 'plotly.express',
|
| 106 |
-
'seaborn'], max_iterations=10)
|
| 107 |
-
results = agent.run(
|
| 108 |
-
task= prompt.format(question=question, schema=schema, table_name=table_name)
|
| 109 |
-
)
|
| 110 |
-
|
| 111 |
-
return results
|
| 112 |
-
#---------------------------------------
|
| 113 |
-
|
| 114 |
-
|
| 115 |
def main(table, text_query):
|
| 116 |
-
# Empty Fig
|
| 117 |
fig, ax = plt.subplots()
|
| 118 |
-
ax.set_axis_off()
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
try:
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
except Exception as e:
|
|
|
|
| 129 |
gr.Warning(f"❌ Unable to generate the visualization. {e}")
|
| 130 |
-
|
| 131 |
-
return fig,
|
| 132 |
-
|
| 133 |
-
|
| 134 |
|
| 135 |
custom_css = """
|
| 136 |
.gradio-container {
|
|
@@ -150,7 +119,9 @@ custom_css = """
|
|
| 150 |
}
|
| 151 |
"""
|
| 152 |
|
| 153 |
-
with gr.Blocks(
|
|
|
|
|
|
|
| 154 |
gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
|
| 155 |
|
| 156 |
gr.Markdown("""
|
|
@@ -162,13 +133,18 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
|
|
| 162 |
""")
|
| 163 |
|
| 164 |
with gr.Row():
|
| 165 |
-
|
| 166 |
with gr.Column(scale=1):
|
| 167 |
-
schema_dropdown = gr.Dropdown(
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
with gr.Column(scale=2):
|
| 171 |
-
query_input = gr.Textbox(
|
|
|
|
|
|
|
| 172 |
with gr.Row():
|
| 173 |
with gr.Column(scale=7):
|
| 174 |
pass
|
|
@@ -178,18 +154,25 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
|
|
| 178 |
with gr.Tabs():
|
| 179 |
with gr.Tab("Plot"):
|
| 180 |
result_plot = gr.Plot()
|
| 181 |
-
with gr.Tab("SQL"):
|
| 182 |
-
generated_sql = gr.Textbox(
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
data = gr.Dataframe(label="Data", interactive=False)
|
| 186 |
|
| 187 |
-
schema_dropdown.change(
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
if __name__ == "__main__":
|
| 191 |
demo.launch(debug=True)
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
| 1 |
+
import logging
|
| 2 |
import os
|
| 3 |
+
|
| 4 |
import duckdb
|
| 5 |
import gradio as gr
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
+
import pandas as pd
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
from src.client import LLMChain
|
| 10 |
+
from src.models import Charts, TableData
|
| 11 |
+
from src.pipelines import SQLVizChain
|
| 12 |
+
from src.utils import plot_chart
|
| 13 |
|
| 14 |
+
MD_TOKEN = os.getenv("MD_TOKEN")
|
| 15 |
+
conn = duckdb.connect(f"md:my_db?motherduck_token={MD_TOKEN}", read_only=True)
|
| 16 |
+
LEVEL = "INFO" if not os.getenv("ENV") == "PROD" else "WARNING"
|
| 17 |
TAB_LINES = 8
|
| 18 |
|
| 19 |
+
logging.basicConfig(
|
| 20 |
+
level=getattr(logging, LEVEL, logging.INFO),
|
| 21 |
+
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
| 22 |
+
)
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
def _load_pipeline():
|
| 27 |
+
return SQLVizChain(duckdb=conn, chain=LLMChain())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
pipeline = _load_pipeline()
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
|
|
|
| 33 |
def get_schemas():
|
| 34 |
schemas = conn.execute("""
|
| 35 |
SELECT DISTINCT schema_name
|
|
|
|
| 38 |
""").fetchall()
|
| 39 |
return [item[0] for item in schemas]
|
| 40 |
|
| 41 |
+
|
| 42 |
def get_tables(schema_name):
|
| 43 |
+
tables = conn.execute(
|
| 44 |
+
f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}'"
|
| 45 |
+
).fetchall()
|
| 46 |
return [table[0] for table in tables]
|
| 47 |
|
| 48 |
+
|
| 49 |
def update_tables(schema_name):
|
| 50 |
tables = get_tables(schema_name)
|
| 51 |
return gr.update(choices=tables)
|
| 52 |
|
| 53 |
+
|
| 54 |
def get_table_schema(table):
|
| 55 |
+
result = conn.sql(
|
| 56 |
+
f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';"
|
| 57 |
+
).df()
|
| 58 |
+
ddl_create = result.iloc[0, 0]
|
| 59 |
+
parent_database = result.iloc[0, 1]
|
| 60 |
+
schema_name = result.iloc[0, 2]
|
| 61 |
full_path = f"{parent_database}.{schema_name}.{table}"
|
| 62 |
if schema_name != "main":
|
| 63 |
old_path = f"{schema_name}.{table}"
|
|
|
|
| 67 |
return ddl_create, full_path
|
| 68 |
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
def main(table, text_query):
|
|
|
|
| 71 |
fig, ax = plt.subplots()
|
| 72 |
+
ax.set_axis_off()
|
| 73 |
+
schema, _ = get_table_schema(table)
|
| 74 |
+
|
|
|
|
| 75 |
try:
|
| 76 |
+
results = pipeline.run(user_question=text_query, context=schema)
|
| 77 |
+
chart_data = results["chart_data"]
|
| 78 |
+
chart_config = results["chart_config"]
|
| 79 |
+
chart_type = results["chart_type"]
|
| 80 |
+
generated_sql = results["sql_config"]["sql_query"]
|
| 81 |
+
|
| 82 |
+
if not chart_type and chart_data is not None:
|
| 83 |
+
if isinstance(chart_data, TableData):
|
| 84 |
+
data = pd.DataFrame(chart_data.model_dump(exclude_none=True))
|
| 85 |
+
return (fig, generated_sql, data)
|
| 86 |
+
|
| 87 |
+
if chart_type is not None and chart_data is not None:
|
| 88 |
+
if isinstance(chart_data, Charts):
|
| 89 |
+
chart_dict = chart_data.model_dump(exclude_none=True).get(chart_type)
|
| 90 |
+
data = pd.DataFrame(chart_dict["data"])
|
| 91 |
+
fig = plot_chart(chart_type=chart_type, data=data, **chart_config)
|
| 92 |
+
return (fig, generated_sql, data)
|
| 93 |
+
|
| 94 |
+
if chart_data is None:
|
| 95 |
+
return fig, generated_sql, None
|
| 96 |
|
| 97 |
except Exception as e:
|
| 98 |
+
logger.error(e)
|
| 99 |
gr.Warning(f"❌ Unable to generate the visualization. {e}")
|
| 100 |
+
|
| 101 |
+
return fig, None, None
|
| 102 |
+
|
|
|
|
| 103 |
|
| 104 |
custom_css = """
|
| 105 |
.gradio-container {
|
|
|
|
| 119 |
}
|
| 120 |
"""
|
| 121 |
|
| 122 |
+
with gr.Blocks(
|
| 123 |
+
theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css
|
| 124 |
+
) as demo:
|
| 125 |
gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
|
| 126 |
|
| 127 |
gr.Markdown("""
|
|
|
|
| 133 |
""")
|
| 134 |
|
| 135 |
with gr.Row():
|
|
|
|
| 136 |
with gr.Column(scale=1):
|
| 137 |
+
schema_dropdown = gr.Dropdown(
|
| 138 |
+
choices=get_schemas(), label="Select Schema", interactive=True
|
| 139 |
+
)
|
| 140 |
+
tables_dropdown = gr.Dropdown(
|
| 141 |
+
choices=[], label="Available Tables", value=None
|
| 142 |
+
)
|
| 143 |
|
| 144 |
with gr.Column(scale=2):
|
| 145 |
+
query_input = gr.Textbox(
|
| 146 |
+
lines=3, label="Text Query", placeholder="Enter your text query here..."
|
| 147 |
+
)
|
| 148 |
with gr.Row():
|
| 149 |
with gr.Column(scale=7):
|
| 150 |
pass
|
|
|
|
| 154 |
with gr.Tabs():
|
| 155 |
with gr.Tab("Plot"):
|
| 156 |
result_plot = gr.Plot()
|
| 157 |
+
with gr.Tab("SQL"):
|
| 158 |
+
generated_sql = gr.Textbox(
|
| 159 |
+
lines=TAB_LINES,
|
| 160 |
+
label="Generated SQL",
|
| 161 |
+
value="",
|
| 162 |
+
interactive=False,
|
| 163 |
+
autoscroll=False,
|
| 164 |
+
)
|
| 165 |
+
with gr.Tab("Data"):
|
| 166 |
data = gr.Dataframe(label="Data", interactive=False)
|
| 167 |
|
| 168 |
+
schema_dropdown.change(
|
| 169 |
+
update_tables, inputs=schema_dropdown, outputs=tables_dropdown
|
| 170 |
+
)
|
| 171 |
+
generate_query_button.click(
|
| 172 |
+
main,
|
| 173 |
+
inputs=[tables_dropdown, query_input],
|
| 174 |
+
outputs=[result_plot, generated_sql, data],
|
| 175 |
+
)
|
| 176 |
|
| 177 |
if __name__ == "__main__":
|
| 178 |
demo.launch(debug=True)
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
-
torch
|
| 2 |
-
seaborn
|
| 3 |
-
plotly
|
| 4 |
huggingface_hub
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
huggingface_hub
|
| 2 |
+
duckdb
|
| 3 |
+
pandas
|
| 4 |
+
pydantic
|
| 5 |
+
python-dotenv
|
| 6 |
+
gradio
|
| 7 |
+
pandas
|
| 8 |
+
matplotlib
|
src/__init__.py
ADDED
|
File without changes
|
src/client.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
from huggingface_hub import InferenceClient
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
MAX_RESPONSE_TOKENS = 2048
|
| 14 |
+
TEMPERATURE = 0.9
|
| 15 |
+
|
| 16 |
+
models = json.loads(os.getenv("MODEL_NAMES"))
|
| 17 |
+
providers = json.loads(os.getenv("PROVIDERS"))
|
| 18 |
+
EMB_MODEL = os.getenv("EMB_MODEL")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _engine_working(engine: InferenceClient) -> bool:
|
| 22 |
+
try:
|
| 23 |
+
engine.chat_completion([{"role": "user", "content": "ping"}], max_tokens=1)
|
| 24 |
+
logger.info("Engine is Working.")
|
| 25 |
+
return True
|
| 26 |
+
except Exception as e:
|
| 27 |
+
logger.exception(f"Engine is not working: {e}")
|
| 28 |
+
return False
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _load_llm_client() -> InferenceClient:
|
| 32 |
+
"""
|
| 33 |
+
Attempts to load the provided model from the huggingface endpoint.
|
| 34 |
+
|
| 35 |
+
Returns InferenceClient if successful.
|
| 36 |
+
Raises Exception if no model is available.
|
| 37 |
+
"""
|
| 38 |
+
logger.warning("Loading Model...")
|
| 39 |
+
errors = []
|
| 40 |
+
for model in models:
|
| 41 |
+
for provider in providers:
|
| 42 |
+
if isinstance(model, str):
|
| 43 |
+
try:
|
| 44 |
+
logger.info(f"Checking model: {model} provider: {provider}")
|
| 45 |
+
client = InferenceClient(
|
| 46 |
+
model=model,
|
| 47 |
+
timeout=15,
|
| 48 |
+
provider=provider,
|
| 49 |
+
)
|
| 50 |
+
if _engine_working(client):
|
| 51 |
+
logger.info(
|
| 52 |
+
f"The model is loaded : {model} , provider: {provider}"
|
| 53 |
+
)
|
| 54 |
+
return client
|
| 55 |
+
except Exception as e:
|
| 56 |
+
logger.error(
|
| 57 |
+
f"Error loading model {model} provider {provider}: {e}"
|
| 58 |
+
)
|
| 59 |
+
errors.append(str(e))
|
| 60 |
+
raise Exception(f"Unable to load any provided model: {errors}.")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
_default_client = _load_llm_client()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class LLMChain:
|
| 67 |
+
def __init__(self, client: InferenceClient = _default_client):
|
| 68 |
+
self.client = client
|
| 69 |
+
self.total_tokens = 0
|
| 70 |
+
|
| 71 |
+
def run(
|
| 72 |
+
self,
|
| 73 |
+
system_prompt: str | None = None,
|
| 74 |
+
user_prompt: str | None = None,
|
| 75 |
+
messages: list[dict] | None = None,
|
| 76 |
+
format_name: str | None = None,
|
| 77 |
+
response_format: type[BaseModel] | None = None,
|
| 78 |
+
) -> str | dict[str, str | int | float | None] | list[str] | None:
|
| 79 |
+
try:
|
| 80 |
+
if system_prompt and user_prompt:
|
| 81 |
+
messages = [
|
| 82 |
+
{"role": "system", "content": system_prompt},
|
| 83 |
+
{"role": "user", "content": user_prompt},
|
| 84 |
+
]
|
| 85 |
+
elif not messages:
|
| 86 |
+
raise ValueError(
|
| 87 |
+
"Either system_prompt and user_prompt or messages must be provided."
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
llm_response = self.client.chat_completion(
|
| 91 |
+
messages=messages,
|
| 92 |
+
max_tokens=MAX_RESPONSE_TOKENS,
|
| 93 |
+
temperature=TEMPERATURE,
|
| 94 |
+
response_format=(
|
| 95 |
+
{
|
| 96 |
+
"type": "json_schema",
|
| 97 |
+
"json_schema": {
|
| 98 |
+
"name": format_name,
|
| 99 |
+
"schema": response_format.model_json_schema(),
|
| 100 |
+
"strict": True,
|
| 101 |
+
},
|
| 102 |
+
}
|
| 103 |
+
if format_name and response_format
|
| 104 |
+
else None
|
| 105 |
+
),
|
| 106 |
+
)
|
| 107 |
+
self.total_tokens += llm_response.usage.total_tokens
|
| 108 |
+
analysis = llm_response.choices[0].message.content
|
| 109 |
+
if response_format:
|
| 110 |
+
analysis = json.loads(analysis)
|
| 111 |
+
fields = list(response_format.model_fields.keys())
|
| 112 |
+
if len(fields) == 1:
|
| 113 |
+
return analysis.get(fields[0])
|
| 114 |
+
return {field: analysis.get(field) for field in fields}
|
| 115 |
+
|
| 116 |
+
return analysis
|
| 117 |
+
|
| 118 |
+
except Exception as e:
|
| 119 |
+
logger.error(f"Error during LLM calls: {e}")
|
| 120 |
+
return None
|
src/models.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from enum import Enum
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SmallCardNum:
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Continuous:
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DateTime:
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Nominal:
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Route(BaseModel):
|
| 26 |
+
label: int = Field(
|
| 27 |
+
description="Classify user queries as: 0 for Irrelevant/Vague/Incomplete, 1 for Visualizable, and 2 for SQL-only."
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class SQLQueryModel(BaseModel):
|
| 32 |
+
sql_query: str = Field(..., description="SQL query to execute.")
|
| 33 |
+
explanation: str = Field(..., description="Short explanation of the SQL query.")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class DataPoint(BaseModel):
|
| 37 |
+
x: int | float | str | None = None
|
| 38 |
+
y: int | float | str | None = None
|
| 39 |
+
bin_start: int | float | None = None
|
| 40 |
+
bin_end: int | float | None = None
|
| 41 |
+
frequency: int | float | None = None
|
| 42 |
+
|
| 43 |
+
@field_validator("bin_start", "bin_end", "frequency", "x", "y", mode="before")
|
| 44 |
+
@classmethod
|
| 45 |
+
def to_native(cls, field_value):
|
| 46 |
+
if field_value is not None and isinstance(
|
| 47 |
+
field_value, np.float64 | np.float32 | np.int64
|
| 48 |
+
):
|
| 49 |
+
return float(field_value)
|
| 50 |
+
if isinstance(field_value, (datetime, np.datetime64, pd.Timestamp)): # noqa: UP038
|
| 51 |
+
return field_value.strftime("%Y-%m-%d")
|
| 52 |
+
return field_value
|
| 53 |
+
|
| 54 |
+
@model_validator(mode="before")
|
| 55 |
+
@classmethod
|
| 56 |
+
def validate_keys(cls, values):
|
| 57 |
+
x, y = values.get("x"), values.get("y")
|
| 58 |
+
bin_start, bin_end, frequency = (
|
| 59 |
+
values.get("bin_start"),
|
| 60 |
+
values.get("bin_end"),
|
| 61 |
+
values.get("frequency"),
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
xy = x is not None and y is not None
|
| 65 |
+
bxy = bin_start is not None and bin_end is not None and frequency is not None
|
| 66 |
+
|
| 67 |
+
if not (xy or bxy):
|
| 68 |
+
raise ValueError(
|
| 69 |
+
"Invalid input: Must provide either (x, y) OR (bin_start, bin_end, frequency), but not a mix."
|
| 70 |
+
)
|
| 71 |
+
return values
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Data(BaseModel):
|
| 75 |
+
data: list[DataPoint] = Field(default_factory=list)
|
| 76 |
+
|
| 77 |
+
@classmethod
|
| 78 |
+
def validate_data(cls, data):
|
| 79 |
+
try:
|
| 80 |
+
return cls(data=data)
|
| 81 |
+
except ValidationError as e:
|
| 82 |
+
raise ValueError(f"Invalid data format: {e.errors()[0]}") # noqa: B904
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class TableData(BaseModel):
|
| 86 |
+
data: pd.DataFrame = Field(default_factory=None)
|
| 87 |
+
|
| 88 |
+
class Config:
|
| 89 |
+
arbitrary_types_allowed = True
|
| 90 |
+
|
| 91 |
+
@model_validator(mode="after")
|
| 92 |
+
def timestamp_to_str(self):
|
| 93 |
+
# Convert all datetime columns to string format
|
| 94 |
+
for col in self.data.select_dtypes(include=["datetime"]).columns:
|
| 95 |
+
if col:
|
| 96 |
+
self.data[col] = self.data[col].astype(str)
|
| 97 |
+
return self
|
| 98 |
+
|
| 99 |
+
def model_dump(self, *args, **kwargs): # noqa: ARG002
|
| 100 |
+
return self.data.to_dict(orient="list")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Charts(BaseModel):
|
| 104 |
+
bar: Data | None = None
|
| 105 |
+
line: Data | None = None
|
| 106 |
+
pie: Data | None = None
|
| 107 |
+
hist: Data | None = None
|
| 108 |
+
|
| 109 |
+
@model_validator(mode="after")
|
| 110 |
+
def process_charts_data(self):
|
| 111 |
+
def stringify(data):
|
| 112 |
+
if data and data.data:
|
| 113 |
+
for point in data.data:
|
| 114 |
+
if not isinstance(point.x, str):
|
| 115 |
+
point.x = str(point.x)
|
| 116 |
+
return data
|
| 117 |
+
|
| 118 |
+
if self.bar:
|
| 119 |
+
self.bar = stringify(self.bar)
|
| 120 |
+
if self.pie:
|
| 121 |
+
self.pie = stringify(self.pie)
|
| 122 |
+
|
| 123 |
+
return self
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class PlotType(str, Enum):
|
| 127 |
+
bar = ("bar",)
|
| 128 |
+
line = ("line",)
|
| 129 |
+
pie = ("pie",)
|
| 130 |
+
hist = ("hist",)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class PlotConfig(BaseModel):
|
| 134 |
+
type: PlotType = Field(
|
| 135 |
+
description="Type of plot, e.g., 'bar', 'line', 'pie'. Supported types depend on ShadCN implementation.",
|
| 136 |
+
)
|
| 137 |
+
title: str = Field(description="Title of the plot to display above the plot.")
|
| 138 |
+
x_axis_label: str = Field(description="Label for the X-axis of the plot.")
|
| 139 |
+
y_axis_label: str = Field(description="Label for the Y-axis of the plot.")
|
| 140 |
+
legend: bool = Field(
|
| 141 |
+
default=True, description="Flag to display a legend for the plot."
|
| 142 |
+
)
|
src/pipelines.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
from duckdb import DuckDBPyConnection
|
| 9 |
+
|
| 10 |
+
from src.models import (
|
| 11 |
+
Charts,
|
| 12 |
+
Continuous,
|
| 13 |
+
Data,
|
| 14 |
+
DateTime,
|
| 15 |
+
Nominal,
|
| 16 |
+
PlotConfig,
|
| 17 |
+
Route,
|
| 18 |
+
SmallCardNum,
|
| 19 |
+
SQLQueryModel,
|
| 20 |
+
TableData,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
load_dotenv()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
MAX_BARS_COUNT = 20
|
| 30 |
+
SQL_GENERATION_RETRIES = int(os.getenv("SQL_GENERATION_RETRIES", "5"))
|
| 31 |
+
SQL_PROMPT = os.getenv("SQL_PROMPT")
|
| 32 |
+
USER_PROMPT = os.getenv("USER_PROMPT")
|
| 33 |
+
ROUTER_SYSTEM_PROMPT = os.getenv("ROUTER_SYSTEM_PROMPT")
|
| 34 |
+
CHART_CONFIG_SYSTEM_PROMPT = os.getenv("CHART_CONFIG_SYSTEM_PROMPT")
|
| 35 |
+
CHART_CONFIG_USER_PROMPT = os.getenv("CHART_CONFIG_USER_PROMPT")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class SQLPipeline:
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
duckdb: DuckDBPyConnection,
|
| 42 |
+
chain,
|
| 43 |
+
) -> None:
|
| 44 |
+
self._duckdb = duckdb
|
| 45 |
+
self.chain = chain
|
| 46 |
+
|
| 47 |
+
def generate_sql(
|
| 48 |
+
self, user_question: str, context: str, errors: str | None = None
|
| 49 |
+
) -> str | dict[str, str | int | float | None] | list[str] | None:
|
| 50 |
+
"""Generate SQL + description."""
|
| 51 |
+
user_prompt_formatted = USER_PROMPT.format(
|
| 52 |
+
question=user_question, context=context
|
| 53 |
+
)
|
| 54 |
+
if errors:
|
| 55 |
+
user_prompt_formatted += f"Carefully review the previous error or\
|
| 56 |
+
exception and rewrite the SQL so that the error does not occur again.\
|
| 57 |
+
Try a different approach or rewrite SQL if needed. Last error: {errors}"
|
| 58 |
+
|
| 59 |
+
sql = self.chain.run(
|
| 60 |
+
system_prompt=SQL_PROMPT,
|
| 61 |
+
user_prompt=user_prompt_formatted,
|
| 62 |
+
format_name="sql_query",
|
| 63 |
+
response_format=SQLQueryModel,
|
| 64 |
+
)
|
| 65 |
+
logger.info(f"SQL Generated Successfully: {sql}")
|
| 66 |
+
return sql
|
| 67 |
+
|
| 68 |
+
def run_query(self, sql_query: str) -> pd.DataFrame | None:
|
| 69 |
+
"""Execute SQL and return dataframe."""
|
| 70 |
+
logger.info("Query Execution Started.")
|
| 71 |
+
return self._duckdb.query(sql_query).df()
|
| 72 |
+
|
| 73 |
+
def try_sql_with_retries(
|
| 74 |
+
self,
|
| 75 |
+
user_question: str,
|
| 76 |
+
context: str,
|
| 77 |
+
max_retries: int = SQL_GENERATION_RETRIES,
|
| 78 |
+
) -> tuple[
|
| 79 |
+
str | dict[str, str | int | float | None] | list[str] | None,
|
| 80 |
+
pd.DataFrame | None,
|
| 81 |
+
]:
|
| 82 |
+
"""Try SQL generation + execution with retries."""
|
| 83 |
+
last_error = None
|
| 84 |
+
all_errors = ""
|
| 85 |
+
|
| 86 |
+
for attempt in range(
|
| 87 |
+
1, max_retries + 2
|
| 88 |
+
): # @ Since the first is normal and not consider in retries
|
| 89 |
+
try:
|
| 90 |
+
if attempt > 1 and last_error:
|
| 91 |
+
logger.info(f"Retrying: {attempt - 1}")
|
| 92 |
+
# Generate SQL
|
| 93 |
+
sql = self.generate_sql(user_question, context, errors=all_errors)
|
| 94 |
+
if not sql:
|
| 95 |
+
return None, None
|
| 96 |
+
else:
|
| 97 |
+
# Generate SQL
|
| 98 |
+
sql = self.generate_sql(user_question, context)
|
| 99 |
+
if not sql:
|
| 100 |
+
return None, None
|
| 101 |
+
|
| 102 |
+
# Try executing query
|
| 103 |
+
sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
|
| 104 |
+
if not isinstance(sql_query_str, str):
|
| 105 |
+
raise ValueError(
|
| 106 |
+
f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
|
| 107 |
+
)
|
| 108 |
+
query_df = self.run_query(sql_query_str)
|
| 109 |
+
|
| 110 |
+
# If execution succeeds, stop retrying or if df is not empty
|
| 111 |
+
if query_df is not None and not query_df.empty:
|
| 112 |
+
return sql, query_df
|
| 113 |
+
|
| 114 |
+
except Exception as e:
|
| 115 |
+
last_error = f"\nAttempt {attempt - 1}] {type(e).__name__}: {e}"
|
| 116 |
+
logger.error(f"Error during SQL generation or execution: {last_error}")
|
| 117 |
+
all_errors += last_error
|
| 118 |
+
|
| 119 |
+
logger.error(f"Failed after {max_retries} attempts. Last error: {all_errors}")
|
| 120 |
+
return None, None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class QueryRouter:
|
| 124 |
+
def __init__(self, chain) -> None:
|
| 125 |
+
self.chain = chain
|
| 126 |
+
|
| 127 |
+
def route_request(self, user_question: str, context: str) -> int:
|
| 128 |
+
"""Route the user question to 0, 1, or 2."""
|
| 129 |
+
user_prompt_formatted = USER_PROMPT.format(
|
| 130 |
+
question=user_question, context=context
|
| 131 |
+
)
|
| 132 |
+
route = self.chain.run(
|
| 133 |
+
system_prompt=ROUTER_SYSTEM_PROMPT,
|
| 134 |
+
user_prompt=user_prompt_formatted,
|
| 135 |
+
format_name="route_queries",
|
| 136 |
+
response_format=Route,
|
| 137 |
+
)
|
| 138 |
+
logger.info(
|
| 139 |
+
f"Query routed to: {route} Where if query is routed to 0 its irrelevant, if 1 its visualizable, if 2 its only sql, and 3 if its datetime."
|
| 140 |
+
)
|
| 141 |
+
return route
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class ChartFormatter:
    """Selects an appropriate chart type for a query-result DataFrame and
    reshapes its data into plot-ready records."""

    # Numeric columns with fewer distinct values than this are treated as
    # categorical (SmallCardNum) rather than continuous.
    SMALL_CARDINALITY_THRESHOLD = 10

    def _build_xy_data(self, label_data, value_data, limit_unique_x=False):
        """Pair labels and values into ``[{"x": ..., "y": ...}, ...]`` records.

        When *limit_unique_x* is True and there are more than MAX_BARS_COUNT
        distinct labels, only the first MAX_BARS_COUNT rows are kept.
        """
        df = pd.DataFrame({"x": label_data, "y": value_data})

        if limit_unique_x and df["x"].nunique() > MAX_BARS_COUNT:
            df = df.head(MAX_BARS_COUNT)

        return df.to_dict(orient="records")

    def is_continuous(self, dtype) -> bool:
        """Return True for numeric, non-boolean dtypes."""
        # is_numeric_dtype subsumes the separate integer/float checks; bool
        # counts as numeric in pandas, so it must be excluded explicitly.
        return pd.api.types.is_numeric_dtype(dtype) and not pd.api.types.is_bool_dtype(
            dtype
        )

    def is_datetime(self, dtype) -> bool:
        """Return True for datetime64 or timedelta64 dtypes."""
        return pd.api.types.is_datetime64_any_dtype(
            dtype
        ) or pd.api.types.is_timedelta64_dtype(dtype)

    def detect_dtype(self, data):
        """Detects dtypes of columns.

        Maps each column name to a semantic marker instance:
        Continuous / SmallCardNum / DateTime / Nominal.
        """
        type_ = {}
        for col_name in data.columns:
            col_data = data[col_name]
            if self.is_continuous(col_data.dtype):
                # A duplicated column label yields a DataFrame slice, which
                # cannot be classified column-by-column.
                if not isinstance(col_data, pd.Series):
                    raise TypeError(f"unprocessed column type:{type(col_name)}")
                # Low-cardinality numeric columns behave like categories.
                if col_data.nunique() < self.SMALL_CARDINALITY_THRESHOLD:
                    type_[col_name] = SmallCardNum()
                else:
                    type_[col_name] = Continuous()
            elif self.is_datetime(col_data.dtype):
                type_[col_name] = DateTime()
            else:
                type_[col_name] = Nominal()
        return type_

    def build_bar_chart(self, label_data, value_data):
        """Bar-chart records, capped at MAX_BARS_COUNT distinct bars."""
        return self._build_xy_data(label_data, value_data, limit_unique_x=True)

    def build_line_chart(self, label_data, value_data):
        """Line-chart records (no row cap)."""
        return self._build_xy_data(label_data, value_data)

    def build_pie_chart(self, label_data, value_data):
        """Pie-chart records (no row cap)."""
        return self._build_xy_data(label_data, value_data)

    def build_histogram(self, data):
        """Bucket a numeric series into 50 bins of
        ``{bin_start, bin_end, frequency}`` records."""
        # np.histogram raises on an empty range (data.min() of an empty
        # series); report "no bins" instead of crashing.
        if len(data) == 0:
            return []

        range_ = (data.min(), data.max())
        counts, bins = np.histogram(data, bins=50, range=range_)

        return [
            {
                "bin_start": bins[i],
                "bin_end": bins[i + 1],
                "frequency": counts[i],
            }
            for i in range(len(counts))
        ]

    @staticmethod
    def _label_first(first_type, data_x, dtype_x, data_y, dtype_y):
        """Return (label, value) with the *first_type*-typed column as label."""
        if isinstance(dtype_x, first_type):
            return data_x, data_y
        return data_y, data_x

    def format_and_select_chart(self, df: pd.DataFrame):
        """Pick a chart for a 1- or 2-column DataFrame and format its data.

        Returns (chart_type, formatted_data, dtypes) or (None, None, None)
        when no supported chart applies to the column-type combination.
        """
        cols = df.columns.tolist()
        dtypes = self.detect_dtype(df)

        if len(cols) == 1:
            col = cols[0]
            dtype = dtypes[col]

            if isinstance(dtype, Continuous):
                return "hist", self.build_histogram(df[col].dropna()), dtypes

            if isinstance(dtype, (SmallCardNum, Nominal)):
                counts = df[col].value_counts()
                # Pie charts stay readable only with a handful of slices.
                chart = "pie" if counts.size <= 6 else "bar"
                builder = (
                    self.build_pie_chart if chart == "pie" else self.build_bar_chart
                )
                return chart, builder(counts.index, counts.values), dtypes

        if len(cols) == 2:
            x, y = cols
            dtype_x, dtype_y = dtypes[x], dtypes[y]
            data_x, data_y = df[x], df[y]
            pair = {type(dtype_x), type(dtype_y)}

            if pair == {Nominal, Continuous}:
                label, value = self._label_first(
                    Nominal, data_x, dtype_x, data_y, dtype_y
                )
                return "bar", self.build_bar_chart(label, value), dtypes

            if pair == {Continuous}:  # both columns continuous
                return "bar", self.build_bar_chart(data_x, data_y), dtypes

            if pair == {SmallCardNum, Continuous}:
                label, value = self._label_first(
                    SmallCardNum, data_x, dtype_x, data_y, dtype_y
                )
                return "bar", self.build_bar_chart(label, value), dtypes

            if pair == {SmallCardNum}:  # both columns low-cardinality numeric
                return "bar", self.build_bar_chart(data_x, data_y), dtypes

            if pair == {DateTime, Continuous}:
                label, value = self._label_first(
                    DateTime, data_x, dtype_x, data_y, dtype_y
                )
                return "line", self.build_line_chart(label, value), dtypes

            if pair == {DateTime, SmallCardNum}:
                label, value = self._label_first(
                    DateTime, data_x, dtype_x, data_y, dtype_y
                )
                return "line", self.build_line_chart(label, value), dtypes

            if pair == {Nominal, SmallCardNum}:
                label, value = self._label_first(
                    Nominal, data_x, dtype_x, data_y, dtype_y
                )
                return "bar", self.build_bar_chart(label, value), dtypes

        return None, None, None
class SQLVizChain:
    """End-to-end pipeline: route a question, generate and execute SQL over
    DuckDB, then build a chart configuration for the result."""

    def __init__(self, duckdb: DuckDBPyConnection, chain):
        self._duckdb = duckdb
        self.chain = chain
        self.router = QueryRouter(chain=self.chain)
        self.sql_generator = SQLPipeline(duckdb, chain=self.chain)
        self.charting = ChartFormatter()

    @staticmethod
    def _empty_response() -> dict[str, Any]:
        """Uniform payload returned when no chart or SQL result is available."""
        return {
            "chart_data": None,
            "chart_config": None,
            "chart_type": None,
            "sql_config": None,
        }

    def create_chart_config(
        self, query_df: pd.DataFrame, user_question: str, sql: str
    ) -> tuple[list[dict[Any, Any]] | None, dict[str, Any] | None, str | None]:
        """Format data for visualization and return chart config."""
        (
            chart_type,
            formatted_data,
            dtypes,
        ) = self.charting.format_and_select_chart(df=query_df)

        # No supported chart for this column combination (or empty data):
        # nothing to configure.
        if not all([formatted_data, dtypes, chart_type]):
            return None, None, None

        chart_config = self.chain.run(
            system_prompt=CHART_CONFIG_SYSTEM_PROMPT,
            user_prompt=CHART_CONFIG_USER_PROMPT.format(
                question=user_question,
                sql_query=sql,
                dtypes=dtypes,
                chart_type=chart_type,
            ),
            format_name="chart_config",
            response_format=PlotConfig,
        )
        logger.info(f"Chart Config Generated: {chart_config}")
        return formatted_data, chart_config, chart_type

    def create_viz_with_text_response(
        self, query_df: pd.DataFrame, user_question: str, sql_config: dict[Any, Any]
    ) -> dict[str, Any]:
        """Build the final response payload of validated chart data plus the
        chart/SQL configuration; falls back to a plain table when no chart
        config could be produced."""
        formatted_data, chart_config, chart_type = self.create_chart_config(
            query_df, user_question, sql_config["sql_query"]
        )
        table_data = TableData(data=query_df)

        if not all([formatted_data, chart_config, chart_type]):
            logger.info("Failed to format data or generate chart config.")
            logger.info(f"Total Token Counts: {self.chain.total_tokens}")
            # Degrade gracefully: return the raw table without a chart.
            return {
                "chart_data": table_data,
                "chart_config": None,
                "chart_type": None,
                "sql_config": sql_config,
            }

        chart_data = Data.validate_data(data=formatted_data)

        if chart_config and chart_config["type"] in {"bar", "line", "pie", "hist"}:
            data = Charts(**{chart_config["type"]: chart_data})
        else:
            raise ValueError(
                "Invalid Plot Type. Must be one of 'bar', 'line', 'pie', 'hist'"
            )

        logger.info("Visualization Chain Completed Successfully.")
        logger.info(f"Total Token Counts: {self.chain.total_tokens}")
        return {
            "chart_data": data,
            "chart_config": chart_config,
            "chart_type": chart_type,
            "sql_config": sql_config,
        }

    def run(self, user_question: str, context: str) -> dict[str, Any]:
        """Main pipeline: question → SQL → data → chart config."""
        route = self.router.route_request(user_question=user_question, context=context)
        # Route 0 marks the question as irrelevant to the data.
        if route == 0:
            return self._empty_response()

        sql_config, query_df = self.sql_generator.try_sql_with_retries(
            user_question=user_question, context=context
        )
        if sql_config is None or query_df is None:
            logger.info("Failed to generate or execute SQL after retries.")
            logger.info(f"Total Token Counts: {self.chain.total_tokens}")
            return self._empty_response()

        return self.create_viz_with_text_response(
            query_df=query_df, user_question=user_question, sql_config=sql_config
        )
src/utils.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def plot_chart(
    chart_type,
    data,
    title=None,
    x_axis_label=None,
    y_axis_label=None,
    **kwargs,
):
    """Render *data* as a matplotlib figure of the requested *chart_type*.

    Args:
        chart_type: One of "bar", "line", "pie", or "hist".
        data: DataFrame; bar/line/pie expect (labels, values) in the first
            two columns (only the first 20 rows are plotted); hist expects
            (bin_start, bin_end, frequency) in the first three columns.
        title: Optional figure title.
        x_axis_label / y_axis_label: Optional axis labels; defaults are
            derived from the column names (or generic hist labels).
        **kwargs: Accepted for interface compatibility; currently unused.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: On an unsupported chart type or a DataFrame with too
            few columns for the requested chart.
    """
    # Validate before allocating the figure so error paths cannot leak an
    # open figure (the original only closed it for unsupported types).
    if chart_type in {"bar", "line", "pie"}:
        if data.shape[1] < 2:
            raise ValueError("DataFrame must have at least two columns")
    elif chart_type == "hist":
        if data.shape[1] < 3:
            raise ValueError("Histogram DataFrame must have 3 columns")
    else:
        raise ValueError(f"Unsupported chart type: {chart_type}")

    fig, ax = plt.subplots(figsize=(8, 5))

    if chart_type in {"bar", "line", "pie"}:
        # Cap to 20 rows to keep category axes readable.
        data = data.head(20)
        x = data.iloc[:, 0]
        y = data.iloc[:, 1]

        if chart_type == "bar":
            ax.bar(x, y)
            ax.set_xlabel(x_axis_label or data.columns[0])
            ax.set_ylabel(y_axis_label or data.columns[1])
        elif chart_type == "line":
            ax.plot(x, y, marker="o")
            ax.set_xlabel(x_axis_label or data.columns[0])
            ax.set_ylabel(y_axis_label or data.columns[1])
        else:  # pie
            ax.pie(y, labels=x, autopct="%1.1f%%")
            ax.axis("equal")  # keep the pie circular
    else:  # hist
        bin_start = data.iloc[:, 0]
        bin_end = data.iloc[:, 1]
        frequency = data.iloc[:, 2]

        # Draw each bin as a bar spanning [bin_start, bin_end).
        widths = bin_end - bin_start
        ax.bar(bin_start, frequency, width=widths, align="edge")

        ax.set_xlabel(x_axis_label or "Value Range")
        ax.set_ylabel(y_axis_label or "Frequency")

    if title:
        ax.set_title(title)

    fig.tight_layout()
    return fig