| import pandas as pd |
| import matplotlib.pyplot as plt |
| import io |
| import ast |
| from PIL import Image, ImageDraw |
| import google.generativeai as genai |
| import traceback |
| import os |
| from pywebio import start_server |
| from pywebio.input import file_upload, input |
| from pywebio.output import put_text, put_image, put_row, put_column, put_buttons, use_scope |
| from pywebio.session import run_js |
| import base64 |
| import threading |
|
|
def process_file(file, instructions):
    """Ask Gemini for three chart specifications for a dataset and render them.

    Args:
        file: Mapping with a ``'filename'`` key (used to pick the CSV or Excel
            reader) and a ``'content'`` key holding the raw file bytes.
        instructions: Free-text description of the analysis, interpolated
            into the prompt sent to the model.

    Returns:
        A list of exactly three ``PIL.Image.Image`` objects. When fewer than
        three charts were produced the list is padded with blank white
        images. When anything fails, all three entries are the same image
        with the error message and traceback drawn on it.
    """
    try:
        api_key = os.environ.get('GEMINI_API_KEY')
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

        content = file['content']
        # Case-insensitive so that "DATA.CSV" is not handed to the Excel reader.
        if file['filename'].lower().endswith('.csv'):
            df = pd.read_csv(io.BytesIO(content))
        else:
            df = pd.read_excel(io.BytesIO(content))

        response = model.generate_content(f"""
        Analyze the following dataset and instructions:

        Data columns: {list(df.columns)}
        Data shape: {df.shape}
        Instructions: {instructions}

        Based on this, create 3 appropriate visualizations that provide meaningful insights. For each visualization:
        1. Choose the most suitable plot type (bar, line, scatter, hist, pie, heatmap)
        2. Determine appropriate data aggregation (e.g., top 5 categories, yearly averages)
        3. Select relevant columns for x-axis, y-axis, and any additional dimensions (color, size)
        4. Provide a clear, concise title that explains the insight

        Consider data density and choose visualizations that simplify and clarify the information.
        Limit the number of data points displayed to ensure readability (e.g., top 5, top 10, yearly).

        Return your response as a Python list of dictionaries:
        [
            {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
            {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
            {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}}
        ]
        """)

        code_block = _strip_code_fence(response.text)

        print("Generated code block:")
        print(code_block)

        plots = _parse_plot_specs(code_block)

        images = []
        for plot in plots[:3]:
            fig, ax = plt.subplots(figsize=(10, 6))
            try:
                plot_df = _prepare_plot_data(df, plot)
                _draw_plot(ax, plot_df, plot)

                ax.set_title(plot['title'])
                if plot['plot_type'] != 'pie':
                    ax.set_xlabel(plot['x'])
                    ax.set_ylabel(plot['y'])
                fig.tight_layout()

                buf = io.BytesIO()
                fig.savefig(buf, format='png')
                buf.seek(0)
                img = Image.open(buf)
                # Image.open is lazy; decode now so the image no longer
                # depends on the buffer's read position.
                img.load()
                images.append(img)
            finally:
                # Close the figure even when drawing fails, otherwise every
                # failed request leaves a figure registered with pyplot.
                plt.close(fig)

        if len(images) < 3:
            blank = Image.new('RGB', (800, 600), (255, 255, 255))
            images.extend([blank] * (3 - len(images)))
        return images

    except Exception as e:
        # Boundary handler: the caller always expects three images, so the
        # failure is rendered as an image instead of being propagated.
        error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)
        error_image = Image.new('RGB', (800, 400), (255, 255, 255))
        draw = ImageDraw.Draw(error_image)
        draw.text((10, 10), error_message, fill=(255, 0, 0))
        return [error_image] * 3


def _strip_code_fence(text):
    """Return the contents of the first Markdown code fence in *text*, or *text* itself."""
    if '```python' in text:
        return text.split('```python', 1)[1].split('```', 1)[0].strip()
    if '```' in text:
        body = text.split('```')[1]
        tag, newline, rest = body.partition('\n')
        # A fence opener may carry a language tag other than "python"
        # (e.g. "json") on its first line; drop it before parsing.
        if newline and tag.strip().isalnum():
            body = rest
        return body.strip()
    return text.strip()


def _parse_plot_specs(text):
    """Parse the model reply into a list of plot-specification dicts."""
    import json

    try:
        specs = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # Fall back to JSON in case the reply uses null/true/false, which
        # are not Python literals. A failure here propagates to the caller.
        specs = json.loads(text)
    if not isinstance(specs, list):
        raise ValueError(
            f"Expected a list of plot specifications, got {type(specs).__name__}"
        )
    return specs


