Muhammad Mustehson committed on
Commit
4a84072
·
1 Parent(s): 332cf4d

Update Old Code

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
__pycache__/prompt.cpython-311.pyc DELETED
Binary file (2.63 kB)
 
app.py CHANGED
@@ -1,47 +1,35 @@
 
1
  import os
 
2
  import duckdb
3
  import gradio as gr
4
  import matplotlib.pyplot as plt
5
- from transformers import HfEngine, ReactCodeAgent
6
- from transformers.agents import Tool
7
- from langsmith import traceable
8
- from langchain import hub
9
 
 
 
 
 
10
 
11
- # Height of the Tabs Text Area
 
 
12
  TAB_LINES = 8
13
 
 
 
 
 
 
14
 
15
- #----------CONNECT TO DATABASE----------
16
- md_token = os.getenv('MD_TOKEN')
17
- conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
18
- #---------------------------------------
19
-
20
- #-------LOAD HUGGINGFACE MODEL-------
21
- models = ["Qwen/Qwen2.5-72B-Instruct","meta-llama/Meta-Llama-3-70B-Instruct",
22
- "meta-llama/Llama-3.1-70B-Instruct"]
23
 
24
- model_loaded = False
25
- for model in models:
26
- try:
27
- llm_engine = HfEngine(model=model)
28
- info = llm_engine.client.get_endpoint_info()
29
- model_loaded = True
30
- break
31
- except Exception as e:
32
- print(f"Error for model {model}: {e}")
33
- continue
34
 
35
- if not model_loaded:
36
- gr.Warning(f"❌ None of the model form {models} are available. {e}")
37
- #---------------------------------------
38
 
39
- #-----LOAD PROMPT FROM LANCHAIN HUB-----
40
- prompt = hub.pull("viz-prompt")
41
- #-------------------------------------
42
 
43
 
44
- #--------------ALL UTILS----------------
45
  def get_schemas():
46
  schemas = conn.execute("""
47
  SELECT DISTINCT schema_name
@@ -50,22 +38,26 @@ def get_schemas():
50
  """).fetchall()
51
  return [item[0] for item in schemas]
52
 
53
- # Get Tables
54
  def get_tables(schema_name):
55
- tables = conn.execute(f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}'").fetchall()
 
 
56
  return [table[0] for table in tables]
57
 
58
- # Update Tables
59
  def update_tables(schema_name):
60
  tables = get_tables(schema_name)
61
  return gr.update(choices=tables)
62
 
63
- # Get Schema
64
  def get_table_schema(table):
