Deepa Shalini commited on
Commit
53782c9
·
1 Parent(s): ccdfc4f

system prompt and helper methods for the app

Browse files
assets/data_viz_best_practices.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Horizontal Bar Graphs for Long X-axis Labels: Use horizontal bar graphs for long X-axis labels or when the X-axis has more than 20 bars.
2
+
3
+ Sort Bar Graphs: When the bar graphs have categorical x-axes, always sort the bar graphs in order of the y-axis values,
4
+ placing the highest y-axis values at the top (for horizontal bars - ascending order) or the left (for vertical bars - descending order).
5
+
6
+ Pie charts should always have a hole in it.
7
+ Limit Donut Chart Slices: Use donut charts sparingly, with 5 or fewer slices, and set the hole size to 0.5.
8
+
9
+ Use Consistent Intervals: Maintain consistent axis intervals to avoid misleading or confusing viewers.
10
+
11
+ Use Line Charts for Trends: Prefer line charts for showing trends over time to effectively illustrate changes and patterns.
12
+
13
+ Label Data Directly: Whenever possible, label data points directly on the chart to reduce the need for users to cross-reference with legends.
14
+
15
+ Limit Dual Y-Axes: Use dual Y-axes only when absolutely necessary, ensuring both axes are clearly labeled to avoid confusion.
16
+
17
+ Stack Bars Appropriately: Use stacked bar charts to show part-to-whole relationships, ensuring that segments are clearly distinguishable.
18
+
19
+ Avoid Overloading Scatter Plots: Limit the number of data points in scatter plots to prevent overcrowding. Consider using heatmaps or summary statistics if necessary.
20
+
21
+ Use Tooltips for Extra Details: Use tooltips to provide additional information without overcrowding the graph.
assets/example_subplots1.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from plotly.subplots import make_subplots
2
+ import plotly.graph_objects as go
3
+
4
+ # Create subplots
5
+ fig = make_subplots(rows=1, cols=2, subplot_titles=('Bar Chart', 'Line Chart'))
6
+
7
+ # Add first trace (Bar chart)
8
+ fig.add_trace(
9
+ go.Bar(x=['A', 'B', 'C'], y=[1, 3, 2]),
10
+ row=1, col=1
11
+ )
12
+
13
+ # Add second trace (Line chart)
14
+ fig.add_trace(
15
+ go.Scatter(x=['A', 'B', 'C'], y=[2, 1, 3]),
16
+ row=1, col=2
17
+ )
18
+
19
+ # Update layout
20
+ fig.update_layout(title_text='Multiple Visualizations in One Figure', plot_bgcolor='white')
21
+
22
+ # Show the figure
23
+ fig.show()
assets/example_subplots2.txt ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from plotly.subplots import make_subplots
2
+ import plotly.graph_objects as go
3
+
4
+ # Create subplots
5
+ fig = make_subplots(rows=2, cols=2, subplot_titles=('Scatter Plot', 'Pie Chart', 'Line Chart', 'Bar Chart'),
6
+ specs=[[{"type": "scatter"}, {"type": "domain"}],
7
+ [{"type": "scatter"}, {"type": "bar"}]])
8
+
9
+ # Add first trace (Scatter plot)
10
+ fig.add_trace(
11
+ go.Scatter(x=[1, 2, 3], y=[4, 5, 6], mode='markers'),
12
+ row=1, col=1
13
+ )
14
+
15
+ # Add second trace (Pie chart)
16
+ fig.add_trace(
17
+ go.Pie(labels=['A', 'B', 'C'], values=[10, 20, 30], hole=0.5),
18
+ row=1, col=2
19
+ )
20
+
21
+ # Add third trace (Line chart)
22
+ fig.add_trace(
23
+ go.Scatter(x=[1, 2, 3], y=[6, 5, 4], mode='lines'),
24
+ row=2, col=1
25
+ )
26
+
27
+ # Add fourth trace (Bar chart)
28
+ fig.add_trace(
29
+ go.Bar(x=['X', 'Y', 'Z'], y=[2, 3, 1]),
30
+ row=2, col=2
31
+ )
32
+
33
+ # Update layout
34
+ fig.update_layout(title_text='Multiple Visualizations Example', plot_bgcolor='white')
35
+
36
+ # Show the figure
37
+ fig.show()
assets/example_subplots3.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from plotly.subplots import make_subplots
2
+ import plotly.graph_objects as go
3
+
4
+ fig = make_subplots(
5
+ rows=2, cols=2,
6
+ specs=[[{"type": "bar"}, {"type": "barpolar"}],
7
+ [{"type": "pie"}, {"type": "scatter3d"}]],
8
+ )
9
+
10
+ fig.add_trace(go.Bar(y=[2, 3, 1]),
11
+ row=1, col=1)
12
+
13
+ fig.add_trace(go.Barpolar(theta=[0, 45, 90], r=[2, 3, 1]),
14
+ row=1, col=2)
15
+
16
+ fig.add_trace(go.Pie(values=[2, 3, 1], hole=0.5),
17
+ row=2, col=1)
18
+
19
+ fig.add_trace(go.Scatter3d(x=[2, 3, 1], y=[0, 0, 0],
20
+ z=[0.5, 1, 2], mode="lines"),
21
+ row=2, col=2)
22
+
23
+ fig.update_layout(height=700, plot_bgcolor='white')
24
+
25
+ fig.show()
utils/chartbot_dataset_layout.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dash
2
+ import dash_ag_grid as dag
3
+ import dash_mantine_components as dmc
4
+ import dash_bootstrap_components as dbc
5
+ from dash import html, dcc, callback, Input, Output, State
6
+
7
+ from utils import components
8
+
9
+ import pandas as pd
10
+
11
+ def chartbot_common(
12
+ page_title: str,
13
+ csv_file_path: str,
14
+ starter_prompt_1_id: str,
15
+ starter_prompt_1: str,
16
+ starter_prompt_2_id: str,
17
+ starter_prompt_2: str,
18
+ prompt_textarea_id: str,
19
+ submit_button_id: str,
20
+ chartbot_output_id: str,
21
+ python_content_id: str
22
+ ) -> tuple:
23
+ df = pd.read_csv(csv_file_path)
24
+
25
+ layout = html.Div([
26
+ dmc.Title(page_title, order=1),
27
+
28
+ html.Br(),
29
+
30
+ dag.AgGrid(
31
+ rowData=df.to_dict("records"),
32
+ columnDefs=[{"field": col} for col in df.columns],
33
+ defaultColDef={"filter": True, "sortable": True, "resizable": True}
34
+ ),
35
+
36
+ html.Br(),
37
+
38
+ dmc.Group([
39
+ components.button_with_prompt(starter_prompt_1_id, starter_prompt_1),
40
+ components.button_with_prompt(starter_prompt_2_id, starter_prompt_2),
41
+ ], justify="center"),
42
+
43
+ html.Br(),
44
+
45
+ dmc.Group([
46
+ dmc.Textarea(placeholder="Type the prompt here ...", id=prompt_textarea_id, size="lg", w=1470),
47
+ dmc.Button("Submit", id=submit_button_id, color="#E71316", className="float-end")
48
+ ]),
49
+
50
+ dcc.Loading([
51
+ html.Div(id=chartbot_output_id),
52
+ dcc.Markdown(id=python_content_id)
53
+ ], type="cube")
54
+
55
+ ], style={'fontFamily': 'Helvetica'})
56
+
57
+ return layout, df
utils/components.py CHANGED
@@ -37,3 +37,9 @@ def summary_card(
37
  dmc.Button("Get Started", color="#E71316", className="float-end")
38
  ], style={"margin": 10})
