Spaces:
Sleeping
Sleeping
Mustehson commited on
Commit ·
5862423
1
Parent(s): e0b017d
Added Colors
Browse files- app.py +9 -9
- plot_utils.py +11 -4
app.py
CHANGED
|
@@ -121,9 +121,11 @@ Consider these types of questions when recommending a visualization:
|
|
| 121 |
5. Proportions (e.g., "What is the market share of the products?" - Pie Chart)
|
| 122 |
6. Correlations (e.g., "Is there a correlation between marketing spend and revenue?" - Scatter Plot)
|
| 123 |
|
| 124 |
-
|
|
|
|
| 125 |
Recommended Visualization: [Chart type or "None"]. ONLY use the following names: bar, horizontal_bar, line, pie, scatter, none
|
| 126 |
Reason: [Brief explanation for your recommendation]
|
|
|
|
| 127 |
'''
|
| 128 |
human_message = '''
|
| 129 |
User question: {question}
|
|
@@ -139,19 +141,17 @@ Recommend a visualization:
|
|
| 139 |
])
|
| 140 |
|
| 141 |
final_prompt = prompt.format_prompt(question=text_query,
|
| 142 |
-
sql_query=sql_query, results=sql_result
|
| 143 |
response = run_llm(final_prompt)
|
| 144 |
response = response.replace('```', '')
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
visualization =
|
| 148 |
-
reason =
|
| 149 |
-
|
| 150 |
return visualization, reason
|
| 151 |
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
| 155 |
def format_data(text_query, sql_query, sql_result, visualization_type):
|
| 156 |
instruction = graph_instructions[visualization_type]
|
| 157 |
|
|
|
|
| 121 |
5. Proportions (e.g., "What is the market share of the products?" - Pie Chart)
|
| 122 |
6. Correlations (e.g., "Is there a correlation between marketing spend and revenue?" - Scatter Plot)
|
| 123 |
|
| 124 |
+
Generate a JSON object. The JSON object should have the following structure:
|
| 125 |
+
|
| 126 |
Recommended Visualization: [Chart type or "None"]. ONLY use the following names: bar, horizontal_bar, line, pie, scatter, none
|
| 127 |
Reason: [Brief explanation for your recommendation]
|
| 128 |
+
|
| 129 |
'''
|
| 130 |
human_message = '''
|
| 131 |
User question: {question}
|
|
|
|
| 141 |
])
|
| 142 |
|
| 143 |
final_prompt = prompt.format_prompt(question=text_query,
|
| 144 |
+
sql_query=sql_query, results=sql_result)
|
| 145 |
response = run_llm(final_prompt)
|
| 146 |
response = response.replace('```', '')
|
| 147 |
+
response = response.replace('json', '')
|
| 148 |
+
json_data = json.loads(response)
|
| 149 |
+
visualization = json_data['Recommended Visualization']
|
| 150 |
+
reason = json_data['Reason']
|
| 151 |
+
print(visualization, reason)
|
| 152 |
return visualization, reason
|
| 153 |
|
| 154 |
|
|
|
|
|
|
|
| 155 |
def format_data(text_query, sql_query, sql_result, visualization_type):
|
| 156 |
instruction = graph_instructions[visualization_type]
|
| 157 |
|
plot_utils.py
CHANGED
|
@@ -1,13 +1,20 @@
|
|
| 1 |
import matplotlib.pyplot as plt
|
|
|
|
| 2 |
|
| 3 |
def plot_bar_chart(data):
|
| 4 |
-
|
| 5 |
fig, ax = plt.subplots()
|
| 6 |
labels = data['labels']
|
| 7 |
values = data['values']
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
ax.set_title('Bar Chart')
|
| 13 |
ax.set_xlabel('Labels')
|
|
@@ -21,7 +28,7 @@ def plot_horizontal_bar_chart(data):
|
|
| 21 |
values = data['values']
|
| 22 |
|
| 23 |
for value in values:
|
| 24 |
-
ax.barh(labels, value['data'], label=value['label'])
|
| 25 |
|
| 26 |
ax.set_title('Horizontal Bar Chart')
|
| 27 |
ax.set_xlabel('Values')
|
|
|
|
| 1 |
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
|
| 4 |
def plot_bar_chart(data):
|
|
|
|
| 5 |
fig, ax = plt.subplots()
|
| 6 |
labels = data['labels']
|
| 7 |
values = data['values']
|
| 8 |
|
| 9 |
+
colors = ['darkviolet', 'indigo', 'blueviolet'] # Define the colors for each group
|
| 10 |
+
width = 0.2 # Width of the bars
|
| 11 |
+
x = np.arange(len(labels)) # Label locations
|
| 12 |
+
|
| 13 |
+
for i, value in enumerate(values):
|
| 14 |
+
ax.bar(x + i * width, value['data'], width, label=value['label'], color=colors[i])
|
| 15 |
+
|
| 16 |
+
ax.set_xticks(x + width / 2 * (len(values) - 1))
|
| 17 |
+
ax.set_xticklabels(labels)
|
| 18 |
|
| 19 |
ax.set_title('Bar Chart')
|
| 20 |
ax.set_xlabel('Labels')
|
|
|
|
| 28 |
values = data['values']
|
| 29 |
|
| 30 |
for value in values:
|
| 31 |
+
ax.barh(labels, value['data'], label=value['label'], color=['darkviolet', 'indigo', 'blueviolet'])
|
| 32 |
|
| 33 |
ax.set_title('Horizontal Bar Chart')
|
| 34 |
ax.set_xlabel('Values')
|