Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import os | |
| import tempfile | |
| import re | |
| import base64 | |
| import io | |
| import zipfile | |
| import logging | |
| import asyncio | |
| from PIL import Image | |
| from docx import Document | |
| from docx.shared import Inches | |
| from agent.agent import DataVizAgent | |
| from mcp_tools.client import DataVizClient | |
# Logging: INFO and above is written both to 'dataviz_agent.log' (relative to
# the current working directory) and to the console via a StreamHandler.
logging.basicConfig(
    handlers=[logging.FileHandler('dataviz_agent.log'), logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Module-level singletons shared by every request handler below:
# the LLM-backed agent that writes plotting code, then the MCP client
# that executes it.
agent = DataVizAgent()
mcp_client = DataVizClient()
def b64_to_pil(b64_str):
    """Decode a base64-encoded image string and open it as a PIL image.

    Args:
        b64_str: Base64 text (or bytes) of an encoded image file.

    Returns:
        The ``PIL.Image.Image`` returned by ``Image.open`` on the decoded bytes.
    """
    raw_bytes = base64.b64decode(b64_str)
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer)
def analyze_dataset(file_path):
    """
    Analyzes the dataset and returns a summary and the dataframe.

    Args:
        file_path: Path to a ``.csv`` or ``.xlsx`` file (extension matched
            case-insensitively), or ``None``.

    Returns:
        ``(df, summary)`` on success. *summary* is a dict with ``row_count``
        and a ``columns`` list; each entry holds ``name``, ``type`` (dtype as
        str), ``unique_values`` and ``missing_values`` (plain ``int``),
        ``is_numeric``, and for numeric columns ``min``/``max`` (``float`` or
        ``None``).
        ``(None, message)`` on failure, where *message* is a human-readable
        error string.
    """
    if file_path is None:
        return None, "No file uploaded."
    # Lower-case only for the comparison so "DATA.CSV" is accepted as well.
    lowered = file_path.lower()
    try:
        if lowered.endswith('.csv'):
            df = pd.read_csv(file_path)
        elif lowered.endswith('.xlsx'):
            df = pd.read_excel(file_path)
        else:
            return None, "Unsupported file format. Please upload CSV or Excel."
        # Validate dataset
        if df.empty:
            return None, "Error: The uploaded file is empty."
        if len(df.columns) == 0:
            return None, "Error: No columns found in the dataset."
        if len(df) > 1000000:
            return None, "Error: Dataset is too large (>1M rows). Please use a smaller file."
    except Exception as e:
        return None, f"Error loading file: {str(e)}"

    summary = {
        "columns": [],
        "row_count": len(df)
    }
    for col in df.columns:
        series = df[col]
        col_info = {
            "name": col,
            "type": str(series.dtype),
            # Cast numpy integers to plain int so the summary stays
            # JSON-serialisable for whatever consumes it downstream.
            "unique_values": int(series.nunique()),
            "missing_values": int(series.isnull().sum())
        }
        if pd.api.types.is_numeric_dtype(series):
            try:
                min_val = series.min()
                max_val = series.max()
                col_info["min"] = float(min_val) if pd.notna(min_val) else None
                col_info["max"] = float(max_val) if pd.notna(max_val) else None
            except (ValueError, TypeError):
                col_info["min"] = None
                col_info["max"] = None
            col_info["is_numeric"] = True
        else:
            col_info["is_numeric"] = False
        summary["columns"].append(col_info)
    return df, summary
def process_upload(file):
    """Load an uploaded dataset, store a Parquet copy, and summarise it.

    Args:
        file: The ``gr.File`` value. Its ``.name`` is used as the file path;
            a plain path string is accepted too.

    Returns:
        ``(df, summary, summary_str, path)``: the DataFrame, the summary dict
        produced by ``analyze_dataset``, a text summary for display, and the
        path of a temporary ``.parquet`` copy (handed to the MCP tool).
        On failure returns ``(None, {}, error_message, None)``, where
        *error_message* explains what went wrong.
    """
    file_path = getattr(file, "name", file)
    logger.info(f"Processing file upload: {file_path}")
    df, summary = analyze_dataset(file_path)
    if df is None:
        # On failure analyze_dataset returns the error message in place of
        # the summary; surface it instead of a generic message.
        logger.error(f"Failed to load file: {file_path} ({summary})")
        return None, {}, summary, None

    # Save dataframe to a temporary parquet file for the MCP tool
    fd, path = tempfile.mkstemp(suffix='.parquet')
    os.close(fd)
    try:
        df.to_parquet(path)
    except Exception as e:
        # Serialisation can fail (e.g. object columns holding mixed types);
        # don't leave the empty temp file behind.
        logger.error(f"Failed to write parquet copy of {file_path}: {e}")
        try:
            os.remove(path)
        except OSError:
            pass
        return None, {}, f"Error preparing dataset: {e}", None
    logger.info(f"Dataset saved to temp file: {path}")

    # Create a readable summary string
    lines = [f"Dataset Loaded: {len(df)} rows, {len(df.columns)} columns.", "", "Columns:"]
    for col in summary["columns"]:
        line = f"- {col['name']} ({col['type']}): {col['unique_values']} unique"
        if col['is_numeric'] and col.get('min') is not None and col.get('max') is not None:
            line += f", range: [{col['min']:.2f}, {col['max']:.2f}]"
        lines.append(line)
    summary_str = "\n".join(lines) + "\n"
    return df, summary, summary_str, path
async def respond(message, chat_history, state):
    """Handle one chat turn.

    Asks the agent for a reply. If the reply is plotting code, the code is
    executed through the MCP tool and the resulting chart is stored in
    *state*. A ``#<id>`` reference in *message* targets an existing chart:
    its code is passed to the agent as ``existing_code``, and a successful
    result overwrites that chart instead of creating a new one.

    Args:
        message: Text submitted from the chat input box.
        chat_history: Chatbot history as a list of ``{"role", "content"}``
            dicts; appended to in place.
        state: Session state dict (``dataframe``, ``columns_summary``,
            ``charts``, ``next_chart_id``, ``data_path``); mutated in place.

    Returns:
        A 5-tuple for ``[msg, chatbot, gallery, state, chart_selector]``:
        the cleared input text, the history, a gallery value/update, the
        state, and a dropdown-choices update.
    """
    logger.info(f"User message: {message}")

    # Don't send blank prompts to the agent; leave all outputs unchanged.
    if not message or not message.strip():
        return "", chat_history, gr.update(), state, gr.update()

    if state["dataframe"] is None:
        logger.warning("User attempted to chat without uploading dataset")
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": "Please upload a dataset first."})
        return "", chat_history, gr.update(), state, gr.update(choices=[])

    # Check for chart modification request ("#<id>" anywhere in the message)
    existing_code = None
    target_chart_id = None
    chart_id_match = re.search(r'#(\d+)', message)
    if chart_id_match:
        chart_id = int(chart_id_match.group(1))
        if chart_id not in state["charts"]:
            chat_history.append({"role": "user", "content": message})
            chat_history.append({"role": "assistant", "content": f"Chart #{chart_id} not found."})
            return "", chat_history, _get_gallery_items(state), state, _get_chart_choices(state)
        existing_code = state["charts"][chart_id]["code"]
        target_chart_id = chart_id
        logger.info(f"Modifying chart #{chart_id}")

    # The history is passed before the current message is appended to it;
    # the message itself is the first argument.
    try:
        response = agent.generate_plot_code(
            message,
            state["columns_summary"],
            history=chat_history,
            existing_code=existing_code
        )
    except Exception as e:
        # Report agent failures in the chat rather than failing the event
        # (which would drop the user's turn from the history).
        logger.exception("Agent call failed")
        response = {"type": "error", "content": str(e)}

    chat_history.append({"role": "user", "content": message})

    if response["type"] == "error":
        logger.error(f"Agent error: {response['content']}")
        chat_history.append({"role": "assistant", "content": f"Error: {response['content']}"})
        return "", chat_history, _get_gallery_items(state), state, _get_chart_choices(state)

    if response["type"] == "message":
        # Conversational response - no code to execute
        logger.info("Agent provided conversational response")
        chat_history.append({"role": "assistant", "content": response["content"]})
        return "", chat_history, _get_gallery_items(state), state, _get_chart_choices(state)

    if response["type"] != "code":
        logger.warning(f"Unexpected agent response type: {response['type']!r}")
        return "", chat_history, _get_gallery_items(state), state, _get_chart_choices(state)

    # Code generation - execute it using the MCP tool
    code = response["content"]
    logger.info("Executing generated code")
    try:
        result = await mcp_client.generate_plot(code, state["data_path"])
    except Exception as e:
        logger.exception("Plot execution via MCP tool failed")
        result = {"success": False, "error": str(e)}

    if result["success"]:
        # Chart ids start at 1; compare against None rather than truthiness.
        if target_chart_id is not None:
            cid = target_chart_id
            action = "Updated"
        else:
            cid = state["next_chart_id"]
            state["next_chart_id"] += 1
            action = "Created"
        try:
            description = agent.describe_chart(message, code)
        except Exception:
            # The chart itself was produced; don't lose it over a caption.
            logger.exception("Chart description failed; using the request text")
            description = message
        state["charts"][cid] = {
            "code": code,
            "image": result["image"],
            "description": description
        }
        chat_history.append({"role": "assistant", "content": f"{action} chart #{cid}: {description}"})
        logger.info(f"{action} chart #{cid}")
        gallery_update = _get_gallery_items(state, selected_cid=cid)
    else:
        # `or` (not nested .get defaults) so an empty stderr falls through
        # to the "error" field.
        error_details = result.get('stderr') or result.get('error') or 'Unknown error occurred'
        error_msg = f"Failed to generate chart.\nError: {error_details}\n\nCode:\n```python\n{code}\n```"
        chat_history.append({"role": "assistant", "content": error_msg})
        logger.error(f"Chart generation failed: {error_details}")
        gallery_update = _get_gallery_items(state)
    return "", chat_history, gallery_update, state, _get_chart_choices(state)
| def _get_gallery_items(state, selected_cid=None): | |
| items = [] | |
| selected_index = None | |
| current_idx = 0 | |
| # Sort by ID | |
| for cid in sorted(state["charts"].keys()): | |
| chart = state["charts"][cid] | |
| if chart["image"]: | |
| img = b64_to_pil(chart["image"]) | |
| items.append((img, f"#{cid} {chart['description']}")) | |
| if selected_cid is not None and cid == selected_cid: | |
| selected_index = current_idx | |
| current_idx += 1 | |
| if selected_cid is not None: | |
| return gr.update(value=items, selected_index=selected_index) | |
| return items | |
def _get_chart_choices(state):
    """Return a dropdown update listing every chart as ``"#<id>"``, ascending."""
    labels = ["#" + str(chart_id) for chart_id in sorted(state["charts"])]
    return gr.update(choices=labels)
def delete_chart(chart_str, chat_history, state):
    """Delete the chart chosen in the dropdown.

    Args:
        chart_str: Dropdown value such as ``"#3"`` (falsy when nothing is
            selected).
        chat_history: Chatbot message list; a confirmation is appended when
            a chart is actually removed.
        state: Session state; ``state["charts"]`` is mutated in place.

    Returns:
        ``(chat_history, gallery, state, dropdown_update)`` for
        ``[chatbot, gallery, state, chart_selector]``.
    """
    if not chart_str:
        return chat_history, _get_gallery_items(state), state, _get_chart_choices(state)
    try:
        cid = int(str(chart_str).replace("#", ""))
    except ValueError:
        # Only a malformed selection is tolerated here; anything else surfaces.
        logger.warning(f"Ignoring invalid chart selection: {chart_str!r}")
        return chat_history, _get_gallery_items(state), state, _get_chart_choices(state)
    if state["charts"].pop(cid, None) is not None:
        chat_history.append({"role": "assistant", "content": f"🗑️ Chart #{cid} has been deleted."})
        logger.info(f"Deleted chart #{cid}")
    # Clear the selected value too: otherwise the dropdown keeps showing the
    # deleted "#id", which is no longer one of its choices.
    remaining = [f"#{c}" for c in sorted(state["charts"])]
    return chat_history, _get_gallery_items(state), state, gr.update(choices=remaining, value=None)
def download_zip(state):
    """Bundle every chart image into a ZIP archive for download.

    Args:
        state: Session state; only ``state["charts"]`` is read. Each chart's
            ``image`` is a base64 string (or falsy when absent).

    Returns:
        Path to a new temporary ``.zip`` file with one ``chart_<id>.png`` per
        chart that has an image, in ascending id order. Returns ``None`` when
        there are no charts.
    """
    if not state["charts"]:
        return None
    # mkstemp creates the file atomically. tempfile.mktemp only returns a name
    # and is deprecated because another process could claim it first.
    fd, zip_filename = tempfile.mkstemp(suffix=".zip")
    with os.fdopen(fd, "wb") as handle, zipfile.ZipFile(handle, 'w') as zipf:
        for cid in sorted(state["charts"]):
            chart = state["charts"][cid]
            if chart["image"]:
                zipf.writestr(f"chart_{cid}.png", base64.b64decode(chart["image"]))
    return zip_filename
def download_report(state):
    """Build a Word (.docx) report containing every chart.

    For each chart with an image, in ascending id order, the report gets a
    heading with the id and description, the image (6 inches wide), the
    generating code, and a page break.

    Args:
        state: Session state; only ``state["charts"]`` is read.

    Returns:
        Path to a new temporary ``.docx`` file, or ``None`` when there are
        no charts.
    """
    if not state["charts"]:
        return None
    doc = Document()
    doc.add_heading('DataViz Agent Report', 0)
    for cid in sorted(state["charts"]):
        chart = state["charts"][cid]
        if not chart["image"]:
            continue
        doc.add_heading(f"Chart #{cid}: {chart['description']}", level=1)
        # add_picture accepts a file-like object, so no temp PNG is needed.
        doc.add_picture(io.BytesIO(base64.b64decode(chart["image"])), width=Inches(6))
        doc.add_paragraph(f"Code:\n{chart['code']}")
        doc.add_page_break()
    # mkstemp instead of the deprecated, race-prone tempfile.mktemp.
    fd, doc_filename = tempfile.mkstemp(suffix=".docx")
    with os.fdopen(fd, "wb") as handle:
        doc.save(handle)
    return doc_filename
def global_clear():
    """Reset the whole UI to its initial, empty state.

    Returns:
        A 7-tuple for ``[file_upload, dataset_info, chatbot, gallery, state,
        chart_selector, dl_file]``: no file, the welcome text, an empty chat,
        an empty gallery, a fresh session state, an empty dropdown, and no
        download.
    """
    logger.info("Global clear initiated")
    fresh_state = dict(
        dataframe=None,
        columns_summary={},
        charts={},
        next_chart_id=1,
        data_path=None,
    )
    welcome_text = "Upload a dataset to get started."
    empty_dropdown = gr.update(choices=[])
    return (None, welcome_text, [], [], fresh_state, empty_dropdown, None)
# Runs after every chat submission: re-focus the chat textbox (elem_id
# "chat-input") after a short delay so the user can keep typing.
_FOCUS_CHAT_JS = "() => { setTimeout(() => { const el = document.getElementById('chat-input'); if (el) { const input = el.querySelector('textarea') || el.querySelector('input'); if (input) input.focus(); } }, 200); }"

with gr.Blocks(title="DataViz Agent", theme=gr.themes.Soft(), fill_height=True) as demo:
    # Per-session state; same shape as the dict produced by global_clear().
    state = gr.State({
        "dataframe": None,
        "columns_summary": {},
        "charts": {},
        "next_chart_id": 1,
        "data_path": None
    })

    # Header row: titles above the chat column and the gallery column.
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown("## 🤖 DataViz Agent Chat")
        with gr.Column(scale=2):
            gr.Markdown("## 📊 Charts Gallery")

    with gr.Row():
        # Left column: dataset upload/info and the chat.
        with gr.Column(scale=3):
            with gr.Row():
                with gr.Group():
                    file_upload = gr.File(label="Upload Dataset (CSV/XLSX)", file_types=[".csv", ".xlsx"])
                    with gr.Accordion("Dataset Info", open=False):
                        dataset_info = gr.Markdown("Upload a dataset to get started.")
            with gr.Row(scale=1, height=700):
                chatbot = gr.Chatbot(type="messages", height=700)
            with gr.Row(height=50, equal_height=True):
                msg = gr.Textbox(
                    placeholder="Ask to visualize data (e.g., 'Show distribution of age')",
                    show_label=False,
                    elem_id="chat-input",
                    lines=1,
                    max_lines=1,
                    scale=1
                )
                send_btn = gr.Button("Send", variant="primary", scale=0)

        # Right column: chart gallery and chart management / downloads.
        with gr.Column(scale=2):
            with gr.Row(height=626):
                gallery = gr.Gallery(label="Generated Charts", columns=1, object_fit="contain", height=626)
            with gr.Row():
                with gr.Group():
                    gr.Markdown("### Manage Charts")
                    with gr.Row():
                        chart_selector = gr.Dropdown(label="Select Chart to Delete", choices=[])
                        delete_btn = gr.Button("🗑️ Delete Chart", variant="stop")
                    with gr.Row():
                        dl_zip_btn = gr.Button("💾 Download All (ZIP)")
                        dl_report_btn = gr.Button("📄 Download Report (Word)")
                    with gr.Row(height=80):
                        dl_file = gr.File(label="Download", visible=True)

    # Global Clear (Bottom)
    with gr.Row():
        global_clear_btn = gr.Button("Global Clear (Reset All)", variant="stop")

    # Event Handlers
    def _remove_temp_file(path):
        """Best-effort removal of a temp Parquet file created by process_upload."""
        if path and os.path.exists(path):
            try:
                os.remove(path)
                logger.info(f"Cleaned up old temp file: {path}")
            except OSError as e:
                logger.warning(f"Failed to remove temp file: {e}")

    def on_file_upload(file, current_state):
        """Load a newly uploaded dataset into the session state.

        The previous temp Parquet file is deleted only after the replacement
        has loaded, so a failed upload leaves the previous dataset usable.
        Clearing the file input (``file is None``) drops the dataset, so the
        state never points at a deleted file.
        """
        if file is None:
            _remove_temp_file(current_state.get("data_path"))
            current_state["dataframe"] = None
            current_state["columns_summary"] = {}
            current_state["data_path"] = None
            return current_state, "Upload a dataset to get started."
        df, summary, summary_str, path = process_upload(file)
        if df is not None:
            old_path = current_state.get("data_path")
            current_state["dataframe"] = df
            current_state["columns_summary"] = summary
            current_state["data_path"] = path
            _remove_temp_file(old_path)
        return current_state, summary_str

    file_upload.change(
        on_file_upload,
        inputs=[file_upload, state],
        outputs=[state, dataset_info]
    )

    # Chat interactions: Enter in the textbox and the Send button behave the
    # same, then re-focus the textbox.
    for chat_trigger in (msg.submit, send_btn.click):
        chat_trigger(
            respond,
            inputs=[msg, chatbot, state],
            outputs=[msg, chatbot, gallery, state, chart_selector]
        ).then(None, None, None, js=_FOCUS_CHAT_JS)

    # Chart Management
    delete_btn.click(
        delete_chart,
        inputs=[chart_selector, chatbot, state],
        outputs=[chatbot, gallery, state, chart_selector]
    )
    dl_zip_btn.click(
        download_zip,
        inputs=[state],
        outputs=[dl_file]
    )
    dl_report_btn.click(
        download_report,
        inputs=[state],
        outputs=[dl_file]
    )
    global_clear_btn.click(
        global_clear,
        inputs=[],
        outputs=[file_upload, dataset_info, chatbot, gallery, state, chart_selector, dl_file]
    )

if __name__ == "__main__":
    demo.launch()