39
  ], withBorder=True, radius="md", w=320)
 
 
 
 
 
 
 
37
  dmc.Button("Get Started", color="#E71316", className="float-end")
38
  ], style={"margin": 10})
39
  ], withBorder=True, radius="md", w=320)
40
+
41
+ def button_with_prompt(
42
+ identity: str,
43
+ prompt: str
44
+ ) -> dmc.Button:
45
+ return dmc.Button(prompt, id=identity, color="gray", variant="outline", radius="md")
utils/helpers.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from dash import html, dcc
4
+
5
+ # ilibraries to help upload files and parse the contents of the files
6
+ import io
7
+ import re
8
+ import base64
9
+
10
+ # libraries to help with the Dash app, layout, and callbacks
11
+ import dash_ag_grid as dag
12
+
13
+ from utils import prompt
14
+
15
+ # Function to get the path of a file in the app source code
16
+ def get_app_file_path(directory_name: str, file_name: str) -> str:
17
+ return os.path.join(os.path.dirname(__file__), "..\\{}".format(directory_name), file_name)
18
+
19
+ # Function to read the content of a file
20
+ def read_doc(file_path: str) -> str:
21
+ file = open(file_path, "r")
22
+ lines = file.readlines()
23
+ file.close()
24
+
25
+ return "".join(lines)
26
+
27
+ # Function to get the figure from the code
28
+ def get_fig_from_code(code, file_name):
29
+ local_variables = {}
30
+
31
+ try:
32
+ exec(code, {}, local_variables)
33
+
34
+ except Exception as e:
35
+ result_output = prompt.get_python_exception_response(code, str(e))
36
+ return display_response(result_output, file_name)
37
+
38
+ return local_variables["fig"]
39
+
40
+ def display_response(response, file_name):
41
+ code_block_match = re.search(r"```(?:[Pp]ython)?(.*?)```", response, re.DOTALL)
42
+ #print(code_block_match)
43
+
44
+ if code_block_match:
45
+ code_block = code_block_match.group(1).strip()
46
+ cleaned_code = re.sub(r'(?m)^\s*fig\.show\(\)\s*$', '', code_block)
47
+ fig = get_fig_from_code(cleaned_code, file_name)
48
+
49
+ return dcc.Graph(figure=fig), response
50
+
51
+ else:
52
+ return "", response
53
+
54
+ # Function to parse the contents of the uploaded file
55
+ def parse_contents(contents, filename):
56
+ _, content_string = contents.split(",")
57
+ decoded = base64.b64decode(content_string)
58
+
59
+ try:
60
+ if 'csv' in filename:
61
+ df = pd.read_csv(io.StringIO(decoded.decode('utf-8')))
62
+
63
+ elif 'xls' in filename:
64
+ df = pd.read_excel(io.BytesIO(decoded))
65
+
66
+ except Exception as e:
67
+ print(e)
68
+
69
+ return html.Div([
70
+ "There was an error processing this file."
71
+ ])
72
+
73
+ return html.Div([
74
+ html.H5(filename),
75
+ dag.AgGrid(
76
+ rowData=df.to_dict("records"),
77
+ columnDefs=[{"field": col} for col in df.columns],
78
+ defaultColDef={"filter": True, "sortable": True, "resizable": True},
79
+ ),
80
+ dcc.Store(id='stored-data', data=df.to_dict("records")),
81
+ dcc.Store(id='stored-file-name', data=filename),
82
+
83
+ html.Hr()
84
+ ])
85
+
86
+ # Function to save the dataframe to the current path
87
+ def save_dataframe_to_current_path(df: pd.DataFrame, filename: str) -> None:
88
+ if os.path.exists(filename):
89
+ return
90
+
91
+ if 'csv' in filename:
92
+ df.to_csv(filename, index=False)
93
+
94
+ elif 'xls' in filename:
95
+ df.to_excel(filename, index=False)
utils/prompt.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # libraries to help with the environment variables
2
+ import os
3
+ from dotenv import load_dotenv
4
+
5
+ # libraries to help with the AI model
6
+ from langchain_openai import AzureChatOpenAI
7
+ from langchain_core.messages import HumanMessage
8
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
9
+
10
+ from utils import helpers
11
+
12
+ # get the credentials from .env
13
+ load_dotenv()
14
+ AZURE_KEY = os.getenv('KEY')
15
+ AZURE_ENDPOINT = os.getenv('LLM_ENDPOINT')
16
+ AZURE_NAME = os.getenv('LLM_DEPLOYMENT_NAME')
17
+ AZURE_VERSION = os.getenv('VERSION')
18
+
19
+ # define connectivity to the llm
20
+ llm = AzureChatOpenAI(
21
+ deployment_name=AZURE_NAME,
22
+ openai_api_version=AZURE_VERSION,
23
+ openai_api_key=AZURE_KEY,
24
+ azure_endpoint=AZURE_ENDPOINT
25
+ )
26
+
27
+ '''Before creating any visualizations, ensure that any rows with NaN or missing values in the relevant columns are removed. Additionally,
28
+ handle missing values appropriately based on the context, ensuring cleaner visualizations.
29
+ For example, use df.dropna(subset=[column_name]) for data cleaning. Never use this statement: df.dropna(inplace=True).'''
30
+
31
+ def get_prompt_text() -> str:
32
+ return """You are a data visualization expert and you only use the graphing library Plotly.
33
+ Ensure that before performing any data manipulation or plotting, the code checks for column data types and converts them if necessary.
34
+ For example, numeric columns should be converted to floats or integers using pd.to_numeric(), and non-numeric columns should be excluded from numeric operations.
35
+ Before creating any visualizations, ensure that any rows with NaN or missing values in the relevant columns are removed. Additionally,
36
+ handle missing values appropriately based on the context, ensuring cleaner visualizations.
37
+ For example, use df.dropna(subset=[column_name]) for data cleaning. Never use this statement: df.dropna(inplace=True).
38
+ The graphs you plot shall always have a white background and shall follow data visualization best practices.
39
+ Do not ignore any of the following visualization best practices:
40
+ {data_visualization_best_practices}
41
+ If the user requests a single visualization, create the graph using the plotly.express library and set the fig height to 800.
42
+ Ensure that the graph is clearly labeled with a title, x-axis label, y-axis label, and legend.
43
+ If the user has requested for a choropleth map of the United States of America (USA), ensure that the locations parameter in the px.choropleth() method is
44
+ set to the column which contains the two letter code state abbreviations, for example: AL, NY, TN, VT, UT (the column should not be determined by the name of the column,
45
+ but by the values it contains) and the scope parameter is set to 'usa'.
46
+ If the user requests multiple visualizations, create a subplot for each visualization.
47
+ The libraries required for multiple visualizations are: import plotly.graph_objects as go and from plotly.subplots import make_subplots.
48
+ Utilize the plotly.graph_objects library's make_subplots() method to create subplots, specifying the number of rows and columns,
49
+ and the specs parameter to define what type of graph will be present in each subplot to accommodate all requested visualizations.
50
+ Then, use the add_trace() method to add each graph to the appropriate subplot.
51
+ When generating subplots that include pie charts and xy plots (like bar or scatter), ensure that pie charts are assigned a separate 'domain' subplot type.
52
+ Use the make_subplots() function with the specs argument correctly set for pie charts and other plots.
53
+ For example, use make_subplots(rows=1, cols=2, specs=[[dict(type='domain'), dict(type='xy)]]) for a pie chart and a bar plot.
54
+ Before returning the final code, verify that all trace types are compatible with the assigned subplot types,
55
+ particularly ensuring that pie charts are in domain-type subplots. If an error is detected, correct the subplot type automatically.
56
+ Validate the layout before adding traces.
57
+ Ensure each subplot is clearly labeled and formatted according to best practices.
58
+ All the labels in the graph should be of the font family Helvetica, be it title, x-axis, y-axis, or legend.
59
+ Here are examples of how to create multiple visualizations in a single figure:
60
+ Example 1: \n
61
+ {example_subplots1}
62
+ Example 2: \n
63
+ {example_subplots2}
64
+ Example 3: \n
65
+ {example_subplots3}
66
+ The height of the figure (fig) should be set to 800.
67
+ Suppose that the data is provided as a {name_of_file} file.
68
+ Here are the first 5 rows of the data set: {data}. Follow the user's indications when creating the graph.
69
+ There should be no natural language text in the python code block."""
70
+
71
+ def get_response(user_input: str, data_top5_csv_string: str, file_name: str) -> None:
72
+ prompt = ChatPromptTemplate.from_messages(
73
+ [
74
+ (
75
+ "system",
76
+ get_prompt_text()
77
+ ),
78
+ MessagesPlaceholder(variable_name="messages")
79
+ ]
80
+ )
81
+
82
+ chain = prompt | llm
83
+
84
+ response = chain.invoke(
85
+ {
86
+ "messages": [HumanMessage(content=user_input)],
87
+ "data_visualization_best_practices": helpers.read_doc(helpers.get_app_file_path("assets", "data_viz_best_practices.txt")),
88
+ "example_subplots1": helpers.read_doc(helpers.get_app_file_path("assets", "example_subplots1.txt")),
89
+ "example_subplots2": helpers.read_doc(helpers.get_app_file_path("assets", "example_subplots2.txt")),
90
+ "example_subplots3": helpers.read_doc(helpers.get_app_file_path("assets", "example_subplots3.txt")),
91
+ "data": data_top5_csv_string,
92
+ "name_of_file": file_name
93
+ }
94
+ )
95
+
96
+ return response.content
97
+
98
+ def get_python_exception_prompt_text() -> str:
99
+ return """The Python code you provided {code} has an error {exception}"""
100
+
101
+ def get_python_exception_response(code: str, exception: str) -> None:
102
+ prompt = ChatPromptTemplate.from_messages(
103
+ [
104
+ (
105
+ "system",
106
+ get_python_exception_prompt_text()
107
+ ),
108
+ MessagesPlaceholder(variable_name="messages")
109
+ ]
110
+ )
111
+
112
+ chain = prompt | llm
113
+
114
+ response = chain.invoke(
115
+ {
116
+ "messages": [HumanMessage(content="Rewrite the entire Python code so that it does not contain any errors. The code should be able to run without any errors.")],
117
+ "code": code,
118
+ "exception": exception
119
+ }
120
+ )
121
+
122
+ return response.content