Janarthanan-Gnanamurthy committed on
Commit
a295a27
·
verified ·
1 Parent(s): 1d22584

Upload agents.py

Browse files
Files changed (1) hide show
  1. agents.py +807 -0
agents.py ADDED
@@ -0,0 +1,807 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import logging
5
+ import httpx
6
+ import pandas as pd
7
+ import numpy as np
8
+ from typing import Dict, List, TypedDict, Annotated, Literal
9
+ from fastapi import HTTPException
10
+ from langgraph.graph import StateGraph, START, END
11
+ from langgraph.prebuilt import ToolNode
12
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
13
+ from langchain_core.tools import tool
14
+ import matplotlib.pyplot as plt
15
+ import seaborn as sns
16
+ from scipy import stats
17
+ import warnings
18
+ import io
19
+ import base64
20
+ import tempfile
21
+ from dotenv import load_dotenv
22
+
23
# Load environment variables from a local .env file, if one exists.
load_dotenv()

# Configure root logging once, at import time, with a timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Gemini API configuration; the model name can be overridden via GEMINI_MODEL.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash-preview-05-20")
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"

# Fail fast at import: generate_with_gemini() cannot call the API without a key.
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable is required")
40
+
41
+ # Define the agent state
42
class AgentState(TypedDict):
    """State dictionary passed between the LangGraph nodes of the agent."""

    # Conversation messages; seeded with the user's prompt as a HumanMessage.
    messages: Annotated[List[BaseMessage], "The conversation messages"]
    # Natural-language request being analysed.
    prompt: str
    # Input data the request refers to.
    dataframe: pd.DataFrame
    # Column names of ``dataframe`` (``dataframe.columns.tolist()``).
    columns: List[str]
    # Classification produced by analyze_intent_node (has an "intent" key).
    intent: Dict
    # Chart settings produced by generate_visualization_node.
    chart_config: Dict
    # Python source produced by the transformation/statistical generators.
    code: str
    # Final payload returned to the caller.
    result: Dict
    # Error message; an empty string means no error was recorded.
    error: str
    # Routing hint written by each node and read by the router functions.
    next_action: str
    # Initialised to "" by analyze_data_with_agent; not written by any node
    # visible in this file.
    plot_path: str