def _prepare_plot_data(df, plot):
    """Apply the aggregation and top-N selection requested by one plot spec."""
    x, y = plot['x'], plot['y']
    additional = plot.get('additional') or {}
    agg_func = plot.get('agg_func')

    keys = [x]
    color = additional.get('color')
    # A heatmap pivots on the colour column afterwards, so that column has
    # to survive the aggregation instead of being grouped away.
    if plot.get('plot_type') == 'heatmap' and color != x and color in df.columns:
        keys.append(color)

    if agg_func in ('sum', 'mean'):
        plot_df = df.groupby(keys)[y].agg(agg_func).reset_index()
    elif agg_func == 'count':
        plot_df = df.groupby(keys).size().reset_index(name=y)
    else:
        plot_df = df.copy()

    top_n = plot.get('top_n')
    if top_n and y in plot_df.columns:
        plot_df = plot_df.nlargest(int(top_n), y)
    return plot_df


def _draw_plot(ax, plot_df, plot):
    """Draw one plot spec onto *ax*; unknown plot types leave the axes empty."""
    kind = plot['plot_type']
    x, y = plot['x'], plot['y']
    additional = plot.get('additional') or {}

    if kind in ('bar', 'line'):
        plot_df.plot(kind=kind, x=x, y=y, ax=ax)
    elif kind == 'scatter':
        # Only forward colour/size when they name numeric columns that are
        # actually present; placeholders or missing keys are ignored rather
        # than raising a KeyError.
        scatter_kwargs = {}
        color = additional.get('color')
        size = additional.get('size')
        if color in plot_df.columns and pd.api.types.is_numeric_dtype(plot_df[color]):
            scatter_kwargs['c'] = color
        if size in plot_df.columns and pd.api.types.is_numeric_dtype(plot_df[size]):
            scatter_kwargs['s'] = plot_df[size]
        plot_df.plot(kind='scatter', x=x, y=y, ax=ax, **scatter_kwargs)
    elif kind == 'hist':
        plot_df[x].hist(ax=ax, bins=20)
    elif kind == 'pie':
        plot_df.plot(kind='pie', y=y, labels=plot_df[x], ax=ax, autopct='%1.1f%%')
    elif kind == 'heatmap':
        color = additional.get('color')
        if color is None:
            raise ValueError("heatmap requires additional['color'] to name the column axis")
        # pivot_table tolerates repeated (x, colour) pairs, which plain
        # pivot rejects; for unique pairs the values are unchanged.
        pivot_df = plot_df.pivot_table(index=x, columns=color, values=y, aggfunc='mean')
        ax.imshow(pivot_df, cmap='YlOrRd')
        ax.set_xticks(range(len(pivot_df.columns)))
        ax.set_yticks(range(len(pivot_df.index)))
        ax.set_xticklabels(pivot_df.columns)
        ax.set_yticklabels(pivot_df.index)
|
|
def data_analysis_dashboard():
    """Render the dashboard page and run the upload/analyse loop.

    Shows the page heading, creates one output scope per chart, then
    repeatedly asks for a dataset and instructions and passes them to
    ``generate_insights``.
    """
    # Imported locally because the module-level imports do not include them.
    from pywebio.input import input_group
    from pywebio.output import put_markdown, put_scope

    # put_text would show the leading "#" literally; put_markdown renders it
    # as a heading.
    put_markdown("# Data Analysis Dashboard")

    # Placeholders that generate_insights fills (and clears) on every run.
    put_scope('output', [put_scope(f'visualization_{i + 1}') for i in range(3)])

    # file_upload/input are blocking input functions, not output widgets, so
    # they are collected through an input form rather than embedded in
    # put_row/put_column.
    while True:
        form = input_group("Generate Insights", [
            file_upload("Upload Dataset", accept=[".csv", ".xlsx"], name="file"),
            input(
                "Analysis Instructions",
                type="text",
                placeholder="Describe the analysis you want...",
                name="instructions",
            ),
        ])
        generate_insights(form["file"], form["instructions"])


def generate_insights(file=None, instructions=None):
    """Run the analysis for one submission and display the resulting charts.

    Args:
        file: Uploaded-file mapping that is passed on to ``process_file``.
        instructions: Free-text analysis instructions.

    When either argument is missing or empty, a reminder is shown and
    nothing is processed.
    """
    if not file or not instructions:
        put_text("Please upload a file and provide instructions.")
        return

    images = process_file(file, instructions)

    for i, img in enumerate(images):
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        with use_scope(f'visualization_{i+1}', clear=True):
            # Pass the raw PNG bytes: a plain string argument would be
            # interpreted as an image URL, not as base64 data.
            put_image(buffered.getvalue(), width='100%')
|
|
def main():
    """Session entry point passed to ``start_server``; shows the dashboard."""
    data_analysis_dashboard()
|
|
# Script entry point: serve the dashboard with PyWebIO's built-in server.
if __name__ == '__main__':
    start_server(
        main,
        host='0.0.0.0',
        port=7860,
        debug=True,
        cdn=False,
        auto_open_webbrowser=True,
    )