Update app.py
Browse files
app.py
CHANGED
|
@@ -56,22 +56,29 @@ def process_file(file, instructions, api_key):
|
|
| 56 |
for plot in plots[:3]: # Ensure max 3 plots
|
| 57 |
fig, ax = plt.subplots(figsize=(10, 6))
|
| 58 |
|
| 59 |
-
# Apply preprocessing
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
if plot['plot_type'] == 'bar':
|
| 64 |
-
|
| 65 |
elif plot['plot_type'] == 'line':
|
| 66 |
-
|
| 67 |
elif plot['plot_type'] == 'scatter':
|
| 68 |
-
|
| 69 |
elif plot['plot_type'] == 'hist':
|
| 70 |
-
|
| 71 |
|
| 72 |
ax.set_title(plot['title'])
|
| 73 |
ax.set_xlabel(plot['x'])
|
| 74 |
-
ax.set_ylabel(plot['y'] if plot['y'] else '
|
| 75 |
plt.tight_layout()
|
| 76 |
|
| 77 |
buf = io.BytesIO()
|
|
@@ -91,23 +98,4 @@ def process_file(file, instructions, api_key):
|
|
| 91 |
draw.text((10, 10), error_message, fill=(255, 0, 0))
|
| 92 |
return [error_image] * 3
|
| 93 |
|
| 94 |
-
|
| 95 |
-
gr.Markdown("# Data Analysis Dashboard")
|
| 96 |
-
|
| 97 |
-
with gr.Row():
|
| 98 |
-
file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
|
| 99 |
-
instructions = gr.Textbox(label="Analysis Instructions", placeholder="Describe the analysis you want...")
|
| 100 |
-
|
| 101 |
-
api_key = gr.Textbox(label="Gemini API Key", type="password")
|
| 102 |
-
submit = gr.Button("Generate Insights", variant="primary")
|
| 103 |
-
|
| 104 |
-
output_images = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
|
| 105 |
-
|
| 106 |
-
submit.click(
|
| 107 |
-
process_file,
|
| 108 |
-
inputs=[file, instructions, api_key],
|
| 109 |
-
outputs=output_images
|
| 110 |
-
)
|
| 111 |
-
|
| 112 |
-
if __name__ == "__main__":
|
| 113 |
-
demo.launch()
|
|
|
|
| 56 |
for plot in plots[:3]: # Ensure max 3 plots
|
| 57 |
fig, ax = plt.subplots(figsize=(10, 6))
|
| 58 |
|
| 59 |
+
# Apply preprocessing
|
| 60 |
+
plot_df = df.copy()
|
| 61 |
+
if 'Group data by' in plot['preprocessing']:
|
| 62 |
+
group_by = plot['x']
|
| 63 |
+
agg_column = plot['y'][0] if isinstance(plot['y'], list) else plot['y']
|
| 64 |
+
plot_df = plot_df.groupby(group_by)[agg_column].sum().reset_index()
|
| 65 |
+
if 'Sort' in plot['preprocessing']:
|
| 66 |
+
plot_df = plot_df.sort_values(by=plot['y'][0] if isinstance(plot['y'], list) else plot['y'], ascending=False)
|
| 67 |
+
if 'Filter to keep only the top 5' in plot['preprocessing']:
|
| 68 |
+
plot_df = plot_df.head(5)
|
| 69 |
|
| 70 |
if plot['plot_type'] == 'bar':
|
| 71 |
+
plot_df.plot(kind='bar', x=plot['x'], y=plot['y'], ax=ax)
|
| 72 |
elif plot['plot_type'] == 'line':
|
| 73 |
+
plot_df.plot(kind='line', x=plot['x'], y=plot['y'], ax=ax)
|
| 74 |
elif plot['plot_type'] == 'scatter':
|
| 75 |
+
plot_df.plot(kind='scatter', x=plot['x'], y=plot['y'], ax=ax)
|
| 76 |
elif plot['plot_type'] == 'hist':
|
| 77 |
+
plot_df[plot['x']].hist(ax=ax)
|
| 78 |
|
| 79 |
ax.set_title(plot['title'])
|
| 80 |
ax.set_xlabel(plot['x'])
|
| 81 |
+
ax.set_ylabel(plot['y'][0] if isinstance(plot['y'], list) else plot['y'])
|
| 82 |
plt.tight_layout()
|
| 83 |
|
| 84 |
buf = io.BytesIO()
|
|
|
|
| 98 |
draw.text((10, 10), error_message, fill=(255, 0, 0))
|
| 99 |
return [error_image] * 3
|
| 100 |
|
| 101 |
+
# The rest of your code remains unchanged
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|