54
+
55
async def generate_with_gemini(prompt, temperature=0.2):
    """Generate a text completion with the Gemini ``generateContent`` API.

    Args:
        prompt: Text sent as the single part of the request contents.
        temperature: Sampling temperature forwarded in ``generationConfig``.

    Returns:
        The text of the first candidate (all of its text parts joined), or
        ``""`` when the response carries no candidate text.

    Raises:
        HTTPException: With the upstream status code when Gemini answers with
            an error status, or with status 500 for any other failure.
    """
    url = f"{GEMINI_BASE_URL}/{GEMINI_MODEL}:generateContent"

    headers = {
        "Content-Type": "application/json",
        # Authenticate with a header instead of a ``?key=`` query parameter so
        # the key is not part of the request URL, which HTTP clients may log.
        "x-goog-api-key": GEMINI_API_KEY,
    }

    blocked_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )

    payload = {
        "contents": [{"parts": [{"text": prompt}]}],
        "generationConfig": {
            "temperature": temperature,
            "topP": 0.95,
            "topK": 40,
            "maxOutputTokens": 8192,
        },
        "safetySettings": [
            {"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
            for category in blocked_categories
        ],
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(url, json=payload, headers=headers)
            response.raise_for_status()
            result = response.json()

        # Extract text from the first candidate; every level may be missing
        # or empty (e.g. when the response was blocked).
        candidates = result.get("candidates") or []
        if not candidates:
            return ""
        parts = (candidates[0].get("content") or {}).get("parts") or []
        return "".join(
            part.get("text", "") for part in parts if not part.get("thought")
        )

    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error from Gemini API: {e.response.status_code} - {e.response.text}")
        raise HTTPException(status_code=e.response.status_code, detail=f"Gemini API error: {e.response.text}") from e
    except Exception as e:
        logger.error(f"Error generating response with Gemini: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating response with Gemini: {str(e)}") from e
124
+
125
def create_chart(df: pd.DataFrame, chart_config: Dict) -> str:
    """Render a matplotlib chart and return it as a base64-encoded PNG.

    Args:
        df: Data to plot.
        chart_config: Mapping with optional keys ``chart_type`` (default
            ``"bar"``), ``x_axis``, ``y_axis``, ``title`` (default
            ``"Chart"``) and ``aggregation`` (``"sum"``, ``"mean"``,
            ``"count"`` or ``"none"``, default ``"none"``).

    Returns:
        The base64 text of the PNG image, or ``None`` when rendering fails
        (the error is logged, not raised).
    """
    fig = None
    try:
        plt.style.use('seaborn-v0_8')
        fig, ax = plt.subplots(figsize=(12, 8))

        chart_type = chart_config.get("chart_type", "bar")
        x_axis = chart_config.get("x_axis")
        y_axis = chart_config.get("y_axis")
        title = chart_config.get("title", "Chart")
        aggregation = chart_config.get("aggregation", "none")

        # Aggregate per x value when requested; plot_df is only read below,
        # so the input frame does not need to be copied.
        plot_df = df
        if x_axis and y_axis and aggregation in ("sum", "mean", "count"):
            plot_df = df.groupby(x_axis)[y_axis].agg(aggregation).reset_index()

        # Create the chart based on type
        if chart_type == "bar":
            if aggregation != "none":
                ax.bar(plot_df[x_axis], plot_df[y_axis])
            else:
                sns.barplot(data=plot_df, x=x_axis, y=y_axis, ax=ax)

        elif chart_type == "line":
            if aggregation != "none":
                ax.plot(plot_df[x_axis], plot_df[y_axis], marker='o')
            else:
                sns.lineplot(data=plot_df, x=x_axis, y=y_axis, ax=ax)

        elif chart_type == "scatter":
            sns.scatterplot(data=plot_df, x=x_axis, y=y_axis, ax=ax)

        elif chart_type == "histogram":
            if x_axis in df.columns:
                ax.hist(df[x_axis].dropna(), bins=30, alpha=0.7)

        elif chart_type == "boxplot":
            if y_axis and x_axis:
                sns.boxplot(data=plot_df, x=x_axis, y=y_axis, ax=ax)
            else:
                ax.boxplot(df.select_dtypes(include=[np.number]).dropna())

        elif chart_type == "pie":
            if x_axis:
                value_counts = df[x_axis].value_counts()
                ax.pie(value_counts.values, labels=value_counts.index, autopct='%1.1f%%')

        elif chart_type == "area":
            if x_axis and y_axis:
                ax.fill_between(plot_df[x_axis], plot_df[y_axis], alpha=0.7)

        # Customize the chart
        ax.set_title(title, fontsize=16, fontweight='bold')
        if x_axis and chart_type != "pie":
            # Underscores become spaces, matching the y-axis label below.
            ax.set_xlabel(x_axis.replace('_', ' ').title(), fontsize=12)
        if y_axis and chart_type not in ["pie", "histogram"]:
            ax.set_ylabel(y_axis.replace('_', ' ').title(), fontsize=12)

        # Rotate x-axis labels so long category names stay readable.
        if chart_type not in ["pie", "histogram"]:
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        fig.tight_layout()

        # Save this figure (not whatever pyplot considers "current") to base64.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', dpi=300, bbox_inches='tight')
        return base64.b64encode(buffer.getvalue()).decode()

    except Exception as e:
        logger.error(f"Error creating chart: {str(e)}")
        return None

    finally:
        # Always release the figure created here, on success and on failure,
        # without touching figures owned by other callers.
        if fig is not None:
            plt.close(fig)
208
+
209
+ # Agent nodes
210
def _load_intent_json(text):
    """Return the JSON object contained in *text*, or ``None`` if there is none."""
    # Prefer the contents of a markdown code fence when the model used one.
    fenced = re.search(r"```(?:json)?\n(.*?)\n```", text, re.DOTALL)
    if fenced:
        text = fenced.group(1)
    text = text.strip()

    # Try the whole text first, then the outermost {...} span inside it.
    candidates = [text]
    braced = re.search(r"(\{.*\})", text, re.DOTALL)
    if braced:
        candidates.append(braced.group(1))

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def _classify_intent_by_keywords(prompt):
    """Classify *prompt* by keyword when the model reply is unusable."""
    prompt_lower = prompt.lower()
    if any(word in prompt_lower for word in ['chart', 'plot', 'graph', 'visualiz', 'show']):
        return {"intent": "visualization", "reason": "Keywords suggest visualization"}
    if any(word in prompt_lower for word in ['filter', 'transform', 'add', 'modify', 'create column']):
        return {"intent": "transformation", "reason": "Keywords suggest transformation"}
    return {"intent": "statistical", "reason": "Default to statistical analysis"}


async def analyze_intent_node(state: AgentState) -> AgentState:
    """Classify the user's prompt as visualization, transformation or statistical.

    Asks Gemini for a JSON classification. When the reply contains no JSON
    object, or its ``intent`` is not one of the three supported labels, the
    prompt is classified by keywords instead, so ``state["next_action"]`` is
    always a label the router knows.

    Args:
        state: Agent state; reads ``prompt`` and ``columns``.

    Returns:
        The same state with ``intent`` and ``next_action`` set, or with
        ``error`` set and ``next_action == "error"`` when the call fails.
    """
    prompt = state["prompt"]
    columns = state["columns"]

    response_format = {
        "intent": "statistical",
        "reason": "Prompt requests statistical analysis",
        "visualization_type": None,
        "transformation_type": None,
        "statistical_type": "correlation"
    }

    # Column labels are not guaranteed to be strings (e.g. integer headers).
    input_text = f"""Analyze the following prompt and determine if it's requesting data transformation, visualization, or statistical analysis:

Prompt: {prompt}
Available columns: {', '.join(map(str, columns))}

Provide a JSON response with:
1. intent: Either 'visualization', 'transformation', or 'statistical'
2. reason: Brief explanation of why this classification was chosen
3. visualization_type: If intent is 'visualization', specify the chart type ('bar', 'line', 'pie', 'scatter', 'area', 'histogram', 'boxplot')
4. transformation_type: If intent is 'transformation', specify the operation type ('aggregate', 'filter', 'join', 'compute', 'sort', 'group')
5. statistical_type: If intent is 'statistical', specify the test type ('correlation', 'ttest', 'regression', 'descriptive'),

Example response format:
{json.dumps(response_format)}"""

    try:
        json_text = await generate_with_gemini(input_text, temperature=0.4)

        intent = _load_intent_json(json_text)
        label = str(intent.get("intent", "")).strip().lower() if intent else ""
        if label in ("visualization", "transformation", "statistical"):
            # Store the normalised label so routing matches exactly.
            intent["intent"] = label
        else:
            intent = _classify_intent_by_keywords(prompt)

        state["intent"] = intent
        state["next_action"] = intent["intent"]
        logger.info(f"Intent analysis result: {intent}")

    except Exception as e:
        state["error"] = f"Error analyzing prompt intent: {str(e)}"
        state["next_action"] = "error"
        logger.error(f"Error in analyze_intent_node: {str(e)}")

    return state
275
+
276
def _load_chart_config(text):
    """Return the JSON object contained in *text*, or ``None`` if there is none."""
    fenced = re.search(r"```(?:json)?\n(.*?)\n```", text, re.DOTALL)
    if fenced:
        text = fenced.group(1)
    text = text.strip()

    # Try the whole text first, then the outermost {...} span inside it.
    candidates = [text]
    braced = re.search(r"(\{.*\})", text, re.DOTALL)
    if braced:
        candidates.append(braced.group(1))

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


async def generate_visualization_node(state: AgentState) -> AgentState:
    """Ask Gemini for a chart configuration and render the chart.

    Args:
        state: Agent state; reads ``prompt``, ``columns`` and ``dataframe``.

    Returns:
        The same state. On success ``chart_config`` and ``result`` (type
        ``"visualization"``, including the base64 image) are set and
        ``next_action`` is ``"complete"``; otherwise ``error`` is set and
        ``next_action`` is ``"error"``.
    """
    prompt = state["prompt"]
    columns = state["columns"]
    df = state["dataframe"]

    response_format = {
        "chart_type": "bar",
        "x_axis": "date",
        "y_axis": "sales",
        "aggregation": "sum",
        "title": "Total Sales by Date"
    }

    try:
        if not columns:
            # Nothing can be plotted, and the column fallbacks below need
            # at least one column.
            raise ValueError("the DataFrame has no columns")

        input_text = f"""Based on the following prompt, determine the appropriate chart configuration:

Prompt: {prompt}
Available columns: {', '.join(map(str, columns))}

Generate a JSON configuration with:
1. chart_type: 'bar', 'line', 'pie', 'scatter', 'area', 'histogram', 'boxplot'
2. x_axis: column name for x-axis (choose from available columns)
3. y_axis: column name for y-axis (can be None for histograms, choose from available columns)
4. aggregation: 'sum', 'mean', 'count', 'none'
5. title: descriptive chart title

Example response format:
{json.dumps(response_format)}

Provide only the JSON configuration, no explanations."""

        json_text = await generate_with_gemini(input_text, temperature=0.5)

        chart_config = _load_chart_config(json_text)
        if chart_config is None:
            # Fallback configuration built from the column dtypes.
            numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            categorical_cols = df.select_dtypes(include=['object']).columns.tolist()

            chart_config = {
                "chart_type": "bar",
                "x_axis": categorical_cols[0] if categorical_cols else columns[0],
                "y_axis": numeric_cols[0] if numeric_cols else columns[1] if len(columns) > 1 else None,
                "aggregation": "mean" if numeric_cols else "count",
                "title": "Data Visualization"
            }

        # create_chart() defaults a missing chart type to "bar"; record the
        # same default so the result payload below can report it.
        chart_config.setdefault("chart_type", "bar")

        # Validate column names exist
        if chart_config.get("x_axis") not in columns:
            chart_config["x_axis"] = columns[0]
        if chart_config.get("y_axis") and chart_config["y_axis"] not in columns:
            numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            chart_config["y_axis"] = numeric_cols[0] if numeric_cols else None

        state["chart_config"] = chart_config

        # Create the chart immediately
        image_base64 = create_chart(df, chart_config)
        if image_base64:
            state["result"] = {
                "type": "visualization",
                "chart_type": chart_config["chart_type"],
                "config": chart_config,
                "image": image_base64,
                "message": "Visualization created successfully"
            }
            state["next_action"] = "complete"
        else:
            state["error"] = "Failed to create visualization"
            state["next_action"] = "error"

        logger.info(f"Generated chart config: {chart_config}")

    except Exception as e:
        state["error"] = f"Error generating chart configuration: {str(e)}"
        state["next_action"] = "error"
        logger.error(f"Error in generate_visualization_node: {str(e)}")

    return state
367
+
368
async def generate_transformation_node(state: AgentState) -> AgentState:
    """Ask Gemini for pandas code that performs the requested transformation.

    The generated code is expected to leave its output in a variable named
    ``transformed_df``; it is stored in ``state["code"]`` and not run here.

    Args:
        state: Agent state; reads ``prompt`` and ``columns``.

    Returns:
        The same state with ``code`` set and ``next_action == "execute"``, or
        with ``error`` set and ``next_action == "error"`` on failure.
    """
    prompt = state["prompt"]
    columns = state["columns"]

    try:
        input_text = f"""Write Python code to perform the following pandas DataFrame transformation:

{prompt}

Available columns: {', '.join(map(str, columns))}

Pandas Knowledge Base:
1. DataFrame Operations:
- select columns: df[['col1', 'col2']]
- filter rows: df[df['column'] > value]
- group data: df.groupby('column')
- sort data: df.sort_values('column')
- add/modify columns: df['new_col'] = df['col1'] * 2
- drop columns: df.drop(['col1'], axis=1)
- remove duplicates: df.drop_duplicates()
- merge dataframes: pd.merge(df1, df2)

2. Common Functions:
- df.apply(): Apply function to columns/rows
- df.fillna(): Fill missing values
- df.dropna(): Drop missing values
- df.replace(): Replace values
- pd.to_datetime(): Convert to datetime
- df.astype(): Convert data types
- df.round(): Round numbers
- df.sum(), df.mean(), df.count(): Aggregations

3. String Operations:
- df['col'].str.contains(): String contains
- df['col'].str.split(): Split strings
- df['col'].str.replace(): Replace in strings
- df['col'].str.upper(): Convert to uppercase

4. Window Operations:
- df.rolling(): Rolling window operations
- df.shift(): Shift values
- df.expanding(): Expanding window

Requirements:
1. Use pandas DataFrame operations
2. Handle missing values appropriately
3. Store result in 'transformed_df'
4. DO NOT define functions
5. Return a pandas DataFrame
6. Use proper type conversions if needed

Available variables:
- df: pandas DataFrame
- pd: pandas module
- np: numpy module

Example format:
```python
transformed_df = df.copy()
transformed_df['new_column'] = df['column1'] * df['column2']
transformed_df = transformed_df.fillna(0) # Handle nulls
```

Provide only the code, no explanations. DO NOT DEFINE functions, directly perform the operations on the df."""

        code = await generate_with_gemini(input_text, temperature=0.4)

        # Strip a surrounding markdown fence. Accept ```python, ```py and a
        # bare ``` so the backticks never reach exec() as source code.
        code_match = re.search(r"```(?:python|py)?[ \t]*\n(.*?)\n?```", code, re.DOTALL)
        code = code_match.group(1) if code_match else code

        state["code"] = code
        state["next_action"] = "execute"
        logger.info(f"Generated transformation code: {code}")

    except Exception as e:
        state["error"] = f"Error generating transformation code: {str(e)}"
        state["next_action"] = "error"
        logger.error(f"Error in generate_transformation_node: {str(e)}")

    return state
449
+
450
async def generate_statistical_node(state: AgentState) -> AgentState:
    """Ask Gemini for pandas/numpy/scipy code that performs the analysis.

    The generated code is expected to leave its output in a variable named
    ``stat_result``; it is stored in ``state["code"]`` and not run here.

    Args:
        state: Agent state; reads ``prompt`` and ``columns``.

    Returns:
        The same state with ``code`` set and ``next_action == "execute"``, or
        with ``error`` set and ``next_action == "error"`` on failure.
    """
    prompt = state["prompt"]
    columns = state["columns"]

    try:
        # NOTE: this is an f-string, so the literal braces of the dict
        # examples below are doubled ({{ }}). Left single, Python parses them
        # as replacement fields and building the prompt raises
        # "ValueError: Invalid format specifier".
        input_text = f"""Write pandas/numpy code to perform the following statistical analysis:

{prompt}

Available columns: {', '.join(map(str, columns))}

Statistical Analysis Knowledge Base:
1. Descriptive Statistics:
- df.describe(): Summary statistics
- df.mean(), df.std(): Mean and standard deviation
- df.var(): Variance
- df.min(), df.max(): Min/max values
- df.quantile([0.25, 0.5, 0.75]): Quartiles
- df.corr(): Correlation matrix

2. Hypothesis Testing (scipy.stats):
- stats.ttest_ind(): Independent t-test
- stats.ttest_rel(): Paired t-test
- stats.chi2_contingency(): Chi-square test
- stats.pearsonr(): Pearson correlation
- stats.spearmanr(): Spearman correlation

3. Regression Analysis:
- np.polyfit(): Polynomial fitting
- stats.linregress(): Linear regression
- df.rolling().corr(): Rolling correlation

4. Data Quality:
- df.isnull().sum(): Count missing values
- df.duplicated().sum(): Count duplicates
- df.value_counts(): Value counts

Requirements:
1. Use pandas and numpy functions
2. Include proper statistical computations
3. Store main result in 'stat_result'
4. DO NOT define functions
5. Handle null values appropriately
6. Include interpretation comments

Available variables:
- df: pandas DataFrame
- pd: pandas module
- np: numpy module
- stats: scipy.stats module

Example formats:

For correlation analysis:
```python
# Calculate correlation matrix
correlation_matrix = df.select_dtypes(include=[np.number]).corr()
stat_result = correlation_matrix
```

For descriptive statistics:
```python
# Generate summary statistics
stat_result = df.describe()
# Add correlation for numeric columns
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 1:
    stat_result = {{
        'descriptive': df.describe(),
        'correlation': df[numeric_cols].corr()
    }}
```

For hypothesis testing:
```python
# Perform t-test between two groups
group1 = df[df['category'] == 'A']['value']
group2 = df[df['category'] == 'B']['value']
t_stat, p_value = stats.ttest_ind(group1, group2)
stat_result = {{'t_statistic': t_stat, 'p_value': p_value}}
```

Provide only the code, no explanations. DO NOT DEFINE functions."""

        code = await generate_with_gemini(input_text, temperature=0.3)

        # Strip a surrounding markdown fence. Accept ```python, ```py and a
        # bare ``` so the backticks never reach exec() as source code.
        code_match = re.search(r"```(?:python|py)?[ \t]*\n(.*?)\n?```", code, re.DOTALL)
        code = code_match.group(1) if code_match else code

        state["code"] = code
        state["next_action"] = "execute"
        logger.info(f"Generated statistical code: {code}")

    except Exception as e:
        state["error"] = f"Error generating statistical code: {str(e)}"
        state["next_action"] = "error"
        logger.error(f"Error in generate_statistical_node: {str(e)}")

    return state
550
+
551
async def execute_code_node(state: AgentState) -> AgentState:
    """Run the generated code and store its outcome in ``state["result"]``.

    For a ``transformation`` intent the code must define ``transformed_df``;
    for a ``statistical`` intent it must define ``stat_result``. Every failure
    path fills ``state["result"]`` with an error payload as well as
    ``state["error"]``, so the caller still receives a message when the graph
    goes straight from this node to END.

    Args:
        state: Agent state; reads ``code``, ``dataframe`` and ``intent``.

    Returns:
        The same state with ``result`` and ``next_action`` (``"complete"`` or
        ``"error"``) set.
    """
    code = state["code"]
    df = state["dataframe"]

    def _fail(message):
        # Record the failure in both places the rest of the workflow reads.
        state["error"] = message
        state["result"] = {"type": "error", "message": message}
        state["next_action"] = "error"
        return state

    if not code:
        return _fail("No code to execute")

    try:
        # SECURITY NOTE(review): this runs model-generated code with full
        # builtins in the current process. The namespace below only
        # pre-populates names; it is not a sandbox. Do not expose this to
        # untrusted prompts without real isolation.
        namespace = {
            # Work on a copy so in-place edits made by the generated code do
            # not alter the caller's DataFrame.
            'df': df.copy(),
            'pd': pd,
            'np': np,
            'stats': stats,
            'plt': plt,
            'sns': sns
        }

        # Execute the code
        exec(code, namespace)

        # Extract results based on intent
        intent = state["intent"]["intent"]

        if intent == "transformation":
            if 'transformed_df' not in namespace:
                return _fail("No 'transformed_df' found in execution result")

            result_df = namespace['transformed_df']
            if isinstance(result_df, pd.Series):
                # An aggregation may yield a Series; present it as one column.
                result_df = result_df.to_frame()
            state["result"] = {
                "type": "transformation",
                "shape": result_df.shape,
                "columns": result_df.columns.tolist(),
                "preview": result_df.head(10).to_html(classes='table table-striped'),
                "dataframe": result_df,
                "message": f"Data transformed successfully. New shape: {result_df.shape}"
            }

        elif intent == "statistical":
            if 'stat_result' not in namespace:
                return _fail("No 'stat_result' found in execution result")

            formatted_result = format_statistical_result(namespace['stat_result'])
            state["result"] = {
                "type": "statistical",
                "data": formatted_result,
                "message": "Statistical analysis completed successfully"
            }

        state["next_action"] = "complete"
        logger.info("Code executed successfully")

    except Exception as e:
        logger.error(f"Error in execute_code_node: {str(e)}")
        return _fail(f"Error executing code: {str(e)}")

    return state
613
+
614
def format_statistical_result(stat_result) -> str:
    """Format a statistical result as an HTML fragment for display.

    Args:
        stat_result: A DataFrame or Series (rendered as a table), a dict
            (one ``<h4>`` section per key), or any other value (rendered via
            ``str``).

    Returns:
        An HTML string. Real numbers inside a dict are shown with six decimal
        places; booleans and other values are shown as text. Text is
        HTML-escaped. If formatting fails, an HTML error paragraph is returned
        instead of raising.
    """
    import html
    from numbers import Real

    def _as_table(value):
        # Tabular values get pandas' own HTML rendering; others return None.
        if isinstance(value, pd.DataFrame):
            return value.to_html(classes='table table-striped')
        if isinstance(value, pd.Series):
            return value.to_frame().to_html(classes='table table-striped')
        return None

    def _is_number(value):
        # bool is an int subclass, but "1.000000" is a misleading rendering.
        if isinstance(value, (bool, np.bool_)):
            return False
        return isinstance(value, (Real, np.integer, np.floating))

    try:
        table = _as_table(stat_result)
        if table is not None:
            return table

        if isinstance(stat_result, dict):
            html_parts = []
            for key, value in stat_result.items():
                # Keys are not guaranteed to be strings (e.g. integer labels).
                heading = html.escape(str(key).replace('_', ' ').title(), quote=False)
                html_parts.append(f"<h4>{heading}</h4>")

                table = _as_table(value)
                if table is not None:
                    html_parts.append(table)
                elif _is_number(value):
                    html_parts.append(f"<p><strong>{float(value):.6f}</strong></p>")
                else:
                    html_parts.append(f"<p>{html.escape(str(value), quote=False)}</p>")
            return ''.join(html_parts)

        return f"<p><strong>Result:</strong> {html.escape(str(stat_result), quote=False)}</p>"
    except Exception as e:
        return f"<p><strong>Error formatting result:</strong> {html.escape(str(e), quote=False)}</p>"
634
+
635
async def error_handler_node(state: AgentState) -> AgentState:
    """Turn the recorded error into a user-facing error result.

    Args:
        state: Agent state; reads ``error``.

    Returns:
        The same state with ``result`` set to an error payload (message plus
        generic suggestions) and ``next_action == "complete"``.
    """
    # ``error`` is initialised to "" rather than left unset, so fall back on
    # any empty value, not only on a missing key.
    error = state.get("error") or "Unknown error occurred"
    logger.error(f"Error in agent workflow: {error}")

    state["result"] = {
        "type": "error",
        "message": error,
        "suggestions": [
            "Check if the column names are correct",
            "Verify that the data types are appropriate",
            "Ensure the prompt is clear and specific"
        ]
    }
    state["next_action"] = "complete"
    return state
651
+
652
def route_based_on_intent(state: "AgentState") -> Literal["visualization", "transformation", "statistical", "error"]:
    """Pick the generator node that matches the analysed intent.

    Args:
        state: Agent state; reads ``error`` and ``intent["intent"]``.

    Returns:
        ``"visualization"``, ``"transformation"`` or ``"statistical"`` when
        that is the recorded intent and no error is set; ``"error"`` in every
        other case, including a missing or unrecognised intent label.
    """
    if state.get("error"):
        return "error"

    intent = (state.get("intent") or {}).get("intent")
    # Only return labels the graph has an edge for; the label originates from
    # model output and cannot be trusted to be one of them.
    if intent in ("visualization", "transformation", "statistical"):
        return intent
    return "error"
659
+
660
def route_to_execution(state: "AgentState") -> Literal["execute", "error", "complete"]:
    """Decide where to go after a generator node has run.

    Args:
        state: Agent state; reads ``error`` and ``next_action``.

    Returns:
        ``"error"`` when an error is recorded; otherwise ``next_action`` if it
        is ``"execute"`` or ``"complete"``, and ``"error"`` for anything else.
    """
    if state.get("error"):
        return "error"

    next_action = state.get("next_action", "error")
    return next_action if next_action in ("execute", "complete") else "error"
672
+
673
+ # Build the LangGraph workflow
674
def create_data_analysis_agent():
    """Build and compile the LangGraph workflow for data analysis.

    Flow: START -> analyze_intent -> one of visualization / transformation /
    statistical -> execute (when code was generated) -> END. A recorded error
    at any stage, including during code execution, is routed through
    ``error_handler`` before END.

    Returns:
        The compiled graph.
    """
    workflow = StateGraph(AgentState)

    # Add nodes
    nodes = {
        "analyze_intent": analyze_intent_node,
        "visualization": generate_visualization_node,
        "transformation": generate_transformation_node,
        "statistical": generate_statistical_node,
        "execute": execute_code_node,
        "error_handler": error_handler_node,
    }
    for node_name, node_fn in nodes.items():
        workflow.add_node(node_name, node_fn)

    # Add edges
    workflow.add_edge(START, "analyze_intent")

    # Conditional edges based on intent
    workflow.add_conditional_edges(
        "analyze_intent",
        route_based_on_intent,
        {
            "visualization": "visualization",
            "transformation": "transformation",
            "statistical": "statistical",
            "error": "error_handler"
        }
    )

    # Every generator node routes the same way: run code, finish, or fail.
    for generator in ("visualization", "transformation", "statistical"):
        workflow.add_conditional_edges(
            generator,
            route_to_execution,
            {
                "execute": "execute",
                "complete": END,
                "error": "error_handler"
            }
        )

    def _route_after_execute(state):
        # Send execution failures through the error handler so the caller
        # receives an error result instead of an empty one.
        return "error" if state.get("error") else "complete"

    workflow.add_conditional_edges(
        "execute",
        _route_after_execute,
        {
            "complete": END,
            "error": "error_handler"
        }
    )
    workflow.add_edge("error_handler", END)

    # Compile the graph
    return workflow.compile()
739
+
740
+ # Main execution function
741
async def analyze_data_with_agent(prompt: str, dataframe: pd.DataFrame) -> Dict:
    """
    Analyze data using the LangGraph agent.

    Args:
        prompt: Natural language prompt describing the analysis
        dataframe: Pandas DataFrame to analyze

    Returns:
        Dictionary containing the analysis results. It always has a "type"
        key; failures are reported as ``{"type": "error", "message": ...}``
        instead of being raised.
    """
    # Create the agent
    agent = create_data_analysis_agent()

    # Initialize state
    initial_state = {
        "messages": [HumanMessage(content=prompt)],
        "prompt": prompt,
        "dataframe": dataframe,
        "columns": dataframe.columns.tolist(),
        "intent": {},
        "chart_config": {},
        "code": "",
        "result": {},
        "error": "",
        "next_action": "",
        "plot_path": ""
    }

    # Run the agent
    try:
        final_state = await agent.ainvoke(initial_state)
    except Exception as e:
        logger.error(f"Error running agent: {str(e)}")
        return {
            "type": "error",
            "message": f"Agent execution failed: {str(e)}"
        }

    result = final_state.get("result")
    if not result:
        # A run can end with an error recorded but no result written;
        # surface that error rather than returning an empty dict.
        return {
            "type": "error",
            "message": final_state.get("error") or "Agent finished without producing a result"
        }
    return result
780
+
781
+ # Test function
782
async def test_agent():
    """Run the agent on random sample data with one prompt per intent type.

    Prints each prompt and the result dictionary returned for it. Requires a
    working Gemini API key, because every prompt is sent to the model.
    """
    # Create sample data (100 rows of unseeded random values).
    data = {
        'date': pd.date_range('2024-01-01', periods=100),
        'sales': np.random.normal(1000, 200, 100),
        'category': np.random.choice(['A', 'B', 'C'], 100),
        'region': np.random.choice(['North', 'South', 'East', 'West'], 100)
    }
    df = pd.DataFrame(data)

    # One prompt each for visualization, statistics and transformation.
    test_prompts = [
        "Create a bar chart showing average sales by category",
        "Calculate correlation between date and sales",
        "Filter the data to show only category A and add a profit column that is 20% of sales"
    ]

    for prompt in test_prompts:
        print(f"\n--- Testing: {prompt} ---")
        result = await analyze_data_with_agent(prompt, df)
        print(f"Result: {result}")
804
+
805
if __name__ == "__main__":
    # Script entry point: run the smoke test on a fresh event loop.
    import asyncio
    asyncio.run(test_agent())