Mustehson commited on
Commit
5862423
·
1 Parent(s): e0b017d

Added Colors

Browse files
Files changed (2) hide show
  1. app.py +9 -9
  2. plot_utils.py +11 -4
app.py CHANGED
@@ -121,9 +121,11 @@ Consider these types of questions when recommending a visualization:
121
  5. Proportions (e.g., "What is the market share of the products?" - Pie Chart)
122
  6. Correlations (e.g., "Is there a correlation between marketing spend and revenue?" - Scatter Plot)
123
 
124
- Provide your response in the following format:
 
125
  Recommended Visualization: [Chart type or "None"]. ONLY use the following names: bar, horizontal_bar, line, pie, scatter, none
126
  Reason: [Brief explanation for your recommendation]
 
127
  '''
128
  human_message = '''
129
  User question: {question}
@@ -139,19 +141,17 @@ Recommend a visualization:
139
  ])
140
 
141
  final_prompt = prompt.format_prompt(question=text_query,
142
- sql_query=sql_query, results=sql_result[:20])
143
  response = run_llm(final_prompt)
144
  response = response.replace('```', '')
145
- lines = response.strip().split('\n')
146
- print(lines)
147
- visualization = lines[0].split(': ')[1]
148
- reason = lines[1].split(': ')[1]
149
-
150
  return visualization, reason
151
 
152
 
153
-
154
-
155
  def format_data(text_query, sql_query, sql_result, visualization_type):
156
  instruction = graph_instructions[visualization_type]
157
 
 
121
  5. Proportions (e.g., "What is the market share of the products?" - Pie Chart)
122
  6. Correlations (e.g., "Is there a correlation between marketing spend and revenue?" - Scatter Plot)
123
 
124
+ Generate a JSON object. The JSON object should have the following structure:
125
+
126
  Recommended Visualization: [Chart type or "None"]. ONLY use the following names: bar, horizontal_bar, line, pie, scatter, none
127
  Reason: [Brief explanation for your recommendation]
128
+
129
  '''
130
  human_message = '''
131
  User question: {question}
 
141
  ])
142
 
143
  final_prompt = prompt.format_prompt(question=text_query,
144
+ sql_query=sql_query, results=sql_result)
145
  response = run_llm(final_prompt)
146
  response = response.replace('```', '')
147
+ response = response.replace('json', '')
148
+ json_data = json.loads(response)
149
+ visualization = json_data['Recommended Visualization']
150
+ reason = json_data['Reason']
151
+ print(visualization, reason)
152
  return visualization, reason
153
 
154
 
 
 
155
  def format_data(text_query, sql_query, sql_result, visualization_type):
156
  instruction = graph_instructions[visualization_type]
157
 
plot_utils.py CHANGED
@@ -1,13 +1,20 @@
1
  import matplotlib.pyplot as plt
 
2
 
3
  def plot_bar_chart(data):
4
-
5
  fig, ax = plt.subplots()
6
  labels = data['labels']
7
  values = data['values']
8
 
9
- for value in values:
10
- ax.bar(labels, value['data'], label=value['label'])
 
 
 
 
 
 
 
11
 
12
  ax.set_title('Bar Chart')
13
  ax.set_xlabel('Labels')
@@ -21,7 +28,7 @@ def plot_horizontal_bar_chart(data):
21
  values = data['values']
22
 
23
  for value in values:
24
- ax.barh(labels, value['data'], label=value['label'])
25
 
26
  ax.set_title('Horizontal Bar Chart')
27
  ax.set_xlabel('Values')
 
1
  import matplotlib.pyplot as plt
2
+ import numpy as np
3
 
4
  def plot_bar_chart(data):
 
5
  fig, ax = plt.subplots()
6
  labels = data['labels']
7
  values = data['values']
8
 
9
+ colors = ['darkviolet', 'indigo', 'blueviolet'] # Define the colors for each group
10
+ width = 0.2 # Width of the bars
11
+ x = np.arange(len(labels)) # Label locations
12
+
13
+ for i, value in enumerate(values):
14
+ ax.bar(x + i * width, value['data'], width, label=value['label'], color=colors[i])
15
+
16
+ ax.set_xticks(x + width / 2 * (len(values) - 1))
17
+ ax.set_xticklabels(labels)
18
 
19
  ax.set_title('Bar Chart')
20
  ax.set_xlabel('Labels')
 
28
  values = data['values']
29
 
30
  for value in values:
31
+ ax.barh(labels, value['data'], label=value['label'], color=['darkviolet', 'indigo', 'blueviolet'])
32
 
33
  ax.set_title('Horizontal Bar Chart')
34
  ax.set_xlabel('Values')