Nyha15 committed on
Commit
a738995
·
1 Parent(s): 4298f06

Added files

Browse files
Files changed (2) hide show
  1. app.py +1344 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,1344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data Analyst Duo MCP Implementation
3
+
4
+ This script implements a multi-agent system using the Model Context Protocol (MCP).
5
+ It features two agents that collaborate on data analysis tasks:
6
+ - ComputeAgent: Responsible for data loading, cleaning, and computation
7
+ - InterpretAgent: Responsible for interpreting results and visualizing data
8
+
9
+ The application includes a Gradio interface for interaction.
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import json
15
+ import time
16
+ import datetime
17
+ import gradio as gr
18
+ import pandas as pd
19
+ import numpy as np
20
+ import matplotlib.pyplot as plt
21
+ import seaborn as sns
22
+ from typing import Dict, List, Any, Optional, Union, Tuple
23
+ import requests
24
+ from io import StringIO
25
+ import logging
26
+ import uuid
27
+ import anthropic
28
+ import openai
29
+ from dotenv import load_dotenv
30
+
31
+ # Load environment variables
32
+ load_dotenv()
33
+
34
+ # Configure logging
35
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
36
+ logger = logging.getLogger(__name__)
37
+
38
+ # ============== MCP Protocol Implementation ==============
39
+
40
class MCPMessage:
    """A single message exchanged between MCP agents.

    Every message carries a unique id, the sender's name, a free-form
    message type, an arbitrary payload, and an ISO-8601 creation timestamp.
    """

    def __init__(self, sender: str, message_type: str, content: Any):
        self.id = str(uuid.uuid4())
        self.sender = sender
        self.message_type = message_type
        self.content = content
        self.timestamp = datetime.datetime.now().isoformat()

    def to_dict(self) -> Dict:
        """Serialize the message to a plain dictionary."""
        return {
            "id": self.id,
            "sender": self.sender,
            "message_type": self.message_type,
            "content": self.content,
            "timestamp": self.timestamp,
        }

    @staticmethod
    def from_dict(data: Dict) -> 'MCPMessage':
        """Rebuild a message from its dictionary form.

        The original id/timestamp are kept when present; otherwise the
        freshly generated ones from the constructor are used.
        """
        restored = MCPMessage(
            sender=data["sender"],
            message_type=data["message_type"],
            content=data["content"],
        )
        restored.id = data.get("id", restored.id)
        restored.timestamp = data.get("timestamp", restored.timestamp)
        return restored
72
+
73
+
74
class MCPTool:
    """A named, described callable that an agent exposes over MCP.

    The wrapped callable receives a single parameter dictionary and may
    return anything; only the name and description are advertised.
    """

    def __init__(self, name: str, description: str, function):
        self.name = name
        self.description = description
        self.function = function

    def to_dict(self) -> Dict:
        """Return the tool's public metadata (name and description only)."""
        return {"name": self.name, "description": self.description}

    def execute(self, params: Dict) -> Any:
        """Invoke the wrapped callable with *params* and return its result."""
        return self.function(params)
90
+
91
+
92
class MCPAgent:
    """Base class for agents speaking the MCP protocol.

    Provides tool registration and dispatch, peer connections, message
    passing with a full sent/received history, and a transparency log of
    LLM interactions. Subclasses must override :meth:`handle_message`.
    """

    def __init__(self, name: str, description: str, llm_model: Optional[str] = None, api_key: Optional[str] = None):
        self.name = name
        self.description = description
        self.tools: Dict[str, MCPTool] = {}
        self.message_queue: List[MCPMessage] = []
        self.peers: Dict[str, 'MCPAgent'] = {}
        self.message_history: List[Dict] = []
        self.llm_model = llm_model
        self.api_key = api_key
        self.llm_logs = []

    def register_tool(self, tool: MCPTool):
        """Make *tool* callable by this agent, keyed by its own name."""
        self.tools[tool.name] = tool

    def list_tools(self) -> List[Dict]:
        """Return the metadata dictionaries of every registered tool."""
        return [registered.to_dict() for registered in self.tools.values()]

    def call_tool(self, tool_name: str, params: Dict) -> Any:
        """Execute the registered tool *tool_name* with *params*.

        Raises ValueError when no tool with that name is registered.
        """
        if tool_name not in self.tools:
            raise ValueError(f"Tool {tool_name} not found")
        return self.tools[tool_name].execute(params)

    def connect(self, peer: 'MCPAgent'):
        """Register *peer* so messages can be addressed to it by name."""
        self.peers[peer.name] = peer

    def send_message(self, receiver: str, message_type: str, content: Any) -> Dict:
        """Deliver a new message to the peer named *receiver*.

        The message is recorded in this agent's history as "sent" and
        handed directly to the peer's receive_message. Returns the
        serialized message. Raises ValueError for an unknown receiver.
        """
        if receiver not in self.peers:
            raise ValueError(f"Peer {receiver} not found")

        outgoing = MCPMessage(self.name, message_type, content)
        payload = outgoing.to_dict()

        self.message_history.append({"type": "sent", "message": payload})

        self.peers[receiver].receive_message(outgoing)
        logger.info(f"Agent {self.name} sent {message_type} to {receiver}")
        return payload

    def receive_message(self, message: MCPMessage):
        """Queue an incoming message and record it as "received"."""
        self.message_queue.append(message)
        self.message_history.append({"type": "received", "message": message.to_dict()})
        logger.info(f"Agent {self.name} received {message.message_type} from {message.sender}")

    def process_messages(self) -> List[Dict]:
        """Drain the queue in FIFO order, handling each message.

        Returns the list of responses produced by handle_message.
        """
        responses = []
        while self.message_queue:
            responses.append(self.handle_message(self.message_queue.pop(0)))
        return responses

    def handle_message(self, message: MCPMessage) -> Dict:
        """Dispatch one message. Subclasses must provide the behavior."""
        raise NotImplementedError("Subclasses must implement handle_message")

    def log_llm_interaction(self, prompt: str, response: str):
        """Append a timestamped prompt/response pair to the LLM log."""
        entry = {
            "timestamp": datetime.datetime.now().isoformat(),
            "prompt": prompt,
            "response": response,
        }
        self.llm_logs.append(entry)
        return entry

    def get_message_history(self) -> List[Dict]:
        """Return every message this agent has sent or received."""
        return self.message_history

    def get_llm_logs(self) -> List[Dict]:
        """Return the recorded LLM interactions."""
        return self.llm_logs
