Медведев Андрей Васильевич committed on
Commit
ac776ac
·
0 Parent(s):

init commit

Browse files
Files changed (10) hide show
  1. .gitignore +34 -0
  2. LICENSE +21 -0
  3. README.md +66 -0
  4. agent/agent.py +172 -0
  5. app.py +11 -0
  6. mcp_tools/client.py +52 -0
  7. mcp_tools/server.py +136 -0
  8. requirements.txt +11 -0
  9. run.bat +5 -0
  10. ui/app.py +438 -0
.gitignore ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ .Python
7
+ env/
8
+ venv/
9
+ .env
10
+ .env.*
11
+ .venv
12
+ pip-log.txt
13
+ pip-delete-this-directory.txt
14
+ .tox/
15
+ .coverage
16
+ .coverage.*
17
+ .cache
18
+ nosetests.xml
19
+ coverage.xml
20
+ *.cover
21
+ *.log
22
+ .pytest_cache/
23
+ .mypy_cache/
24
+
25
+ # VS Code
26
+ .vscode/
27
+
28
+ # Project specific
29
+ *.parquet
30
+ *.png
31
+ output.png
32
+
33
+ # Logs
34
+ dataviz_agent.log
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 DataViz Agent
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📊 DataViz Agent — MCP-Powered Data Analyst
2
+
3
+ **DataViz Agent** is an intelligent data analyst that turns your CSV/Excel files into beautiful charts through conversation.
4
+
5
+ The project demonstrates the power of the **Model Context Protocol (MCP)**: the UI communicates with an isolated tool server via a standard protocol, ensuring security and flexibility.
6
+
7
+ 🚀 **Demo for Hugging Face Hackathon**
8
+
9
+ ---
10
+
11
+ ## ✨ Key Features
12
+
13
+ * **🗣️ Chat with Data**: Just ask "Plot a histogram of age" or "Show correlation between salary and experience".
14
+ * **🛡️ Sandboxed Execution**: Chart generation code runs in isolated temporary processes. Direct system access is blocked.
15
+ * **🔌 MCP Architecture**: The application is split into Client (UI) and Server (Tools), communicating via the MCP standard (Stdio).
16
+ * **📈 Interactive Gallery**: All charts are saved, have IDs, and can be modified ("Make chart #2 green").
17
+ * **📦 Export**: Download charts as an archive (ZIP) or a ready-made report (Word).
18
+
19
+ ---
20
+
21
+ ## 🛠 Tech Stack
22
+
23
+ * **UI**: Gradio (Async)
24
+ * **LLM**: Gemini 2.0 Flash (via Google GenAI)
25
+ * **Protocol**: Model Context Protocol (MCP) Python SDK
26
+ * **Data**: Pandas, Matplotlib, Seaborn
27
+ * **Security**: `tempfile` isolation, `ast` validation, `matplotlib` Agg backend
28
+
29
+ ---
30
+
31
+ ## 🚀 How to Run
32
+
33
+ ### Locally
34
+
35
+ 1. Clone the repository.
36
+ 2. Create a `.env` file with your key: `GEMINI_API_KEY=your_key`
37
+ 3. Install dependencies:
38
+ ```bash
39
+ pip install -r requirements.txt
40
+ ```
41
+ 4. Run the application:
42
+ ```bash
43
+ python app.py
44
+ ```
45
+
46
+ ### Hugging Face Spaces
47
+
48
+ The project is fully ready for deployment on HF Spaces (SDK: Gradio).
49
+ Just add `GEMINI_API_KEY` to secrets (Settings -> Variables and secrets).
50
+
51
+ ---
52
+
53
+ ## 📂 Project Structure
54
+
55
+ ```text
56
+ /
57
+ ├── app.py # Entry point (for HF Spaces)
58
+ ├── agent/ # LLM Agent logic
59
+ ├── mcp_tools/
60
+ │ ├── server.py # MCP Server (visualization tools)
61
+ │ └── client.py # MCP Client (connection to server)
62
+ ├── ui/ # Gradio Interface
63
+ └── requirements.txt # Dependencies
64
+ ```
65
+
66
+
agent/agent.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import google.generativeai as genai
3
+ from dotenv import load_dotenv
4
+ import re
5
+ import logging
6
+
7
# Logging: INFO and above, module-scoped logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pull variables from a local .env file into the process environment
load_dotenv()

# Gemini is configured only when a key is present; the missing-key case is
# reported later by DataVizAgent.__init__.
api_key = os.environ.get("GEMINI_API_KEY")
if api_key:
    genai.configure(api_key=api_key)