65
- result = conn.sql(f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';").df()
66
- ddl_create = result.iloc[0,0]
67
- parent_database = result.iloc[0,1]
68
- schema_name = result.iloc[0,2]
 
 
69
  full_path = f"{parent_database}.{schema_name}.{table}"
70
  if schema_name != "main":
71
  old_path = f"{schema_name}.{table}"
@@ -75,62 +67,39 @@ def get_table_schema(table):
75
  return ddl_create, full_path
76
 
77
 
78
- class SQLExecutorTool(Tool):
79
- name = "sql_engine"
80
- inputs = {
81
- "query": {
82
- "type": "text",
83
- "description": f"The query to perform. This should be correct DuckDB SQL.",
84
- }
85
- }
86
- description = """Allows you to perform SQL queries on the table. Returns a pandas dataframe representation of the result."""
87
- output_type = "pandas.core.frame.DataFrame"
88
-
89
- def forward(self, query: str) -> str:
90
- output_df = conn.sql(query).df()
91
- return output_df
92
-
93
- tool = SQLExecutorTool()
94
-
95
- def process_outputs(output) :
96
- return {
97
- 'sql': output.get('sql', None),
98
- 'code': output.get('code', None)
99
- }
100
-
101
- @traceable(process_outputs=process_outputs)
102
- def get_visualization(question, schema, table_name):
103
- agent = ReactCodeAgent(tools=[tool], llm_engine=llm_engine, add_base_tools=True,
104
- additional_authorized_imports=['matplotlib.pyplot',
105
- 'pandas', 'plotly.express',
106
- 'seaborn'], max_iterations=10)
107
- results = agent.run(
108
- task= prompt.format(question=question, schema=schema, table_name=table_name)
109
- )
110
-
111
- return results
112
- #---------------------------------------
113
-
114
-
115
  def main(table, text_query):
116
- # Empty Fig
117
  fig, ax = plt.subplots()
118
- ax.set_axis_off()
119
-
120
- schema, table_name = get_table_schema(table)
121
-
122
  try:
123
- output = get_visualization(question=text_query, schema=schema, table_name=table_name)
124
- fig = output.get('fig', None)
125
- generated_sql = output.get('sql', None)
126
- data = output.get('data', None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  except Exception as e:
 
129
  gr.Warning(f"❌ Unable to generate the visualization. {e}")
130
-
131
- return fig, generated_sql, data
132
-
133
-
134
 
135
  custom_css = """
136
  .gradio-container {
@@ -150,7 +119,9 @@ custom_css = """
150
  }
151
  """
152
 
153
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
 
 
154
  gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
155
 
156
  gr.Markdown("""
@@ -162,13 +133,18 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
162
  """)
163
 
164
  with gr.Row():
165
-
166
  with gr.Column(scale=1):
167
- schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
168
- tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
 
 
 
 
169
 
170
  with gr.Column(scale=2):
171
- query_input = gr.Textbox(lines=3, label="Text Query", placeholder="Enter your text query here...")
 
 
172
  with gr.Row():
173
  with gr.Column(scale=7):
174
  pass
@@ -178,18 +154,25 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
178
  with gr.Tabs():
179
  with gr.Tab("Plot"):
180
  result_plot = gr.Plot()
181
- with gr.Tab("SQL"):
182
- generated_sql = gr.Textbox(lines=TAB_LINES, label="Generated SQL", value="", interactive=False,
183
- autoscroll=False)
184
- with gr.Tab("Data"):
 
 
 
 
 
185
  data = gr.Dataframe(label="Data", interactive=False)
186
 
187
- schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
188
- generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[result_plot, generated_sql, data])
 
 
 
 
 
 
189
 
190
  if __name__ == "__main__":
191
  demo.launch(debug=True)
192
-
193
-
194
-
195
-
 
1
+ import logging
2
  import os
3
+
4
  import duckdb
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
7
+ import pandas as pd
 
 
 
8
 
9
+ from src.client import LLMChain
10
+ from src.models import Charts, TableData
11
+ from src.pipelines import SQLVizChain
12
+ from src.utils import plot_chart
13
 
14
+ MD_TOKEN = os.getenv("MD_TOKEN")
15
+ conn = duckdb.connect(f"md:my_db?motherduck_token={MD_TOKEN}", read_only=True)
16
+ LEVEL = "INFO" if not os.getenv("ENV") == "PROD" else "WARNING"
17
  TAB_LINES = 8
18
 
19
+ logging.basicConfig(
20
+ level=getattr(logging, LEVEL, logging.INFO),
21
+ format="%(asctime)s %(levelname)s %(name)s: %(message)s",
22
+ )
23
+ logger = logging.getLogger(__name__)
24
 
 
 
 
 
 
 
 
 
25
 
26
+ def _load_pipeline():
27
+ return SQLVizChain(duckdb=conn, chain=LLMChain())
 
 
 
 
 
 
 
 
28
 
 
 
 
29
 
30
+ pipeline = _load_pipeline()
 
 
31
 
32
 
 
33
  def get_schemas():
34
  schemas = conn.execute("""
35
  SELECT DISTINCT schema_name
 
38
  """).fetchall()
39
  return [item[0] for item in schemas]
40
 
41
+
42
  def get_tables(schema_name):
43
+ tables = conn.execute(
44
+ f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}'"
45
+ ).fetchall()
46
  return [table[0] for table in tables]
47
 
48
+
49
  def update_tables(schema_name):
50
  tables = get_tables(schema_name)
51
  return gr.update(choices=tables)
52
 
53
+
54
  def get_table_schema(table):
55
+ result = conn.sql(
56
+ f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';"
57
+ ).df()
58
+ ddl_create = result.iloc[0, 0]
59
+ parent_database = result.iloc[0, 1]
60
+ schema_name = result.iloc[0, 2]
61
  full_path = f"{parent_database}.{schema_name}.{table}"
62
  if schema_name != "main":
63
  old_path = f"{schema_name}.{table}"
 
67
  return ddl_create, full_path
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def main(table, text_query):
 
71
  fig, ax = plt.subplots()
72
+ ax.set_axis_off()
73
+ schema, _ = get_table_schema(table)
74
+
 
75
  try:
76
+ results = pipeline.run(user_question=text_query, context=schema)
77
+ chart_data = results["chart_data"]
78
+ chart_config = results["chart_config"]
79
+ chart_type = results["chart_type"]
80
+ generated_sql = results["sql_config"]["sql_query"]
81
+
82
+ if not chart_type and chart_data is not None:
83
+ if isinstance(chart_data, TableData):
84
+ data = pd.DataFrame(chart_data.model_dump(exclude_none=True))
85
+ return (fig, generated_sql, data)
86
+
87
+ if chart_type is not None and chart_data is not None:
88
+ if isinstance(chart_data, Charts):
89
+ chart_dict = chart_data.model_dump(exclude_none=True).get(chart_type)
90
+ data = pd.DataFrame(chart_dict["data"])
91
+ fig = plot_chart(chart_type=chart_type, data=data, **chart_config)
92
+ return (fig, generated_sql, data)
93
+
94
+ if chart_data is None:
95
+ return fig, generated_sql, None
96
 
97
  except Exception as e:
98
+ logger.error(e)
99
  gr.Warning(f"❌ Unable to generate the visualization. {e}")
100
+
101
+ return fig, None, None
102
+
 
103
 
104
  custom_css = """
105
  .gradio-container {
 
119
  }
120
  """
121
 
122
+ with gr.Blocks(
123
+ theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css
124
+ ) as demo:
125
  gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
126
 
127
  gr.Markdown("""
 
133
  """)
134
 
135
  with gr.Row():
 
136
  with gr.Column(scale=1):
137
+ schema_dropdown = gr.Dropdown(
138
+ choices=get_schemas(), label="Select Schema", interactive=True
139
+ )
140
+ tables_dropdown = gr.Dropdown(
141
+ choices=[], label="Available Tables", value=None
142
+ )
143
 
144
  with gr.Column(scale=2):
145
+ query_input = gr.Textbox(
146
+ lines=3, label="Text Query", placeholder="Enter your text query here..."
147
+ )
148
  with gr.Row():
149
  with gr.Column(scale=7):
150
  pass
 
154
  with gr.Tabs():
155
  with gr.Tab("Plot"):
156
  result_plot = gr.Plot()
157
+ with gr.Tab("SQL"):
158
+ generated_sql = gr.Textbox(
159
+ lines=TAB_LINES,
160
+ label="Generated SQL",
161
+ value="",
162
+ interactive=False,
163
+ autoscroll=False,
164
+ )
165
+ with gr.Tab("Data"):
166
  data = gr.Dataframe(label="Data", interactive=False)
167
 
168
+ schema_dropdown.change(
169
+ update_tables, inputs=schema_dropdown, outputs=tables_dropdown
170
+ )
171
+ generate_query_button.click(
172
+ main,
173
+ inputs=[tables_dropdown, query_input],
174
+ outputs=[result_plot, generated_sql, data],
175
+ )
176
 
177
  if __name__ == "__main__":
178
  demo.launch(debug=True)
 
 
 
 
requirements.txt CHANGED
@@ -1,9 +1,8 @@
1
- torch
2
- seaborn
3
- plotly
4
  huggingface_hub
5
- accelerate==0.34.2
6
- transformers==4.44.2
7
- duckdb==1.1.1
8
- langsmith==0.1.135
9
- langchain==0.3.4
 
 
 
 
 
 
1
  huggingface_hub
2
+ duckdb
3
+ pandas
4
+ pydantic
5
+ python-dotenv
6
+ gradio
7
+ pandas
8
+ matplotlib
src/__init__.py ADDED
File without changes
src/client.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+
5
+ from dotenv import load_dotenv
6
+ from huggingface_hub import InferenceClient
7
+ from pydantic import BaseModel
8
+
9
+ load_dotenv()
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ MAX_RESPONSE_TOKENS = 2048
14
+ TEMPERATURE = 0.9
15
+
16
+ models = json.loads(os.getenv("MODEL_NAMES"))
17
+ providers = json.loads(os.getenv("PROVIDERS"))
18
+ EMB_MODEL = os.getenv("EMB_MODEL")
19
+
20
+
21
+ def _engine_working(engine: InferenceClient) -> bool:
22
+ try:
23
+ engine.chat_completion([{"role": "user", "content": "ping"}], max_tokens=1)
24
+ logger.info("Engine is Working.")
25
+ return True
26
+ except Exception as e:
27
+ logger.exception(f"Engine is not working: {e}")
28
+ return False
29
+
30
+
31
def _load_llm_client() -> InferenceClient:
    """
    Attempts to load the provided model from the huggingface endpoint.

    Every (model, provider) combination from the module-level ``models`` and
    ``providers`` lists is tried in order; the first client that answers the
    ``_engine_working`` probe is returned.

    Returns InferenceClient if successful.
    Raises Exception if no model is available; the message lists why each
    candidate was rejected.
    """
    logger.warning("Loading Model...")
    errors = []
    for model in models:
        # Only string model ids are usable; checked once per model rather
        # than once per provider.
        if not isinstance(model, str):
            errors.append(f"Skipped non-string model entry: {model!r}")
            continue

        for provider in providers:
            try:
                logger.info(f"Checking model: {model} provider: {provider}")
                client = InferenceClient(
                    model=model,
                    timeout=15,
                    provider=provider,
                )
            except Exception as e:
                logger.error(f"Error loading model {model} provider {provider}: {e}")
                errors.append(str(e))
                continue

            # _engine_working logs and swallows its own exceptions, so a
            # failed probe must be recorded here or the final error is empty.
            if _engine_working(client):
                logger.info(f"The model is loaded : {model} , provider: {provider}")
                return client
            errors.append(f"Health check failed for model {model} provider {provider}")

    raise Exception(f"Unable to load any provided model: {errors}.")
61
+
62
+
63
+ _default_client = _load_llm_client()
64
+
65
+
66
+ class LLMChain:
67
+ def __init__(self, client: InferenceClient = _default_client):
68
+ self.client = client
69
+ self.total_tokens = 0
70
+
71
+ def run(
72
+ self,
73
+ system_prompt: str | None = None,
74
+ user_prompt: str | None = None,
75
+ messages: list[dict] | None = None,
76
+ format_name: str | None = None,
77
+ response_format: type[BaseModel] | None = None,
78
+ ) -> str | dict[str, str | int | float | None] | list[str] | None:
79
+ try:
80
+ if system_prompt and user_prompt:
81
+ messages = [
82
+ {"role": "system", "content": system_prompt},
83
+ {"role": "user", "content": user_prompt},
84
+ ]
85
+ elif not messages:
86
+ raise ValueError(
87
+ "Either system_prompt and user_prompt or messages must be provided."
88
+ )
89
+
90
+ llm_response = self.client.chat_completion(
91
+ messages=messages,
92
+ max_tokens=MAX_RESPONSE_TOKENS,
93
+ temperature=TEMPERATURE,
94
+ response_format=(
95
+ {
96
+ "type": "json_schema",
97
+ "json_schema": {
98
+ "name": format_name,
99
+ "schema": response_format.model_json_schema(),
100
+ "strict": True,
101
+ },
102
+ }
103
+ if format_name and response_format
104
+ else None
105
+ ),
106
+ )
107
+ self.total_tokens += llm_response.usage.total_tokens
108
+ analysis = llm_response.choices[0].message.content
109
+ if response_format:
110
+ analysis = json.loads(analysis)
111
+ fields = list(response_format.model_fields.keys())
112
+ if len(fields) == 1:
113
+ return analysis.get(fields[0])
114
+ return {field: analysis.get(field) for field in fields}
115
+
116
+ return analysis
117
+
118
+ except Exception as e:
119
+ logger.error(f"Error during LLM calls: {e}")
120
+ return None
src/models.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from enum import Enum
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator
7
+
8
+
9
+ class SmallCardNum:
10
+ pass
11
+
12
+
13
+ class Continuous:
14
+ pass
15
+
16
+
17
+ class DateTime:
18
+ pass
19
+
20
+
21
+ class Nominal:
22
+ pass
23
+
24
+
25
+ class Route(BaseModel):
26
+ label: int = Field(
27
+ description="Classify user queries as: 0 for Irrelevant/Vague/Incomplete, 1 for Visualizable, and 2 for SQL-only."
28
+ )
29
+
30
+
31
+ class SQLQueryModel(BaseModel):
32
+ sql_query: str = Field(..., description="SQL query to execute.")
33
+ explanation: str = Field(..., description="Short explanation of the SQL query.")
34
+
35
+
36
class DataPoint(BaseModel):
    """One chart point: either an (x, y) pair or a histogram bin."""

    x: int | float | str | None = None
    y: int | float | str | None = None
    bin_start: int | float | None = None
    bin_end: int | float | None = None
    frequency: int | float | None = None

    @field_validator("bin_start", "bin_end", "frequency", "x", "y", mode="before")
    @classmethod
    def to_native(cls, field_value):
        """Convert numpy/pandas scalars to plain Python values before validation.

        Date-like values become ``YYYY-MM-DD`` strings; numpy integer and
        floating scalars become ``float``. Anything else passes through.
        """
        # np.datetime64 has no strftime(); go through pandas first.
        if isinstance(field_value, np.datetime64):
            field_value = pd.Timestamp(field_value)
        # pd.Timestamp subclasses datetime, so one check covers both.
        if isinstance(field_value, datetime):
            return field_value.strftime("%Y-%m-%d")
        if isinstance(field_value, (np.integer, np.floating)):
            return float(field_value)
        return field_value

    @model_validator(mode="before")
    @classmethod
    def validate_keys(cls, values):
        """Require a complete (x, y) pair or a complete histogram bin.

        Raises:
            ValueError: If neither group of fields is fully provided.
        """
        # Non-mapping input has no .get(); leave it to pydantic's own
        # validation instead of failing here with AttributeError.
        if not isinstance(values, dict):
            return values

        x, y = values.get("x"), values.get("y")
        bin_start, bin_end, frequency = (
            values.get("bin_start"),
            values.get("bin_end"),
            values.get("frequency"),
        )

        xy = x is not None and y is not None
        bxy = bin_start is not None and bin_end is not None and frequency is not None

        if not (xy or bxy):
            raise ValueError(
                "Invalid input: Must provide either (x, y) OR (bin_start, bin_end, frequency), but not a mix."
            )
        return values
72
+
73
+
74
+ class Data(BaseModel):
75
+ data: list[DataPoint] = Field(default_factory=list)
76
+
77
+ @classmethod
78
+ def validate_data(cls, data):
79
+ try:
80
+ return cls(data=data)
81
+ except ValidationError as e:
82
+ raise ValueError(f"Invalid data format: {e.errors()[0]}") # noqa: B904
83
+
84
+
85
class TableData(BaseModel):
    """Wrap a query-result DataFrame so it can be dumped as column lists."""

    # Required field, spelled explicitly.
    data: pd.DataFrame = Field(...)

    class Config:
        # pd.DataFrame is not a pydantic-native type.
        arbitrary_types_allowed = True

    @model_validator(mode="after")
    def timestamp_to_str(self):
        """Replace datetime columns with their string form.

        Works on a copy so the DataFrame handed in by the caller is left
        untouched.
        """
        datetime_cols = self.data.select_dtypes(
            include=["datetime", "datetimetz"]
        ).columns
        if len(datetime_cols) > 0:
            converted = self.data.copy()
            # Every datetime column is converted, including ones whose label
            # is falsy (e.g. 0 or "").
            for col in datetime_cols:
                converted[col] = converted[col].astype(str)
            self.data = converted
        return self

    def model_dump(self, *args, **kwargs):  # noqa: ARG002
        """Return the table as ``{column: [values, ...]}``; arguments are ignored."""
        return self.data.to_dict(orient="list")
101
+
102
+
103
+ class Charts(BaseModel):
104
+ bar: Data | None = None
105
+ line: Data | None = None
106
+ pie: Data | None = None
107
+ hist: Data | None = None
108
+
109
+ @model_validator(mode="after")
110
+ def process_charts_data(self):
111
+ def stringify(data):
112
+ if data and data.data:
113
+ for point in data.data:
114
+ if not isinstance(point.x, str):
115
+ point.x = str(point.x)
116
+ return data
117
+
118
+ if self.bar:
119
+ self.bar = stringify(self.bar)
120
+ if self.pie:
121
+ self.pie = stringify(self.pie)
122
+
123
+ return self
124
+
125
+
126
class PlotType(str, Enum):
    """Supported plot kinds; each member's value is its lowercase name."""

    # Plain strings rather than accidental one-element tuples.
    bar = "bar"
    line = "line"
    pie = "pie"
    hist = "hist"
131
+
132
+
133
+ class PlotConfig(BaseModel):
134
+ type: PlotType = Field(
135
+ description="Type of plot, e.g., 'bar', 'line', 'pie'. Supported types depend on ShadCN implementation.",
136
+ )
137
+ title: str = Field(description="Title of the plot to display above the plot.")
138
+ x_axis_label: str = Field(description="Label for the X-axis of the plot.")
139
+ y_axis_label: str = Field(description="Label for the Y-axis of the plot.")
140
+ legend: bool = Field(
141
+ default=True, description="Flag to display a legend for the plot."
142
+ )
src/pipelines.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from dotenv import load_dotenv
8
+ from duckdb import DuckDBPyConnection
9
+
10
+ from src.models import (
11
+ Charts,
12
+ Continuous,
13
+ Data,
14
+ DateTime,
15
+ Nominal,
16
+ PlotConfig,
17
+ Route,
18
+ SmallCardNum,
19
+ SQLQueryModel,
20
+ TableData,
21
+ )
22
+
23
+ load_dotenv()
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ MAX_BARS_COUNT = 20
30
+ SQL_GENERATION_RETRIES = int(os.getenv("SQL_GENERATION_RETRIES", "5"))
31
+ SQL_PROMPT = os.getenv("SQL_PROMPT")
32
+ USER_PROMPT = os.getenv("USER_PROMPT")
33
+ ROUTER_SYSTEM_PROMPT = os.getenv("ROUTER_SYSTEM_PROMPT")
34
+ CHART_CONFIG_SYSTEM_PROMPT = os.getenv("CHART_CONFIG_SYSTEM_PROMPT")
35
+ CHART_CONFIG_USER_PROMPT = os.getenv("CHART_CONFIG_USER_PROMPT")
36
+
37
+
38
class SQLPipeline:
    """Generate SQL with the LLM chain and execute it on a DuckDB connection."""

    def __init__(
        self,
        duckdb: DuckDBPyConnection,
        chain,
    ) -> None:
        self._duckdb = duckdb
        self.chain = chain

    def generate_sql(
        self, user_question: str, context: str, errors: str | None = None
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """Generate SQL + description.

        Args:
            user_question: Natural-language question.
            context: Text substituted into the prompt's ``context`` slot.
            errors: Accumulated error text from earlier attempts; when
                non-empty it is appended to the prompt as a retry hint.

        Returns:
            Whatever ``chain.run`` returns for ``SQLQueryModel``.
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        if errors:
            # Separated from the question and free of stray continuation
            # whitespace.
            user_prompt_formatted += (
                "\n\nCarefully review the previous error or exception and rewrite "
                "the SQL so that the error does not occur again. Try a different "
                f"approach or rewrite SQL if needed. Last error: {errors}"
            )

        sql = self.chain.run(
            system_prompt=SQL_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="sql_query",
            response_format=SQLQueryModel,
        )
        # NOTE(review): this also logs when chain.run returned None.
        logger.info(f"SQL Generated Successfully: {sql}")
        return sql

    def run_query(self, sql_query: str) -> pd.DataFrame | None:
        """Execute SQL and return dataframe."""
        logger.info("Query Execution Started.")
        return self._duckdb.query(sql_query).df()

    def try_sql_with_retries(
        self,
        user_question: str,
        context: str,
        max_retries: int = SQL_GENERATION_RETRIES,
    ) -> tuple[
        str | dict[str, str | int | float | None] | list[str] | None,
        pd.DataFrame | None,
    ]:
        """Try SQL generation + execution with retries.

        One initial attempt is made, followed by up to ``max_retries``
        retries. Errors raised while generating or executing are accumulated
        and fed back into the next generation prompt. An empty result set is
        treated as "try again" but adds no error text.

        Returns:
            ``(sql, dataframe)`` for the first attempt that yields a
            non-empty result; ``(None, None)`` if generation returns nothing
            or every attempt fails.
        """
        all_errors = ""
        total_attempts = max_retries + 1  # the first attempt is not a retry

        for attempt in range(1, total_attempts + 1):
            if attempt > 1:
                logger.info(f"Retrying: {attempt - 1}")
            try:
                # An empty error string is equivalent to "no feedback".
                sql = self.generate_sql(
                    user_question, context, errors=all_errors or None
                )
                if not sql:
                    return None, None

                sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
                if not isinstance(sql_query_str, str):
                    raise ValueError(
                        f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
                    )
                query_df = self.run_query(sql_query_str)

                if query_df is not None and not query_df.empty:
                    return sql, query_df
                logger.info(f"Attempt {attempt} returned no rows.")

            except Exception as e:
                attempt_error = f"\n[Attempt {attempt}] {type(e).__name__}: {e}"
                logger.error(
                    f"Error during SQL generation or execution: {attempt_error}"
                )
                all_errors += attempt_error

        logger.error(
            f"Failed after {total_attempts} attempts. Last error: {all_errors}"
        )
        return None, None
121
+
122
+
123
+ class QueryRouter:
124
+ def __init__(self, chain) -> None:
125
+ self.chain = chain
126
+
127
+ def route_request(self, user_question: str, context: str) -> int:
128
+ """Route the user question to 0, 1, or 2."""
129
+ user_prompt_formatted = USER_PROMPT.format(
130
+ question=user_question, context=context
131
+ )
132
+ route = self.chain.run(
133
+ system_prompt=ROUTER_SYSTEM_PROMPT,
134
+ user_prompt=user_prompt_formatted,
135
+ format_name="route_queries",
136
+ response_format=Route,
137
+ )
138
+ logger.info(
139
+ f"Query routed to: {route} Where if query is routed to 0 its irrelevant, if 1 its visualizable, if 2 its only sql, and 3 if its datetime."
140
+ )
141
+ return route
142
+
143
+
144
+ class ChartFormatter:
145
+ def _build_xy_data(self, label_data, value_data, limit_unique_x=False):
146
+ df = pd.DataFrame({"x": label_data, "y": value_data})
147
+
148
+ if limit_unique_x and df["x"].nunique() > MAX_BARS_COUNT:
149
+ df = df.head(MAX_BARS_COUNT)
150
+
151
+ return df.to_dict(orient="records")
152
+
153
+ def is_continuous(self, dtype) -> bool:
154
+ if pd.api.types.is_bool_dtype(dtype):
155
+ return False
156
+
157
+ return (
158
+ pd.api.types.is_integer_dtype(dtype)
159
+ or pd.api.types.is_float_dtype(dtype)
160
+ or pd.api.types.is_numeric_dtype(dtype)
161
+ )
162
+
163
+ def is_datetime(self, dtype) -> bool:
164
+ return pd.api.types.is_datetime64_any_dtype(
165
+ dtype
166
+ ) or pd.api.types.is_timedelta64_dtype(dtype)
167
+
168
+ def detect_dtype(self, data):
169
+ """Detects dtypes of columns."""
170
+ type_ = {}
171
+ for col_name in data.columns:
172
+ col_data = data[col_name]
173
+ if self.is_continuous(col_data.dtype):
174
+ # detect as categorical if distinct value is small
175
+ if isinstance(col_data, pd.Series):
176
+ nuniques = col_data.nunique()
177
+ else:
178
+ raise TypeError(f"unprocessed column type:{type(col_name)}")
179
+ small_cardinality_threshold = 10
180
+ if nuniques < small_cardinality_threshold:
181
+ type_[col_name] = SmallCardNum()
182
+ else:
183
+ type_[col_name] = Continuous()
184
+
185
+ elif self.is_datetime(col_data.dtype):
186
+ type_[col_name] = DateTime()
187
+ else:
188
+ type_[col_name] = Nominal()
189
+ return type_
190
+
191
+ def build_bar_chart(self, label_data, value_data):
192
+ return self._build_xy_data(label_data, value_data, limit_unique_x=True)
193
+
194
+ def build_line_chart(self, label_data, value_data):
195
+ return self._build_xy_data(label_data, value_data)
196
+
197
+ def build_pie_chart(self, label_data, value_data):
198
+ return self._build_xy_data(label_data, value_data)
199
+
200
+ def build_histogram(self, data):
201
+ range_ = (data.min(), data.max())
202
+ counts, bins = np.histogram(data, bins=50, range=range_)
203
+
204
+ return [
205
+ {
206
+ "bin_start": bins[i],
207
+ "bin_end": bins[i + 1],
208
+ "frequency": counts[i],
209
+ }
210
+ for i in range(len(counts))
211
+ ]
212
+
213
+ def format_and_select_chart(self, df: pd.DataFrame):
214
+ cols = df.columns.tolist()
215
+ dtypes = self.detect_dtype(df)
216
+
217
+ if len(cols) == 1:
218
+ col = cols[0]
219
+ dtype = dtypes[col]
220
+
221
+ if isinstance(dtype, Continuous):
222
+ return "hist", self.build_histogram(df[col].dropna()), dtypes
223
+
224
+ if isinstance(dtype, (SmallCardNum, Nominal)):
225
+ counts = df[col].value_counts()
226
+ chart = "pie" if counts.size <= 6 else "bar"
227
+ builder = (
228
+ self.build_pie_chart if chart == "pie" else self.build_bar_chart
229
+ )
230
+ return chart, builder(counts.index, counts.values), dtypes
231
+
232
+ if len(cols) == 2:
233
+ x, y = cols
234
+ dtype_x = dtypes[x]
235
+ dtype_y = dtypes[y]
236
+ data_x = df[x]
237
+ data_y = df[y]
238
+
239
+ if {type(dtype_x), type(dtype_y)} == {Nominal, Continuous}:
240
+ label, value = (
241
+ (data_x, data_y)
242
+ if isinstance(dtype_x, Nominal)
243
+ else (data_y, data_x)
244
+ )
245
+ formatted_data = self.build_bar_chart(label, value)
246
+ return "bar", formatted_data, dtypes
247
+
248
+ elif {type(dtype_x), type(dtype_y)} == {Continuous, Continuous}:
249
+ label, value = (
250
+ (data_x, data_y)
251
+ if isinstance(dtype_x, Continuous)
252
+ else (data_y, data_x)
253
+ )
254
+ formatted_data = self.build_bar_chart(label, value)
255
+ return "bar", formatted_data, dtypes
256
+
257
+ elif {type(dtype_x), type(dtype_y)} == {SmallCardNum, Continuous}:
258
+ label, value = (
259
+ (data_x, data_y)
260
+ if isinstance(dtype_x, SmallCardNum)
261
+ else (data_y, data_x)
262
+ )
263
+ formatted_data = self.build_bar_chart(label, value)
264
+ return "bar", formatted_data, dtypes
265
+
266
+ elif isinstance(dtype_x, SmallCardNum) and isinstance(
267
+ dtype_y, SmallCardNum
268
+ ):
269
+ formatted_data = self.build_bar_chart(data_x, data_y)
270
+ return "bar", formatted_data, dtypes
271
+
272
+ elif {type(dtype_x), type(dtype_y)} == {DateTime, Continuous}:
273
+ label, value = (
274
+ (data_x, data_y)
275
+ if isinstance(dtype_x, DateTime)
276
+ else (data_y, data_x)
277
+ )
278
+ formatted_data = self.build_line_chart(label, value)
279
+ return "line", formatted_data, dtypes
280
+
281
+ elif (
282
+ isinstance(dtype_x, DateTime) and isinstance(dtype_y, SmallCardNum)
283
+ ) or (isinstance(dtype_y, DateTime) and isinstance(dtype_x, SmallCardNum)):
284
+ label, value = (
285
+ (data_x, data_y)
286
+ if isinstance(dtype_x, DateTime)
287
+ else (data_y, data_x)
288
+ )
289
+ formatted_data = self.build_line_chart(label, value)
290
+ return "line", formatted_data, dtypes
291
+
292
+ elif {type(dtype_x), type(dtype_y)} == {Nominal, SmallCardNum}:
293
+ label, value = (
294
+ (data_x, data_y)
295
+ if isinstance(dtype_x, Nominal)
296
+ else (data_y, data_x)
297
+ )
298
+ formatted_data = self.build_bar_chart(label, value)
299
+ return "bar", formatted_data, dtypes
300
+
301
+ return None, None, None
302
+
303
+
304
class SQLVizChain:
    """End-to-end chain: route the question, generate SQL, format chart data."""

    _SUPPORTED_CHARTS = frozenset({"bar", "line", "pie", "hist"})

    def __init__(self, duckdb: DuckDBPyConnection, chain):
        self._duckdb = duckdb
        self.chain = chain
        self.router = QueryRouter(chain=self.chain)
        self.sql_generator = SQLPipeline(duckdb, chain=self.chain)
        self.charting = ChartFormatter()

    @staticmethod
    def _no_result() -> dict[str, Any]:
        """Result payload used when nothing could be produced."""
        return {
            "chart_data": None,
            "chart_config": None,
            "chart_type": None,
            "sql_config": None,
        }

    def create_chart_config(
        self, query_df: pd.DataFrame, user_question: str, sql: str
    ) -> tuple[list[dict[Any, Any]] | None, dict[str, Any] | None, str | None]:
        """Format data for visualization and return chart config.

        Returns:
            ``(formatted_data, chart_config, chart_type)``, or three ``None``
            values when the formatter could not pick a chart for ``query_df``.
        """
        chart_type, formatted_data, dtypes = self.charting.format_and_select_chart(
            df=query_df
        )

        if not all([formatted_data, dtypes, chart_type]):
            return None, None, None

        chart_config = self.chain.run(
            system_prompt=CHART_CONFIG_SYSTEM_PROMPT,
            user_prompt=CHART_CONFIG_USER_PROMPT.format(
                question=user_question,
                sql_query=sql,
                dtypes=dtypes,
                chart_type=chart_type,
            ),
            format_name="chart_config",
            response_format=PlotConfig,
        )
        logger.info(f"Chart Config Generated: {chart_config}")
        return formatted_data, chart_config, chart_type

    def create_viz_with_text_response(
        self, query_df: pd.DataFrame, user_question: str, sql_config: dict[Any, Any]
    ) -> dict[str, Any]:
        """Build the final payload for a successful query.

        Falls back to a ``TableData`` payload (no chart) when the data could
        not be formatted or no usable chart config was produced.

        Raises:
            ValueError: If the selected chart type is not supported.
        """
        formatted_data, chart_config, chart_type = self.create_chart_config(
            query_df, user_question, sql_config["sql_query"]
        )
        table_data = TableData(data=query_df)

        if not all([formatted_data, chart_config, chart_type]) or not isinstance(
            chart_config, dict
        ):
            logger.info("Failed to format data or generate chart config.")
            logger.info(f"Total Token Counts: {self.chain.total_tokens}")
            return {
                "chart_data": table_data,
                "chart_config": None,
                "chart_type": None,
                "sql_config": sql_config,
            }

        if chart_type not in self._SUPPORTED_CHARTS:
            raise ValueError(
                "Invalid Plot Type. Must be one of 'bar', 'line', 'pie', 'hist'"
            )

        chart_data = Data.validate_data(data=formatted_data)

        # The data was shaped for the formatter's chart_type, and that same
        # value is returned below. Key the Charts payload by it, so a config
        # that names a different type cannot leave the returned chart_type
        # pointing at an empty slot.
        if chart_config.get("type") != chart_type:
            logger.warning(
                f"Chart config type {chart_config.get('type')!r} replaced by "
                f"selected chart type {chart_type!r}."
            )
            chart_config = {**chart_config, "type": chart_type}
        data = Charts(**{chart_type: chart_data})

        logger.info("Visualization Chain Completed Successfully.")
        logger.info(f"Total Token Counts: {self.chain.total_tokens}")
        return {
            "chart_data": data,
            "chart_config": chart_config,
            "chart_type": chart_type,
            "sql_config": sql_config,
        }

    def run(self, user_question: str, context: str) -> dict[str, Any]:
        """Main pipeline: question → SQL → data → chart config.

        Returns:
            A dict with keys ``chart_data``, ``chart_config``, ``chart_type``
            and ``sql_config``; all ``None`` when the question is routed to 0
            or SQL generation/execution fails.
        """
        route = self.router.route_request(user_question=user_question, context=context)
        if route == 0:
            return self._no_result()

        sql_config, query_df = self.sql_generator.try_sql_with_retries(
            user_question=user_question, context=context
        )
        if sql_config is None or query_df is None:
            logger.info("Failed to generate or execute SQL after retries.")
            logger.info(f"Total Token Counts: {self.chain.total_tokens}")
            return self._no_result()

        return self.create_viz_with_text_response(
            query_df=query_df, user_question=user_question, sql_config=sql_config
        )
src/utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+
3
+
4
def plot_chart(
    chart_type,
    data,
    title=None,
    x_axis_label=None,
    y_axis_label=None,
    **kwargs,
):
    """Draw ``data`` as a matplotlib figure of the requested kind.

    Args:
        chart_type: One of ``"bar"``, ``"line"``, ``"pie"`` or ``"hist"``.
        data: DataFrame read by column position. For bar/line/pie the first
            two columns are x and y (only the first 20 rows are drawn); for
            hist the first three columns are bin start, bin end and frequency.
        title: Optional title shown above the plot.
        x_axis_label: X-axis label; defaults to the first column name
            (``"Value Range"`` for hist). Not used for pie.
        y_axis_label: Y-axis label; defaults to the second column name
            (``"Frequency"`` for hist). Not used for pie.
        **kwargs: Accepted and ignored, so a full chart config can be
            splatted into the call.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``chart_type`` is unsupported or ``data`` has too few
            columns for it.
    """
    # Validate before allocating a figure so bad input cannot leak one.
    if chart_type in {"bar", "line", "pie"}:
        if data.shape[1] < 2:
            raise ValueError("DataFrame must have at least two columns")
    elif chart_type == "hist":
        if data.shape[1] < 3:
            raise ValueError("Histogram DataFrame must have 3 columns")
    else:
        raise ValueError(f"Unsupported chart type: {chart_type}")

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        if chart_type == "hist":
            bin_start = data.iloc[:, 0]
            bin_end = data.iloc[:, 1]
            frequency = data.iloc[:, 2]

            # Bars are anchored at each bin's left edge and span its width.
            widths = bin_end - bin_start
            ax.bar(bin_start, frequency, width=widths, align="edge")

            ax.set_xlabel(x_axis_label or "Value Range")
            ax.set_ylabel(y_axis_label or "Frequency")
        else:
            data = data.head(20)
            x = data.iloc[:, 0]
            y = data.iloc[:, 1]

            if chart_type == "pie":
                ax.pie(y, labels=x, autopct="%1.1f%%")
                ax.axis("equal")
            else:
                if chart_type == "bar":
                    ax.bar(x, y)
                else:
                    ax.plot(x, y, marker="o")
                ax.set_xlabel(x_axis_label or data.columns[0])
                ax.set_ylabel(y_axis_label or data.columns[1])

        if title:
            ax.set_title(title)

        fig.tight_layout()
    except Exception:
        # Do not leave a half-drawn figure registered with pyplot.
        plt.close(fig)
        raise
    return fig