185
+
186
+
187
+ # ============== Compute Agent Implementation ==============
188
+
189
class ComputeAgent(MCPAgent):
    """Agent responsible for data loading, cleaning, and computation.

    Owns the session's pandas DataFrame and exposes its capabilities
    (load / clean / statistics / correlation / filter / aggregate) as MCP
    tools. Every tool returns a dict with a "status" key ("success" or
    "error") instead of raising, so results can be relayed to peers as-is.
    """

    def __init__(self, name: str = "ComputeAgent", llm_model: Optional[str] = None, api_key: Optional[str] = None):
        super().__init__(name, "Agent responsible for data loading, cleaning and computation", llm_model, api_key)
        # The currently loaded dataset; None until _load_dataset succeeds.
        self.dataframe = None
        # NOTE(review): assigned but never read anywhere in this class.
        self.current_task = None

        # Register tools — one MCPTool per private implementation method.
        self.register_tool(MCPTool(
            "load_dataset",
            "Load a dataset from Kaggle or URL",
            self._load_dataset
        ))

        self.register_tool(MCPTool(
            "clean_data",
            "Clean the loaded dataset by handling missing values, duplicates, etc.",
            self._clean_data
        ))

        self.register_tool(MCPTool(
            "compute_statistics",
            "Compute basic statistics on the dataset",
            self._compute_statistics
        ))

        self.register_tool(MCPTool(
            "compute_correlation",
            "Compute correlation between columns",
            self._compute_correlation
        ))

        self.register_tool(MCPTool(
            "filter_data",
            "Filter data based on conditions",
            self._filter_data
        ))

        self.register_tool(MCPTool(
            "compute_aggregation",
            "Compute aggregation (sum, mean, etc.) grouped by a column",
            self._compute_aggregation
        ))

    def _load_dataset(self, params: Dict) -> Dict:
        """Load a CSV dataset from a URL into self.dataframe.

        params["url"] may be "default"/"cereals" (bundled cereal dataset),
        a Kaggle dataset page (resolved to a direct link via the LLM when
        one is configured), or any direct CSV download URL.

        Returns row count, column list, a 5-row preview and per-column
        dtypes on success; {"status": "error", ...} on any failure.
        """
        dataset_url = params.get("url")

        try:
            # Check if it's the default cereals dataset.
            # NOTE(review): if params has no "url", dataset_url is None and
            # .lower() raises AttributeError — caught below and surfaced as
            # a generic error message; confirm that is acceptable.
            if dataset_url == "default" or dataset_url.lower() == "cereals":
                dataset_url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/cereal.csv"

            # Check if it's a Kaggle URL and extract the dataset path
            elif "kaggle.com/datasets" in dataset_url:
                # For simplicity, we use direct download links.
                # In a real implementation, you would use the Kaggle API.
                prompt = f"""
                I have a Kaggle dataset URL: {dataset_url}
                Find the direct download link or alternative source for this dataset if possible.
                If not, suggest a suitable replacement dataset that's freely available.
                """

                # Ask whichever LLM family is configured to resolve the link.
                if self.llm_model and self.llm_model.startswith("claude"):
                    client = anthropic.Anthropic(api_key=self.api_key)
                    response = client.messages.create(
                        model="claude-3-sonnet-20240229",
                        max_tokens=1000,
                        messages=[{"role": "user", "content": prompt}]
                    )
                    result = response.content[0].text
                elif self.llm_model and self.llm_model.startswith("gpt"):
                    client = openai.OpenAI(api_key=self.api_key)
                    response = client.chat.completions.create(
                        model="gpt-4o",
                        messages=[{"role": "user", "content": prompt}]
                    )
                    result = response.choices[0].message.content
                else:
                    result = "For non-default datasets, please provide a direct download link."

                self.log_llm_interaction(prompt, result)

                # Extract the first plausible CSV/XLSX URL from the response;
                # the for/else falls back to the cereal dataset when none found.
                lines = result.split('\n')
                for line in lines:
                    if line.startswith("http") and (".csv" in line or ".xlsx" in line):
                        dataset_url = line.strip()
                        break
                else:
                    # If no URL found, use default cereals dataset
                    dataset_url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/cereal.csv"

            # Load the dataset.
            # NOTE(review): no timeout on the HTTP request — a stalled server
            # would block the agent indefinitely; consider requests.get(url, timeout=...).
            response = requests.get(dataset_url)
            content = response.content.decode('utf-8')
            self.dataframe = pd.read_csv(StringIO(content))

            # Basic info about the dataset
            info = {
                "status": "success",
                "rows": len(self.dataframe),
                "columns": list(self.dataframe.columns),
                "preview": self.dataframe.head(5).to_dict(orient="records"),
                "dtypes": {col: str(dtype) for col, dtype in self.dataframe.dtypes.items()}
            }

            return info

        except Exception as e:
            return {"status": "error", "message": str(e)}

    def _clean_data(self, params: Dict) -> Dict:
        """Clean the loaded dataset in place (mutates self.dataframe).

        params:
            missing_strategy: "drop" (default) | "mean" | "median" | "mode"
            remove_duplicates: bool, default True
            convert_dtypes: optional {column: dtype} mapping for astype
        """
        if self.dataframe is None:
            return {"status": "error", "message": "No dataset loaded"}

        try:
            original_shape = self.dataframe.shape

            # Handle missing values based on strategy
            missing_strategy = params.get("missing_strategy", "drop")
            if missing_strategy == "drop":
                self.dataframe = self.dataframe.dropna()
            elif missing_strategy == "mean":
                self.dataframe = self.dataframe.fillna(self.dataframe.mean(numeric_only=True))
            elif missing_strategy == "median":
                self.dataframe = self.dataframe.fillna(self.dataframe.median(numeric_only=True))
            elif missing_strategy == "mode":
                # Numeric columns are filled with their mean, non-numeric
                # with their mode (despite the strategy's name).
                # NOTE(review): .mode()[0] raises IndexError on an all-NaN
                # column — caught by the except below; confirm acceptable.
                for column in self.dataframe.columns:
                    if pd.api.types.is_numeric_dtype(self.dataframe[column]):
                        self.dataframe[column] = self.dataframe[column].fillna(self.dataframe[column].mean())
                    else:
                        self.dataframe[column] = self.dataframe[column].fillna(self.dataframe[column].mode()[0])

            # Remove duplicates if specified
            if params.get("remove_duplicates", True):
                self.dataframe = self.dataframe.drop_duplicates()

            # Convert datatypes if specified
            if "convert_dtypes" in params:
                for col, dtype in params["convert_dtypes"].items():
                    self.dataframe[col] = self.dataframe[col].astype(dtype)

            new_shape = self.dataframe.shape

            return {
                "status": "success",
                "original_shape": original_shape,
                "new_shape": new_shape,
                "missing_values_remaining": self.dataframe.isna().sum().to_dict(),
                # NOTE(review): this counts ALL rows removed, including rows
                # dropped by the missing-value strategy, not only duplicates.
                "duplicate_rows_removed": original_shape[0] - new_shape[0]
            }

        except Exception as e:
            return {"status": "error", "message": str(e)}

    def _compute_statistics(self, params: Dict) -> Dict:
        """Compute statistics on self.dataframe.

        params:
            columns: columns to analyze (default: all numeric columns)
            descriptive (default True), central_tendency, dispersion,
            shape: booleans toggling which stat groups are computed.
        """
        if self.dataframe is None:
            return {"status": "error", "message": "No dataset loaded"}

        try:
            # Get columns to compute stats for (all numeric by default).
            columns = params.get("columns", list(self.dataframe.select_dtypes(include=[np.number]).columns))

            # Compute different statistics based on parameters
            stats = {}

            # Basic descriptive statistics (count/mean/std/quartiles/min/max)
            if params.get("descriptive", True):
                stats["descriptive"] = self.dataframe[columns].describe().to_dict()

            # Central tendency
            if params.get("central_tendency", False):
                stats["mean"] = self.dataframe[columns].mean().to_dict()
                stats["median"] = self.dataframe[columns].median().to_dict()
                # Mode is more complex as it can return multiple values,
                # so each numeric column maps to a list of modes.
                mode_results = {}
                for col in columns:
                    if pd.api.types.is_numeric_dtype(self.dataframe[col]):
                        mode_vals = self.dataframe[col].mode().tolist()
                        mode_results[col] = mode_vals
                stats["mode"] = mode_results

            # Dispersion: variance, std dev, range (max-min), IQR.
            if params.get("dispersion", False):
                stats["variance"] = self.dataframe[columns].var().to_dict()
                stats["std_dev"] = self.dataframe[columns].std().to_dict()
                stats["range"] = {col: self.dataframe[col].max() - self.dataframe[col].min() for col in columns}
                stats["iqr"] = {col: self.dataframe[col].quantile(0.75) - self.dataframe[col].quantile(0.25) for col in columns}

            # Shape of the distribution: skewness and kurtosis.
            if params.get("shape", False):
                stats["skewness"] = self.dataframe[columns].skew().to_dict()
                stats["kurtosis"] = self.dataframe[columns].kurtosis().to_dict()

            return {
                "status": "success",
                "statistics": stats
            }

        except Exception as e:
            return {"status": "error", "message": str(e)}

    def _compute_correlation(self, params: Dict) -> Dict:
        """Compute pairwise correlations between columns.

        params:
            columns: columns to correlate (default: all numeric columns)
            method: "pearson" (default) | "kendall" | "spearman"
        """
        if self.dataframe is None:
            return {"status": "error", "message": "No dataset loaded"}

        try:
            # Get columns to compute correlation for
            columns = params.get("columns", list(self.dataframe.select_dtypes(include=[np.number]).columns))
            method = params.get("method", "pearson")  # pearson, kendall, spearman

            # NOTE(review): .corr() is computed twice (here and just below);
            # the second call could reuse the first result.
            corr_matrix = self.dataframe[columns].corr(method=method).to_dict()

            # Find highest correlated pairs.
            # NOTE(review): the `< 1.0` filter removes self-correlation but
            # also drops any genuine perfect correlation, and each remaining
            # pair appears twice ((a,b) and (b,a)) in the unstacked series.
            corr_df = self.dataframe[columns].corr(method=method).unstack()
            corr_df = corr_df[corr_df < 1.0]  # Remove self-correlation
            highest_corr = corr_df.sort_values(ascending=False)[:10].to_dict()

            return {
                "status": "success",
                "correlation_matrix": corr_matrix,
                "highest_correlations": highest_corr
            }

        except Exception as e:
            return {"status": "error", "message": str(e)}

    def _filter_data(self, params: Dict) -> Dict:
        """Filter rows by a list of column/operator/value conditions.

        params["filters"] is a list of {"column", "operator", "value"}
        dicts; supported operators: ==, !=, >, <, >=, <=, in, not in.
        The filtered frame is kept on self.filtered_df for later use by
        _compute_aggregation; self.dataframe itself is not modified.
        """
        if self.dataframe is None:
            return {"status": "error", "message": "No dataset loaded"}

        try:
            # Apply filters sequentially (conditions are AND-ed).
            filtered_df = self.dataframe.copy()
            filters = params.get("filters", [])

            for filter_item in filters:
                column = filter_item["column"]
                operator = filter_item["operator"]
                value = filter_item["value"]

                if operator == "==":
                    filtered_df = filtered_df[filtered_df[column] == value]
                elif operator == "!=":
                    filtered_df = filtered_df[filtered_df[column] != value]
                elif operator == ">":
                    filtered_df = filtered_df[filtered_df[column] > value]
                elif operator == "<":
                    filtered_df = filtered_df[filtered_df[column] < value]
                elif operator == ">=":
                    filtered_df = filtered_df[filtered_df[column] >= value]
                elif operator == "<=":
                    filtered_df = filtered_df[filtered_df[column] <= value]
                elif operator == "in":
                    filtered_df = filtered_df[filtered_df[column].isin(value)]
                elif operator == "not in":
                    filtered_df = filtered_df[~filtered_df[column].isin(value)]
                # NOTE(review): unknown operators are silently ignored.

            # Store the filtered dataframe temporarily for use in subsequent
            # operations (attribute is created lazily on first filter call).
            self.filtered_df = filtered_df

            return {
                "status": "success",
                "original_rows": len(self.dataframe),
                "filtered_rows": len(filtered_df),
                "preview": filtered_df.head(5).to_dict(orient="records")
            }

        except Exception as e:
            return {"status": "error", "message": str(e)}

    def _compute_aggregation(self, params: Dict) -> Dict:
        """Group-by aggregation over the (possibly filtered) dataset.

        params:
            groupby: list of column names to group by
            columns: columns to aggregate
            functions: aggregation function names (default ["mean"])
        """
        if self.dataframe is None:
            return {"status": "error", "message": "No dataset loaded"}

        try:
            # Get params
            groupby_cols = params.get("groupby", [])
            agg_cols = params.get("columns", [])
            agg_funcs = params.get("functions", ["mean"])

            # Use filtered dataframe if a prior _filter_data set one,
            # otherwise fall back to the full dataset.
            df_to_use = getattr(self, "filtered_df", self.dataframe)

            # Prepare aggregation dict: every column gets every function.
            agg_dict = {col: agg_funcs for col in agg_cols}

            # Compute aggregation
            result = df_to_use.groupby(groupby_cols).agg(agg_dict).reset_index()

            return {
                "status": "success",
                "result": result.to_dict(orient="records")
            }

        except Exception as e:
            return {"status": "error", "message": str(e)}

    def handle_message(self, message: MCPMessage) -> Dict:
        """Route peer requests to the matching tool implementation.

        Each request_* message type maps to one private method; the result
        is sent back to the requester as the corresponding *_result message.
        Unknown message types return an error dict (without replying).
        """
        if message.message_type == "request_data_load":
            result = self._load_dataset(message.content)
            return self.send_message(message.sender, "data_load_result", result)

        elif message.message_type == "request_data_cleaning":
            result = self._clean_data(message.content)
            return self.send_message(message.sender, "data_cleaning_result", result)

        elif message.message_type == "request_statistics":
            result = self._compute_statistics(message.content)
            return self.send_message(message.sender, "statistics_result", result)

        elif message.message_type == "request_correlation":
            result = self._compute_correlation(message.content)
            return self.send_message(message.sender, "correlation_result", result)

        elif message.message_type == "request_filter":
            result = self._filter_data(message.content)
            return self.send_message(message.sender, "filter_result", result)

        elif message.message_type == "request_aggregation":
            result = self._compute_aggregation(message.content)
            return self.send_message(message.sender, "aggregation_result", result)

        else:
            return {"status": "error", "message": f"Unknown message type: {message.message_type}"}