18
+
19
class DataVizAgent:
    """Gemini-backed assistant that either chats about a dataset or emits plotting code."""

    # Fenced markdown block. Group 1 is the optional python tag, group 2 the
    # body. Tolerates CRLF line endings, a missing newline before the closing
    # fence and a `py` / `Python` tag.
    _CODE_FENCE = re.compile(
        r"```[ \t]*(python|py)?[ \t]*\r?\n(.*?)```",
        re.DOTALL | re.IGNORECASE,
    )

    def __init__(self):
        """Create the Gemini model handle.

        Raises:
            ValueError: if GEMINI_API_KEY was not found in the environment.
        """
        if not api_key:
            raise ValueError("GEMINI_API_KEY not found. Please check your .env file.")
        self.model = genai.GenerativeModel('gemini-2.0-flash')  # Using a fast model

    def generate_plot_code(self, user_query, columns_summary, history=None, existing_code=None):
        """
        Generates Python code for plotting based on user query and dataset summary.
        Can also respond conversationally without generating code.

        Args:
            user_query: User's message
            columns_summary: Dataset column information (dict with a 'columns' list)
            history: Chat history for context (list of dicts with 'role' and 'content')
            existing_code: Code from existing chart to modify

        Returns:
            dict with 'type' ('code', 'message' or 'error') and 'content'
        """
        # Build the column overview with a single join instead of repeated +=
        summary_parts = ["Dataset Columns:\n"]
        for col in columns_summary.get("columns", []):
            line = f"- {col['name']} ({col['type']})"
            if col.get('is_numeric') and col.get('min') is not None:
                line += f", range: [{col.get('min')}, {col.get('max')}]"
            line += f", unique values: {col.get('unique_values')}\n"
            summary_parts.append(line)
        summary_str = "".join(summary_parts)

        system_prompt = f"""
        You are an expert Data Visualization Assistant. You have access to a pandas DataFrame named `df`.

        {summary_str}

        YOUR CAPABILITIES:
        1. **Conversational Mode**: Answer questions, provide suggestions, explain concepts about data visualization
        2. **Code Generation Mode**: Generate Python code for creating visualizations

        WHEN TO USE EACH MODE:
        - Use CONVERSATIONAL mode when user:
        * Asks for suggestions or advice (e.g., "What graphs can I build?", "What would you recommend?")
        * Asks questions about the data (e.g., "What columns do I have?")
        * Wants explanations or discussions
        * Greets you or makes small talk

        - Use CODE GENERATION mode when user:
        * Explicitly requests a visualization (e.g., "Create a histogram", "Show distribution", "Plot X vs Y")
        * Asks to modify an existing chart
        * Uses visualization-related verbs (plot, show, draw, create, build, visualize)

        CODE GENERATION RULES (only when generating code):
        1. The DataFrame `df` is ALREADY LOADED and available with the columns listed above.
        2. Import pandas as pd, matplotlib.pyplot as plt, and seaborn as sns at the start of your code.
        3. MANDATORY: Save the plot to a file named 'plot.png' in the current directory:
        ```python
        plt.savefig('plot.png')
        ```
        4. Do NOT use `plt.show()`.
        5. Create clear plots with proper titles, labels, and legends.
        6. Handle potential NaN or missing values appropriately.
        7. If modifying an existing chart, I will provide the existing code - update it to match the new request.

        OUTPUT FORMAT:
        - For CONVERSATIONAL responses: Reply naturally in plain text, no code blocks
        - For CODE GENERATION: Output ONLY Python code wrapped in ```python ... ``` markdown block

        EXAMPLES:
        User: "What visualizations would you suggest for this data?"
        Assistant: "Based on your dataset, here are some visualization ideas:
        1. Distribution plots for numerical columns like [column names]
        2. Count plots for categorical data
        3. Correlation heatmaps if you want to see relationships between variables
        4. Scatter plots to explore relationships between specific pairs of columns
        What interests you most?"

        User: "Create a histogram of age"
        Assistant: ```python
        import pandas as pd
        import matplotlib.pyplot as plt
        import seaborn as sns

        plt.figure(figsize=(10, 6))
        plt.hist(df['age'].dropna(), bins=30, edgecolor='black')
        plt.title('Distribution of Age')
        plt.xlabel('Age')
        plt.ylabel('Frequency')
        plt.grid(alpha=0.3)
        plt.savefig('plot.png')
        ```
        """

        messages = []

        # Add chat history for context. Entries whose content is not plain
        # text (e.g. file/component messages) cannot be sent as a text part
        # and are skipped instead of failing the whole request.
        for msg in history or []:
            content = msg.get("content")
            if not isinstance(content, str):
                continue
            role = "user" if msg.get("role") == "user" else "model"
            messages.append({"role": role, "parts": [content]})

        # Add system prompt and current query
        if existing_code:
            current_message = f"{system_prompt}\n\nExisting Code:\n```python\n{existing_code}\n```\n\nUser Request: {user_query}"
        else:
            current_message = f"{system_prompt}\n\nUser Request: {user_query}"

        messages.append({"role": "user", "parts": [current_message]})

        try:
            logger.info(f"Generating response for: {user_query}")
            response = self.model.generate_content(messages)
            response_text = response.text

            # Check if response contains code
            if "```python" in response_text or "```\n" in response_text:
                code = self._extract_code(response_text)
                logger.info("Generated code response")
                return {"type": "code", "content": code}

            logger.info("Generated conversational response")
            return {"type": "message", "content": response_text}

        except Exception as e:
            # Boundary handler: any SDK/network failure is reported to the
            # caller as an 'error' response rather than raised.
            logger.error(f"Error generating response: {str(e)}")
            return {"type": "error", "content": f"Error generating response: {str(e)}"}

    def _extract_code(self, text):
        """
        Extracts python code from markdown code blocks.

        A block tagged as python is preferred; otherwise the first untagged
        block is used. When no fenced block is found the raw text is returned
        unchanged (it might be an error message or direct code).
        """
        first_untagged = None
        for match in self._CODE_FENCE.finditer(text):
            body = match.group(2).rstrip()
            if match.group(1):
                return body
            if first_untagged is None:
                first_untagged = body

        return text if first_untagged is None else first_untagged

    def describe_chart(self, user_query, code):
        """
        Generates a short description/title for the chart.

        Returns the literal "Chart" when the model call fails.
        """
        prompt = f"""
        Based on the user query: "{user_query}" and the generated code, provide a short, descriptive title for this chart (max 10 words).
        Code:
        {code}
        """
        try:
            response = self.model.generate_content(prompt)
            return response.text.strip()
        except Exception as e:
            # Best-effort: a missing title must not fail chart creation, but
            # the cause should not be swallowed silently (and a bare
            # `except:` would also trap KeyboardInterrupt/SystemExit).
            logger.warning(f"Could not generate chart description: {e}")
            return "Chart"
