Deepa Shalini committed on
Commit
d04a737
·
1 Parent(s): 723862f

support for dumbbell, choropleth and polar charts

Browse files
app.py CHANGED
@@ -37,8 +37,8 @@ app.layout = dmc.MantineProvider(
37
  className="brand"
38
  ),
39
  html.Button(
40
- "New Chat",
41
- id="new-chat-button",
42
  className="pill",
43
  n_clicks=0,
44
  style={
@@ -281,7 +281,7 @@ def download_html(encoded):
281
  Output("html-buffer", "data", allow_duplicate=True),
282
  Output("submit-button", "disabled", allow_duplicate=True),
283
  Output("upload-data", "contents"),
284
- Input("new-chat-button", "n_clicks"),
285
  prevent_initial_call=True
286
  )
287
  def reset_chat(n_clicks):
 
37
  className="brand"
38
  ),
39
  html.Button(
40
+ "New Chart",
41
+ id="new-chart-button",
42
  className="pill",
43
  n_clicks=0,
44
  style={
 
281
  Output("html-buffer", "data", allow_duplicate=True),
282
  Output("submit-button", "disabled", allow_duplicate=True),
283
  Output("upload-data", "contents"),
284
+ Input("new-chart-button", "n_clicks"),
285
  prevent_initial_call=True
286
  )
287
  def reset_chat(n_clicks):
assets/example_dumbbell_chart.txt ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.graph_objects as go
2
+ import pandas as pd
3
+
4
+ # Sample data for dumbbell chart
5
+ countries = ['Country A', 'Country B', 'Country C', 'Country D', 'Country E']
6
+ year_1952 = [65, 68, 70, 72, 75]
7
+ year_2002 = [72, 76, 78, 80, 82]
8
+
9
+ # Prepare line coordinates for connecting dots
10
+ line_x = []
11
+ line_y = []
12
+ for i, country in enumerate(countries):
13
+ line_x.extend([year_1952[i], year_2002[i], None])
14
+ line_y.extend([country, country, None])
15
+
16
+ # Create dumbbell chart
17
+ fig = go.Figure(
18
+ data=[
19
+ # Add connecting lines
20
+ go.Scatter(
21
+ x=line_x,
22
+ y=line_y,
23
+ mode='markers+lines',
24
+ showlegend=False,
25
+ marker=dict(
26
+ symbol="arrow",
27
+ color="black",
28
+ size=16,
29
+ angleref="previous",
30
+ standoff=8
31
+ )
32
+ ),
33
+ # Add first year markers
34
+ go.Scatter(
35
+ x=year_1952,
36
+ y=countries,
37
+ mode='markers',
38
+ name='1952',
39
+ marker=dict(color='green', size=10)
40
+ ),
41
+ # Add second year markers
42
+ go.Scatter(
43
+ x=year_2002,
44
+ y=countries,
45
+ mode='markers',
46
+ name='2002',
47
+ marker=dict(color='blue', size=10)
48
+ ),
49
+ ]
50
+ )
51
+
52
+ # Update layout
53
+ fig.update_layout(
54
+ title='Comparison Between Two Years',
55
+ height=800,
56
+ plot_bgcolor='white',
57
+ legend_itemclick=False
58
+ )
59
+
60
+ # Show the figure
61
+ fig.show()
assets/example_polar_bar.txt ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import plotly.graph_objects as go
4
+
5
+ # Sample data for demonstration (full year 2024)
6
+ data = {
7
+ 'date': pd.date_range('2024-01-01', periods=365, freq='D'),
8
+ 'value': np.random.randint(50, 200, 365)
9
+ }
10
+ df = pd.DataFrame(data)
11
+
12
+ # Extract calendar components
13
+ df['month'] = df['date'].dt.month # 1..12
14
+ # Convert pandas weekday (Monday=0..Sunday=6) to Sun=0..Sat=6
15
+ df['weekday_sun0'] = (df['date'].dt.dayofweek + 1) % 7
16
+
17
+ # Aggregate values by month x weekday
18
+ agg = (
19
+ df.groupby(['month', 'weekday_sun0'], as_index=False)['value']
20
+ .sum()
21
+ .rename(columns={'value': 'total_value'})
22
+ )
23
+
24
+ # Ensure all 12x7 cells exist (fill missing with 0)
25
+ full = pd.MultiIndex.from_product(
26
+ [range(1, 13), range(0, 7)],
27
+ names=['month', 'weekday_sun0']
28
+ ).to_frame(index=False)
29
+ agg = full.merge(agg, on=['month', 'weekday_sun0'], how='left').fillna({'total_value': 0.0})
30
+
31
+ # Labels for months and weekdays
32
+ month_labels = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
33
+ weekday_labels = ['Sun', 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat']
34
+
35
+ agg['month_name'] = agg['month'].map(lambda m: month_labels[m-1])
36
+ agg['weekday_name'] = agg['weekday_sun0'].map(lambda w: weekday_labels[w])
37
+
38
+ # Polar "cell" geometry
39
+ # Each month occupies a 30-degree sector (360/12 = 30)
40
+ month_width = 360 / 12
41
+ agg['theta'] = (agg['month'] - 1) * month_width # 0, 30, 60, ..., 330 (Jan at 0)
42
+ agg['width'] = month_width # sector width
43
+ agg['base'] = agg['weekday_sun0'] # ring start (0..6)
44
+ agg['r'] = 1 # ring thickness (each weekday is one ring)
45
+
46
+ # Bin values into 5 categories for color coding
47
+ s = agg['total_value'].astype(float)
48
+ nonzero = s[s > 0]
49
+
50
+ if nonzero.empty:
51
+ agg['bin'] = 'All zero'
52
+ bin_labels = ['All zero']
53
+ else:
54
+ # Quantile binning on non-zero values
55
+ binned_nz = pd.qcut(nonzero, q=5, duplicates='drop')
56
+ intervals = binned_nz.cat.categories
57
+ bin_labels = [f'{iv.left:,.0f}–{iv.right:,.0f}' for iv in intervals]
58
+
59
+ nz_labels = pd.Series(binned_nz.astype(str), index=nonzero.index)
60
+ interval_to_label = {str(iv): lbl for iv, lbl in zip(intervals, bin_labels)}
61
+ nz_labels = nz_labels.map(interval_to_label)
62
+
63
+ agg['bin'] = '0' # default for zeros
64
+ agg.loc[nonzero.index, 'bin'] = nz_labels.values
65
+ bin_labels = ['0'] + bin_labels
66
+
67
+ # Color palette (5 colors)
68
+ palette = ['#edf8fb', '#b2e2e2', '#66c2a4', '#2ca25f', '#006d2c']
69
+ unique_bins = [b for b in bin_labels if b in agg['bin'].unique()]
70
+ colors = palette[:max(1, len(unique_bins))]
71
+ color_map = dict(zip(unique_bins, colors))
72
+
73
+ # Build figure with one Barpolar trace per bin
74
+ fig = go.Figure()
75
+
76
+ for b in unique_bins:
77
+ sub = agg[agg['bin'] == b]
78
+ fig.add_trace(go.Barpolar(
79
+ theta=sub['theta'],
80
+ r=sub['r'],
81
+ base=sub['base'],
82
+ width=sub['width'],
83
+ name=b,
84
+ marker_color=color_map[b],
85
+ marker_line_width=0, # removes gaps between cells
86
+ hovertemplate=(
87
+ 'Month: %{customdata[0]}<br>'
88
+ 'Weekday: %{customdata[1]}<br>'
89
+ 'Value: %{customdata[2]:,.2f}<extra></extra>'
90
+ ),
91
+ customdata=np.stack([sub['month_name'], sub['weekday_name'], sub['total_value']], axis=1),
92
+ ))
93
+
94
+ # Radial ticks placed at ring centers (0.5..6.5)
95
+ tickvals = [i + 0.5 for i in range(7)]
96
+
97
+ fig.update_layout(
98
+ title='Circular Calendar View - Monthly Values by Weekday (2024)',
99
+ template='plotly_white',
100
+ margin=dict(l=40, r=40, t=70, b=40),
101
+ polar=dict(
102
+ angularaxis=dict(
103
+ direction='clockwise',
104
+ rotation=90, # puts theta=0 (Jan) at top
105
+ tickmode='array',
106
+ tickvals=[i * month_width for i in range(12)],
107
+ ticktext=month_labels,
108
+ ),
109
+ radialaxis=dict(
110
+ range=[0, 7],
111
+ tickmode='array',
112
+ tickvals=tickvals,
113
+ ticktext=weekday_labels, # Sun..Sat
114
+ showline=False,
115
+ gridcolor='rgba(0,0,0,0.12)',
116
+ ),
117
+ ),
118
+ legend_title_text='Value (binned)',
119
+ )
120
+
121
+ fig.show()
assets/example_polar_scatter.txt ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import plotly.express as px
4
+
5
+ # Sample data for demonstration (full year 2024)
6
+ data = {
7
+ 'date': pd.date_range('2024-01-01', periods=365, freq='D'),
8
+ 'value': np.random.randint(50, 200, 365)
9
+ }
10
+ df = pd.DataFrame(data)
11
+
12
+ # Extract calendar components
13
+ df['month'] = df['date'].dt.month # 1..12
14
+ # Convert pandas weekday (Monday=0..Sunday=6) to Sun=0..Sat=6
15
+ df['weekday_sun0'] = (df['date'].dt.dayofweek + 1) % 7
16
+
17
+ # Aggregate values by month x weekday
18
+ agg = (
19
+ df.groupby(['month', 'weekday_sun0'], as_index=False)['value']
20
+ .sum()
21
+ .rename(columns={'value': 'total_value'})
22
+ )
23
+
24
+ # Ensure all 12x7 cells exist (fill missing with 0)
25
+ full = pd.MultiIndex.from_product(
26
+ [range(1, 13), range(0, 7)],
27
+ names=['month', 'weekday_sun0']
28
+ ).to_frame(index=False)
29
+ agg = full.merge(agg, on=['month', 'weekday_sun0'], how='left').fillna({'total_value': 0.0})
30
+
31
+ # Labels for months and weekdays
32
+ month_labels = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
33
+ weekday_labels = ['Sun', 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat']
34
+
35
+ agg['month_name'] = agg['month'].map(lambda m: month_labels[m-1])
36
+ agg['weekday_name'] = agg['weekday_sun0'].map(lambda w: weekday_labels[w])
37
+
38
+ # Rings: 1..7 (Sun=1 inner → Sat=7 outer)
39
+ agg['r'] = agg['weekday_sun0'] + 1
40
+
41
+ # Bubble size normalization (log1p compresses large values; then scale to pixel range)
42
+ max_marker_px = 42
43
+ min_marker_px = 6
44
+
45
+ s = agg['total_value'].to_numpy(dtype=float)
46
+ s_log = np.log1p(s)
47
+
48
+ if np.allclose(s_log.max(), s_log.min()):
49
+ agg['size_px'] = min_marker_px
50
+ else:
51
+ # Scale log values to [min_marker_px, max_marker_px]
52
+ scaled = (s_log - s_log.min()) / (s_log.max() - s_log.min())
53
+ agg['size_px'] = min_marker_px + scaled * (max_marker_px - min_marker_px)
54
+
55
+ # Sizeref for area sizing
56
+ sizeref = 2.0 * agg['size_px'].max() / (max_marker_px ** 2)
57
+
58
+ # Build polar scatter chart
59
+ fig = px.scatter_polar(
60
+ agg,
61
+ r='r',
62
+ theta='month_name',
63
+ size='size_px',
64
+ size_max=max_marker_px,
65
+ color='total_value',
66
+ color_continuous_scale='Viridis',
67
+ hover_data={
68
+ 'month_name': True,
69
+ 'weekday_name': True,
70
+ 'total_value': ':,.2f',
71
+ 'r': False,
72
+ 'size_px': False,
73
+ 'month': False,
74
+ 'weekday_sun0': False,
75
+ },
76
+ title='Circular Calendar View - Monthly Values by Weekday (2024)',
77
+ )
78
+
79
+ # Force area sizing behavior
80
+ fig.update_traces(marker=dict(sizemode='area', sizeref=sizeref, line=dict(width=0.6)))
81
+
82
+ # Clockwise months, start Jan at top
83
+ fig.update_layout(
84
+ polar=dict(
85
+ angularaxis=dict(
86
+ direction='clockwise',
87
+ rotation=90, # puts Jan at the top
88
+ ),
89
+ radialaxis=dict(
90
+ tickmode='array',
91
+ tickvals=list(range(1, 8)),
92
+ ticktext=weekday_labels, # Sun..Sat
93
+ range=[0.5, 7.5],
94
+ showline=False,
95
+ gridcolor='rgba(0,0,0,0.12)',
96
+ ),
97
+ ),
98
+ coloraxis_colorbar=dict(title='Value'),
99
+ template='plotly_white',
100
+ margin=dict(l=40, r=40, t=70, b=40),
101
+ height=800,
102
+ )
103
+
104
+ fig.show()
data/polar_bar_data.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:862aa7ccd01dc0f0a2226b8ca69ae4d3d3b7f0eab93fc2fc359082c0bc228d25
3
+ size 1123
utils/helpers.py CHANGED
@@ -43,10 +43,14 @@ def get_fig_from_code(code, file_name):
43
  def display_response(response, file_name):
44
  try:
45
  code_block_match = re.search(r"```(?:[Pp]ython)?(.*?)```", response, re.DOTALL)
46
- #print(code_block_match)
47
 
48
  if code_block_match:
49
  code_block = code_block_match.group(1).strip()
 
 
 
 
 
50
  cleaned_code = re.sub(r'(?m)^\s*fig\.show\(\)\s*$', '', code_block)
51
 
52
  try:
 
43
  def display_response(response, file_name):
44
  try:
45
  code_block_match = re.search(r"```(?:[Pp]ython)?(.*?)```", response, re.DOTALL)
 
46
 
47
  if code_block_match:
48
  code_block = code_block_match.group(1).strip()
49
+
50
+ # Check if code ends with fig.show() and add it if missing
51
+ if not re.search(r'fig\.show\(\)\s*$', code_block, re.MULTILINE):
52
+ code_block = code_block + "\nfig.show()"
53
+
54
  cleaned_code = re.sub(r'(?m)^\s*fig\.show\(\)\s*$', '', code_block)
55
 
56
  try:
utils/prompt.py CHANGED
@@ -61,6 +61,14 @@ def get_prompt_text() -> str:
61
  If any validation rule fails, return ONLY the error message in the format specified above. Do NOT generate any Python code.
62
 
63
  IF VALIDATION PASSES, PROCEED WITH CODE GENERATION:
 
 
 
 
 
 
 
 
64
  Ensure that before performing any data manipulation or plotting, the code checks for column data types and converts them if necessary.
65
  For example, numeric columns should be converted to floats or integers using pd.to_numeric(), and non-numeric columns should be excluded from numeric operations.
66
  Before creating any visualizations, ensure that any rows with NaN or missing values in the relevant columns are removed. Additionally,
@@ -71,9 +79,28 @@ def get_prompt_text() -> str:
71
  {data_visualization_best_practices}
72
  If the user requests a single visualization, figure height to 800.
73
  Ensure that the graph is clearly labeled with a title, x-axis label, y-axis label, and legend.
74
- If the user has requested for a choropleth map of the United States of America (USA), ensure that the locations parameter in the px.choropleth() method is
75
- set to the column which contains the two letter code state abbreviations, for example: AL, NY, TN, VT, UT (the column should not be determined by the name of the column,
76
- but by the values it contains) and the scope parameter is set to 'usa'.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  If the user requests multiple visualizations, create a subplot for each visualization.
78
  The libraries required for multiple visualizations are: import plotly.graph_objects as go and from plotly.subplots import make_subplots.
79
  Utilize the plotly.graph_objects library's make_subplots() method to create subplots, specifying the number of rows and columns,
@@ -96,7 +123,44 @@ def get_prompt_text() -> str:
96
  The height of the figure (fig) should be set to 800.
97
  Suppose that the data is provided as a {name_of_file} file.
98
  Here are the first 5 rows of the data set: {data}. Follow the user's indications when creating the graph.
99
- There should be no natural language text in the python code block."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  def get_response(user_input: str, data_top5_csv_string: str, file_name: str) -> str:
102
  """
@@ -114,6 +178,54 @@ def get_response(user_input: str, data_top5_csv_string: str, file_name: str) ->
114
  Exception: If API call fails or validation fails
115
  """
116
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  prompt = ChatPromptTemplate.from_messages(
118
  [
119
  ("system", get_prompt_text()),
@@ -123,25 +235,27 @@ def get_response(user_input: str, data_top5_csv_string: str, file_name: str) ->
123
 
124
  chain = prompt | llm
125
 
126
- response = chain.invoke(
127
- {
128
- "messages": [HumanMessage(content=user_input)],
129
- "data_visualization_best_practices": helpers.read_doc(
130
- helpers.get_app_file_path("assets", "data_viz_best_practices.txt")
131
- ),
132
- "example_subplots1": helpers.read_doc(
133
- helpers.get_app_file_path("assets", "example_subplots1.txt")
134
- ),
135
- "example_subplots2": helpers.read_doc(
136
- helpers.get_app_file_path("assets", "example_subplots2.txt")
137
- ),
138
- "example_subplots3": helpers.read_doc(
139
- helpers.get_app_file_path("assets", "example_subplots3.txt")
140
- ),
141
- "data": data_top5_csv_string,
142
- "name_of_file": file_name
143
- }
144
- )
 
 
145
 
146
  # Check if the response is an error message instead of code
147
  response_text = response.content.strip()
 
61
  If any validation rule fails, return ONLY the error message in the format specified above. Do NOT generate any Python code.
62
 
63
  IF VALIDATION PASSES, PROCEED WITH CODE GENERATION:
64
+
65
+ PANDAS DATA HANDLING BEST PRACTICES:
66
+ - Always use .copy() when creating a new dataframe from a subset or filtered view to avoid SettingWithCopyWarning.
67
+ - Example: df_filtered = df[df['column'] > 0].copy()
68
+ - When modifying data, always work on explicit copies, not chained indexing.
69
+ - Use .loc[] for setting values: df.loc[condition, 'column'] = value
70
+ - Avoid chained assignment like df[condition]['column'] = value
71
+
72
  Ensure that before performing any data manipulation or plotting, the code checks for column data types and converts them if necessary.
73
  For example, numeric columns should be converted to floats or integers using pd.to_numeric(), and non-numeric columns should be excluded from numeric operations.
74
  Before creating any visualizations, ensure that any rows with NaN or missing values in the relevant columns are removed. Additionally,
 
79
  {data_visualization_best_practices}
80
  If the user requests a single visualization, figure height to 800.
81
  Ensure that the graph is clearly labeled with a title, x-axis label, y-axis label, and legend.
82
+
83
+ SPECIFIC CHART TYPE INSTRUCTIONS:
84
+
85
+ CHOROPLETH MAPS:
86
+ CRITICAL: When creating a choropleth map of the United States, you MUST include ALL of the following parameters:
87
+ - locations: Set to the column containing two-letter state abbreviations (e.g., 'AL', 'NY', 'CA', 'TX')
88
+ - locationmode: MUST be set to 'USA-states' (this is CRITICAL - without it, the map will be blank)
89
+ - scope: Set to 'usa'
90
+ Example:
91
+ fig = px.choropleth(df,
92
+ locations='state_code_column',
93
+ locationmode='USA-states',
94
+ scope='usa',
95
+ color='value_column',
96
+ title='Map Title')
97
+ The locations parameter should reference the column with state codes, not the column with full state names.
98
+ Always verify that locationmode='USA-states' is present in the code.
99
+
100
+ {dumbbell_charts_section}
101
+
102
+ {polar_charts_section}
103
+
104
  If the user requests multiple visualizations, create a subplot for each visualization.
105
  The libraries required for multiple visualizations are: import plotly.graph_objects as go and from plotly.subplots import make_subplots.
106
  Utilize the plotly.graph_objects library's make_subplots() method to create subplots, specifying the number of rows and columns,
 
123
  The height of the figure (fig) should be set to 800.
124
  Suppose that the data is provided as a {name_of_file} file.
125
  Here are the first 5 rows of the data set: {data}. Follow the user's indications when creating the graph.
126
+ There should be no natural language text in the python code block.
127
+
128
+ REMINDER: Your code MUST end with fig.show() to display the visualization."""
129
+
130
+ def _should_include_dumbbell_examples(user_input: str) -> bool:
131
+ """
132
+ Check if user's request is about dumbbell charts or comparison visualizations.
133
+
134
+ Args:
135
+ user_input: User's visualization request
136
+
137
+ Returns:
138
+ bool: True if dumbbell chart examples should be included
139
+ """
140
+ dumbbell_keywords = [
141
+ 'dumbbell', 'dumb bell', 'dumbell', 'dumbel', 'comparison', 'before and after', 'before after',
142
+ 'start and end', 'start end', 'range', 'difference', 'gap', 'change over'
143
+ ]
144
+
145
+ user_input_lower = user_input.lower()
146
+ return any(keyword in user_input_lower for keyword in dumbbell_keywords)
147
+
148
+ def _should_include_polar_examples(user_input: str) -> bool:
149
+ """
150
+ Check if user's request is about polar charts, calendar views, or circular visualizations.
151
+
152
+ Args:
153
+ user_input: User's visualization request
154
+
155
+ Returns:
156
+ bool: True if polar chart examples should be included
157
+ """
158
+ polar_keywords = [
159
+ 'polar', 'circular', 'radial', 'circular fashion', 'radar', 'rose'
160
+ ]
161
+
162
+ user_input_lower = user_input.lower()
163
+ return any(keyword in user_input_lower for keyword in polar_keywords)
164
 
165
  def get_response(user_input: str, data_top5_csv_string: str, file_name: str) -> str:
166
  """
 
178
  Exception: If API call fails or validation fails
179
  """
180
  try:
181
+ # Determine if dumbbell chart examples should be included
182
+ include_dumbbell = _should_include_dumbbell_examples(user_input)
183
+
184
+ # Determine if polar chart examples should be included
185
+ include_polar = _should_include_polar_examples(user_input)
186
+
187
+ # Build dumbbell charts section conditionally
188
+ dumbbell_charts_section = ""
189
+ if include_dumbbell:
190
+ dumbbell_example = helpers.read_doc(
191
+ helpers.get_app_file_path("assets", "example_dumbbell_chart.txt")
192
+ )
193
+ dumbbell_charts_section = f"""
194
+ DUMBBELL PLOTS:
195
+ When creating a dumbbell plot, use plotly.graph_objects (go) instead of plotly.express (px).
196
+ Use go.Figure() and add two go.Scatter traces for the two data points, and a go.Scatter trace for the lines connecting them.
197
+ Ensure proper labeling of axes and title for clarity.
198
+ Example: \n
199
+ {dumbbell_example}
200
+ """
201
+
202
+ # Build polar charts section conditionally
203
+ polar_charts_section = ""
204
+ if include_polar:
205
+ polar_bar_example = helpers.read_doc(
206
+ helpers.get_app_file_path("assets", "example_polar_bar.txt")
207
+ )
208
+ polar_scatter_example = helpers.read_doc(
209
+ helpers.get_app_file_path("assets", "example_polar_scatter.txt")
210
+ )
211
+ polar_charts_section = f"""
212
+ POLAR CHARTS (RADIAL/CIRCULAR VISUALIZATIONS):
213
+ Polar charts are effective for displaying calendar views, weekly patterns, or circular data distributions.
214
+ Use them for innovative visualizations of time-based or cyclical data.
215
+
216
+ Example 1 - Polar Calendar with Cells (Barpolar):
217
+ {polar_bar_example}
218
+
219
+ Example 2 - Polar Calendar with Scatter:
220
+ {polar_scatter_example}
221
+
222
+ Use polar charts when the user requests:
223
+ - Calendar-like views
224
+ - Weekly or cyclical patterns
225
+ - Circular representations of data
226
+ - Radial visualizations
227
+ """
228
+
229
  prompt = ChatPromptTemplate.from_messages(
230
  [
231
  ("system", get_prompt_text()),
 
235
 
236
  chain = prompt | llm
237
 
238
+ invoke_params = {
239
+ "messages": [HumanMessage(content=user_input)],
240
+ "data_visualization_best_practices": helpers.read_doc(
241
+ helpers.get_app_file_path("assets", "data_viz_best_practices.txt")
242
+ ),
243
+ "example_subplots1": helpers.read_doc(
244
+ helpers.get_app_file_path("assets", "example_subplots1.txt")
245
+ ),
246
+ "example_subplots2": helpers.read_doc(
247
+ helpers.get_app_file_path("assets", "example_subplots2.txt")
248
+ ),
249
+ "example_subplots3": helpers.read_doc(
250
+ helpers.get_app_file_path("assets", "example_subplots3.txt")
251
+ ),
252
+ "dumbbell_charts_section": dumbbell_charts_section,
253
+ "polar_charts_section": polar_charts_section,
254
+ "data": data_top5_csv_string,
255
+ "name_of_file": file_name
256
+ }
257
+
258
+ response = chain.invoke(invoke_params)
259
 
260
  # Check if the response is an error message instead of code
261
  response_text = response.content.strip()