523
+
524
+
525
+ # ============== Interpret Agent Implementation ==============
526
+
527
+ class InterpretAgent(MCPAgent):
528
+ """Agent responsible for interpreting results and visualizing data"""
529
+
530
+ def __init__(self, name: str = "InterpretAgent", llm_model: Optional[str] = None, api_key: Optional[str] = None):
531
+ super().__init__(name, "Agent responsible for interpreting results and visualizing data", llm_model, api_key)
532
+ self.dataset_info = None
533
+ self.statistics = None
534
+ self.correlation_data = None
535
+ self.filter_results = None
536
+ self.aggregation_results = None
537
+ self.visualization_results = {}
538
+
539
+ # Register tools
540
+ self.register_tool(MCPTool(
541
+ "interpret_statistics",
542
+ "Interpret statistical results and provide insights",
543
+ self._interpret_statistics
544
+ ))
545
+
546
+ self.register_tool(MCPTool(
547
+ "interpret_correlation",
548
+ "Interpret correlation results and provide insights",
549
+ self._interpret_correlation
550
+ ))
551
+
552
+ self.register_tool(MCPTool(
553
+ "create_visualization",
554
+ "Create a visualization based on data",
555
+ self._create_visualization
556
+ ))
557
+
558
+ self.register_tool(MCPTool(
559
+ "generate_report",
560
+ "Generate a report with key findings",
561
+ self._generate_report
562
+ ))
563
+
564
+ def _interpret_statistics(self, params: Dict) -> Dict:
565
+ """Interpret statistical results and provide insights"""
566
+ if not self.statistics:
567
+ return {"status": "error", "message": "No statistics data available"}
568
+
569
+ try:
570
+ # If we have LLM access, use it for more advanced interpretation
571
+ if self.llm_model:
572
+ prompt = f"""
573
+ As a data analyst, interpret these statistics and provide insights:
574
+ {json.dumps(self.statistics, indent=2)}
575
+
576
+ Provide:
577
+ 1. 5 key insights about the data
578
+ 2. Any potential anomalies or interesting observations
579
+ 3. Any patterns or trends visible in the descriptive statistics
580
+ """
581
+
582
+ if self.llm_model.startswith("claude"):
583
+ client = anthropic.Anthropic(api_key=self.api_key)
584
+ response = client.messages.create(
585
+ model="claude-3-sonnet-20240229",
586
+ max_tokens=1000,
587
+ messages=[{"role": "user", "content": prompt}]
588
+ )
589
+ result = response.content[0].text
590
+ elif self.llm_model.startswith("gpt"):
591
+ client = openai.OpenAI(api_key=self.api_key)
592
+ response = client.chat.completions.create(
593
+ model="gpt-4o",
594
+ messages=[{"role": "user", "content": prompt}]
595
+ )
596
+ result = response.choices[0].message.content
597
+
598
+ self.log_llm_interaction(prompt, result)
599
+
600
+ return {
601
+ "status": "success",
602
+ "insights": result.split('\n'),
603
+ "summary": "Statistical analysis complete with LLM-generated insights."
604
+ }
605
+
606
+ # Fallback to rule-based insights if no LLM available
607
+ insights = []
608
+ stats = self.statistics.get("statistics", {})
609
+
610
+ # Analyze descriptive statistics
611
+ if "descriptive" in stats:
612
+ desc_stats = stats["descriptive"]
613
+
614
+ # Look at each numerical column
615
+ for col in desc_stats:
616
+ col_stats = desc_stats[col]
617
+
618
+ # Check for outliers using IQR method
619
+ q1 = col_stats.get("25%", 0)
620
+ q3 = col_stats.get("75%", 0)
621
+ iqr = q3 - q1
622
+ lower_bound = q1 - 1.5 * iqr
623
+ upper_bound = q3 + 1.5 * iqr
624
+
625
+ if col_stats.get("min", 0) < lower_bound or col_stats.get("max", 0) > upper_bound:
626
+ insights.append(f"Column '{col}' may contain outliers.")
627
+
628
+ # Check for skewness
629
+ mean = col_stats.get("mean", 0)
630
+ median = col_stats.get("50%", 0)
631
+ if abs(mean - median) > 0.1 * mean:
632
+ skew_direction = "right" if mean > median else "left"
633
+ insights.append(f"Column '{col}' appears to be skewed to the {skew_direction}.")
634
+
635
+ # Check for variability
636
+ std = col_stats.get("std", 0)
637
+ mean = col_stats.get("mean", 0)
638
+ cv = std / mean if mean != 0 else 0
639
+ if cv > 1:
640
+ insights.append(f"Column '{col}' shows high variability (CV > 1).")
641
+
642
+ return {
643
+ "status": "success",
644
+ "insights": insights,
645
+ "summary": "Statistical analysis reveals potential patterns and anomalies in the data."
646
+ }
647
+
648
+ except Exception as e:
649
+ return {"status": "error", "message": str(e)}
650
+
651
+ def _interpret_correlation(self, params: Dict) -> Dict:
652
+ """Interpret correlation results and provide insights"""
653
+ if not self.correlation_data:
654
+ return {"status": "error", "message": "No correlation data available"}
655
+
656
+ try:
657
+ # If we have LLM access, use it for more advanced interpretation
658
+ if self.llm_model:
659
+ prompt = f"""
660
+ As a data analyst, interpret this correlation data and provide insights:
661
+ {json.dumps(self.correlation_data, indent=2)}
662
+
663
+ Provide:
664
+ 1. The 5 most significant correlations found and what they might indicate
665
+ 2. Any interesting patterns of correlation in the dataset
666
+ 3. Suggestions for variables that might have causal relationships
667
+ """
668
+
669
+ if self.llm_model.startswith("claude"):
670
+ client = anthropic.Anthropic(api_key=self.api_key)
671
+ response = client.messages.create(
672
+ model="claude-3-sonnet-20240229",
673
+ max_tokens=1000,
674
+ messages=[{"role": "user", "content": prompt}]
675
+ )
676
+ result = response.content[0].text
677
+ elif self.llm_model.startswith("gpt"):
678
+ client = openai.OpenAI(api_key=self.api_key)
679
+ response = client.chat.completions.create(
680
+ model="gpt-4o",
681
+ messages=[{"role": "user", "content": prompt}]
682
+ )
683
+ result = response.choices[0].message.content
684
+
685
+ self.log_llm_interaction(prompt, result)
686
+
687
+ return {
688
+ "status": "success",
689
+ "insights": result.split('\n'),
690
+ "summary": "Correlation analysis complete with LLM-generated insights."
691
+ }
692
+
693
+ # Fallback to rule-based insights if no LLM available
694
+ insights = []
695
+ corr_matrix = self.correlation_data.get("correlation_matrix", {})
696
+ highest_corr = self.correlation_data.get("highest_correlations", {})
697
+
698
+ # Find strong positive correlations
699
+ strong_pos_corr = [(k, v) for k, v in highest_corr.items() if v > 0.7]
700
+ if strong_pos_corr:
701
+ for (col1, col2), value in strong_pos_corr[:3]:
702
+ insights.append(f"Strong positive correlation ({value:.2f}) between '{col1}' and '{col2}'.")
703
+
704
+ # Find strong negative correlations
705
+ strong_neg_corr = [(k, v) for k, v in highest_corr.items() if v < -0.7]
706
+ if strong_neg_corr:
707
+ for (col1, col2), value in strong_neg_corr[:3]:
708
+ insights.append(f"Strong negative correlation ({value:.2f}) between '{col1}' and '{col2}'.")
709
+
710
+ # Identify potential multicollinearity
711
+ multi_corr = [(k, v) for k, v in highest_corr.items() if abs(v) > 0.9]
712
+ if multi_corr:
713
+ insights.append("Potential multicollinearity detected between some features.")
714
+
715
+ return {
716
+ "status": "success",
717
+ "insights": insights,
718
+ "summary": "Correlation analysis reveals interesting relationships between variables."
719
+ }
720
+
721
+ except Exception as e:
722
+ return {"status": "error", "message": str(e)}
723
+
724
+ def _create_visualization(self, params: Dict) -> Dict:
725
+ """Create a visualization based on data"""
726
+ try:
727
+ viz_type = params.get("type", "histogram")
728
+ title = params.get("title", "Data Visualization")
729
+ x_column = params.get("x", None)
730
+ y_column = params.get("y", None)
731
+
732
+ # Generate a unique ID for this visualization
733
+ viz_id = str(uuid.uuid4())
734
+
735
+ # Create the visualization and save it to a file
736
+ plt.figure(figsize=(10, 6))
737
+
738
+ if not hasattr(self, "compute_agent") or not hasattr(self.compute_agent, "dataframe"):
739
+ return {"status": "error", "message": "No data available for visualization"}
740
+
741
+ df = self.compute_agent.dataframe
742
+
743
+ if viz_type == "histogram":
744
+ if x_column:
745
+ sns.histplot(df[x_column], kde=True)
746
+ plt.xlabel(x_column)
747
+ plt.ylabel("Frequency")
748
+ else:
749
+ return {"status": "error", "message": "Column name required for histogram"}
750
+
751
+ elif viz_type == "scatter":
752
+ if x_column and y_column:
753
+ sns.scatterplot(x=df[x_column], y=df[y_column])
754
+ plt.xlabel(x_column)
755
+ plt.ylabel(y_column)
756
+ else:
757
+ return {"status": "error", "message": "X and Y column names required for scatter plot"}
758
+
759
+ elif viz_type == "bar":
760
+ if x_column and y_column:
761
+ sns.barplot(x=df[x_column], y=df[y_column])
762
+ plt.xlabel(x_column)
763
+ plt.ylabel(y_column)
764
+ else:
765
+ return {"status": "error", "message": "X and Y column names required for bar chart"}
766
+
767
+ elif viz_type == "boxplot":
768
+ if x_column:
769
+ sns.boxplot(y=df[x_column])
770
+ plt.ylabel(x_column)
771
+ elif x_column and y_column:
772
+ sns.boxplot(x=df[x_column], y=df[y_column])
773
+ plt.xlabel(x_column)
774
+ plt.ylabel(y_column)
775
+ else:
776
+ return {"status": "error", "message": "At least one column name required for boxplot"}
777
+
778
+ elif viz_type == "heatmap":
779
+ if params.get("columns"):
780
+ corr = df[params["columns"]].corr()
781
+ sns.heatmap(corr, annot=True, cmap="coolwarm")
782
+ else:
783
+ corr = df.select_dtypes(include=[np.number]).corr()
784
+ sns.heatmap(corr, annot=True, cmap="coolwarm")
785
+
786
+ plt.title(title)
787
+ plt.tight_layout()
788
+
789
+ # Save the visualization
790
+ viz_filename = f"viz_{viz_id}.png"
791
+ plt.savefig(viz_filename)
792
+ plt.close()
793
+
794
+ # Store visualization details
795
+ viz_details = {
796
+ "id": viz_id,
797
+ "type": viz_type,
798
+ "title": title,
799
+ "filename": viz_filename,
800
+ "x_column": x_column,
801
+ "y_column": y_column
802
+ }
803
+
804
+ self.visualization_results[viz_id] = viz_details
805
+
806
+ return {
807
+ "status": "success",
808
+ "visualization": viz_details
809
+ }
810
+
811
+ except Exception as e:
812
+ return {"status": "error", "message": str(e)}
813
+
814
+ def _generate_report(self, params: Dict) -> Dict:
815
+ """Generate a report with key findings"""
816
+ try:
817
+ # If LLM available, use it for advanced report generation
818
+ if self.llm_model:
819
+ # Gather all the data we have
820
+ report_data = {
821
+ "dataset_info": self.dataset_info,
822
+ "statistics": self.statistics,
823
+ "correlation_data": self.correlation_data,
824
+ "filter_results": self.filter_results,
825
+ "aggregation_results": self.aggregation_results
826
+ }
827
+
828
+ prompt = f"""
829
+ As a data analyst, generate a comprehensive report based on the following analysis data:
830
+ {json.dumps(report_data, indent=2)}
831
+
832
+ The report should include:
833
+ 1. Dataset Overview
834
+ 2. Key Findings from Statistical Analysis
835
+ 3. Correlation Analysis Highlights
836
+ 4. Filtered Data Analysis (if applicable)
837
+ 5. Aggregation Insights (if applicable)
838
+ 6. Conclusions and Recommendations
839
+
840
+ Format the report in a professional style with clear sections.
841
+ """
842
+
843
+ if self.llm_model.startswith("claude"):
844
+ client = anthropic.Anthropic(api_key=self.api_key)
845
+ response = client.messages.create(
846
+ model="claude-3-sonnet-20240229",
847
+ max_tokens=2000,
848
+ messages=[{"role": "user", "content": prompt}]
849
+ )
850
+ result = response.content[0].text
851
+ elif self.llm_model.startswith("gpt"):
852
+ client = openai.OpenAI(api_key=self.api_key)
853
+ response = client.chat.completions.create(
854
+ model="gpt-4o",
855
+ messages=[{"role": "user", "content": prompt}]
856
+ )
857
+ result = response.choices[0].message.content
858
+
859
+ self.log_llm_interaction(prompt, result)
860
+
861
+ return {
862
+ "status": "success",
863
+ "report": {
864
+ "title": params.get("report_title", "Data Analysis Report"),
865
+ "content": result
866
+ }
867
+ }
868
+
869
+ # Fallback to template-based report if no LLM available
870
+ # Gather all the insights and results
871
+ report_sections = []
872
+
873
+ # Dataset overview
874
+ if self.dataset_info:
875
+ report_sections.append({
876
+ "title": "Dataset Overview",
877
+ "content": f"The dataset contains {self.dataset_info.get('rows', 0)} rows and {len(self.dataset_info.get('columns', []))} columns."
878
+ })
879
+
880
+ # Statistical insights
881
+ if self.statistics:
882
+ # Interpret statistics if not already done
883
+ if not hasattr(self, 'stat_insights'):
884
+ self.stat_insights = self._interpret_statistics({}).get('insights', [])
885
+
886
+ report_sections.append({
887
+ "title": "Statistical Analysis",
888
+ "content": "Key findings from statistical analysis:",
889
+ "insights": self.stat_insights
890
+ })
891
+
892
+ # Correlation insights
893
+ if self.correlation_data:
894
+ # Interpret correlations if not already done
895
+ if not hasattr(self, 'corr_insights'):
896
+ self.corr_insights = self._interpret_correlation({}).get('insights', [])
897
+
898
+ report_sections.append({
899
+ "title": "Correlation Analysis",
900
+ "content": "Key findings from correlation analysis:",
901
+ "insights": self.corr_insights
902
+ })
903
+
904
+ # Filter results
905
+ if self.filter_results:
906
+ report_sections.append({
907
+ "title": "Filtered Data Analysis",
908
+ "content": f"The filtered dataset contains {self.filter_results.get('filtered_rows', 0)} rows, down from {self.filter_results.get('original_rows', 0)} rows."
909
+ })
910
+
911
+ # Aggregation results
912
+ if self.aggregation_results:
913
+ report_sections.append({
914
+ "title": "Aggregation Analysis",
915
+ "content": "Key insights from aggregated data:",
916
+ "data": self.aggregation_results.get('result', [])
917
+ })
918
+
919
+ # Conclusions
920
+ report_sections.append({
921
+ "title": "Conclusions",
922
+ "content": "Based on the analysis, several patterns and relationships have been identified in the data."
923
+ })
924
+
925
+ return {
926
+ "status": "success",
927
+ "report": {
928
+ "title": params.get("report_title", "Data Analysis Report"),
929
+ "sections": report_sections
930
+ }
931
+ }
932
+
933
+ except Exception as e:
934
+ return {"status": "error", "message": str(e)}
935
+
936
+ def handle_message(self, message: MCPMessage) -> Dict:
937
+ """Handle incoming messages from other agents"""
938
+ if message.message_type == "data_load_result":
939
+ self.dataset_info = message.content
940
+ return self.send_message(message.sender, "acknowledge", {"status": "received", "message": "Dataset info received"})
941
+
942
+ elif message.message_type == "data_cleaning_result":
943
+ return self.send_message(message.sender, "acknowledge", {"status": "received", "message": "Data cleaning result received"})
944
+
945
+ elif message.message_type == "statistics_result":
946
+ self.statistics = message.content
947
+ insights = self._interpret_statistics({})
948
+ return self.send_message(message.sender, "statistics_interpretation", insights)
949
+
950
+ elif message.message_type == "correlation_result":
951
+ self.correlation_data = message.content
952
+ insights = self._interpret_correlation({})
953
+ return self.send_message(message.sender, "correlation_interpretation", insights)
954
+
955
+ elif message.message_type == "filter_result":
956
+ self.filter_results = message.content
957
+ return self.send_message(message.sender, "acknowledge", {"status": "received", "message": "Filter result received"})
958
+
959
+ elif message.message_type == "aggregation_result":
960
+ self.aggregation_results = message.content
961
+ return self.send_message(message.sender, "acknowledge", {"status": "received", "message": "Aggregation result received"})
962
+
963
+ elif message.message_type == "request_report":
964
+ report = self._generate_report(message.content)
965
+ return self.send_message(message.sender, "report_result", report)
966
+
967
+ elif message.message_type == "request_visualization":
968
+ visualization = self._create_visualization(message.content)
969
+ return self.send_message(message.sender, "visualization_result", visualization)
970
+
971
+ else:
972
+ return {"status": "error", "message": f"Unknown message type: {message.message_type}"}
973
+
974
+ def set_compute_agent(self, compute_agent):
975
+ """Set reference to compute agent for access to dataframe"""
976
+ self.compute_agent = compute_agent
977
+
978
+
979
# ============== Main Analysis Workflow ==============