172
+
app.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
# Make the directory containing this script importable so that the `agent`,
# `mcp_tools` and `ui` packages resolve.
_project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_project_root)

# The Gradio `demo` object is defined in ui/app.py
from ui.app import demo

if __name__ == "__main__":
    demo.launch()
mcp_tools/client.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import sys
4
+ import json
5
+ from contextlib import asynccontextmanager
6
+ from mcp import ClientSession, StdioServerParameters
7
+ from mcp.client.stdio import stdio_client
8
+
9
+ class DataVizClient:
10
+ def __init__(self):
11
+ # Determine the path to the server script
12
+ current_dir = os.path.dirname(os.path.abspath(__file__))
13
+ self.server_script = os.path.join(current_dir, "server.py")
14
+
15
+ # Server launch parameters (python mcp_tools/server.py)
16
+ self.server_params = StdioServerParameters(
17
+ command=sys.executable, # Use the same python as the main app
18
+ args=[self.server_script],
19
+ env=None
20
+ )
21
+
22
+ @asynccontextmanager
23
+ async def connect(self):
24
+ # Start server and connect to it
25
+ async with stdio_client(self.server_params) as (read, write):
26
+ async with ClientSession(read, write) as session:
27
+ yield session
28
+
29
+ async def generate_plot(self, code: str, data_path: str = None):
30
+ """
31
+ Calls the 'run_plot_code' tool via MCP protocol
32
+ """
33
+ async with self.connect() as session:
34
+ # Initialization (handshake)
35
+ await session.initialize()
36
+
37
+ # Call tool
38
+ result = await session.call_tool(
39
+ "run_plot_code",
40
+ arguments={"code": code, "data_path": data_path}
41
+ )
42
+
43
+ # Parse result
44
+ if not result.content:
45
+ return {"success": False, "error": "Empty response from MCP server"}
46
+
47
+ # FastMCP returns JSON string inside TextContent
48
+ try:
49
+ text_content = result.content[0].text
50
+ return json.loads(text_content)
51
+ except Exception as e:
52
+ return {"success": False, "error": f"Failed to parse MCP response: {e}", "raw": str(result.content)}
mcp_tools/server.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mcp.server.fastmcp import FastMCP
2
+ import subprocess
3
+ import os
4
+ import tempfile
5
+ import base64
6
+ import sys
7
+ import ast
8
+ import logging
9
+
10
# Log to a file rather than the console so log output does not interfere
# with the stdio channel used by the server.
logging.basicConfig(
    filename='mcp_server.log',
    filemode='a',
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# The FastMCP server instance that tools register against
mcp = FastMCP("DataViz Tools")
21
+
22
# Whitelist of allowed imports for security (compared against the top-level
# package name of each import statement)
ALLOWED_IMPORTS = {
    'pandas', 'pd',
    'matplotlib', 'pyplot', 'plt',
    'seaborn', 'sns',
    'numpy', 'np',
    'warnings',
    'math',
    'datetime',
    'collections',
}

# Built-ins that give arbitrary code execution or file access
_DANGEROUS_BUILTINS = frozenset({'eval', 'exec', 'compile', '__import__', 'open', '__builtins__'})

# Dunder attributes commonly used to climb from an ordinary object back to
# the interpreter internals
_DANGEROUS_ATTRIBUTES = frozenset({
    '__class__', '__base__', '__bases__', '__mro__', '__subclasses__',
    '__globals__', '__builtins__', '__dict__', '__getattribute__',
    '__import__', '__code__', '__closure__', '__loader__', '__spec__',
})


def validate_code_safety(code: str) -> tuple[bool, str]:
    """
    Validates Python code for security risks.

    This is a best-effort static filter over the AST, not a sandbox: it
    rejects non-whitelisted and relative imports, any reference to a
    dangerous built-in (called directly or merely aliased) and access to
    introspection dunders. Dynamic tricks such as ``getattr`` with a computed
    name are not detected.

    Args:
        code: Source code to inspect.

    Returns:
        (is_safe, error_message); error_message is "" when the code passes.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError) as e:
        # ValueError: e.g. source containing null bytes on older interpreters
        return False, f"Syntax error: {str(e)}"

    for node in ast.walk(tree):
        # Check imports
        if isinstance(node, ast.Import):
            for alias in node.names:
                module_name = alias.name.split('.')[0]
                if module_name not in ALLOWED_IMPORTS:
                    return False, f"Import '{alias.name}' is not allowed for security reasons"

        elif isinstance(node, ast.ImportFrom):
            # `from . import x` has no module name and would otherwise skip
            # the whitelist entirely.
            if node.level or not node.module:
                return False, "Relative imports are not allowed for security reasons"
            module_name = node.module.split('.')[0]
            if module_name not in ALLOWED_IMPORTS:
                return False, f"Import from '{node.module}' is not allowed for security reasons"

        # Check for dangerous functions. Any *read* of the name is rejected,
        # so `f = open; f(...)` is caught as well as a direct call.
        elif isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Load) and node.id in _DANGEROUS_BUILTINS:
                return False, f"Function '{node.id}' is not allowed for security reasons"

        elif isinstance(node, ast.Attribute):
            if node.attr in _DANGEROUS_ATTRIBUTES:
                return False, f"Attribute '{node.attr}' is not allowed for security reasons"

    return True, ""
67
+
68
# Seconds the generated script may run before it is killed
_EXECUTION_TIMEOUT_SECONDS = 120

# pandas reader function used for each supported dataset extension
_DATA_LOADERS = {
    '.csv': 'read_csv',
    '.xlsx': 'read_excel',
    '.parquet': 'read_parquet',
}


def _build_script(code: str, data_path: str = None) -> str:
    """Return the script to execute: Agg-backend preamble, optional `df` loading, then *code*."""
    parts = ["import matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n"]
    if data_path:
        loader = _DATA_LOADERS.get(os.path.splitext(data_path)[1].lower())
        if loader is None:
            # Same as before for unknown extensions: no `df` is injected.
            logger.warning("Unsupported data file extension, 'df' not loaded: %s", data_path)
        else:
            # Forward slashes avoid backslash escapes; repr() yields a
            # correctly quoted literal even if the path contains quotes.
            safe_data_path = data_path.replace('\\', '/')
            parts.append(f"import pandas as pd\ndf = pd.{loader}({safe_data_path!r})\n")
    parts.append(code)
    return "".join(parts)


@mcp.tool()
def run_plot_code(code: str, data_path: str = None) -> dict:
    """
    Executes Python code to generate a plot.
    The code should use matplotlib/seaborn and save the figure to 'plot.png'.

    The code is first checked with validate_code_safety and rejected without
    being run if it fails that check.

    Args:
        code: The Python code to execute.
        data_path: Optional path to a dataset file (csv, xlsx, parquet) to load as 'df' before execution.

    Returns:
        A dictionary with 'success'; on success also 'image' (base64 PNG) and
        'logs' (stdout); on failure 'error' and, when the script ran, 'logs'.
    """
    # Refuse unsafe code before anything is written or executed
    is_safe, reason = validate_code_safety(code)
    if not is_safe:
        logger.warning("Rejected plot code: %s", reason)
        return {"success": False, "error": reason}

    # Create a temporary directory for execution
    with tempfile.TemporaryDirectory() as temp_dir:
        script_path = os.path.join(temp_dir, 'script.py')
        plot_path = os.path.join(temp_dir, 'plot.png')

        # Write the script
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(_build_script(code, data_path))

        try:
            # Run the script in the temporary directory
            result = subprocess.run(
                [sys.executable, script_path],
                capture_output=True,
                text=True,
                cwd=temp_dir,
                timeout=_EXECUTION_TIMEOUT_SECONDS,
                stdin=subprocess.DEVNULL,
            )

            if result.returncode != 0:
                return {"success": False, "error": result.stderr, "logs": result.stdout}

            # Check if plot was created
            if not os.path.exists(plot_path):
                return {"success": False, "error": "Plot file 'plot.png' was not created.", "logs": result.stdout}

            with open(plot_path, "rb") as img_file:
                b64_img = base64.b64encode(img_file.read()).decode('utf-8')
            return {"success": True, "image": b64_img, "logs": result.stdout}

        except subprocess.TimeoutExpired:
            return {
                "success": False,
                "error": f"Plot code timed out after {_EXECUTION_TIMEOUT_SECONDS} seconds.",
            }
        except Exception as e:
            return {"success": False, "error": str(e)}


if __name__ == "__main__":
    try:
        mcp.run()
    except Exception as e:
        logger.critical(f"Server failed to start: {e}", exc_info=True)
        sys.exit(1)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ mcp
3
+ pandas
4
+ matplotlib
5
+ seaborn
6
+ google-generativeai
7
+ python-dotenv
8
+ openpyxl
9
+ uvicorn
10
+ pyarrow
11
+ python-docx
run.bat ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
@echo off
rem Work from the folder that contains this script so the relative venv path
rem and PYTHONPATH are correct no matter where the script is launched from.
cd /d "%~dp0"
call env\Scripts\activate
set PYTHONPATH=%CD%
python -m app
pause
ui/app.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import os
4
+ import tempfile
5
+ import re
6
+ import base64
7
+ import io
8
+ import zipfile
9
+ import logging
10
+ import asyncio
11
+ from PIL import Image
12
+ from docx import Document
13
+ from docx.shared import Inches
14
+ from agent.agent import DataVizAgent
15
+ from mcp_tools.client import DataVizClient
16
+
17
# Configure logging.
# force=True is required: importing agent.agent above already ran
# logging.basicConfig(), and a second basicConfig() call is silently ignored
# once the root logger has handlers, so without it dataviz_agent.log would
# never be written.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit encoding: user messages are logged and may be non-ASCII
        logging.FileHandler('dataviz_agent.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
    force=True,
)
logger = logging.getLogger(__name__)

# Initialize Agent
agent = DataVizAgent()
# Initialize MCP Client
mcp_client = DataVizClient()
32
+
33
def b64_to_pil(b64_str):
    """Decode a base64 string and open the resulting bytes as a PIL image."""
    raw_bytes = base64.b64decode(b64_str)
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer)
35
+
36
def analyze_dataset(file_path):
    """
    Analyzes the dataset and returns the dataframe and a summary.

    Args:
        file_path: Path to a .csv or .xlsx file (extension matched
            case-insensitively), or None.

    Returns:
        (df, summary) on success, where summary is a dict with 'columns'
        (one info dict per column) and 'row_count'.
        (None, error_message) when the file is missing, unsupported, empty,
        too large or cannot be read.
    """
    if file_path is None:
        return None, "No file uploaded."

    extension = os.path.splitext(str(file_path))[1].lower()
    try:
        if extension == '.csv':
            df = pd.read_csv(file_path)
        elif extension == '.xlsx':
            df = pd.read_excel(file_path)
        else:
            return None, "Unsupported file format. Please upload CSV or Excel."
    except Exception as e:
        return None, f"Error loading file: {str(e)}"

    # Validate dataset. The column check comes first: a frame without
    # columns is also "empty", which would otherwise mask this message.
    if len(df.columns) == 0:
        return None, "Error: No columns found in the dataset."

    if df.empty:
        return None, "Error: The uploaded file is empty."

    if len(df) > 1000000:
        return None, "Error: Dataset is too large (>1M rows). Please use a smaller file."

    summary = {
        "columns": [],
        "row_count": len(df),
    }

    # df.items() yields one Series per column even when names are
    # duplicated, where df[col] would return a DataFrame and break .dtype.
    for name, series in df.items():
        is_numeric = bool(pd.api.types.is_numeric_dtype(series))
        col_info = {
            "name": name,
            "type": str(series.dtype),
            # Plain ints instead of numpy scalars so the summary stays
            # JSON-friendly.
            "unique_values": int(series.nunique()),
            "missing_values": int(series.isnull().sum()),
        }
        if is_numeric:
            try:
                min_val = series.min()
                max_val = series.max()
                col_info["min"] = float(min_val) if pd.notna(min_val) else None
                col_info["max"] = float(max_val) if pd.notna(max_val) else None
            except (ValueError, TypeError):
                # e.g. values that cannot be converted to float
                col_info["min"] = None
                col_info["max"] = None
        col_info["is_numeric"] = is_numeric

        summary["columns"].append(col_info)

    return df, summary
92
+
93
def process_upload(file):
    """
    Load an uploaded file and stage it as a temporary parquet file for the MCP tool.

    Args:
        file: The uploaded file; either an object exposing the path as
            `.name` or the path string itself.

    Returns:
        (df, summary, summary_str, parquet_path) on success, or
        (None, {}, error_message, None) on failure.
    """
    file_path = getattr(file, "name", file)
    logger.info(f"Processing file upload: {file_path}")
    df, summary = analyze_dataset(file_path)
    if df is None:
        # On failure analyze_dataset returns the reason in place of the
        # summary; surface it instead of a generic message.
        logger.error(f"Failed to load file: {file_path} ({summary})")
        return None, {}, summary or "Error loading file.", None

    # Save dataframe to a temporary parquet file for the MCP tool
    fd, path = tempfile.mkstemp(suffix='.parquet')
    os.close(fd)
    try:
        df.to_parquet(path)
    except Exception as e:
        # Do not leave the empty temp file behind when serialisation fails
        os.remove(path)
        logger.error(f"Failed to save dataset to parquet: {e}")
        return None, {}, f"Error preparing dataset: {e}", None
    logger.info(f"Dataset saved to temp file: {path}")

    # Create a readable summary string
    lines = [f"Dataset Loaded: {len(df)} rows, {len(df.columns)} columns.\n\nColumns:\n"]
    for col in summary["columns"]:
        line = f"- {col['name']} ({col['type']}): {col['unique_values']} unique"
        if col['is_numeric'] and col.get('min') is not None and col.get('max') is not None:
            line += f", range: [{col['min']:.2f}, {col['max']:.2f}]"
        lines.append(line + "\n")
    summary_str = "".join(lines)

    return df, summary, summary_str, path
115
+
116
async def respond(message, chat_history, state):
    """
    Handle one chat turn: ask the agent, run generated code via MCP, update the gallery.

    Args:
        message: The user's text.
        chat_history: List of {'role', 'content'} dicts; mutated in place.
        state: Session dict with 'dataframe', 'columns_summary', 'charts',
            'next_chart_id' and 'data_path'; mutated in place.

    Returns:
        Tuple for (textbox, chatbot, gallery, state, chart dropdown).
    """
    logger.info(f"User message: {message}")
    if state["dataframe"] is None:
        logger.warning("User attempted to chat without uploading dataset")
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": "Please upload a dataset first."})
        return "", chat_history, gr.update(), state, gr.update(choices=[])

    def _finish(reply, gallery_update=None):
        """Append the assistant reply and build the standard output tuple."""
        chat_history.append({"role": "assistant", "content": reply})
        if gallery_update is None:
            gallery_update = _get_gallery_items(state)
        return "", chat_history, gallery_update, state, _get_chart_choices(state)

    # Check for chart modification request
    existing_code = None
    target_chart_id = None
    chart_id_match = re.search(r'#(\d+)', message)

    if chart_id_match:
        chart_id = int(chart_id_match.group(1))
        if chart_id not in state["charts"]:
            chat_history.append({"role": "user", "content": message})
            return _finish(f"Chart #{chart_id} not found.")
        existing_code = state["charts"][chart_id]["code"]
        target_chart_id = chart_id
        logger.info(f"Modifying chart #{chart_id}")

    # Generate response using Agent (with chat history). The agent call is
    # synchronous, so it runs in a worker thread to keep the event loop free.
    response = await asyncio.to_thread(
        agent.generate_plot_code,
        message,
        state["columns_summary"],
        history=list(chat_history),
        existing_code=existing_code,
    )

    chat_history.append({"role": "user", "content": message})

    # Check response type
    response_type = response["type"]

    if response_type == "error":
        logger.error(f"Agent error: {response['content']}")
        return _finish(f"Error: {response['content']}")

    if response_type == "message":
        # Conversational response - no code to execute
        logger.info("Agent provided conversational response")
        return _finish(response["content"])

    if response_type != "code":
        # Fallback for an unknown response type: nothing to add to the chat
        return "", chat_history, _get_gallery_items(state), state, _get_chart_choices(state)

    # Code generation - execute it
    code = response["content"]
    logger.info("Executing generated code")

    # Execute code using MCP Tool. A failure to start or talk to the server
    # is reported in the chat instead of surfacing as an unhandled error.
    try:
        result = await mcp_client.generate_plot(code, state["data_path"])
    except Exception as exc:
        logger.exception("MCP tool call failed")
        result = {"success": False, "error": f"MCP tool call failed: {exc}"}

    if not result.get("success"):
        error_details = result.get('stderr') or result.get('error') or 'Unknown error occurred'
        error_msg = f"Failed to generate chart.\nError: {error_details}\n\nCode:\n```python\n{code}\n```"
        logger.error(f"Chart generation failed: {error_details}")
        return _finish(error_msg)

    # Determine Chart ID
    if target_chart_id is not None:
        cid = target_chart_id
        action = "Updated"
    else:
        cid = state["next_chart_id"]
        state["next_chart_id"] += 1
        action = "Created"

    # Generate description (blocking call, also moved off the event loop)
    description = await asyncio.to_thread(agent.describe_chart, message, code)

    # Update State
    state["charts"][cid] = {
        "code": code,
        "image": result.get("image"),
        "description": description,
    }

    logger.info(f"{action} chart #{cid}")
    return _finish(
        f"{action} chart #{cid}: {description}",
        _get_gallery_items(state, selected_cid=cid),
    )
207
+
208
def _get_gallery_items(state, selected_cid=None):
    """
    Build the gallery contents from the charts in *state*, ordered by chart id.

    Returns the plain list of (image, caption) pairs, or - when
    *selected_cid* is given - a gr.update that also selects that chart.
    """
    charts = state["charts"]
    items = []
    selected_index = None

    # Sort by ID; charts without an image are left out of the gallery
    for cid in sorted(charts.keys()):
        chart = charts[cid]
        if not chart["image"]:
            continue

        # The position this chart is about to take in the gallery
        if selected_cid is not None and cid == selected_cid:
            selected_index = len(items)

        caption = f"#{cid} {chart['description']}"
        items.append((b64_to_pil(chart["image"]), caption))

    if selected_cid is None:
        return items

    return gr.update(value=items, selected_index=selected_index)
228
+
229
def _get_chart_choices(state):
    """Return a dropdown update listing every chart id as "#<id>", in ascending order."""
    labels = []
    for cid in sorted(state["charts"].keys()):
        labels.append(f"#{cid}")
    return gr.update(choices=labels)
231
+
232
def delete_chart(chart_str, chat_history, state):
    """
    Delete the chart selected in the dropdown.

    Args:
        chart_str: Dropdown value such as "#3"; empty/None means no selection.
        chat_history: Chat messages; a confirmation is appended on success.
        state: Session dict whose 'charts' mapping is modified in place.

    Returns:
        Tuple for (chatbot, gallery, state, chart dropdown).
    """
    if chart_str:
        try:
            cid = int(str(chart_str).replace("#", ""))
        except ValueError:
            # Only a malformed selection is tolerated here; anything else
            # should surface instead of being swallowed by a bare except.
            logger.warning(f"Ignoring invalid chart selection: {chart_str!r}")
        else:
            if cid in state["charts"]:
                del state["charts"][cid]
                chat_history.append({"role": "assistant", "content": f"🗑️ Chart #{cid} has been deleted."})

    return chat_history, _get_gallery_items(state), state, _get_chart_choices(state)
245
+
246
def download_zip(state):
    """
    Write every chart image to a ZIP archive.

    Args:
        state: Session dict; 'charts' maps chart id -> dict with a base64
            'image' (charts without an image are skipped).

    Returns:
        Path of the created .zip file, or None when there are no charts.
    """
    if not state["charts"]:
        return None

    # mkstemp creates the file atomically; tempfile.mktemp only returns a
    # name, which is race-prone and deprecated.
    fd, zip_filename = tempfile.mkstemp(suffix=".zip")
    os.close(fd)

    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        # Sorted so the archive lists charts in id order
        for cid in sorted(state["charts"]):
            chart = state["charts"][cid]
            if chart["image"]:
                img_data = base64.b64decode(chart["image"])
                zipf.writestr(f"chart_{cid}.png", img_data)

    return zip_filename
258
+
259
def download_report(state):
    """
    Build a Word report with one section per chart (title, image, code).

    Args:
        state: Session dict; 'charts' maps chart id -> dict with 'image'
            (base64), 'description' and 'code'.

    Returns:
        Path of the created .docx file, or None when there are no charts.
    """
    if not state["charts"]:
        return None

    doc = Document()
    doc.add_heading('DataViz Agent Report', 0)

    for cid in sorted(state["charts"].keys()):
        chart = state["charts"][cid]
        if not chart["image"]:
            continue

        doc.add_heading(f"Chart #{cid}: {chart['description']}", level=1)

        # Save temp image for docx; delete=False so it can be reopened by
        # path after the handle is closed.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
            tmp_img.write(base64.b64decode(chart["image"]))
            tmp_img_path = tmp_img.name

        try:
            doc.add_picture(tmp_img_path, width=Inches(6))
        finally:
            os.remove(tmp_img_path)

        doc.add_paragraph(f"Code:\n{chart['code']}")
        doc.add_page_break()

    # mkstemp creates the file atomically; tempfile.mktemp only returns a
    # name, which is race-prone and deprecated.
    fd, doc_filename = tempfile.mkstemp(suffix=".docx")
    os.close(fd)
    doc.save(doc_filename)
    return doc_filename
287
+
288
def global_clear():
    """
    Reset the whole UI.

    Returns a 7-tuple of fresh values for, in order: file upload, dataset
    info text, chat, gallery, session state, chart dropdown, download file.
    """
    logger.info("Global clear initiated")

    # Same shape as the initial session state
    fresh_state = dict(
        dataframe=None,
        columns_summary={},
        charts={},
        next_chart_id=1,
        data_path=None,
    )

    uploaded_file = None
    info_text = "Upload a dataset to get started."
    chat_messages = []
    gallery_items = []
    dropdown_update = gr.update(choices=[])
    download_value = None

    return (
        uploaded_file,
        info_text,
        chat_messages,
        gallery_items,
        fresh_state,
        dropdown_update,
        download_value,
    )
306
+
307
# NOTE(review): the source this was rebuilt from had lost its indentation, so
# the nesting of the layout containers below is reconstructed - confirm it
# against the rendered UI.
with gr.Blocks(title="DataViz Agent", theme=gr.themes.Soft(), fill_height=True) as demo:
    # Per-session state shared by all handlers
    state = gr.State({
        "dataframe": None,
        "columns_summary": {},
        "charts": {},
        "next_chart_id": 1,
        "data_path": None
    })

    with gr.Row():
        gr.Markdown("## 🤖 DataViz Agent Chat")
        gr.Markdown("## 📊 Charts Gallery")

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Row():
                with gr.Group():
                    file_upload = gr.File(label="Upload Dataset (CSV/XLSX)", file_types=[".csv", ".xlsx"])
                    with gr.Accordion("Dataset Info", open=False):
                        dataset_info = gr.Markdown("Upload a dataset to get started.")

            with gr.Row(scale=1, height=700):
                chatbot = gr.Chatbot(type="messages", height=700)

            with gr.Row(height=50, equal_height=True):
                msg = gr.Textbox(
                    placeholder="Ask to visualize data (e.g., 'Show distribution of age')",
                    show_label=False,
                    elem_id="chat-input",
                    lines=1,
                    max_lines=1,
                    scale=1
                )
                send_btn = gr.Button("Send", variant="primary", scale=0)

        with gr.Column(scale=2):
            with gr.Row(height=626):
                gallery = gr.Gallery(label="Generated Charts", columns=1, object_fit="contain", height=626)

            with gr.Row():
                with gr.Group():
                    gr.Markdown("### Manage Charts")
                    with gr.Row():
                        chart_selector = gr.Dropdown(label="Select Chart to Delete", choices=[])
                        delete_btn = gr.Button("🗑️ Delete Chart", variant="stop")

            with gr.Row():
                dl_zip_btn = gr.Button("💾 Download All (ZIP)")
                dl_report_btn = gr.Button("📄 Download Report (Word)")

            with gr.Row(height=80):
                dl_file = gr.File(label="Download", visible=True)

    # Global Clear (Bottom)
    with gr.Row():
        global_clear_btn = gr.Button("Global Clear (Reset All)", variant="stop")

    # Event Handlers
    def on_file_upload(file, current_state):
        """Load *file* into the session state and return (state, info text)."""
        if file is None:
            return current_state, "Upload a dataset to get started."

        df, summary, summary_str, path = process_upload(file)
        if df is not None:
            current_state["dataframe"] = df
            current_state["columns_summary"] = summary
            current_state["data_path"] = path
        # On failure summary_str carries the error text from process_upload
        return current_state, summary_str

    def on_file_upload_wrapper(file, current_state):
        """Discard the previous dataset (temp file and state), then load the new file."""
        # Clean up old temporary file if exists
        old_path = current_state.get("data_path")
        if old_path and os.path.exists(old_path):
            try:
                os.remove(old_path)
                logger.info(f"Cleaned up old temp file: {old_path}")
            except Exception as e:
                logger.warning(f"Failed to remove temp file: {e}")

        # The old dataset's backing file is gone, so drop every reference to
        # it. Otherwise clearing the upload or a failed upload would leave
        # the chat pointing at a deleted file.
        current_state["dataframe"] = None
        current_state["columns_summary"] = {}
        current_state["data_path"] = None

        return on_file_upload(file, current_state)

    file_upload.change(
        on_file_upload_wrapper,
        inputs=[file_upload, state],
        outputs=[state, dataset_info]
    )

    # Browser-side snippet that returns keyboard focus to the chat input
    # after a response has been rendered (shared by Enter and the Send button)
    focus_chat_input_js = "() => { setTimeout(() => { const el = document.getElementById('chat-input'); if (el) { const input = el.querySelector('textarea') || el.querySelector('input'); if (input) input.focus(); } }, 200); }"

    # Chat interactions
    msg.submit(
        respond,
        inputs=[msg, chatbot, state],
        outputs=[msg, chatbot, gallery, state, chart_selector]
    ).then(
        None, None, None,
        js=focus_chat_input_js
    )

    send_btn.click(
        respond,
        inputs=[msg, chatbot, state],
        outputs=[msg, chatbot, gallery, state, chart_selector]
    ).then(
        None, None, None,
        js=focus_chat_input_js
    )

    # Chart Management
    delete_btn.click(
        delete_chart,
        inputs=[chart_selector, chatbot, state],
        outputs=[chatbot, gallery, state, chart_selector]
    )

    dl_zip_btn.click(
        download_zip,
        inputs=[state],
        outputs=[dl_file]
    )

    dl_report_btn.click(
        download_report,
        inputs=[state],
        outputs=[dl_file]
    )

    global_clear_btn.click(
        global_clear,
        inputs=[],
        outputs=[file_upload, dataset_info, chatbot, gallery, state, chart_selector, dl_file]
    )

if __name__ == "__main__":
    demo.launch()