Janarthanan-Gnanamurthy commited on
Commit
285025c
·
verified ·
1 Parent(s): 05d34db

initial commit

Browse files
Files changed (5) hide show
  1. agents.py +744 -0
  2. gradio_app.py +639 -0
  3. mermaid_graph.png +0 -0
  4. requirements.txt +10 -0
  5. visualize_agent.py +156 -0
agents.py ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import logging
5
+ import httpx
6
+ import pandas as pd
7
+ import numpy as np
8
+ from typing import Dict, List, TypedDict, Annotated, Literal
9
+ from fastapi import HTTPException
10
+ from langgraph.graph import StateGraph, START, END
11
+ from langgraph.prebuilt import ToolNode
12
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
13
+ from langchain_core.tools import tool
14
+ import matplotlib.pyplot as plt
15
+ import seaborn as sns
16
+ from scipy import stats
17
+ import warnings
18
+ import io
19
+ import base64
20
+ import tempfile
21
+ from dotenv import load_dotenv
22
+
23
+ # Load environment variables from .env file
24
+ load_dotenv()
25
+
26
+ # Configure logging
27
+ logging.basicConfig(
28
+ level=logging.INFO,
29
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
30
+ )
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # Gemini API configuration
34
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
35
+ GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash-preview-05-20")
36
+ GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"
37
+
38
+ if not GEMINI_API_KEY:
39
+ raise ValueError("GEMINI_API_KEY environment variable is required")
40
+
41
+ # Define the agent state
42
+ class AgentState(TypedDict):
43
+ messages: Annotated[List[BaseMessage], "The conversation messages"]
44
+ prompt: str
45
+ dataframe: pd.DataFrame
46
+ columns: List[str]
47
+ intent: Dict
48
+ chart_config: Dict
49
+ code: str
50
+ result: Dict
51
+ error: str
52
+ next_action: str
53
+ plot_path: str
54
+
55
+ async def generate(prompt, temperature=0.2, model="gemma3:12b-it-qat"):
56
+ """Generate response using your deployed Ollama API."""
57
+ url = "https://sumansuriya7010--ollama-server3-ollamaserver-serve.modal.run/v1/chat/completions"
58
+
59
+ headers = {
60
+ "Content-Type": "application/json",
61
+ }
62
+
63
+ payload = {
64
+ "model": model,
65
+ "messages": [
66
+ {
67
+ "role": "user",
68
+ "content": prompt
69
+ }
70
+ ],
71
+ "temperature": temperature,
72
+ "max_tokens": 8192,
73
+ "stream": False
74
+ }
75
+
76
+ try:
77
+ async with httpx.AsyncClient(timeout=120.0) as client:
78
+ response = await client.post(
79
+ url,
80
+ json=payload,
81
+ headers=headers
82
+ )
83
+ response.raise_for_status()
84
+ result = response.json()
85
+
86
+ # Extract text from Ollama/OpenAI compatible response
87
+ if "choices" in result and len(result["choices"]) > 0:
88
+ choice = result["choices"][0]
89
+ if "message" in choice and "content" in choice["message"]:
90
+ return choice["message"]["content"]
91
+
92
+ return ""
93
+
94
+ except httpx.HTTPStatusError as e:
95
+ logger.error(f"HTTP error from Ollama API: {e.response.status_code} - {e.response.text}")
96
+ raise HTTPException(status_code=e.response.status_code, detail=f"Ollama API error: {e.response.text}")
97
+ except Exception as e:
98
+ logger.error(f"Error generating response with Ollama: {str(e)}")
99
+ raise HTTPException(status_code=500, detail=f"Error generating response with Ollama: {str(e)}")
100
+
101
+ def create_chart(df: pd.DataFrame, chart_config: Dict) -> str:
102
+ """Create a matplotlib chart and return the base64 encoded image."""
103
+ try:
104
+ plt.style.use('seaborn-v0_8')
105
+ fig, ax = plt.subplots(figsize=(12, 8))
106
+
107
+ chart_type = chart_config.get("chart_type", "bar")
108
+ x_axis = chart_config.get("x_axis")
109
+ y_axis = chart_config.get("y_axis")
110
+ title = chart_config.get("title", "Chart")
111
+ aggregation = chart_config.get("aggregation", "none")
112
+
113
+ # Handle data aggregation if needed
114
+ plot_df = df.copy()
115
+ if aggregation != "none" and x_axis and y_axis:
116
+ if aggregation == "sum":
117
+ plot_df = df.groupby(x_axis)[y_axis].sum().reset_index()
118
+ elif aggregation == "mean":
119
+ plot_df = df.groupby(x_axis)[y_axis].mean().reset_index()
120
+ elif aggregation == "count":
121
+ plot_df = df.groupby(x_axis)[y_axis].count().reset_index()
122
+
123
+ # Create the chart based on type
124
+ if chart_type == "bar":
125
+ if aggregation != "none":
126
+ ax.bar(plot_df[x_axis], plot_df[y_axis])
127
+ else:
128
+ sns.barplot(data=plot_df, x=x_axis, y=y_axis, ax=ax)
129
+
130
+ elif chart_type == "line":
131
+ if aggregation != "none":
132
+ ax.plot(plot_df[x_axis], plot_df[y_axis], marker='o')
133
+ else:
134
+ sns.lineplot(data=plot_df, x=x_axis, y=y_axis, ax=ax)
135
+
136
+ elif chart_type == "scatter":
137
+ sns.scatterplot(data=plot_df, x=x_axis, y=y_axis, ax=ax)
138
+
139
+ elif chart_type == "histogram":
140
+ if x_axis in df.columns:
141
+ ax.hist(df[x_axis].dropna(), bins=30, alpha=0.7)
142
+
143
+ elif chart_type == "boxplot":
144
+ if y_axis and x_axis:
145
+ sns.boxplot(data=plot_df, x=x_axis, y=y_axis, ax=ax)
146
+ else:
147
+ ax.boxplot(df.select_dtypes(include=[np.number]).dropna())
148
+
149
+ elif chart_type == "pie":
150
+ if x_axis:
151
+ value_counts = df[x_axis].value_counts()
152
+ ax.pie(value_counts.values, labels=value_counts.index, autopct='%1.1f%%')
153
+
154
+ elif chart_type == "area":
155
+ if x_axis and y_axis:
156
+ ax.fill_between(plot_df[x_axis], plot_df[y_axis], alpha=0.7)
157
+
158
+ # Customize the chart
159
+ ax.set_title(title, fontsize=16, fontweight='bold')
160
+ if x_axis and chart_type != "pie":
161
+ ax.set_xlabel(x_axis.replace('_', '').title(), fontsize=12)
162
+ if y_axis and chart_type not in ["pie", "histogram"]:
163
+ ax.set_ylabel(y_axis.replace('_', ' ').title(), fontsize=12)
164
+
165
+ # Rotate x-axis labels if they're long
166
+ if chart_type not in ["pie", "histogram"]:
167
+ plt.xticks(rotation=45, ha='right')
168
+
169
+ plt.tight_layout()
170
+
171
+ # Save to base64
172
+ buffer = io.BytesIO()
173
+ plt.savefig(buffer, format='png', dpi=300, bbox_inches='tight')
174
+ buffer.seek(0)
175
+ image_base64 = base64.b64encode(buffer.read()).decode()
176
+ plt.close(fig)
177
+
178
+ return image_base64
179
+
180
+ except Exception as e:
181
+ logger.error(f"Error creating chart: {str(e)}")
182
+ plt.close('all') # Clean up any open figures
183
+ return None
184
+
185
+ # Agent nodes
186
+ async def analyze_intent_node(state: AgentState) -> AgentState:
187
+ """Analyze the user's prompt to determine intent."""
188
+ prompt = state["prompt"]
189
+ columns = state["columns"]
190
+
191
+ response_format = {
192
+ "intent": "statistical",
193
+ "reason": "Prompt requests statistical analysis",
194
+ "visualization_type": None,
195
+ "transformation_type": None,
196
+ "statistical_type": "correlation"
197
+ }
198
+
199
+ input_text = f"""Analyze the following prompt and determine if it's requesting data transformation, visualization, or statistical analysis:
200
+
201
+ Prompt: {prompt}
202
+ Available columns: {', '.join(columns)}
203
+
204
+ Provide a JSON response with:
205
+ 1. intent: Either 'visualization', 'transformation', or 'statistical'
206
+ 2. reason: Brief explanation of why this classification was chosen
207
+ 3. visualization_type: If intent is 'visualization', specify the chart type ('bar', 'line', 'pie', 'scatter', 'area', 'histogram', 'boxplot')
208
+ 4. transformation_type: If intent is 'transformation', specify the operation type ('aggregate', 'filter', 'join', 'compute', 'sort', 'group')
209
+ 5. statistical_type: If intent is 'statistical', specify the test type ('correlation', 'ttest', 'regression', 'descriptive'),
210
+
211
+ Example response format:
212
+ {json.dumps(response_format)}"""
213
+
214
+ try:
215
+ json_text = await generate(input_text, temperature=0.4)
216
+
217
+ # Try to extract JSON from markdown code blocks if present
218
+ json_match = re.search(r"```(?:json)?\n(.*?)\n```", json_text, re.DOTALL)
219
+ if json_match:
220
+ json_text = json_match.group(1)
221
+
222
+ json_text = json_text.strip()
223
+
224
+ try:
225
+ intent = json.loads(json_text)
226
+ except json.JSONDecodeError:
227
+ # If direct parsing fails, try to extract just the JSON object
228
+ json_obj_match = re.search(r"(\{.*\})", json_text, re.DOTALL)
229
+ if json_obj_match:
230
+ intent = json.loads(json_obj_match.group(1))
231
+ else:
232
+ # Fallback classification based on keywords
233
+ prompt_lower = prompt.lower()
234
+ if any(word in prompt_lower for word in ['chart', 'plot', 'graph', 'visualiz', 'show']):
235
+ intent = {"intent": "visualization", "reason": "Keywords suggest visualization"}
236
+ elif any(word in prompt_lower for word in ['filter', 'transform', 'add', 'modify', 'create column']):
237
+ intent = {"intent": "transformation", "reason": "Keywords suggest transformation"}
238
+ else:
239
+ intent = {"intent": "statistical", "reason": "Default to statistical analysis"}
240
+
241
+ state["intent"] = intent
242
+ state["next_action"] = intent["intent"]
243
+ logger.info(f"Intent analysis result: {intent}")
244
+
245
+ except Exception as e:
246
+ state["error"] = f"Error analyzing prompt intent: {str(e)}"
247
+ state["next_action"] = "error"
248
+ logger.error(f"Error in analyze_intent_node: {str(e)}")
249
+
250
+ return state
251
+
252
+ async def generate_visualization_node(state: AgentState) -> AgentState:
253
+ """Generate visualization configuration and create the chart."""
254
+ prompt = state["prompt"]
255
+ columns = state["columns"]
256
+ df = state["dataframe"]
257
+
258
+ response_format = {
259
+ "chart_type": "bar",
260
+ "x_axis": "date",
261
+ "y_axis": "sales",
262
+ "aggregation": "sum",
263
+ "title": "Total Sales by Date"
264
+ }
265
+
266
+ input_text = f"""Based on the following prompt, determine the appropriate chart configuration:
267
+
268
+ Prompt: {prompt}
269
+ Available columns: {', '.join(columns)}
270
+
271
+ Generate a JSON configuration with:
272
+ 1. chart_type: 'bar', 'line', 'pie', 'scatter', 'area', 'histogram', 'boxplot'
273
+ 2. x_axis: column name for x-axis (choose from available columns)
274
+ 3. y_axis: column name for y-axis (can be None for histograms, choose from available columns)
275
+ 4. aggregation: 'sum', 'mean', 'count', 'none'
276
+ 5. title: descriptive chart title
277
+
278
+ Example response format:
279
+ {json.dumps(response_format)}
280
+
281
+ Provide only the JSON configuration, no explanations."""
282
+
283
+ try:
284
+ json_text = await generate(input_text, temperature=0.5)
285
+
286
+ json_match = re.search(r"```(?:json)?\n(.*?)\n```", json_text, re.DOTALL)
287
+ if json_match:
288
+ json_text = json_match.group(1)
289
+
290
+ json_text = json_text.strip()
291
+
292
+ try:
293
+ chart_config = json.loads(json_text)
294
+ except json.JSONDecodeError:
295
+ json_obj_match = re.search(r"(\{.*\})", json_text, re.DOTALL)
296
+ if json_obj_match:
297
+ chart_config = json.loads(json_obj_match.group(1))
298
+ else:
299
+ # Fallback configuration
300
+ numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
301
+ categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
302
+
303
+ chart_config = {
304
+ "chart_type": "bar",
305
+ "x_axis": categorical_cols[0] if categorical_cols else columns[0],
306
+ "y_axis": numeric_cols[0] if numeric_cols else columns[1] if len(columns) > 1 else None,
307
+ "aggregation": "mean" if numeric_cols else "count",
308
+ "title": "Data Visualization"
309
+ }
310
+
311
+ # Validate column names exist
312
+ if chart_config.get("x_axis") not in columns:
313
+ chart_config["x_axis"] = columns[0]
314
+ if chart_config.get("y_axis") and chart_config["y_axis"] not in columns:
315
+ numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
316
+ chart_config["y_axis"] = numeric_cols[0] if numeric_cols else None
317
+
318
+ state["chart_config"] = chart_config
319
+
320
+ # Create the chart immediately
321
+ image_base64 = create_chart(df, chart_config)
322
+ if image_base64:
323
+ state["result"] = {
324
+ "type": "visualization",
325
+ "chart_type": chart_config["chart_type"],
326
+ "config": chart_config,
327
+ "image": image_base64,
328
+ "message": "Visualization created successfully"
329
+ }
330
+ state["next_action"] = "complete"
331
+ else:
332
+ state["error"] = "Failed to create visualization"
333
+ state["next_action"] = "error"
334
+
335
+ logger.info(f"Generated chart config: {chart_config}")
336
+
337
+ except Exception as e:
338
+ state["error"] = f"Error generating chart configuration: {str(e)}"
339
+ state["next_action"] = "error"
340
+ logger.error(f"Error in generate_visualization_node: {str(e)}")
341
+
342
+ return state
343
+
344
+ async def generate_transformation_node(state: AgentState) -> AgentState:
345
+ """Generate pandas transformation code."""
346
+ prompt = state["prompt"]
347
+ columns = state["columns"]
348
+
349
+ input_text = f"""Write Python code to perform the following pandas DataFrame transformation:
350
+
351
+ {prompt}
352
+
353
+ Available columns: {', '.join(columns)}
354
+
355
+ Pandas Knowledge Base:
356
+ 1. DataFrame Operations:
357
+ - select columns: df[['col1', 'col2']]
358
+ - filter rows: df[df['column'] > value]
359
+ - group data: df.groupby('column')
360
+ - sort data: df.sort_values('column')
361
+ - add/modify columns: df['new_col'] = df['col1'] * 2
362
+ - drop columns: df.drop(['col1'], axis=1)
363
+ - remove duplicates: df.drop_duplicates()
364
+ - merge dataframes: pd.merge(df1, df2)
365
+
366
+ 2. Common Functions:
367
+ - df.apply(): Apply function to columns/rows
368
+ - df.fillna(): Fill missing values
369
+ - df.dropna(): Drop missing values
370
+ - df.replace(): Replace values
371
+ - pd.to_datetime(): Convert to datetime
372
+ - df.astype(): Convert data types
373
+ - df.round(): Round numbers
374
+ - df.sum(), df.mean(), df.count(): Aggregations
375
+
376
+ 3. String Operations:
377
+ - df['col'].str.contains(): String contains
378
+ - df['col'].str.split(): Split strings
379
+ - df['col'].str.replace(): Replace in strings
380
+ - df['col'].str.upper(): Convert to uppercase
381
+
382
+ 4. Window Operations:
383
+ - df.rolling(): Rolling window operations
384
+ - df.shift(): Shift values
385
+ - df.expanding(): Expanding window
386
+
387
+ Requirements:
388
+ 1. Use pandas DataFrame operations
389
+ 2. Handle missing values appropriately
390
+ 3. Store result in 'transformed_df'
391
+ 4. DO NOT define functions
392
+ 5. Return a pandas DataFrame
393
+ 6. Use proper type conversions if needed
394
+
395
+ Available variables:
396
+ - df: pandas DataFrame
397
+ - pd: pandas module
398
+ - np: numpy module
399
+
400
+ Example format:
401
+ ```python
402
+ transformed_df = df.copy()
403
+ transformed_df['new_column'] = df['column1'] * df['column2']
404
+ transformed_df = transformed_df.fillna(0) # Handle nulls
405
+ ```
406
+
407
+ Provide only the code, no explanations. DO NOT DEFINE functions, directly perform the operations on the df."""
408
+
409
+ try:
410
+ code = await generate(input_text, temperature=0.4)
411
+
412
+ code_match = re.search(r"```python\n(.*?)\n```", code, re.DOTALL)
413
+ code = code_match.group(1) if code_match else code
414
+
415
+ state["code"] = code
416
+ state["next_action"] = "execute"
417
+ logger.info(f"Generated transformation code: {code}")
418
+
419
+ except Exception as e:
420
+ state["error"] = f"Error generating transformation code: {str(e)}"
421
+ state["next_action"] = "error"
422
+ logger.error(f"Error in generate_transformation_node: {str(e)}")
423
+
424
+ return state
425
+
426
+ async def generate_statistical_node(state: AgentState) -> AgentState:
427
+ """Generate robust pandas/numpy code for statistical analysis with fallbacks."""
428
+ prompt = state.get("prompt", "")
429
+ print(prompt+" - Prompt received in generate_statistical_node")
430
+
431
+ columns = state.get("columns", [])
432
+ # Use predefined templates based on prompt keywords
433
+ operations = []
434
+ if any(x in prompt.lower() for x in ["describe", "summary"]):
435
+ operations.append("describe")
436
+ if any(x in prompt.lower() for x in ["correlation", "corr"]):
437
+ operations.append("correlation")
438
+ if any(x in prompt.lower() for x in ["ttest", "hypothesis"]):
439
+ operations.append("ttest")
440
+ if not operations:
441
+ operations = ["describe"] # default
442
+
443
+ code_blocks = []
444
+ # Build code blocks robustly
445
+ if "describe" in operations:
446
+ code_blocks.append(
447
+ "# Descriptive statistics\n"
448
+ "desc = df.describe(include='all')\n"
449
+ )
450
+ if "correlation" in operations:
451
+ code_blocks.append(
452
+ "# Correlation for numeric columns\n"
453
+ "num_cols = df.select_dtypes(include=[np.number]).columns.tolist()\n"
454
+ "corr = df[num_cols].corr() if len(num_cols) > 1 else pd.DataFrame()\n"
455
+ )
456
+ if "ttest" in operations and 'category' in columns:
457
+ # safe t-test only if category and value exist
458
+ code_blocks.append(
459
+ "# Independent T-test between two groups in 'category' on 'value' column\n"
460
+ "groups = df['category'].dropna().unique().tolist()[:2]\n"
461
+ "if len(groups) == 2:\n"
462
+ " g1 = df[df['category'] == groups[0]]['value'].dropna()\n"
463
+ " g2 = df[df['category'] == groups[1]]['value'].dropna()\n"
464
+ " t_stat, p_val = stats.ttest_ind(g1, g2, nan_policy='omit')\n"
465
+ "else:\n"
466
+ " t_stat, p_val = None, None\n"
467
+ )
468
+ # Assemble result dict
469
+ code_blocks.append(
470
+ "# Assemble results\n"
471
+ "results = {}\n"
472
+ "if 'desc' in locals(): results['descriptive'] = desc\n"
473
+ "if 'corr' in locals(): results['correlation'] = corr\n"
474
+ "if 't_stat' in locals(): results['ttest'] = {'t_statistic': t_stat, 'p_value': p_val}\n"
475
+ "# Final assignment\n"
476
+ "stat_result = results\n"
477
+ )
478
+
479
+ state['code'] = '\n'.join(code_blocks)
480
+ state['next_action'] = 'execute'
481
+ logger.info(f"Generated statistical code with operations {operations}")
482
+ return state
483
+
484
+
485
+
486
+ async def execute_code_node(state: AgentState) -> AgentState:
487
+ """Execute the generated code safely."""
488
+ code = state["code"]
489
+ df = state["dataframe"]
490
+
491
+ if not code:
492
+ state["error"] = "No code to execute"
493
+ state["next_action"] = "error"
494
+ return state
495
+
496
+ try:
497
+ # Create safe execution environment
498
+ safe_globals = {
499
+ 'df': df,
500
+ 'pd': pd,
501
+ 'np': np,
502
+ 'stats': stats,
503
+ 'plt': plt,
504
+ 'sns': sns
505
+ }
506
+
507
+ # Execute the code
508
+ exec(code, safe_globals)
509
+
510
+ # Extract results based on intent
511
+ intent = state["intent"]["intent"]
512
+
513
+ if intent == "transformation":
514
+ if 'transformed_df' in safe_globals:
515
+ result_df = safe_globals['transformed_df']
516
+ state["result"] = {
517
+ "type": "transformation",
518
+ "shape": result_df.shape,
519
+ "columns": result_df.columns.tolist(),
520
+ "preview": result_df.head(10).to_html(classes='table table-striped'),
521
+ "dataframe": result_df,
522
+ "message": f"Data transformed successfully. New shape: {result_df.shape}"
523
+ }
524
+ else:
525
+ state["error"] = "No 'transformed_df' found in execution result"
526
+
527
+ elif intent == "statistical":
528
+ exec(code, safe_globals)
529
+ stat_result = safe_globals.get('stat_result')
530
+ if stat_result is None:
531
+ raise ValueError("'stat_result' not found after execution")
532
+ if not isinstance(stat_result, dict):
533
+ stat_result = {'result': stat_result}
534
+ formatted = format_statistical_result(stat_result)
535
+ state['result'] = {
536
+ 'type': 'statistical',
537
+ 'data': formatted,
538
+ 'message': 'Statistical analysis completed successfully'
539
+ }
540
+
541
+ state["next_action"] = "complete"
542
+ logger.info("Code executed successfully")
543
+
544
+ except Exception as e:
545
+ state["error"] = f"Error executing code: {str(e)}"
546
+ state["next_action"] = "error"
547
+ logger.error(f"Error in execute_code_node: {str(e)}")
548
+
549
+ return state
550
+
551
+ def format_statistical_result(stat_result) -> str:
552
+ """Format statistical results for display in Gradio."""
553
+ try:
554
+ if isinstance(stat_result, pd.DataFrame):
555
+ return stat_result.to_html(classes='table table-striped')
556
+ elif isinstance(stat_result, dict):
557
+ html_parts = []
558
+ for key, value in stat_result.items():
559
+ html_parts.append(f"<h4>{key.replace('_', ' ').title()}</h4>")
560
+ if isinstance(value, pd.DataFrame):
561
+ html_parts.append(value.to_html(classes='table table-striped'))
562
+ elif isinstance(value, (int, float)):
563
+ html_parts.append(f"<p><strong>{value:.6f}</strong></p>")
564
+ else:
565
+ html_parts.append(f"<p>{str(value)}</p>")
566
+ return ''.join(html_parts)
567
+ else:
568
+ return f"<p><strong>Result:</strong> {str(stat_result)}</p>"
569
+ except Exception as e:
570
+ return f"<p><strong>Error formatting result:</strong> {str(e)}</p>"
571
+
572
+ async def error_handler_node(state: AgentState) -> AgentState:
573
+ """Handle errors and provide feedback."""
574
+ error = state.get("error", "Unknown error occurred")
575
+ logger.error(f"Error in agent workflow: {error}")
576
+
577
+ state["result"] = {
578
+ "type": "error",
579
+ "message": error,
580
+ "suggestions": [
581
+ "Check if the column names are correct",
582
+ "Verify that the data types are appropriate",
583
+ "Ensure the prompt is clear and specific"
584
+ ]
585
+ }
586
+ state["next_action"] = "complete"
587
+ return state
588
+
589
+ def route_based_on_intent(state: AgentState) -> Literal["visualization", "transformation", "statistical", "error"]:
590
+ """Route to appropriate node based on intent analysis."""
591
+ if state.get("error"):
592
+ return "error"
593
+
594
+ intent = state.get("intent", {}).get("intent", "error")
595
+ return intent
596
+
597
+ def route_to_execution(state: AgentState) -> Literal["execute", "error", "complete"]:
598
+ """Route to execution or error handling."""
599
+ if state.get("error"):
600
+ return "error"
601
+
602
+ next_action = state.get("next_action", "error")
603
+ if next_action == "execute":
604
+ return "execute"
605
+ elif next_action == "complete":
606
+ return "complete"
607
+ else:
608
+ return "error"
609
+
610
+ # Build the LangGraph workflow
611
+ def create_data_analysis_agent():
612
+ """Create the data analysis agent using LangGraph."""
613
+
614
+ # Create the state graph
615
+ workflow = StateGraph(AgentState)
616
+
617
+ # Add nodes
618
+ workflow.add_node("analyze_intent", analyze_intent_node)
619
+ workflow.add_node("visualization", generate_visualization_node)
620
+ workflow.add_node("transformation", generate_transformation_node)
621
+ workflow.add_node("statistical", generate_statistical_node)
622
+ workflow.add_node("execute", execute_code_node)
623
+ workflow.add_node("error_handler", error_handler_node)
624
+
625
+ # Add edges
626
+ workflow.add_edge(START, "analyze_intent")
627
+
628
+ # Conditional edges based on intent
629
+ workflow.add_conditional_edges(
630
+ "analyze_intent",
631
+ route_based_on_intent,
632
+ {
633
+ "visualization": "visualization",
634
+ "transformation": "transformation",
635
+ "statistical": "statistical",
636
+ "error": "error_handler"
637
+ }
638
+ )
639
+
640
+ # Route from generation nodes to execution
641
+ workflow.add_conditional_edges(
642
+ "visualization",
643
+ route_to_execution,
644
+ {
645
+ "execute": "execute",
646
+ "complete": END,
647
+ "error": "error_handler"
648
+ }
649
+ )
650
+ workflow.add_conditional_edges(
651
+ "transformation",
652
+ route_to_execution,
653
+ {
654
+ "execute": "execute",
655
+ "complete": END,
656
+ "error": "error_handler"
657
+ }
658
+ )
659
+ workflow.add_conditional_edges(
660
+ "statistical",
661
+ route_to_execution,
662
+ {
663
+ "execute": "execute",
664
+ "complete": END,
665
+ "error": "error_handler"
666
+ }
667
+ )
668
+
669
+ # Final edges
670
+ workflow.add_edge("execute", END)
671
+ workflow.add_edge("error_handler", END)
672
+
673
+ # Compile the graph
674
+ app = workflow.compile()
675
+ return app
676
+
677
+ # Main execution function
678
+ async def analyze_data_with_agent(prompt: str, dataframe: pd.DataFrame) -> Dict:
679
+ """
680
+ Analyze data using the LangGraph agent.
681
+
682
+ Args:
683
+ prompt: Natural language prompt describing the analysis
684
+ dataframe: Pandas DataFrame to analyze
685
+
686
+ Returns:
687
+ Dictionary containing the analysis results
688
+ """
689
+ # Create the agent
690
+ agent = create_data_analysis_agent()
691
+
692
+ # Initialize state
693
+ initial_state = {
694
+ "messages": [HumanMessage(content=prompt)],
695
+ "prompt": prompt,
696
+ "dataframe": dataframe,
697
+ "columns": dataframe.columns.tolist(),
698
+ "intent": {},
699
+ "chart_config": {},
700
+ "code": "",
701
+ "result": {},
702
+ "error": "",
703
+ "next_action": "",
704
+ "plot_path": ""
705
+ }
706
+
707
+ # Run the agent
708
+ try:
709
+ final_state = await agent.ainvoke(initial_state)
710
+ return final_state["result"]
711
+ except Exception as e:
712
+ logger.error(f"Error running agent: {str(e)}")
713
+ return {
714
+ "type": "error",
715
+ "message": f"Agent execution failed: {str(e)}"
716
+ }
717
+
718
+ # Test function
719
+ async def test_agent():
720
+ """Test the data analysis agent."""
721
+ # Create sample data
722
+ data = {
723
+ 'date': pd.date_range('2024-01-01', periods=100),
724
+ 'sales': np.random.normal(1000, 200, 100),
725
+ 'category': np.random.choice(['A', 'B', 'C'], 100),
726
+ 'region': np.random.choice(['North', 'South', 'East', 'West'], 100)
727
+ }
728
+ df = pd.DataFrame(data)
729
+
730
+ # Test different types of prompts
731
+ test_prompts = [
732
+ "Create a bar chart showing average sales by category",
733
+ "Calculate correlation between date and sales",
734
+ "Filter the data to show only category A and add a profit column that is 20% of sales"
735
+ ]
736
+
737
+ for prompt in test_prompts:
738
+ print(f"\n--- Testing: {prompt} ---")
739
+ result = await analyze_data_with_agent(prompt, df)
740
+ print(f"Result: {result}")
741
+
742
+ if __name__ == "__main__":
743
+ import asyncio
744
+ asyncio.run(test_agent())
gradio_app.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import json
4
+ from agents import analyze_data_with_agent
5
+ import io
6
+ import asyncio
7
+ import logging
8
+
9
+ # Configure logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ async def process_data_and_prompt(file, prompt):
14
+ """Process uploaded file and prompt using the data analysis agent."""
15
+ try:
16
+ if not file:
17
+ return "Please upload a data file.", None, None
18
+
19
+ if not prompt or prompt.strip() == "":
20
+ return "Please enter an analysis prompt.", None, None
21
+
22
+ # Read the uploaded file
23
+ if file.name.endswith('.csv'):
24
+ df = pd.read_csv(file.name)
25
+ elif file.name.endswith(('.xlsx', '.xls')):
26
+ df = pd.read_excel(file.name)
27
+ elif file.name.endswith('.json'):
28
+ df = pd.read_json(file.name)
29
+ else:
30
+ return "Error: Unsupported file format. Please upload CSV, Excel, or JSON files.", None, None
31
+
32
+ # Clean column names
33
+ df.columns = [str(col).strip().lower().replace(' ', '_').replace('-', '_') for col in df.columns]
34
+
35
+ # Show data preview
36
+ # data_preview = f"""
37
+ # <div class="data-section">
38
+ # <h3>Data Preview</h3>
39
+ # <p><strong>Shape:</strong> {df.shape[0]} rows × {df.shape[1]} columns</p>
40
+ # <p><strong>Columns:</strong> {', '.join(df.columns.tolist())}</p>
41
+ # {df.head().to_html(classes='table data-table', table_id='data-preview')}
42
+ # </div>
43
+ # """
44
+ data_preview = f"""
45
+ <div></div>"""
46
+
47
+ # Process with agent
48
+ logger.info(f"Processing prompt: {prompt}")
49
+ result = await analyze_data_with_agent(prompt, df)
50
+ logger.info(f"Agent result type: {result.get('type')}")
51
+
52
+ # Handle different result types
53
+ if result["type"] == "error":
54
+ error_html = f"""
55
+ <div class="error-box">
56
+ <h3>Error</h3>
57
+ <p><strong>Message:</strong> {result['message']}</p>
58
+ {f"<p><strong>Suggestions:</strong></p><ul>{''.join([f'<li>{s}</li>' for s in result.get('suggestions', [])])}</ul>" if result.get('suggestions') else ""}
59
+ </div>
60
+ """
61
+ return data_preview + error_html, None, None
62
+
63
+ elif result["type"] == "visualization":
64
+ # Display the chart
65
+ image_base64 = result.get("image")
66
+ if image_base64:
67
+ chart_html = f"""
68
+ <div class="analysis-result">
69
+ <h3>Visualization Result</h3>
70
+ <p><strong>Chart Type:</strong> {result.get('chart_type', 'Unknown').title()}</p>
71
+ <div class="chart-container">
72
+ <img src="data:image/png;base64,{image_base64}" class="chart-image">
73
+ </div>
74
+ <p><em>{result.get('message', 'Visualization created successfully')}</em></p>
75
+ </div>
76
+ """
77
+ return data_preview + chart_html, None, None
78
+ else:
79
+ return data_preview + "<p>Error: Could not generate visualization</p>", None, None
80
+
81
+ elif result["type"] == "statistical":
82
+ # Format statistical results
83
+ stat_html = f"""
84
+ <div class="analysis-result">
85
+ <h3>Statistical Analysis Results</h3>
86
+ <div class="stat-output-box">
87
+ {result.get('data', 'No statistical results available')}
88
+ </div>
89
+ <p><em>{result.get('message', 'Statistical analysis completed')}</em></p>
90
+ </div>
91
+ """
92
+ return data_preview + stat_html, None, None
93
+
94
+ elif result["type"] == "transformation":
95
+ # Return transformed data
96
+ transformed_df = result.get("dataframe")
97
+ if transformed_df is not None:
98
+ # Create CSV for download
99
+ csv_buffer = io.StringIO()
100
+ transformed_df.to_csv(csv_buffer, index=False)
101
+ csv_data = csv_buffer.getvalue()
102
+
103
+ # Create temporary file for download (Gradio handles temporary files for downloads)
104
+ temp_file_name = "transformed_data.csv"
105
+ with open(temp_file_name, 'w', encoding='utf-8') as f:
106
+ f.write(csv_data)
107
+
108
+ transform_html = f"""
109
+ <div class="analysis-result">
110
+ <h3>Data Transformation Results</h3>
111
+ <p><strong>Original Shape:</strong> {df.shape[0]} rows × {df.shape[1]} columns</p>
112
+ <p><strong>New Shape:</strong> {result.get('shape', 'Unknown')}</p>
113
+ <p><strong>New Columns:</strong> {', '.join(result.get('columns', []))}</p>
114
+ <div class="transformed-data-preview">
115
+ <h4>Preview of Transformed Data:</h4>
116
+ {result.get('preview', 'No preview available')}
117
+ </div>
118
+ <p><em>{result.get('message', 'Data transformation completed')}</em></p>
119
+ <p><strong>Download the transformed data using the button below.</strong></p>
120
+ </div>
121
+ """
122
+ return data_preview + transform_html, temp_file_name, None
123
+ else:
124
+ return data_preview + "<p>Error: Could not retrieve transformed data</p>", None, None
125
+
126
+ else:
127
+ return data_preview + f"<p>Unknown result type: {result.get('type')}</p>", None, None
128
+
129
+ except Exception as e:
130
+ logger.error(f"Error processing data: {str(e)}")
131
+ error_html = f"""
132
+ <div class="error-box">
133
+ <h3>Processing Error</h3>
134
+ <p><strong>Error:</strong> {str(e)}</p>
135
+ <p><strong>Please check:</strong></p>
136
+ <ul>
137
+ <li>File format is supported (CSV, Excel)</li>
138
+ <li>File is not corrupted</li>
139
+ <li>Prompt is clear and specific</li>
140
+ <li>Ollama server is running</li>
141
+ </ul>
142
+ </div>
143
+ """
144
+ return error_html, None, None
145
+
146
+ def process_sync(file, prompt):
147
+ """Synchronous wrapper for the async processing function."""
148
+ try:
149
+ # Check if an event loop is already running
150
+ try:
151
+ loop = asyncio.get_running_loop()
152
+ except RuntimeError:
153
+ loop = asyncio.new_event_loop()
154
+ asyncio.set_event_loop(loop)
155
+ return loop.run_until_complete(process_data_and_prompt(file, prompt))
156
+ except Exception as e:
157
+ logger.error(f"Error in sync wrapper: {str(e)}")
158
+ return f"Error: {str(e)}", None, None
159
+
160
+ def generate_preview(file):
161
+ """Generate a preview of the uploaded file."""
162
+ try:
163
+ if not file:
164
+ return "Please upload a data file to see preview."
165
+
166
+ # Read the uploaded file
167
+ if file.name.endswith('.csv'):
168
+ df = pd.read_csv(file.name)
169
+ elif file.name.endswith(('.xlsx', '.xls')):
170
+ df = pd.read_excel(file.name)
171
+ elif file.name.endswith('.json'):
172
+ df = pd.read_json(file.name)
173
+ else:
174
+ return "Error: Unsupported file format. Please upload CSV, Excel, or JSON files."
175
+
176
+ # Clean column names
177
+ df.columns = [str(col).strip().lower().replace(' ', '_').replace('-', '_') for col in df.columns]
178
+
179
+ # Show data preview
180
+ data_preview = f"""
181
+ <div class="data-section">
182
+ <h3>📊 Data Preview</h3>
183
+ <div class="data-stats">
184
+ <span class="stat-badge">📏 {df.shape[0]} rows</span>
185
+ <span class="stat-badge">📋 {df.shape[1]} columns</span>
186
+ </div>
187
+ <div class="columns-info">
188
+ <strong>Columns:</strong> {', '.join(df.columns.tolist())}
189
+ </div>
190
+ <div class="table-container">
191
+ {df.head(4).to_html(classes='table data-table', table_id='data-preview')}
192
+ </div>
193
+ </div>
194
+ """
195
+ return data_preview
196
+ except Exception as e:
197
+ logger.error(f"Error generating preview: {str(e)}")
198
+ return f"<div class='error-box'>Error generating preview: {str(e)}</div>"
199
+
200
+ # Sample prompts for different analysis types
201
+ sample_prompts = {
202
+ "Data Transformation": [
203
+ "Filter data where [column] > 1000 ",
204
+ "Group by [column] and calculate average [values]",
205
+ "Create new columns based on existing ones",
206
+ "Remove duplicates and sort by date",
207
+ ],
208
+ "Visualization": [
209
+ "Create a bar chart showing the distribution of [categories]",
210
+ "Generate a line plot of sales over time",
211
+ "Make a scatter plot of [column1] vs [column2]",
212
+ "Show a histogram of [column2]",
213
+ "Create a pie chart of market share by region"
214
+ ],
215
+ "Statistical Analysis": [
216
+ "Calculate correlation matrix for all numeric columns",
217
+ "Perform descriptive statistics analysis",
218
+ ]
219
+ }
220
+
221
+ # Create the Gradio interface
222
+ with gr.Blocks(
223
+ title="Data Analysis Agent",
224
+ theme=gr.themes.Soft(),
225
+ css="""
226
+ /* Main container */
227
+ .gradio-container {
228
+ max-width: 900px;
229
+ margin: auto;
230
+ padding: 20px;
231
+ }
232
+
233
+ /* Header styling */
234
+ .main-header {
235
+ text-align: center;
236
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
237
+ color: white;
238
+ padding: 30px;
239
+ border-radius: 15px;
240
+ margin-bottom: 30px;
241
+ box-shadow: 0 8px 32px rgba(0,0,0,0.1);
242
+ }
243
+
244
+ .main-header h1 {
245
+ margin: 0;
246
+ font-size: 2.5em;
247
+ font-weight: 600;
248
+ }
249
+
250
+ .main-header p {
251
+ margin: 10px 0 0 0;
252
+ font-size: 1.1em;
253
+ opacity: 0.9;
254
+ }
255
+
256
+ /* Accordion styling */
257
+ .gr-accordion {
258
+ margin-bottom: 20px !important;
259
+ border-radius: 12px !important;
260
+ border: 1px solid var(--border-color-primary) !important;
261
+ box-shadow: 0 2px 8px rgba(0,0,0,0.05) !important;
262
+ overflow: hidden !important;
263
+ }
264
+
265
+ .gr-accordion-header {
266
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
267
+ color: white !important;
268
+ padding: 15px 20px !important;
269
+ font-weight: 600 !important;
270
+ font-size: 1.1em !important;
271
+ border: none !important;
272
+ cursor: pointer !important;
273
+ transition: all 0.3s ease !important;
274
+ }
275
+
276
+ .gr-accordion-header:hover {
277
+ background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%) !important;
278
+ transform: translateY(-1px) !important;
279
+ }
280
+
281
+ .gr-accordion-content {
282
+ background: var(--background-fill-secondary) !important;
283
+ padding: 25px !important;
284
+ border-top: 1px solid var(--border-color-primary) !important;
285
+ }
286
+
287
+ /* Special styling for example prompt accordions */
288
+ .gr-accordion .gr-accordion {
289
+ margin-bottom: 15px !important;
290
+ border-radius: 8px !important;
291
+ box-shadow: 0 1px 4px rgba(0,0,0,0.1) !important;
292
+ }
293
+
294
+ .gr-accordion .gr-accordion .gr-accordion-header {
295
+ background: var(--color-accent-soft) !important;
296
+ color: var(--text-color-body) !important;
297
+ padding: 12px 16px !important;
298
+ font-size: 1em !important;
299
+ font-weight: 500 !important;
300
+ }
301
+
302
+ .gr-accordion .gr-accordion .gr-accordion-header:hover {
303
+ background: var(--color-accent) !important;
304
+ color: white !important;
305
+ transform: none !important;
306
+ }
307
+
308
+ .gr-accordion .gr-accordion .gr-accordion-content {
309
+ background: var(--background-fill-primary) !important;
310
+ padding: 15px !important;
311
+ }
312
+
313
+ /* Section styling (keeping for compatibility) */
314
+ .section {
315
+ background: var(--background-fill-secondary);
316
+ border-radius: 12px;
317
+ padding: 25px;
318
+ margin-bottom: 25px;
319
+ border: 1px solid var(--border-color-primary);
320
+ box-shadow: 0 2px 8px rgba(0,0,0,0.05);
321
+ }
322
+
323
+ .section h2 {
324
+ margin: 0 0 20px 0;
325
+ color: var(--text-color-body);
326
+ font-size: 1.4em;
327
+ font-weight: 600;
328
+ display: flex;
329
+ align-items: center;
330
+ gap: 10px;
331
+ }
332
+
333
+ /* File upload styling */
334
+ .upload-area {
335
+ border: 2px dashed var(--border-color-accent);
336
+ border-radius: 10px;
337
+ padding: 20px;
338
+ text-align: center;
339
+ background: var(--background-fill-primary);
340
+ transition: all 0.3s ease;
341
+ }
342
+
343
+ .upload-area:hover {
344
+ border-color: var(--color-accent);
345
+ background: var(--background-fill-hover);
346
+ }
347
+
348
+ /* Data preview styling */
349
+ .data-section {
350
+ background: var(--background-fill-primary);
351
+ border-radius: 10px;
352
+ padding: 20px;
353
+ border: 1px solid var(--border-color-primary);
354
+ margin: 15px 0;
355
+ }
356
+
357
+ .data-section h3 {
358
+ margin: 0 0 15px 0;
359
+ color: var(--text-color-body);
360
+ font-size: 1.2em;
361
+ }
362
+
363
+ .data-stats {
364
+ display: flex;
365
+ gap: 10px;
366
+ margin-bottom: 15px;
367
+ flex-wrap: wrap;
368
+ }
369
+
370
+ .stat-badge {
371
+ background: var(--color-accent-soft);
372
+ color: var(--text-color-body);
373
+ padding: 6px 12px;
374
+ border-radius: 20px;
375
+ font-size: 0.9em;
376
+ font-weight: 500;
377
+ }
378
+
379
+ .columns-info {
380
+ margin-bottom: 15px;
381
+ padding: 10px;
382
+ background: var(--background-fill-secondary);
383
+ border-radius: 8px;
384
+ font-size: 0.9em;
385
+ }
386
+
387
+ .table-container {
388
+ overflow-x: auto;
389
+ border-radius: 8px;
390
+ }
391
+
392
+ /* Table styling */
393
+ .table {
394
+ width: 100%;
395
+ border-collapse: collapse;
396
+ font-size: 0.85em;
397
+ background: var(--background-fill-primary);
398
+ }
399
+
400
+ .table th {
401
+ background: var(--background-fill-secondary);
402
+ color: var(--text-color-body);
403
+ font-weight: 600;
404
+ padding: 12px 8px;
405
+ border: 1px solid var(--border-color-primary);
406
+ text-align: left;
407
+ }
408
+
409
+ .table td {
410
+ padding: 10px 8px;
411
+ border: 1px solid var(--border-color-primary);
412
+ color: var(--text-color-body);
413
+ }
414
+
415
+ .table tr:nth-child(even) {
416
+ background: var(--background-fill-hover);
417
+ }
418
+
419
+ /* Prompt examples styling */
420
+ .prompt-examples {
421
+ display: grid;
422
+ gap: 15px;
423
+ margin-top: 15px;
424
+ }
425
+
426
+ .prompt-category {
427
+ background: var(--background-fill-primary);
428
+ border-radius: 8px;
429
+ padding: 15px;
430
+ border: 1px solid var(--border-color-primary);
431
+ }
432
+
433
+ .prompt-category h4 {
434
+ margin: 0 0 10px 0;
435
+ color: var(--text-color-body);
436
+ font-size: 1em;
437
+ }
438
+
439
+ .prompt-buttons {
440
+ display: flex;
441
+ flex-wrap: wrap;
442
+ gap: 8px;
443
+ }
444
+
445
+ .prompt-btn {
446
+ font-size: 0.8em !important;
447
+ padding: 6px 12px !important;
448
+ border-radius: 15px !important;
449
+ background: var(--color-accent-soft) !important;
450
+ color: var(--text-color-body) !important;
451
+ border: 1px solid var(--border-color-accent) !important;
452
+ cursor: pointer;
453
+ transition: all 0.2s ease;
454
+ }
455
+
456
+ .prompt-btn:hover {
457
+ background: var(--color-accent) !important;
458
+ color: white !important;
459
+ }
460
+
461
+ /* Analysis results styling */
462
+ .analysis-result {
463
+ background: var(--background-fill-primary);
464
+ border-radius: 10px;
465
+ padding: 20px;
466
+ margin: 15px 0;
467
+ border: 1px solid var(--border-color-primary);
468
+ }
469
+
470
+ .analysis-result h3 {
471
+ margin: 0 0 15px 0;
472
+ color: var(--text-color-body);
473
+ }
474
+
475
+ /* Chart styling */
476
+ .chart-container {
477
+ text-align: center;
478
+ margin: 20px 0;
479
+ background: var(--background-fill-primary);
480
+ padding: 15px;
481
+ border-radius: 8px;
482
+ border: 1px solid var(--border-color-primary);
483
+ }
484
+
485
+ .chart-image {
486
+ max-width: 100%;
487
+ height: auto;
488
+ border-radius: 8px;
489
+ box-shadow: 0 4px 12px rgba(0,0,0,0.1);
490
+ }
491
+
492
+ /* Error styling */
493
+ .error-box {
494
+ background: #fee;
495
+ border: 1px solid #fcc;
496
+ color: #c33;
497
+ padding: 15px;
498
+ border-radius: 8px;
499
+ margin: 15px 0;
500
+ }
501
+
502
+ .error-box h3 {
503
+ margin: 0 0 10px 0;
504
+ color: #c33;
505
+ }
506
+
507
+ /* Button styling */
508
+ .analyze-btn {
509
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
510
+ color: white !important;
511
+ border: none !important;
512
+ border-radius: 25px !important;
513
+ padding: 15px 30px !important;
514
+ font-size: 1.1em !important;
515
+ font-weight: 600 !important;
516
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
517
+ transition: all 0.3s ease !important;
518
+ }
519
+
520
+ .analyze-btn:hover {
521
+ transform: translateY(-2px) !important;
522
+ box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
523
+ }
524
+
525
+ /* Responsive design */
526
+ @media (max-width: 768px) {
527
+ .gradio-container {
528
+ padding: 10px;
529
+ }
530
+
531
+ .main-header h1 {
532
+ font-size: 2em;
533
+ }
534
+
535
+ .section {
536
+ padding: 15px;
537
+ }
538
+
539
+ .data-stats {
540
+ flex-direction: column;
541
+ }
542
+
543
+ .prompt-buttons {
544
+ flex-direction: column;
545
+ }
546
+ }
547
+ """
548
+ ) as demo:
549
+
550
+ # Header
551
+ gr.Markdown("""
552
+ # 🤖 Data Analysis Agent
553
+
554
+ Upload your data file and describe what analysis you want to perform. The AI agent will:
555
+ - 📊 Create visualizations (charts, plots, graphs)
556
+ - 🔢 Perform statistical analysis (correlations, tests, summaries)
557
+ - 🔧 Transform your data (filter, aggregate, compute new columns)
558
+
559
+ **Supported formats:** CSV, Excel (.xlsx, .xls)
560
+ """)
561
+
562
+ # Step 1: File Upload
563
+ with gr.Accordion("📁 Step 1: Upload Your Data", open=True):
564
+ file_input = gr.File(
565
+ label="Choose your data file (CSV, Excel)",
566
+ file_types=[".csv", ".xlsx", ".xls"],
567
+ type="filepath"
568
+ )
569
+
570
+ # Step 2: Data Preview
571
+ with gr.Accordion("👀 Step 2: Data Preview", open=True):
572
+ preview_output = gr.HTML(value="<p style='text-align: center; color: #888; padding: 40px;'>Upload a file to see data preview</p>")
573
+
574
+ # Step 3: Analysis Prompt
575
+ with gr.Accordion("💬 Step 3: Describe Your Analysis", open=True):
576
+ prompt_input = gr.Textbox(
577
+ label="What would you like to analyze?",
578
+ placeholder="e.g., 'Create a bar chart showing sales by category' or 'Calculate correlation between price and quantity'",
579
+ lines=3
580
+ )
581
+
582
+ # Example prompts in separate collapsible sections
583
+ gr.HTML('<h4 style="margin: 20px 0 10px 0;">💡 Need inspiration? Try these examples:</h4>')
584
+
585
+
586
+ with gr.Accordion("🔧 Data Transformation Examples", open=False):
587
+ for prompt in sample_prompts["Data Transformation"]:
588
+ gr.Button(prompt, size="sm", elem_classes=["prompt-btn"]).click(
589
+ lambda p=prompt: p, inputs=[], outputs=prompt_input, queue=False
590
+ )
591
+
592
+ with gr.Accordion("📊 Visualization Examples", open=False):
593
+ for prompt in sample_prompts["Visualization"]:
594
+ gr.Button(prompt, size="sm", elem_classes=["prompt-btn"]).click(
595
+ lambda p=prompt: p, inputs=[], outputs=prompt_input, queue=False
596
+ )
597
+
598
+ with gr.Accordion("📈 Statistical Analysis Examples", open=False):
599
+ for prompt in sample_prompts["Statistical Analysis"]:
600
+ gr.Button(prompt, size="sm", elem_classes=["prompt-btn"]).click(
601
+ lambda p=prompt: p, inputs=[], outputs=prompt_input, queue=False
602
+ )
603
+
604
+
605
+
606
+ # Step 4: Analysis Button
607
+ with gr.Accordion("🚀 Step 4: Run Analysis", open=True):
608
+ submit_btn = gr.Button("🚀 Analyze Data", variant="primary", size="lg", elem_classes=["analyze-btn"])
609
+
610
+ # Step 5: Results
611
+ with gr.Accordion("📊 Step 5: Analysis Results", open=True):
612
+ output = gr.HTML(value="<p style='text-align: center; color: #888; padding: 40px;'>Click 'Analyze Data' to see results here</p>")
613
+
614
+ # Step 6: Downloads
615
+ with gr.Accordion("📥 Step 6: Downloads", open=True):
616
+ download_output = gr.File(label="Transformed Data (if applicable)", visible=True)
617
+ gr.HTML("<p style='color: #666; font-size: 0.9em;'>Download will appear here for data transformation results</p>")
618
+
619
+ # Event handlers
620
+ file_input.change(
621
+ fn=generate_preview,
622
+ inputs=[file_input],
623
+ outputs=[preview_output]
624
+ )
625
+
626
+ submit_btn.click(
627
+ fn=process_sync,
628
+ inputs=[file_input, prompt_input],
629
+ outputs=[output, download_output],
630
+ show_progress=True
631
+ )
632
+
633
+ if __name__ == "__main__":
634
+ demo.launch(
635
+ server_name="0.0.0.0",
636
+ server_port=7860,
637
+ share=False,
638
+ debug=True
639
+ )
mermaid_graph.png ADDED
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.19.2
2
+ pandas>=2.2.0
3
+ numpy>=1.26.0
4
+ matplotlib>=3.8.0
5
+ seaborn>=0.13.0
6
+ scipy>=1.12.0
7
+ httpx>=0.26.0
8
+ langgraph>=0.0.20
9
+ langchain-core>=0.1.27
10
+ python-dotenv>=1.0.0
visualize_agent.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langgraph.graph import StateGraph, START, END
3
+ from IPython.display import Image, display
4
+
5
+ # Assuming 'agents' module and its contents (AgentState, nodes, routes) are available.
6
+ # For a runnable example, you'd need to define these or mock them.
7
+ # Example placeholders if 'agents.py' isn't provided:
8
+ class AgentState:
9
+ """A placeholder for AgentState."""
10
+ pass
11
+
12
+ def analyze_intent_node(state):
13
+ """Placeholder for analyze_intent_node."""
14
+ print("Analyzing intent...")
15
+ # In a real scenario, this would determine the next step
16
+ # For demonstration, let's simulate routing to visualization
17
+ return {"next_step": "visualization"}
18
+
19
+ def generate_visualization_node(state):
20
+ """Placeholder for generate_visualization_node."""
21
+ print("Generating visualization code...")
22
+ # Simulate success
23
+ return {"code_generated": True}
24
+
25
+ def generate_transformation_node(state):
26
+ """Placeholder for generate_transformation_node."""
27
+ print("Generating transformation code...")
28
+ return {"code_generated": True}
29
+
30
+ def generate_statistical_node(state):
31
+ """Placeholder for generate_statistical_node."""
32
+ print("Generating statistical code...")
33
+ return {"code_generated": True}
34
+
35
+ def execute_code_node(state):
36
+ """Placeholder for execute_code_node."""
37
+ print("Executing code...")
38
+ return {"execution_successful": True}
39
+
40
+ def error_handler_node(state):
41
+ """Placeholder for error_handler_node."""
42
+ print("Handling error...")
43
+ return {}
44
+
45
+ def route_based_on_intent(state):
46
+ """Placeholder for route_based_on_intent."""
47
+ # In a real app, this would use state to determine the route
48
+ if state.get("next_step") == "visualization":
49
+ return "visualization"
50
+ elif state.get("next_step") == "transformation":
51
+ return "transformation"
52
+ elif state.get("next_step") == "statistical":
53
+ return "statistical"
54
+ return "error"
55
+
56
+ def route_to_execution(state):
57
+ """Placeholder for route_to_execution."""
58
+ # In a real app, this would check if code generation was successful
59
+ if state.get("code_generated"):
60
+ return "execute"
61
+ return "error"
62
+
63
+
64
+ def create_visualization():
65
+ """Create and save a visualization of the agent workflow."""
66
+ # Create the state graph
67
+ workflow = StateGraph(AgentState)
68
+
69
+ # Add nodes
70
+ workflow.add_node("analyze_intent", analyze_intent_node)
71
+ workflow.add_node("visualization", generate_visualization_node)
72
+ workflow.add_node("transformation", generate_transformation_node)
73
+ workflow.add_node("statistical", generate_statistical_node)
74
+ workflow.add_node("execute", execute_code_node)
75
+ workflow.add_node("error_handler", error_handler_node)
76
+
77
+ # Add edges
78
+ workflow.add_edge(START, "analyze_intent")
79
+
80
+ # Conditional edges based on intent
81
+ workflow.add_conditional_edges(
82
+ "analyze_intent",
83
+ route_based_on_intent,
84
+ {
85
+ "visualization": "visualization",
86
+ "transformation": "transformation",
87
+ "statistical": "statistical",
88
+ "error": "error_handler"
89
+ }
90
+ )
91
+
92
+ # Route from generation nodes to execution
93
+ workflow.add_conditional_edges(
94
+ "visualization",
95
+ route_to_execution,
96
+ {
97
+ "execute": "execute",
98
+ "complete": END, # Added 'complete' to allow direct END from visualization if needed
99
+ "error": "error_handler"
100
+ }
101
+ )
102
+ workflow.add_conditional_edges(
103
+ "transformation",
104
+ route_to_execution,
105
+ {
106
+ "execute": "execute",
107
+ "complete": END,
108
+ "error": "error_handler"
109
+ }
110
+ )
111
+ workflow.add_conditional_edges(
112
+ "statistical",
113
+ route_to_execution,
114
+ {
115
+ "execute": "execute",
116
+ "complete": END,
117
+ "error": "error_handler"
118
+ }
119
+ )
120
+
121
+ # Final edges
122
+ workflow.add_edge("execute", END)
123
+ workflow.add_edge("error_handler", END)
124
+
125
+ # Create visualization directory if it doesn't exist
126
+ os.makedirs("visualizations", exist_ok=True)
127
+
128
+ # Generate and save the visualization
129
+ graph = workflow.compile()
130
+
131
+ try:
132
+ # Get the graph as a Mermaid diagram and draw it to PNG
133
+ # This requires 'mermaid-py' and potentially 'puppeteer' (for playwright backend)
134
+ png_data = graph.get_graph().draw_mermaid_png()
135
+
136
+ # Define the filename
137
+ filename = os.path.join("visualizations", "mermaid_graph.png")
138
+
139
+ # Save the PNG data to a file
140
+ with open(filename, "wb") as f:
141
+ f.write(png_data)
142
+
143
+ print(f"Image successfully saved as '{filename}'")
144
+
145
+ # Optionally, display the image after saving
146
+ display(Image(png_data))
147
+ except ImportError:
148
+ print("Please install 'mermaid-py' to generate PNG visualizations.")
149
+ print("You might also need to install a browser automation tool like 'playwright' for mermaid-py.")
150
+ except Exception as e:
151
+ print(f"An error occurred during visualization: {e}")
152
+ # This requires some extra dependencies and is optional
153
+ pass
154
+
155
+ if __name__ == "__main__":
156
+ create_visualization()