class DataAnalystDuo:
    """Coordinator that wires a ComputeAgent and an InterpretAgent into one
    MCP workflow and drives the end-to-end analysis pipeline."""

    def __init__(self, llm_model: Optional[str] = None, api_key: Optional[str] = None):
        self.compute_agent = ComputeAgent(llm_model=llm_model, api_key=api_key)
        self.interpret_agent = InterpretAgent(llm_model=llm_model, api_key=api_key)

        # Wire the two agents as MCP peers (each can message the other).
        self.compute_agent.connect(self.interpret_agent)
        self.interpret_agent.connect(self.compute_agent)

        # The interpreter needs direct access to the compute agent's dataframe.
        self.interpret_agent.set_compute_agent(self.compute_agent)

        # Workflow log: one entry per pipeline step.
        self.logs = []

    def log_step(self, step_name: str, details: Dict):
        """Append a timestamped workflow-step entry to the log and return it."""
        entry = {
            "timestamp": datetime.datetime.now().isoformat(),
            "step": step_name,
            "details": details
        }
        self.logs.append(entry)
        return entry

    def _ask_compute(self, step_label: str, message_type: str, payload: Dict):
        """Log a step, request work from the ComputeAgent, and drain both queues
        (receiver first, then sender, matching the original exchange order)."""
        self.log_step(step_label, payload)
        self.interpret_agent.send_message("ComputeAgent", message_type, payload)
        self.compute_agent.process_messages()
        self.interpret_agent.process_messages()

    def _ask_interpret(self, step_label: str, message_type: str, payload: Dict):
        """Log a step, request work from the InterpretAgent, and drain both queues
        (receiver first, then sender, matching the original exchange order)."""
        self.log_step(step_label, payload)
        self.compute_agent.send_message("InterpretAgent", message_type, payload)
        self.interpret_agent.process_messages()
        self.compute_agent.process_messages()

    def run_analysis(self, dataset_url: str, analysis_params: Dict = None) -> Dict:
        """Run the full load -> clean -> analyze -> visualize -> report pipeline.

        Args:
            dataset_url: Location of the dataset to load.
            analysis_params: Optional dict of per-step parameter overrides
                (clean_params, stats_params, corr_params, filter_params,
                agg_params, viz_params, report_params).

        Returns:
            Dict bundling every collected result plus message and LLM logs.
        """
        opts = analysis_params if analysis_params is not None else {}

        # 1. Load dataset
        self._ask_compute("Initiating dataset loading", "request_data_load",
                          {"url": dataset_url})

        # 2. Clean data
        self._ask_compute(
            "Initiating data cleaning", "request_data_cleaning",
            opts.get("clean_params", {"missing_strategy": "mean", "remove_duplicates": True}))

        # 3. Compute statistics
        self._ask_compute(
            "Initiating statistical analysis", "request_statistics",
            opts.get("stats_params", {"descriptive": True, "central_tendency": True, "dispersion": True}))

        # 4. Compute correlation
        self._ask_compute(
            "Initiating correlation analysis", "request_correlation",
            opts.get("corr_params", {"method": "pearson"}))

        # 5. Filter data if requested
        if "filter_params" in opts:
            self._ask_compute("Initiating data filtering", "request_filter",
                              opts["filter_params"])

        # 6. Compute aggregation if requested
        if "agg_params" in opts:
            self._ask_compute("Initiating data aggregation", "request_aggregation",
                              opts["agg_params"])

        # 7. Create visualizations if requested (one exchange per spec)
        for viz_param in opts.get("viz_params", []):
            self._ask_interpret("Initiating visualization creation",
                                "request_visualization", viz_param)

        # 8. Generate the final report
        self._ask_interpret(
            "Generating final report", "request_report",
            opts.get("report_params", {"report_title": "Data Analysis Report"}))

        # Bundle everything the interpreter accumulated plus both agents' logs.
        return {
            "dataset_info": self.interpret_agent.dataset_info,
            "statistics": self.interpret_agent.statistics,
            "correlation_data": self.interpret_agent.correlation_data,
            "filter_results": self.interpret_agent.filter_results,
            "aggregation_results": self.interpret_agent.aggregation_results,
            "visualizations": self.interpret_agent.visualization_results,
            "compute_agent_messages": self.compute_agent.get_message_history(),
            "interpret_agent_messages": self.interpret_agent.get_message_history(),
            "compute_agent_llm_logs": self.compute_agent.get_llm_logs(),
            "interpret_agent_llm_logs": self.interpret_agent.get_llm_logs(),
            "workflow_logs": self.logs,
        }
1085
+
1086
+
1087
# ============== Gradio Interface ==============

def format_json(json_data):
    """Pretty-print dicts and lists as indented JSON; str() everything else.

    Args:
        json_data: Any value; only dicts and lists are JSON-formatted.

    Returns:
        A display-ready string.
    """
    # Idiom fix: one isinstance() call with a type tuple replaces the
    # original `isinstance(...) or isinstance(...)` chain.
    if isinstance(json_data, (dict, list)):
        return json.dumps(json_data, indent=2)
    return str(json_data)
1094
+
1095
def run_data_analysis(dataset_url, llm_provider, api_key, missing_strategy, create_visualizations, high_fiber_filter):
    """Run the data-analysis workflow and return the four Gradio outputs.

    Args:
        dataset_url: Dataset location; empty string falls back to the default
            cereals dataset.
        llm_provider: "none", "claude", or "gpt".
        api_key: Provider API key (required unless llm_provider is "none").
        missing_strategy: Missing-value strategy passed to data cleaning.
        create_visualizations: Whether to request the standard chart set.
        high_fiber_filter: Whether to add the fiber > 5 row filter.

    Returns:
        A 4-tuple (mcp_messages, llm_logs, visualizations_html, report_html)
        in the same order as the click handler's output components.
        BUG FIX: the previous version returned a dict keyed by strings;
        Gradio requires multiple outputs to be returned as a tuple/list in
        component order (dict returns must be keyed by components), so the
        old form broke the UI update.
    """
    try:
        # Empty URL means "use the bundled cereals dataset".
        if not dataset_url:
            dataset_url = "default"

        if llm_provider != "none" and not api_key:
            return ("Error: API key is required for LLM integration", "", "", "")

        # Map the UI provider choice onto the internal model identifier.
        llm_model = None
        if llm_provider == "claude":
            llm_model = "claude"
        elif llm_provider == "gpt":
            llm_model = "gpt"
            if not api_key:
                api_key = os.environ.get("OPENAI_API_KEY", "")

        # Create the data analyst duo.
        duo = DataAnalystDuo(llm_model=llm_model, api_key=api_key)

        # Base analysis parameters driven by the UI controls.
        analysis_params = {
            "clean_params": {
                "missing_strategy": missing_strategy,
                "remove_duplicates": True
            },
            "stats_params": {
                "descriptive": True,
                "central_tendency": True,
                "dispersion": True
            },
            "corr_params": {
                "method": "pearson"
            }
        }

        # Optional high-fiber row filter.
        if high_fiber_filter:
            analysis_params["filter_params"] = {
                "filters": [
                    {"column": "fiber", "operator": ">", "value": 5}
                ]
            }

        # Always aggregate the nutrition columns by manufacturer.
        analysis_params["agg_params"] = {
            "groupby": ["mfr"],
            "columns": ["calories", "protein", "fat", "fiber", "sugars"],
            "functions": ["mean", "min", "max"]
        }

        # Standard chart set, when requested.
        if create_visualizations:
            analysis_params["viz_params"] = [
                {
                    "type": "scatter",
                    "title": "Calories vs Sugar Content",
                    "x": "calories",
                    "y": "sugars"
                },
                {
                    "type": "histogram",
                    "title": "Distribution of Fiber Content",
                    "x": "fiber"
                },
                {
                    "type": "heatmap",
                    "title": "Correlation Matrix",
                    "columns": ["calories", "protein", "fat", "fiber", "sugars", "rating"]
                }
            ]

        # Run the analysis.
        results = duo.run_analysis(dataset_url, analysis_params)

        compute_messages = results["compute_agent_messages"]
        interpret_messages = results["interpret_agent_messages"]

        # Merge both agents' histories, tagged by agent, ordered by timestamp.
        all_messages = []
        for agent_name, history in (("ComputeAgent", compute_messages),
                                    ("InterpretAgent", interpret_messages)):
            for msg in history:
                tagged = msg.copy()
                tagged["agent"] = agent_name
                all_messages.append(tagged)
        all_messages.sort(key=lambda m: m["message"]["timestamp"])

        # Format the MCP message flow for display.
        formatted_messages = []
        for msg in all_messages:
            message = msg["message"]
            formatted_msg = f"[{message['timestamp']}] {msg['agent']} {msg['type'].upper()} - Type: {message['message_type']}\n"
            formatted_msg += format_json(message['content'])
            formatted_msg += "\n\n" + "-"*80 + "\n\n"
            formatted_messages.append(formatted_msg)

        # Format LLM prompt/response logs from both agents.
        formatted_llm_logs = []
        for log in results["compute_agent_llm_logs"] + results["interpret_agent_llm_logs"]:
            formatted_log = f"[{log['timestamp']}]\n"
            formatted_log += "PROMPT:\n" + log['prompt'] + "\n\n"
            formatted_log += "RESPONSE:\n" + log['response'] + "\n\n"
            formatted_log += "-"*80 + "\n\n"
            formatted_llm_logs.append(formatted_log)

        # Visualization gallery HTML.
        viz_html = ""
        if create_visualizations and results.get("visualizations"):
            viz_html = "<div style='display: flex; flex-wrap: wrap;'>"
            for viz_id, viz_data in results["visualizations"].items():
                viz_html += "<div style='margin: 10px;'>"
                viz_html += f"<h3>{viz_data['title']}</h3>"
                viz_html += f"<img src='file={viz_data['filename']}' width='400' />"
                viz_html += "</div>"
            viz_html += "</div>"

        # Final report HTML: single pass over received messages (the old code
        # scanned the list twice); the last report_result wins, as before.
        report_html = "<h2>No report generated</h2>"
        for msg in compute_messages:
            if msg["type"] == "received" and msg["message"]["message_type"] == "report_result":
                report_content = msg["message"]["content"]["report"]
                if "content" in report_content:
                    # LLM-generated free-form report.
                    # BUG FIX: the replace() used to sit inside the f-string
                    # expression; a backslash there is a SyntaxError before
                    # Python 3.12, so it is hoisted out.
                    body = report_content['content'].replace("\n", "<br/>")
                    report_html = f"<h2>{report_content['title']}</h2>"
                    report_html += f"<div>{body}</div>"
                elif "sections" in report_content:
                    # Template-based sectioned report.
                    report_html = f"<h2>{report_content['title']}</h2>"
                    for section in report_content["sections"]:
                        report_html += f"<h3>{section['title']}</h3>"
                        report_html += f"<p>{section['content']}</p>"
                        if "insights" in section:
                            report_html += "<ul>"
                            for insight in section["insights"]:
                                report_html += f"<li>{insight}</li>"
                            report_html += "</ul>"
                        if "data" in section:
                            report_html += "<pre>" + format_json(section["data"]) + "</pre>"

        # Return the outputs in component order.
        return ("\n".join(formatted_messages),
                "\n".join(formatted_llm_logs),
                viz_html,
                report_html)

    except Exception as e:
        import traceback
        return (f"Error: {str(e)}\n\n{traceback.format_exc()}", "", "", "")
1276
+
1277
# Define the Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the Data Analyst Duo demo.

    Layout: input controls on top, then tabbed text panes for MCP messages
    and LLM logs, then tabbed HTML panes for visualizations and the report.

    Returns:
        The assembled (not yet launched) gr.Blocks application.
    """
    with gr.Blocks(title="Data Analyst Duo - MCP Implementation") as app:
        # Intro banner describing the two-agent architecture.
        gr.Markdown("""
        # Data Analyst Duo - Model Context Protocol (MCP) Implementation

        This application demonstrates a multi-agent system using the Model Context Protocol (MCP).
        It consists of two agents:

        1. **ComputeAgent**: Responsible for data loading, cleaning, and computation
        2. **InterpretAgent**: Responsible for interpreting results and visualizing data

        The agents communicate directly using standardized MCP messages, showcasing agent-to-agent communication.
        """)

        with gr.Row():
            with gr.Column():
                # Dataset source; an empty string falls back to the bundled cereals data.
                dataset_url = gr.Textbox(label="Dataset URL (leave empty for default cereals dataset)", placeholder="Enter dataset URL or leave empty for default")

                # Optional LLM configuration.
                with gr.Row():
                    llm_provider = gr.Radio(["none", "claude", "gpt"], label="LLM Provider (Optional)", value="none")
                    api_key = gr.Textbox(label="API Key (if using LLM)", placeholder="Enter API key if using Claude or GPT")

                # Analysis options.
                with gr.Row():
                    missing_strategy = gr.Dropdown(["drop", "mean", "median", "mode"], label="Missing Values Strategy", value="mean")
                    create_visualizations = gr.Checkbox(label="Create Visualizations", value=True)
                    high_fiber_filter = gr.Checkbox(label="Filter for High Fiber & Aggregate by Manufacturer", value=True)

                run_button = gr.Button("Run Data Analysis")

        # Text panes for the raw protocol traffic and LLM transcripts.
        with gr.Row():
            with gr.Tab("MCP Messages"):
                mcp_messages = gr.Textbox(label="MCP Message Flow", lines=20)
            with gr.Tab("LLM Logs"):
                llm_logs = gr.Textbox(label="LLM Interaction Logs", lines=20)

        # HTML panes for the chart gallery and the generated report.
        with gr.Row():
            with gr.Tab("Visualizations"):
                visualizations = gr.HTML(label="Data Visualizations")
            with gr.Tab("Final Report"):
                final_report = gr.HTML(label="Analysis Report")

        # Connect the button to the analysis function; the inputs/outputs
        # order must match run_data_analysis's parameters and return values.
        run_button.click(
            fn=run_data_analysis,
            inputs=[dataset_url, llm_provider, api_key, missing_strategy, create_visualizations, high_fiber_filter],
            outputs=[mcp_messages, llm_logs, visualizations, final_report]
        )

        # Closing explainer of which MCP features the demo exercises.
        gr.Markdown("""
        ## How This Demonstrates MCP

        This application shows the Model Context Protocol in action:

        1. **Standardized Message Structure**: All communication between agents follows a consistent format
        2. **Tool Registration**: Agents register their capabilities as tools with descriptions
        3. **Direct Peer Communication**: Agents communicate directly with structured messages
        4. **Asynchronous Processing**: Each agent processes messages independently

        The message flow display shows the exact JSON messages exchanged between agents, demonstrating the protocol in action.
        """)

    return app
1340
+
1341
# Build and launch the Gradio interface when executed as a script.
if __name__ == "__main__":
    create_interface().launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.13.0
2
+ pandas==2.1.1
3
+ numpy==1.26.0
4
+ matplotlib==3.8.0
5
+ seaborn==0.13.0
6
+ anthropic==0.8.1
7
+ openai==1.1.1
8
+ python-dotenv==1.0.0
9
+ requests==2.31.0