Nyha15 commited on
Commit
60fbe6c
·
1 Parent(s): 84900a9

Simplified Implementation to avoid html errors

Browse files
Files changed (2) hide show
  1. app.py +116 -982
  2. requirements.txt +0 -5
app.py CHANGED
@@ -1,35 +1,17 @@
1
  """
2
- Data Analyst Duo MCP Implementation
3
-
4
- This script implements a multi-agent system using the Model Context Protocol (MCP).
5
- It features two agents that collaborate on data analysis tasks:
6
- - ComputeAgent: Responsible for data loading, cleaning, and computation
7
- - InterpretAgent: Responsible for interpreting results and visualizing data
8
-
9
- The application includes a Gradio interface for interaction.
10
  """
11
 
12
  import os
13
- import sys
14
  import json
15
- import time
16
  import datetime
17
  import gradio as gr
18
  import pandas as pd
19
  import numpy as np
20
- import matplotlib.pyplot as plt
21
- import seaborn as sns
22
- from typing import Dict, List, Any, Optional, Union, Tuple
23
  import requests
24
  from io import StringIO
25
  import logging
26
  import uuid
27
- import anthropic
28
- import openai
29
- from dotenv import load_dotenv
30
-
31
- # Load environment variables
32
- load_dotenv()
33
 
34
  # Configure logging
35
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -40,14 +22,14 @@ logger = logging.getLogger(__name__)
40
  class MCPMessage:
41
  """Base class for MCP messages that agents exchange"""
42
 
43
- def __init__(self, sender: str, message_type: str, content: Any):
44
  self.id = str(uuid.uuid4())
45
  self.sender = sender
46
  self.message_type = message_type
47
  self.content = content
48
  self.timestamp = datetime.datetime.now().isoformat()
49
 
50
- def to_dict(self) -> Dict:
51
  return {
52
  "id": self.id,
53
  "sender": self.sender,
@@ -56,72 +38,39 @@ class MCPMessage:
56
  "timestamp": self.timestamp
57
  }
58
 
59
- @staticmethod
60
- def from_dict(data: Dict) -> 'MCPMessage':
61
- msg = MCPMessage(
62
- sender=data["sender"],
63
- message_type=data["message_type"],
64
- content=data["content"]
65
- )
66
- # Restore ID and timestamp if present
67
- if "id" in data:
68
- msg.id = data["id"]
69
- if "timestamp" in data:
70
- msg.timestamp = data["timestamp"]
71
- return msg
72
-
73
 
74
  class MCPTool:
75
  """Defines a tool that can be used by agents through the MCP protocol"""
76
 
77
- def __init__(self, name: str, description: str, function):
78
  self.name = name
79
  self.description = description
80
  self.function = function
81
 
82
- def to_dict(self) -> Dict:
83
- return {
84
- "name": self.name,
85
- "description": self.description
86
- }
87
-
88
- def execute(self, params: Dict) -> Any:
89
  return self.function(params)
90
 
91
 
92
  class MCPAgent:
93
  """Base agent class implementing MCP protocol"""
94
 
95
- def __init__(self, name: str, description: str, llm_model: Optional[str] = None, api_key: Optional[str] = None):
96
  self.name = name
97
  self.description = description
98
- self.tools: Dict[str, MCPTool] = {}
99
- self.message_queue: List[MCPMessage] = []
100
- self.peers: Dict[str, 'MCPAgent'] = {}
101
- self.message_history: List[Dict] = []
102
- self.llm_model = llm_model
103
- self.api_key = api_key
104
- self.llm_logs = []
105
-
106
- def register_tool(self, tool: MCPTool):
107
  """Register a tool that this agent can use"""
108
  self.tools[tool.name] = tool
109
 
110
- def list_tools(self) -> List[Dict]:
111
- """List all tools available to this agent"""
112
- return [tool.to_dict() for tool in self.tools.values()]
113
-
114
- def call_tool(self, tool_name: str, params: Dict) -> Any:
115
- """Call a tool by name with parameters"""
116
- if tool_name not in self.tools:
117
- raise ValueError(f"Tool {tool_name} not found")
118
- return self.tools[tool_name].execute(params)
119
-
120
- def connect(self, peer: 'MCPAgent'):
121
  """Connect to another agent as a peer"""
122
  self.peers[peer.name] = peer
123
 
124
- def send_message(self, receiver: str, message_type: str, content: Any) -> Dict:
125
  """Send a message to a peer agent"""
126
  if receiver not in self.peers:
127
  raise ValueError(f"Peer {receiver} not found")
@@ -140,7 +89,7 @@ class MCPAgent:
140
  logger.info(f"Agent {self.name} sent {message_type} to {receiver}")
141
  return message_dict
142
 
143
- def receive_message(self, message: MCPMessage):
144
  """Receive a message from a peer agent"""
145
  self.message_queue.append(message)
146
 
@@ -152,7 +101,7 @@ class MCPAgent:
152
 
153
  logger.info(f"Agent {self.name} received {message.message_type} from {message.sender}")
154
 
155
- def process_messages(self) -> List[Dict]:
156
  """Process all messages in the queue"""
157
  processed = []
158
  while self.message_queue:
@@ -161,52 +110,31 @@ class MCPAgent:
161
  processed.append(response)
162
  return processed
163
 
164
- def handle_message(self, message: MCPMessage) -> Dict:
165
  """Handle a message - to be implemented by subclasses"""
166
  raise NotImplementedError("Subclasses must implement handle_message")
167
 
168
- def log_llm_interaction(self, prompt: str, response: str):
169
- """Log LLM interactions for transparency"""
170
- log_entry = {
171
- "timestamp": datetime.datetime.now().isoformat(),
172
- "prompt": prompt,
173
- "response": response
174
- }
175
- self.llm_logs.append(log_entry)
176
- return log_entry
177
-
178
- def get_message_history(self) -> List[Dict]:
179
  """Get the agent's message history"""
180
  return self.message_history
181
 
182
- def get_llm_logs(self) -> List[Dict]:
183
- """Get the agent's LLM interaction logs"""
184
- return self.llm_logs
185
-
186
 
187
  # ============== Compute Agent Implementation ==============
188
 
189
  class ComputeAgent(MCPAgent):
190
  """Agent responsible for data loading, cleaning, and computation"""
191
 
192
- def __init__(self, name: str = "ComputeAgent", llm_model: Optional[str] = None, api_key: Optional[str] = None):
193
- super().__init__(name, "Agent responsible for data loading, cleaning and computation", llm_model, api_key)
194
  self.dataframe = None
195
- self.current_task = None
196
 
197
  # Register tools
198
  self.register_tool(MCPTool(
199
  "load_dataset",
200
- "Load a dataset from Kaggle or URL",
201
  self._load_dataset
202
  ))
203
 
204
- self.register_tool(MCPTool(
205
- "clean_data",
206
- "Clean the loaded dataset by handling missing values, duplicates, etc.",
207
- self._clean_data
208
- ))
209
-
210
  self.register_tool(MCPTool(
211
  "compute_statistics",
212
  "Compute basic statistics on the dataset",
@@ -219,67 +147,15 @@ class ComputeAgent(MCPAgent):
219
  self._compute_correlation
220
  ))
221
 
222
- self.register_tool(MCPTool(
223
- "filter_data",
224
- "Filter data based on conditions",
225
- self._filter_data
226
- ))
227
-
228
- self.register_tool(MCPTool(
229
- "compute_aggregation",
230
- "Compute aggregation (sum, mean, etc.) grouped by a column",
231
- self._compute_aggregation
232
- ))
233
-
234
- def _load_dataset(self, params: Dict) -> Dict:
235
- """Load a dataset from Kaggle or URL"""
236
  dataset_url = params.get("url")
237
 
238
  try:
239
- # Check if it's the default cereals dataset
240
- if dataset_url == "default" or dataset_url.lower() == "cereals":
241
  dataset_url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/cereal.csv"
242
 
243
- # Check if it's a Kaggle URL and extract the dataset path
244
- elif "kaggle.com/datasets" in dataset_url:
245
- # For simplicity, we use direct download links
246
- # In real implementation, you would use the Kaggle API
247
- prompt = f"""
248
- I have a Kaggle dataset URL: {dataset_url}
249
- Find the direct download link or alternative source for this dataset if possible.
250
- If not, suggest a suitable replacement dataset that's freely available.
251
- """
252
-
253
- if self.llm_model and self.llm_model.startswith("claude"):
254
- client = anthropic.Anthropic(api_key=self.api_key)
255
- response = client.messages.create(
256
- model="claude-3-sonnet-20240229",
257
- max_tokens=1000,
258
- messages=[{"role": "user", "content": prompt}]
259
- )
260
- result = response.content[0].text
261
- elif self.llm_model and self.llm_model.startswith("gpt"):
262
- client = openai.OpenAI(api_key=self.api_key)
263
- response = client.chat.completions.create(
264
- model="gpt-4o",
265
- messages=[{"role": "user", "content": prompt}]
266
- )
267
- result = response.choices[0].message.content
268
- else:
269
- result = "For non-default datasets, please provide a direct download link."
270
-
271
- self.log_llm_interaction(prompt, result)
272
-
273
- # Extract URL from the response
274
- lines = result.split('\n')
275
- for line in lines:
276
- if line.startswith("http") and (".csv" in line or ".xlsx" in line):
277
- dataset_url = line.strip()
278
- break
279
- else:
280
- # If no URL found, use default cereals dataset
281
- dataset_url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/cereal.csv"
282
-
283
  # Load the dataset
284
  response = requests.get(dataset_url)
285
  content = response.content.decode('utf-8')
@@ -290,8 +166,7 @@ class ComputeAgent(MCPAgent):
290
  "status": "success",
291
  "rows": len(self.dataframe),
292
  "columns": list(self.dataframe.columns),
293
- "preview": self.dataframe.head(5).to_dict(orient="records"),
294
- "dtypes": {col: str(dtype) for col, dtype in self.dataframe.dtypes.items()}
295
  }
296
 
297
  return info
@@ -299,53 +174,7 @@ class ComputeAgent(MCPAgent):
299
  except Exception as e:
300
  return {"status": "error", "message": str(e)}
301
 
302
- def _clean_data(self, params: Dict) -> Dict:
303
- """Clean the loaded dataset"""
304
- if self.dataframe is None:
305
- return {"status": "error", "message": "No dataset loaded"}
306
-
307
- try:
308
- original_shape = self.dataframe.shape
309
-
310
- # Handle missing values based on strategy
311
- missing_strategy = params.get("missing_strategy", "drop")
312
- if missing_strategy == "drop":
313
- self.dataframe = self.dataframe.dropna()
314
- elif missing_strategy == "mean":
315
- self.dataframe = self.dataframe.fillna(self.dataframe.mean(numeric_only=True))
316
- elif missing_strategy == "median":
317
- self.dataframe = self.dataframe.fillna(self.dataframe.median(numeric_only=True))
318
- elif missing_strategy == "mode":
319
- # Fill categorical with mode, numerics separately
320
- for column in self.dataframe.columns:
321
- if pd.api.types.is_numeric_dtype(self.dataframe[column]):
322
- self.dataframe[column] = self.dataframe[column].fillna(self.dataframe[column].mean())
323
- else:
324
- self.dataframe[column] = self.dataframe[column].fillna(self.dataframe[column].mode()[0])
325
-
326
- # Remove duplicates if specified
327
- if params.get("remove_duplicates", True):
328
- self.dataframe = self.dataframe.drop_duplicates()
329
-
330
- # Convert datatypes if specified
331
- if "convert_dtypes" in params:
332
- for col, dtype in params["convert_dtypes"].items():
333
- self.dataframe[col] = self.dataframe[col].astype(dtype)
334
-
335
- new_shape = self.dataframe.shape
336
-
337
- return {
338
- "status": "success",
339
- "original_shape": original_shape,
340
- "new_shape": new_shape,
341
- "missing_values_remaining": self.dataframe.isna().sum().to_dict(),
342
- "duplicate_rows_removed": original_shape[0] - new_shape[0]
343
- }
344
-
345
- except Exception as e:
346
- return {"status": "error", "message": str(e)}
347
-
348
- def _compute_statistics(self, params: Dict) -> Dict:
349
  """Compute basic statistics on the dataset"""
350
  if self.dataframe is None:
351
  return {"status": "error", "message": "No dataset loaded"}
@@ -354,36 +183,8 @@ class ComputeAgent(MCPAgent):
354
  # Get columns to compute stats for
355
  columns = params.get("columns", list(self.dataframe.select_dtypes(include=[np.number]).columns))
356
 
357
- # Compute different statistics based on parameters
358
- stats = {}
359
-
360
  # Basic descriptive statistics
361
- if params.get("descriptive", True):
362
- stats["descriptive"] = self.dataframe[columns].describe().to_dict()
363
-
364
- # Central tendency
365
- if params.get("central_tendency", False):
366
- stats["mean"] = self.dataframe[columns].mean().to_dict()
367
- stats["median"] = self.dataframe[columns].median().to_dict()
368
- # Mode is more complex as it can return multiple values
369
- mode_results = {}
370
- for col in columns:
371
- if pd.api.types.is_numeric_dtype(self.dataframe[col]):
372
- mode_vals = self.dataframe[col].mode().tolist()
373
- mode_results[col] = mode_vals
374
- stats["mode"] = mode_results
375
-
376
- # Dispersion
377
- if params.get("dispersion", False):
378
- stats["variance"] = self.dataframe[columns].var().to_dict()
379
- stats["std_dev"] = self.dataframe[columns].std().to_dict()
380
- stats["range"] = {col: self.dataframe[col].max() - self.dataframe[col].min() for col in columns}
381
- stats["iqr"] = {col: self.dataframe[col].quantile(0.75) - self.dataframe[col].quantile(0.25) for col in columns}
382
-
383
- # Shape
384
- if params.get("shape", False):
385
- stats["skewness"] = self.dataframe[columns].skew().to_dict()
386
- stats["kurtosis"] = self.dataframe[columns].kurtosis().to_dict()
387
 
388
  return {
389
  "status": "success",
@@ -393,7 +194,7 @@ class ComputeAgent(MCPAgent):
393
  except Exception as e:
394
  return {"status": "error", "message": str(e)}
395
 
396
- def _compute_correlation(self, params: Dict) -> Dict:
397
  """Compute correlation between columns"""
398
  if self.dataframe is None:
399
  return {"status": "error", "message": "No dataset loaded"}
@@ -401,107 +202,23 @@ class ComputeAgent(MCPAgent):
401
  try:
402
  # Get columns to compute correlation for
403
  columns = params.get("columns", list(self.dataframe.select_dtypes(include=[np.number]).columns))
404
- method = params.get("method", "pearson") # pearson, kendall, spearman
405
-
406
- corr_matrix = self.dataframe[columns].corr(method=method).to_dict()
407
 
408
- # Find highest correlated pairs
409
- corr_df = self.dataframe[columns].corr(method=method).unstack()
410
- corr_df = corr_df[corr_df < 1.0] # Remove self-correlation
411
- highest_corr = corr_df.sort_values(ascending=False)[:10].to_dict()
412
 
413
  return {
414
  "status": "success",
415
- "correlation_matrix": corr_matrix,
416
- "highest_correlations": highest_corr
417
  }
418
 
419
  except Exception as e:
420
  return {"status": "error", "message": str(e)}
421
 
422
- def _filter_data(self, params: Dict) -> Dict:
423
- """Filter data based on conditions"""
424
- if self.dataframe is None:
425
- return {"status": "error", "message": "No dataset loaded"}
426
-
427
- try:
428
- # Apply filters
429
- filtered_df = self.dataframe.copy()
430
- filters = params.get("filters", [])
431
-
432
- for filter_item in filters:
433
- column = filter_item["column"]
434
- operator = filter_item["operator"]
435
- value = filter_item["value"]
436
-
437
- if operator == "==":
438
- filtered_df = filtered_df[filtered_df[column] == value]
439
- elif operator == "!=":
440
- filtered_df = filtered_df[filtered_df[column] != value]
441
- elif operator == ">":
442
- filtered_df = filtered_df[filtered_df[column] > value]
443
- elif operator == "<":
444
- filtered_df = filtered_df[filtered_df[column] < value]
445
- elif operator == ">=":
446
- filtered_df = filtered_df[filtered_df[column] >= value]
447
- elif operator == "<=":
448
- filtered_df = filtered_df[filtered_df[column] <= value]
449
- elif operator == "in":
450
- filtered_df = filtered_df[filtered_df[column].isin(value)]
451
- elif operator == "not in":
452
- filtered_df = filtered_df[~filtered_df[column].isin(value)]
453
-
454
- # Store the filtered dataframe temporarily for use in subsequent operations
455
- self.filtered_df = filtered_df
456
-
457
- return {
458
- "status": "success",
459
- "original_rows": len(self.dataframe),
460
- "filtered_rows": len(filtered_df),
461
- "preview": filtered_df.head(5).to_dict(orient="records")
462
- }
463
-
464
- except Exception as e:
465
- return {"status": "error", "message": str(e)}
466
-
467
- def _compute_aggregation(self, params: Dict) -> Dict:
468
- """Compute aggregation grouped by a column"""
469
- if self.dataframe is None:
470
- return {"status": "error", "message": "No dataset loaded"}
471
-
472
- try:
473
- # Get params
474
- groupby_cols = params.get("groupby", [])
475
- agg_cols = params.get("columns", [])
476
- agg_funcs = params.get("functions", ["mean"])
477
-
478
- # Use filtered dataframe if available, otherwise use original
479
- df_to_use = getattr(self, "filtered_df", self.dataframe)
480
-
481
- # Prepare aggregation dict
482
- agg_dict = {col: agg_funcs for col in agg_cols}
483
-
484
- # Compute aggregation
485
- result = df_to_use.groupby(groupby_cols).agg(agg_dict).reset_index()
486
-
487
- return {
488
- "status": "success",
489
- "result": result.to_dict(orient="records")
490
- }
491
-
492
- except Exception as e:
493
- return {"status": "error", "message": str(e)}
494
-
495
- def handle_message(self, message: MCPMessage) -> Dict:
496
  """Handle incoming messages from other agents"""
497
  if message.message_type == "request_data_load":
498
  result = self._load_dataset(message.content)
499
  return self.send_message(message.sender, "data_load_result", result)
500
 
501
- elif message.message_type == "request_data_cleaning":
502
- result = self._clean_data(message.content)
503
- return self.send_message(message.sender, "data_cleaning_result", result)
504
-
505
  elif message.message_type == "request_statistics":
506
  result = self._compute_statistics(message.content)
507
  return self.send_message(message.sender, "statistics_result", result)
@@ -510,14 +227,6 @@ class ComputeAgent(MCPAgent):
510
  result = self._compute_correlation(message.content)
511
  return self.send_message(message.sender, "correlation_result", result)
512
 
513
- elif message.message_type == "request_filter":
514
- result = self._filter_data(message.content)
515
- return self.send_message(message.sender, "filter_result", result)
516
-
517
- elif message.message_type == "request_aggregation":
518
- result = self._compute_aggregation(message.content)
519
- return self.send_message(message.sender, "aggregation_result", result)
520
-
521
  else:
522
  return {"status": "error", "message": f"Unknown message type: {message.message_type}"}
523
 
@@ -527,14 +236,11 @@ class ComputeAgent(MCPAgent):
527
  class InterpretAgent(MCPAgent):
528
  """Agent responsible for interpreting results and visualizing data"""
529
 
530
- def __init__(self, name: str = "InterpretAgent", llm_model: Optional[str] = None, api_key: Optional[str] = None):
531
- super().__init__(name, "Agent responsible for interpreting results and visualizing data", llm_model, api_key)
532
  self.dataset_info = None
533
  self.statistics = None
534
  self.correlation_data = None
535
- self.filter_results = None
536
- self.aggregation_results = None
537
- self.visualization_results = {}
538
 
539
  # Register tools
540
  self.register_tool(MCPTool(
@@ -549,325 +255,60 @@ class InterpretAgent(MCPAgent):
549
  self._interpret_correlation
550
  ))
551
 
552
- self.register_tool(MCPTool(
553
- "create_visualization",
554
- "Create a visualization based on data",
555
- self._create_visualization
556
- ))
557
-
558
  self.register_tool(MCPTool(
559
  "generate_report",
560
  "Generate a report with key findings",
561
  self._generate_report
562
  ))
563
 
564
- def _interpret_statistics(self, params: Dict) -> Dict:
565
  """Interpret statistical results and provide insights"""
566
  if not self.statistics:
567
  return {"status": "error", "message": "No statistics data available"}
568
 
569
  try:
570
- # If we have LLM access, use it for more advanced interpretation
571
- if self.llm_model:
572
- prompt = f"""
573
- As a data analyst, interpret these statistics and provide insights:
574
- {json.dumps(self.statistics, indent=2)}
575
-
576
- Provide:
577
- 1. 5 key insights about the data
578
- 2. Any potential anomalies or interesting observations
579
- 3. Any patterns or trends visible in the descriptive statistics
580
- """
581
-
582
- if self.llm_model.startswith("claude"):
583
- client = anthropic.Anthropic(api_key=self.api_key)
584
- response = client.messages.create(
585
- model="claude-3-sonnet-20240229",
586
- max_tokens=1000,
587
- messages=[{"role": "user", "content": prompt}]
588
- )
589
- result = response.content[0].text
590
- elif self.llm_model.startswith("gpt"):
591
- client = openai.OpenAI(api_key=self.api_key)
592
- response = client.chat.completions.create(
593
- model="gpt-4o",
594
- messages=[{"role": "user", "content": prompt}]
595
- )
596
- result = response.choices[0].message.content
597
-
598
- self.log_llm_interaction(prompt, result)
599
-
600
- return {
601
- "status": "success",
602
- "insights": result.split('\n'),
603
- "summary": "Statistical analysis complete with LLM-generated insights."
604
- }
605
-
606
- # Fallback to rule-based insights if no LLM available
607
  insights = []
608
  stats = self.statistics.get("statistics", {})
609
 
610
- # Analyze descriptive statistics
611
- if "descriptive" in stats:
612
- desc_stats = stats["descriptive"]
613
-
614
- # Look at each numerical column
615
- for col in desc_stats:
616
- col_stats = desc_stats[col]
617
-
618
- # Check for outliers using IQR method
619
- q1 = col_stats.get("25%", 0)
620
- q3 = col_stats.get("75%", 0)
621
- iqr = q3 - q1
622
- lower_bound = q1 - 1.5 * iqr
623
- upper_bound = q3 + 1.5 * iqr
624
-
625
- if col_stats.get("min", 0) < lower_bound or col_stats.get("max", 0) > upper_bound:
626
- insights.append(f"Column '{col}' may contain outliers.")
627
-
628
- # Check for skewness
629
- mean = col_stats.get("mean", 0)
630
- median = col_stats.get("50%", 0)
631
- if abs(mean - median) > 0.1 * mean:
632
- skew_direction = "right" if mean > median else "left"
633
- insights.append(f"Column '{col}' appears to be skewed to the {skew_direction}.")
634
-
635
- # Check for variability
636
- std = col_stats.get("std", 0)
637
- mean = col_stats.get("mean", 0)
638
- cv = std / mean if mean != 0 else 0
639
- if cv > 1:
640
- insights.append(f"Column '{col}' shows high variability (CV > 1).")
641
 
642
  return {
643
  "status": "success",
644
- "insights": insights,
645
- "summary": "Statistical analysis reveals potential patterns and anomalies in the data."
646
  }
647
 
648
  except Exception as e:
649
  return {"status": "error", "message": str(e)}
650
 
651
- def _interpret_correlation(self, params: Dict) -> Dict:
652
  """Interpret correlation results and provide insights"""
653
  if not self.correlation_data:
654
  return {"status": "error", "message": "No correlation data available"}
655
 
656
  try:
657
- # If we have LLM access, use it for more advanced interpretation
658
- if self.llm_model:
659
- prompt = f"""
660
- As a data analyst, interpret this correlation data and provide insights:
661
- {json.dumps(self.correlation_data, indent=2)}
662
-
663
- Provide:
664
- 1. The 5 most significant correlations found and what they might indicate
665
- 2. Any interesting patterns of correlation in the dataset
666
- 3. Suggestions for variables that might have causal relationships
667
- """
668
-
669
- if self.llm_model.startswith("claude"):
670
- client = anthropic.Anthropic(api_key=self.api_key)
671
- response = client.messages.create(
672
- model="claude-3-sonnet-20240229",
673
- max_tokens=1000,
674
- messages=[{"role": "user", "content": prompt}]
675
- )
676
- result = response.content[0].text
677
- elif self.llm_model.startswith("gpt"):
678
- client = openai.OpenAI(api_key=self.api_key)
679
- response = client.chat.completions.create(
680
- model="gpt-4o",
681
- messages=[{"role": "user", "content": prompt}]
682
- )
683
- result = response.choices[0].message.content
684
-
685
- self.log_llm_interaction(prompt, result)
686
-
687
- return {
688
- "status": "success",
689
- "insights": result.split('\n'),
690
- "summary": "Correlation analysis complete with LLM-generated insights."
691
- }
692
-
693
- # Fallback to rule-based insights if no LLM available
694
- insights = []
695
- corr_matrix = self.correlation_data.get("correlation_matrix", {})
696
- highest_corr = self.correlation_data.get("highest_correlations", {})
697
-
698
- # Find strong positive correlations
699
- strong_pos_corr = [(k, v) for k, v in highest_corr.items() if v > 0.7]
700
- if strong_pos_corr:
701
- for (col1, col2), value in strong_pos_corr[:3]:
702
- insights.append(f"Strong positive correlation ({value:.2f}) between '{col1}' and '{col2}'.")
703
-
704
- # Find strong negative correlations
705
- strong_neg_corr = [(k, v) for k, v in highest_corr.items() if v < -0.7]
706
- if strong_neg_corr:
707
- for (col1, col2), value in strong_neg_corr[:3]:
708
- insights.append(f"Strong negative correlation ({value:.2f}) between '{col1}' and '{col2}'.")
709
-
710
- # Identify potential multicollinearity
711
- multi_corr = [(k, v) for k, v in highest_corr.items() if abs(v) > 0.9]
712
- if multi_corr:
713
- insights.append("Potential multicollinearity detected between some features.")
714
 
715
  return {
716
  "status": "success",
717
  "insights": insights,
718
- "summary": "Correlation analysis reveals interesting relationships between variables."
719
  }
720
 
721
  except Exception as e:
722
  return {"status": "error", "message": str(e)}
723
 
724
- def _create_visualization(self, params: Dict) -> Dict:
725
- """Create a visualization based on data"""
726
- try:
727
- viz_type = params.get("type", "histogram")
728
- title = params.get("title", "Data Visualization")
729
- x_column = params.get("x", None)
730
- y_column = params.get("y", None)
731
-
732
- # Generate a unique ID for this visualization
733
- viz_id = str(uuid.uuid4())
734
-
735
- # Create the visualization and save it to a file
736
- plt.figure(figsize=(10, 6))
737
-
738
- if not hasattr(self, "compute_agent") or not hasattr(self.compute_agent, "dataframe"):
739
- return {"status": "error", "message": "No data available for visualization"}
740
-
741
- df = self.compute_agent.dataframe
742
-
743
- if viz_type == "histogram":
744
- if x_column:
745
- sns.histplot(df[x_column], kde=True)
746
- plt.xlabel(x_column)
747
- plt.ylabel("Frequency")
748
- else:
749
- return {"status": "error", "message": "Column name required for histogram"}
750
-
751
- elif viz_type == "scatter":
752
- if x_column and y_column:
753
- sns.scatterplot(x=df[x_column], y=df[y_column])
754
- plt.xlabel(x_column)
755
- plt.ylabel(y_column)
756
- else:
757
- return {"status": "error", "message": "X and Y column names required for scatter plot"}
758
-
759
- elif viz_type == "bar":
760
- if x_column and y_column:
761
- sns.barplot(x=df[x_column], y=df[y_column])
762
- plt.xlabel(x_column)
763
- plt.ylabel(y_column)
764
- else:
765
- return {"status": "error", "message": "X and Y column names required for bar chart"}
766
-
767
- elif viz_type == "boxplot":
768
- if x_column:
769
- sns.boxplot(y=df[x_column])
770
- plt.ylabel(x_column)
771
- elif x_column and y_column:
772
- sns.boxplot(x=df[x_column], y=df[y_column])
773
- plt.xlabel(x_column)
774
- plt.ylabel(y_column)
775
- else:
776
- return {"status": "error", "message": "At least one column name required for boxplot"}
777
-
778
- elif viz_type == "heatmap":
779
- if params.get("columns"):
780
- corr = df[params["columns"]].corr()
781
- sns.heatmap(corr, annot=True, cmap="coolwarm")
782
- else:
783
- corr = df.select_dtypes(include=[np.number]).corr()
784
- sns.heatmap(corr, annot=True, cmap="coolwarm")
785
-
786
- plt.title(title)
787
- plt.tight_layout()
788
-
789
- # Save the visualization
790
- viz_filename = f"viz_{viz_id}.png"
791
- plt.savefig(viz_filename)
792
- plt.close()
793
-
794
- # Store visualization details
795
- viz_details = {
796
- "id": viz_id,
797
- "type": viz_type,
798
- "title": title,
799
- "filename": viz_filename,
800
- "x_column": x_column,
801
- "y_column": y_column
802
- }
803
-
804
- self.visualization_results[viz_id] = viz_details
805
-
806
- return {
807
- "status": "success",
808
- "visualization": viz_details
809
- }
810
-
811
- except Exception as e:
812
- return {"status": "error", "message": str(e)}
813
-
814
- def _generate_report(self, params: Dict) -> Dict:
815
  """Generate a report with key findings"""
816
  try:
817
- # If LLM available, use it for advanced report generation
818
- if self.llm_model:
819
- # Gather all the data we have
820
- report_data = {
821
- "dataset_info": self.dataset_info,
822
- "statistics": self.statistics,
823
- "correlation_data": self.correlation_data,
824
- "filter_results": self.filter_results,
825
- "aggregation_results": self.aggregation_results
826
- }
827
-
828
- prompt = f"""
829
- As a data analyst, generate a comprehensive report based on the following analysis data:
830
- {json.dumps(report_data, indent=2)}
831
-
832
- The report should include:
833
- 1. Dataset Overview
834
- 2. Key Findings from Statistical Analysis
835
- 3. Correlation Analysis Highlights
836
- 4. Filtered Data Analysis (if applicable)
837
- 5. Aggregation Insights (if applicable)
838
- 6. Conclusions and Recommendations
839
-
840
- Format the report in a professional style with clear sections.
841
- """
842
-
843
- if self.llm_model.startswith("claude"):
844
- client = anthropic.Anthropic(api_key=self.api_key)
845
- response = client.messages.create(
846
- model="claude-3-sonnet-20240229",
847
- max_tokens=2000,
848
- messages=[{"role": "user", "content": prompt}]
849
- )
850
- result = response.content[0].text
851
- elif self.llm_model.startswith("gpt"):
852
- client = openai.OpenAI(api_key=self.api_key)
853
- response = client.chat.completions.create(
854
- model="gpt-4o",
855
- messages=[{"role": "user", "content": prompt}]
856
- )
857
- result = response.choices[0].message.content
858
-
859
- self.log_llm_interaction(prompt, result)
860
-
861
- return {
862
- "status": "success",
863
- "report": {
864
- "title": params.get("report_title", "Data Analysis Report"),
865
- "content": result
866
- }
867
- }
868
-
869
- # Fallback to template-based report if no LLM available
870
- # Gather all the insights and results
871
  report_sections = []
872
 
873
  # Dataset overview
@@ -877,49 +318,10 @@ class InterpretAgent(MCPAgent):
877
  "content": f"The dataset contains {self.dataset_info.get('rows', 0)} rows and {len(self.dataset_info.get('columns', []))} columns."
878
  })
879
 
880
- # Statistical insights
881
- if self.statistics:
882
- # Interpret statistics if not already done
883
- if not hasattr(self, 'stat_insights'):
884
- self.stat_insights = self._interpret_statistics({}).get('insights', [])
885
-
886
- report_sections.append({
887
- "title": "Statistical Analysis",
888
- "content": "Key findings from statistical analysis:",
889
- "insights": self.stat_insights
890
- })
891
-
892
- # Correlation insights
893
- if self.correlation_data:
894
- # Interpret correlations if not already done
895
- if not hasattr(self, 'corr_insights'):
896
- self.corr_insights = self._interpret_correlation({}).get('insights', [])
897
-
898
- report_sections.append({
899
- "title": "Correlation Analysis",
900
- "content": "Key findings from correlation analysis:",
901
- "insights": self.corr_insights
902
- })
903
-
904
- # Filter results
905
- if self.filter_results:
906
- report_sections.append({
907
- "title": "Filtered Data Analysis",
908
- "content": f"The filtered dataset contains {self.filter_results.get('filtered_rows', 0)} rows, down from {self.filter_results.get('original_rows', 0)} rows."
909
- })
910
-
911
- # Aggregation results
912
- if self.aggregation_results:
913
- report_sections.append({
914
- "title": "Aggregation Analysis",
915
- "content": "Key insights from aggregated data:",
916
- "data": self.aggregation_results.get('result', [])
917
- })
918
-
919
- # Conclusions
920
  report_sections.append({
921
  "title": "Conclusions",
922
- "content": "Based on the analysis, several patterns and relationships have been identified in the data."
923
  })
924
 
925
  return {
@@ -933,15 +335,12 @@ class InterpretAgent(MCPAgent):
933
  except Exception as e:
934
  return {"status": "error", "message": str(e)}
935
 
936
- def handle_message(self, message: MCPMessage) -> Dict:
937
  """Handle incoming messages from other agents"""
938
  if message.message_type == "data_load_result":
939
  self.dataset_info = message.content
940
  return self.send_message(message.sender, "acknowledge", {"status": "received", "message": "Dataset info received"})
941
 
942
- elif message.message_type == "data_cleaning_result":
943
- return self.send_message(message.sender, "acknowledge", {"status": "received", "message": "Data cleaning result received"})
944
-
945
  elif message.message_type == "statistics_result":
946
  self.statistics = message.content
947
  insights = self._interpret_statistics({})
@@ -952,134 +351,55 @@ class InterpretAgent(MCPAgent):
952
  insights = self._interpret_correlation({})
953
  return self.send_message(message.sender, "correlation_interpretation", insights)
954
 
955
- elif message.message_type == "filter_result":
956
- self.filter_results = message.content
957
- return self.send_message(message.sender, "acknowledge", {"status": "received", "message": "Filter result received"})
958
-
959
- elif message.message_type == "aggregation_result":
960
- self.aggregation_results = message.content
961
- return self.send_message(message.sender, "acknowledge", {"status": "received", "message": "Aggregation result received"})
962
-
963
  elif message.message_type == "request_report":
964
  report = self._generate_report(message.content)
965
  return self.send_message(message.sender, "report_result", report)
966
 
967
- elif message.message_type == "request_visualization":
968
- visualization = self._create_visualization(message.content)
969
- return self.send_message(message.sender, "visualization_result", visualization)
970
-
971
  else:
972
  return {"status": "error", "message": f"Unknown message type: {message.message_type}"}
973
 
974
- def set_compute_agent(self, compute_agent):
975
- """Set reference to compute agent for access to dataframe"""
976
- self.compute_agent = compute_agent
977
-
978
 
979
  # ============== Main Analysis Workflow ==============
980
 
981
  class DataAnalystDuo:
982
  """Main class for the Data Analyst Duo MCP implementation"""
983
 
984
- def __init__(self, llm_model: Optional[str] = None, api_key: Optional[str] = None):
985
- self.compute_agent = ComputeAgent(llm_model=llm_model, api_key=api_key)
986
- self.interpret_agent = InterpretAgent(llm_model=llm_model, api_key=api_key)
987
 
988
  # Connect the agents as peers
989
  self.compute_agent.connect(self.interpret_agent)
990
  self.interpret_agent.connect(self.compute_agent)
991
 
992
- # Set reference to compute agent inside interpret agent
993
- self.interpret_agent.set_compute_agent(self.compute_agent)
994
-
995
- # Logs to store message flow and intermediate results
996
- self.logs = []
997
-
998
- def log_step(self, step_name: str, details: Dict):
999
- """Log a step in the analysis workflow"""
1000
- log_entry = {
1001
- "timestamp": datetime.datetime.now().isoformat(),
1002
- "step": step_name,
1003
- "details": details
1004
- }
1005
- self.logs.append(log_entry)
1006
- return log_entry
1007
-
1008
- def run_analysis(self, dataset_url: str, analysis_params: Dict = None) -> Dict:
1009
  """Run the complete analysis workflow"""
1010
- if analysis_params is None:
1011
- analysis_params = {}
1012
-
1013
- results = {}
1014
 
1015
  # 1. Load dataset
1016
- self.log_step("Initiating dataset loading", {"url": dataset_url})
1017
  self.interpret_agent.send_message("ComputeAgent", "request_data_load", {"url": dataset_url})
1018
  self.compute_agent.process_messages()
1019
  self.interpret_agent.process_messages()
1020
 
1021
- # 2. Clean data
1022
- clean_params = analysis_params.get("clean_params", {"missing_strategy": "mean", "remove_duplicates": True})
1023
- self.log_step("Initiating data cleaning", clean_params)
1024
- self.interpret_agent.send_message("ComputeAgent", "request_data_cleaning", clean_params)
1025
  self.compute_agent.process_messages()
1026
  self.interpret_agent.process_messages()
1027
 
1028
- # 3. Compute statistics
1029
- stats_params = analysis_params.get("stats_params", {"descriptive": True, "central_tendency": True, "dispersion": True})
1030
- self.log_step("Initiating statistical analysis", stats_params)
1031
- self.interpret_agent.send_message("ComputeAgent", "request_statistics", stats_params)
1032
  self.compute_agent.process_messages()
1033
  self.interpret_agent.process_messages()
1034
 
1035
- # 4. Compute correlation
1036
- corr_params = analysis_params.get("corr_params", {"method": "pearson"})
1037
- self.log_step("Initiating correlation analysis", corr_params)
1038
- self.interpret_agent.send_message("ComputeAgent", "request_correlation", corr_params)
1039
- self.compute_agent.process_messages()
1040
- self.interpret_agent.process_messages()
1041
-
1042
- # 5. Filter data if requested
1043
- if "filter_params" in analysis_params:
1044
- self.log_step("Initiating data filtering", analysis_params["filter_params"])
1045
- self.interpret_agent.send_message("ComputeAgent", "request_filter", analysis_params["filter_params"])
1046
- self.compute_agent.process_messages()
1047
- self.interpret_agent.process_messages()
1048
-
1049
- # 6. Compute aggregation if requested
1050
- if "agg_params" in analysis_params:
1051
- self.log_step("Initiating data aggregation", analysis_params["agg_params"])
1052
- self.interpret_agent.send_message("ComputeAgent", "request_aggregation", analysis_params["agg_params"])
1053
- self.compute_agent.process_messages()
1054
- self.interpret_agent.process_messages()
1055
-
1056
- # 7. Create visualizations if requested
1057
- if "viz_params" in analysis_params:
1058
- for viz_param in analysis_params["viz_params"]:
1059
- self.log_step("Initiating visualization creation", viz_param)
1060
- self.compute_agent.send_message("InterpretAgent", "request_visualization", viz_param)
1061
- self.interpret_agent.process_messages()
1062
- self.compute_agent.process_messages()
1063
-
1064
- # 8. Generate final report
1065
- report_params = analysis_params.get("report_params", {"report_title": "Data Analysis Report"})
1066
- self.log_step("Generating final report", report_params)
1067
- self.compute_agent.send_message("InterpretAgent", "request_report", report_params)
1068
  self.interpret_agent.process_messages()
1069
  self.compute_agent.process_messages()
1070
 
1071
  # Collect results
1072
- results["dataset_info"] = self.interpret_agent.dataset_info
1073
- results["statistics"] = self.interpret_agent.statistics
1074
- results["correlation_data"] = self.interpret_agent.correlation_data
1075
- results["filter_results"] = self.interpret_agent.filter_results
1076
- results["aggregation_results"] = self.interpret_agent.aggregation_results
1077
- results["visualizations"] = self.interpret_agent.visualization_results
1078
- results["compute_agent_messages"] = self.compute_agent.get_message_history()
1079
- results["interpret_agent_messages"] = self.interpret_agent.get_message_history()
1080
- results["compute_agent_llm_logs"] = self.compute_agent.get_llm_logs()
1081
- results["interpret_agent_llm_logs"] = self.interpret_agent.get_llm_logs()
1082
- results["workflow_logs"] = self.logs
1083
 
1084
  return results
1085
 
@@ -1092,264 +412,78 @@ def format_json(json_data):
1092
  return json.dumps(json_data, indent=2)
1093
  return str(json_data)
1094
 
1095
- def run_data_analysis(dataset_url, llm_provider, api_key, missing_strategy, create_visualizations, high_fiber_filter):
1096
- """Run the data analysis workflow and return results"""
1097
  try:
1098
- # Validate inputs
1099
  if not dataset_url:
1100
- dataset_url = "default" # Use default cereals dataset
1101
 
1102
- if llm_provider != "none" and not api_key:
1103
- return {
1104
- 'mcp_messages': "Error: API key is required for LLM integration",
1105
- 'llm_logs': "",
1106
- 'visualizations': "",
1107
- 'final_report': ""
1108
- }
1109
-
1110
- # Initialize the analyst duo
1111
- llm_model = None
1112
- if llm_provider == "claude":
1113
- llm_model = "claude"
1114
- elif llm_provider == "gpt":
1115
- llm_model = "gpt"
1116
- if not api_key:
1117
- api_key = os.environ.get("OPENAI_API_KEY", "")
1118
-
1119
- # Create the data analyst duo
1120
- duo = DataAnalystDuo(llm_model=llm_model, api_key=api_key)
1121
-
1122
- # Prepare analysis parameters
1123
- analysis_params = {
1124
- "clean_params": {
1125
- "missing_strategy": missing_strategy,
1126
- "remove_duplicates": True
1127
- },
1128
- "stats_params": {
1129
- "descriptive": True,
1130
- "central_tendency": True,
1131
- "dispersion": True
1132
- },
1133
- "corr_params": {
1134
- "method": "pearson"
1135
- }
1136
- }
1137
-
1138
- # Add filter for high fiber if requested
1139
- if high_fiber_filter:
1140
- analysis_params["filter_params"] = {
1141
- "filters": [
1142
- {"column": "fiber", "operator": ">", "value": 5}
1143
- ]
1144
- }
1145
-
1146
- # Add aggregation by manufacturer
1147
- analysis_params["agg_params"] = {
1148
- "groupby": ["mfr"],
1149
- "columns": ["calories", "protein", "fat", "fiber", "sugars"],
1150
- "functions": ["mean", "min", "max"]
1151
- }
1152
 
1153
- # Add visualizations if requested
1154
- if create_visualizations:
1155
- analysis_params["viz_params"] = [
1156
- {
1157
- "type": "scatter",
1158
- "title": "Calories vs Sugar Content",
1159
- "x": "calories",
1160
- "y": "sugars"
1161
- },
1162
- {
1163
- "type": "histogram",
1164
- "title": "Distribution of Fiber Content",
1165
- "x": "fiber"
1166
- },
1167
- {
1168
- "type": "heatmap",
1169
- "title": "Correlation Matrix",
1170
- "columns": ["calories", "protein", "fat", "fiber", "sugars", "rating"]
1171
- }
1172
- ]
1173
 
1174
- # Run the analysis
1175
- results = duo.run_analysis(dataset_url, analysis_params)
 
 
 
 
1176
 
1177
- # Extract MCP messages for display
1178
- compute_messages = results["compute_agent_messages"]
1179
- interpret_messages = results["interpret_agent_messages"]
 
 
 
1180
 
1181
- # Extract LLM logs
1182
- compute_llm_logs = results["compute_agent_llm_logs"]
1183
- interpret_llm_logs = results["interpret_agent_llm_logs"]
1184
 
1185
- # Format messages for display
1186
- formatted_messages = []
1187
 
1188
- # Combine and sort messages by timestamp
1189
- all_messages = []
1190
- for msg in compute_messages:
1191
- msg_copy = msg.copy()
1192
- msg_copy["agent"] = "ComputeAgent"
1193
- all_messages.append(msg_copy)
1194
-
1195
- for msg in interpret_messages:
1196
- msg_copy = msg.copy()
1197
- msg_copy["agent"] = "InterpretAgent"
1198
- all_messages.append(msg_copy)
1199
-
1200
- # Sort by timestamp
1201
- all_messages.sort(key=lambda x: x["message"]["timestamp"])
1202
-
1203
- # Format for display
1204
- for msg in all_messages:
1205
- agent = msg["agent"]
1206
- direction = msg["type"]
1207
- message = msg["message"]
1208
-
1209
- formatted_msg = f"[{message['timestamp']}] {agent} {direction.upper()} - Type: {message['message_type']}\n"
1210
- formatted_msg += format_json(message['content'])
1211
- formatted_msg += "\n\n" + "-"*80 + "\n\n"
1212
- formatted_messages.append(formatted_msg)
1213
-
1214
- # Format LLM logs
1215
- formatted_llm_logs = []
1216
-
1217
- for log in compute_llm_logs + interpret_llm_logs:
1218
- formatted_log = f"[{log['timestamp']}]\n"
1219
- formatted_log += "PROMPT:\n" + log['prompt'] + "\n\n"
1220
- formatted_log += "RESPONSE:\n" + log['response'] + "\n\n"
1221
- formatted_log += "-"*80 + "\n\n"
1222
- formatted_llm_logs.append(formatted_log)
1223
-
1224
- # Prepare visualization display
1225
- viz_html = ""
1226
- if create_visualizations and "visualizations" in results and results["visualizations"]:
1227
- viz_html = "<div style='display: flex; flex-wrap: wrap;'>"
1228
- for viz_id, viz_data in results["visualizations"].items():
1229
- viz_html += f"<div style='margin: 10px;'>"
1230
- viz_html += f"<h3>{viz_data['title']}</h3>"
1231
- viz_html += f"<img src='file={viz_data['filename']}' width='400' />"
1232
- viz_html += "</div>"
1233
- viz_html += "</div>"
1234
-
1235
- # Get the final report
1236
- report_html = "<h2>No report generated</h2>"
1237
-
1238
- # Check if report result exists
1239
- report_found = False
1240
- for msg in compute_messages:
1241
- if msg["type"] == "received" and msg["message"]["message_type"] == "report_result":
1242
- report_found = True
1243
- try:
1244
- if "content" in msg["message"] and "report" in msg["message"]["content"]:
1245
- report_content = msg["message"]["content"]["report"]
1246
- if "content" in report_content:
1247
- # LLM-generated report
1248
- report_html = f"<h2>{report_content['title']}</h2>"
1249
- newline_replaced = report_content['content'].replace("\n", "<br>")
1250
- report_html += f"<div>{newline_replaced}</div>"
1251
- elif "sections" in report_content:
1252
- # Template-based report
1253
- report_html = f"<h2>{report_content['title']}</h2>"
1254
- for section in report_content["sections"]:
1255
- report_html += f"<h3>{section['title']}</h3>"
1256
- report_html += f"<p>{section['content']}</p>"
1257
- if "insights" in section:
1258
- report_html += "<ul>"
1259
- for insight in section["insights"]:
1260
- report_html += f"<li>{insight}</li>"
1261
- report_html += "</ul>"
1262
- if "data" in section:
1263
- report_html += "<pre>" + format_json(section["data"]) + "</pre>"
1264
- except Exception as e:
1265
- report_html = f"<h2>Error generating report: {str(e)}</h2>"
1266
-
1267
- if not report_found:
1268
- report_html = "<h2>No report message received</h2>"
1269
-
1270
- # Return all results
1271
- return {
1272
- 'mcp_messages': "\n".join(formatted_messages),
1273
- 'llm_logs': "\n".join(formatted_llm_logs),
1274
- 'visualizations': viz_html,
1275
- 'final_report': report_html
1276
- }
1277
 
1278
  except Exception as e:
1279
  import traceback
1280
- error_traceback = traceback.format_exc()
1281
- return {
1282
- 'mcp_messages': f"Error: {str(e)}\n\n{error_traceback}",
1283
- 'llm_logs': "",
1284
- 'visualizations': "",
1285
- 'final_report': f"<h2>Error: {str(e)}</h2><pre>{error_traceback}</pre>"
1286
- }
1287
 
1288
  # Define the Gradio interface
1289
- def create_interface():
1290
- with gr.Blocks(title="Data Analyst Duo - MCP Implementation") as app:
1291
- gr.Markdown("""
1292
- # Data Analyst Duo - Model Context Protocol (MCP) Implementation
1293
-
1294
- This application demonstrates a multi-agent system using the Model Context Protocol (MCP).
1295
- It consists of two agents:
1296
-
1297
- 1. **ComputeAgent**: Responsible for data loading, cleaning, and computation
1298
- 2. **InterpretAgent**: Responsible for interpreting results and visualizing data
1299
-
1300
- The agents communicate directly using standardized MCP messages, showcasing agent-to-agent communication.
1301
- """)
1302
-
1303
- with gr.Row():
1304
- with gr.Column():
1305
- dataset_url = gr.Textbox(label="Dataset URL (leave empty for default cereals dataset)", placeholder="Enter dataset URL or leave empty for default")
1306
-
1307
- with gr.Row():
1308
- llm_provider = gr.Radio(["none", "claude", "gpt"], label="LLM Provider (Optional)", value="none")
1309
- api_key = gr.Textbox(label="API Key (if using LLM)", placeholder="Enter API key if using Claude or GPT")
1310
-
1311
- with gr.Row():
1312
- missing_strategy = gr.Dropdown(["drop", "mean", "median", "mode"], label="Missing Values Strategy", value="mean")
1313
- create_visualizations = gr.Checkbox(label="Create Visualizations", value=True)
1314
- high_fiber_filter = gr.Checkbox(label="Filter for High Fiber & Aggregate by Manufacturer", value=True)
1315
 
1316
- run_button = gr.Button("Run Data Analysis")
 
1317
 
1318
- with gr.Row():
1319
- with gr.Tab("MCP Messages"):
1320
- mcp_messages = gr.Textbox(label="MCP Message Flow", lines=20)
1321
- with gr.Tab("LLM Logs"):
1322
- llm_logs = gr.Textbox(label="LLM Interaction Logs", lines=20)
1323
 
1324
- with gr.Row():
1325
- with gr.Tab("Visualizations"):
1326
- visualizations = gr.HTML(label="Data Visualizations")
1327
- with gr.Tab("Final Report"):
1328
- final_report = gr.HTML(label="Analysis Report")
1329
 
1330
- # Connect the button to the analysis function
1331
- run_button.click(
1332
- fn=run_data_analysis,
1333
- inputs=[dataset_url, llm_provider, api_key, missing_strategy, create_visualizations, high_fiber_filter],
1334
- outputs=[mcp_messages, llm_logs, visualizations, final_report]
1335
- )
1336
 
1337
- gr.Markdown("""
1338
- ## How This Demonstrates MCP
1339
 
1340
- This application shows the Model Context Protocol in action:
 
1341
 
1342
- 1. **Standardized Message Structure**: All communication between agents follows a consistent format
1343
- 2. **Tool Registration**: Agents register their capabilities as tools with descriptions
1344
- 3. **Direct Peer Communication**: Agents communicate directly with structured messages
1345
- 4. **Asynchronous Processing**: Each agent processes messages independently
1346
 
1347
- The message flow display shows the exact JSON messages exchanged between agents, demonstrating the protocol in action.
1348
- """)
 
1349
 
1350
- return app
 
1351
 
1352
- # Create and launch the interface
1353
  if __name__ == "__main__":
1354
- app = create_interface()
1355
  app.launch()
 
1
  """
2
+ Data Analyst Duo MCP Implementation - Simplified version
 
 
 
 
 
 
 
3
  """
4
 
5
  import os
 
6
  import json
 
7
  import datetime
8
  import gradio as gr
9
  import pandas as pd
10
  import numpy as np
 
 
 
11
  import requests
12
  from io import StringIO
13
  import logging
14
  import uuid
 
 
 
 
 
 
15
 
16
  # Configure logging
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
22
  class MCPMessage:
23
  """Base class for MCP messages that agents exchange"""
24
 
25
+ def __init__(self, sender, message_type, content):
26
  self.id = str(uuid.uuid4())
27
  self.sender = sender
28
  self.message_type = message_type
29
  self.content = content
30
  self.timestamp = datetime.datetime.now().isoformat()
31
 
32
+ def to_dict(self):
33
  return {
34
  "id": self.id,
35
  "sender": self.sender,
 
38
  "timestamp": self.timestamp
39
  }
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
class MCPTool:
    """A named capability that an agent exposes through the MCP protocol."""

    def __init__(self, name, description, function):
        """Wrap *function* as a tool identified by *name*.

        The description is human-readable metadata; the callable does the work.
        """
        self.name = name
        self.description = description
        self.function = function

    def execute(self, params):
        """Invoke the wrapped callable with the given parameter payload."""
        run = self.function
        return run(params)
52
 
53
 
54
  class MCPAgent:
55
  """Base agent class implementing MCP protocol"""
56
 
57
+ def __init__(self, name, description):
58
  self.name = name
59
  self.description = description
60
+ self.tools = {}
61
+ self.message_queue = []
62
+ self.peers = {}
63
+ self.message_history = []
64
+
65
+ def register_tool(self, tool):
 
 
 
66
  """Register a tool that this agent can use"""
67
  self.tools[tool.name] = tool
68
 
69
+ def connect(self, peer):
 
 
 
 
 
 
 
 
 
 
70
  """Connect to another agent as a peer"""
71
  self.peers[peer.name] = peer
72
 
73
+ def send_message(self, receiver, message_type, content):
74
  """Send a message to a peer agent"""
75
  if receiver not in self.peers:
76
  raise ValueError(f"Peer {receiver} not found")
 
89
  logger.info(f"Agent {self.name} sent {message_type} to {receiver}")
90
  return message_dict
91
 
92
+ def receive_message(self, message):
93
  """Receive a message from a peer agent"""
94
  self.message_queue.append(message)
95
 
 
101
 
102
  logger.info(f"Agent {self.name} received {message.message_type} from {message.sender}")
103
 
104
+ def process_messages(self):
105
  """Process all messages in the queue"""
106
  processed = []
107
  while self.message_queue:
 
110
  processed.append(response)
111
  return processed
112
 
113
+ def handle_message(self, message):
114
  """Handle a message - to be implemented by subclasses"""
115
  raise NotImplementedError("Subclasses must implement handle_message")
116
 
117
+ def get_message_history(self):
 
 
 
 
 
 
 
 
 
 
118
  """Get the agent's message history"""
119
  return self.message_history
120
 
 
 
 
 
121
 
122
  # ============== Compute Agent Implementation ==============
123
 
124
  class ComputeAgent(MCPAgent):
125
  """Agent responsible for data loading, cleaning, and computation"""
126
 
127
+ def __init__(self, name="ComputeAgent"):
128
+ super().__init__(name, "Agent responsible for data loading, cleaning and computation")
129
  self.dataframe = None
 
130
 
131
  # Register tools
132
  self.register_tool(MCPTool(
133
  "load_dataset",
134
+ "Load a dataset from URL",
135
  self._load_dataset
136
  ))
137
 
 
 
 
 
 
 
138
  self.register_tool(MCPTool(
139
  "compute_statistics",
140
  "Compute basic statistics on the dataset",
 
147
  self._compute_correlation
148
  ))
149
 
150
+ def _load_dataset(self, params):
151
+ """Load a dataset from URL"""
 
 
 
 
 
 
 
 
 
 
 
 
152
  dataset_url = params.get("url")
153
 
154
  try:
155
+ # Use default cereals dataset if not specified
156
+ if not dataset_url or dataset_url == "default":
157
  dataset_url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/cereal.csv"
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  # Load the dataset
160
  response = requests.get(dataset_url)
161
  content = response.content.decode('utf-8')
 
166
  "status": "success",
167
  "rows": len(self.dataframe),
168
  "columns": list(self.dataframe.columns),
169
+ "preview": self.dataframe.head(5).to_dict(orient="records")
 
170
  }
171
 
172
  return info
 
174
  except Exception as e:
175
  return {"status": "error", "message": str(e)}
176
 
177
+ def _compute_statistics(self, params):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  """Compute basic statistics on the dataset"""
179
  if self.dataframe is None:
180
  return {"status": "error", "message": "No dataset loaded"}
 
183
  # Get columns to compute stats for
184
  columns = params.get("columns", list(self.dataframe.select_dtypes(include=[np.number]).columns))
185
 
 
 
 
186
  # Basic descriptive statistics
187
+ stats = self.dataframe[columns].describe().to_dict()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  return {
190
  "status": "success",
 
194
  except Exception as e:
195
  return {"status": "error", "message": str(e)}
196
 
197
+ def _compute_correlation(self, params):
198
  """Compute correlation between columns"""
199
  if self.dataframe is None:
200
  return {"status": "error", "message": "No dataset loaded"}
 
202
  try:
203
  # Get columns to compute correlation for
204
  columns = params.get("columns", list(self.dataframe.select_dtypes(include=[np.number]).columns))
 
 
 
205
 
206
+ corr_matrix = self.dataframe[columns].corr().to_dict()
 
 
 
207
 
208
  return {
209
  "status": "success",
210
+ "correlation_matrix": corr_matrix
 
211
  }
212
 
213
  except Exception as e:
214
  return {"status": "error", "message": str(e)}
215
 
216
+ def handle_message(self, message):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  """Handle incoming messages from other agents"""
218
  if message.message_type == "request_data_load":
219
  result = self._load_dataset(message.content)
220
  return self.send_message(message.sender, "data_load_result", result)
221
 
 
 
 
 
222
  elif message.message_type == "request_statistics":
223
  result = self._compute_statistics(message.content)
224
  return self.send_message(message.sender, "statistics_result", result)
 
227
  result = self._compute_correlation(message.content)
228
  return self.send_message(message.sender, "correlation_result", result)
229
 
 
 
 
 
 
 
 
 
230
  else:
231
  return {"status": "error", "message": f"Unknown message type: {message.message_type}"}
232
 
 
236
  class InterpretAgent(MCPAgent):
237
  """Agent responsible for interpreting results and visualizing data"""
238
 
239
+ def __init__(self, name="InterpretAgent"):
240
+ super().__init__(name, "Agent responsible for interpreting results and visualizing data")
241
  self.dataset_info = None
242
  self.statistics = None
243
  self.correlation_data = None
 
 
 
244
 
245
  # Register tools
246
  self.register_tool(MCPTool(
 
255
  self._interpret_correlation
256
  ))
257
 
 
 
 
 
 
 
258
  self.register_tool(MCPTool(
259
  "generate_report",
260
  "Generate a report with key findings",
261
  self._generate_report
262
  ))
263
 
264
+ def _interpret_statistics(self, params):
265
  """Interpret statistical results and provide insights"""
266
  if not self.statistics:
267
  return {"status": "error", "message": "No statistics data available"}
268
 
269
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  insights = []
271
  stats = self.statistics.get("statistics", {})
272
 
273
+ # Simple rule-based insights
274
+ for col, col_stats in stats.items():
275
+ # Add a simple insight about the mean value
276
+ if "mean" in col_stats:
277
+ insights.append(f"The average {col} is {col_stats['mean']:.2f}")
278
+
279
+ # Add insight about range
280
+ if "min" in col_stats and "max" in col_stats:
281
+ insights.append(f"{col} ranges from {col_stats['min']:.2f} to {col_stats['max']:.2f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
  return {
284
  "status": "success",
285
+ "insights": insights[:3], # Limit to top 3 insights
286
+ "summary": "Statistical analysis complete."
287
  }
288
 
289
  except Exception as e:
290
  return {"status": "error", "message": str(e)}
291
 
292
+ def _interpret_correlation(self, params):
293
  """Interpret correlation results and provide insights"""
294
  if not self.correlation_data:
295
  return {"status": "error", "message": "No correlation data available"}
296
 
297
  try:
298
+ insights = ["Correlation analysis complete."]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  return {
301
  "status": "success",
302
  "insights": insights,
303
+ "summary": "Correlation analysis complete."
304
  }
305
 
306
  except Exception as e:
307
  return {"status": "error", "message": str(e)}
308
 
309
+ def _generate_report(self, params):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  """Generate a report with key findings"""
311
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  report_sections = []
313
 
314
  # Dataset overview
 
318
  "content": f"The dataset contains {self.dataset_info.get('rows', 0)} rows and {len(self.dataset_info.get('columns', []))} columns."
319
  })
320
 
321
+ # Simple conclusion
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  report_sections.append({
323
  "title": "Conclusions",
324
+ "content": "Analysis complete."
325
  })
326
 
327
  return {
 
335
  except Exception as e:
336
  return {"status": "error", "message": str(e)}
337
 
338
+ def handle_message(self, message):
339
  """Handle incoming messages from other agents"""
340
  if message.message_type == "data_load_result":
341
  self.dataset_info = message.content
342
  return self.send_message(message.sender, "acknowledge", {"status": "received", "message": "Dataset info received"})
343
 
 
 
 
344
  elif message.message_type == "statistics_result":
345
  self.statistics = message.content
346
  insights = self._interpret_statistics({})
 
351
  insights = self._interpret_correlation({})
352
  return self.send_message(message.sender, "correlation_interpretation", insights)
353
 
 
 
 
 
 
 
 
 
354
  elif message.message_type == "request_report":
355
  report = self._generate_report(message.content)
356
  return self.send_message(message.sender, "report_result", report)
357
 
 
 
 
 
358
  else:
359
  return {"status": "error", "message": f"Unknown message type: {message.message_type}"}
360
 
 
 
 
 
361
 
362
  # ============== Main Analysis Workflow ==============
363
 
364
  class DataAnalystDuo:
365
  """Main class for the Data Analyst Duo MCP implementation"""
366
 
367
+ def __init__(self):
368
+ self.compute_agent = ComputeAgent()
369
+ self.interpret_agent = InterpretAgent()
370
 
371
  # Connect the agents as peers
372
  self.compute_agent.connect(self.interpret_agent)
373
  self.interpret_agent.connect(self.compute_agent)
374
 
375
+ def run_analysis(self, dataset_url="default"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  """Run the complete analysis workflow"""
 
 
 
 
377
 
378
  # 1. Load dataset
 
379
  self.interpret_agent.send_message("ComputeAgent", "request_data_load", {"url": dataset_url})
380
  self.compute_agent.process_messages()
381
  self.interpret_agent.process_messages()
382
 
383
+ # 2. Compute statistics
384
+ self.interpret_agent.send_message("ComputeAgent", "request_statistics", {"descriptive": True})
 
 
385
  self.compute_agent.process_messages()
386
  self.interpret_agent.process_messages()
387
 
388
+ # 3. Compute correlation
389
+ self.interpret_agent.send_message("ComputeAgent", "request_correlation", {"method": "pearson"})
 
 
390
  self.compute_agent.process_messages()
391
  self.interpret_agent.process_messages()
392
 
393
+ # 4. Generate final report
394
+ self.compute_agent.send_message("InterpretAgent", "request_report", {"report_title": "Data Analysis Report"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  self.interpret_agent.process_messages()
396
  self.compute_agent.process_messages()
397
 
398
  # Collect results
399
+ results = {
400
+ "compute_agent_messages": self.compute_agent.get_message_history(),
401
+ "interpret_agent_messages": self.interpret_agent.get_message_history()
402
+ }
 
 
 
 
 
 
 
403
 
404
  return results
405
 
 
412
  return json.dumps(json_data, indent=2)
413
  return str(json_data)
414
 
415
def run_analysis(dataset_url):
    """Run the data analysis workflow and return the formatted MCP transcript.

    An empty URL falls back to the bundled cereals dataset. On any failure
    the error plus traceback is returned as text instead of raising, so the
    UI always gets something to display.
    """
    try:
        target_url = dataset_url if dataset_url else "default"

        # Create and run the analyst duo.
        duo = DataAnalystDuo()
        results = duo.run_analysis(target_url)

        def render(agent_label, entry):
            # One readable paragraph per logged message, keyed by timestamp
            # so the combined transcript can be ordered chronologically.
            message = entry["message"]
            header = f"[{message['timestamp']}] {agent_label} {entry['type'].upper()} - Type: {message['message_type']}\n"
            body = format_json(message["content"]) + "\n\n" + "-" * 80 + "\n\n"
            return message["timestamp"], header + body

        timeline = [render("ComputeAgent", m) for m in results["compute_agent_messages"]]
        timeline += [render("InterpretAgent", m) for m in results["interpret_agent_messages"]]

        # Chronological order across both agents; the stable sort keeps
        # equal timestamps in compute-then-interpret order, as before.
        timeline.sort(key=lambda item: item[0])

        return "\n".join(text for _, text in timeline)

    except Exception as e:
        import traceback
        return f"Error: {str(e)}\n\n{traceback.format_exc()}"
 
 
 
 
 
 
454
 
455
  # Define the Gradio interface
456
# Build the Gradio UI at import time; `app` is launched under __main__ below.
with gr.Blocks(title="Data Analyst Duo - MCP Communication") as app:
    # Header / explanation shown above the controls.
    gr.Markdown("""
    # Data Analyst Duo - Model Context Protocol (MCP) Implementation

    This application demonstrates a multi-agent system using the Model Context Protocol (MCP).
    It consists of two agents:

    1. **ComputeAgent**: Responsible for data loading, cleaning, and computation
    2. **InterpretAgent**: Responsible for interpreting results

    The agents communicate directly using standardized MCP messages, showcasing agent-to-agent communication.
    """)

    # Input: dataset URL (blank falls back to the bundled cereals dataset).
    dataset_url = gr.Textbox(label="Dataset URL (leave empty for default cereals dataset)", placeholder="Enter dataset URL or leave empty for default")
    run_button = gr.Button("Run Analysis")
    # Output: the full agent-to-agent message transcript as plain text.
    mcp_messages = gr.Textbox(label="MCP Message Flow", lines=30)

    # Clicking the button runs the whole analysis workflow synchronously.
    run_button.click(fn=run_analysis, inputs=dataset_url, outputs=mcp_messages)

    gr.Markdown("""
    ## How This Demonstrates MCP

    This application shows the Model Context Protocol in action:

    1. **Standardized Message Structure**: All communication between agents follows a consistent format
    2. **Direct Peer Communication**: Agents communicate directly with structured messages
    3. **Asynchronous Processing**: Each agent processes messages independently

    The message flow display shows the exact JSON messages exchanged between agents, demonstrating the protocol in action.
    """)

# Launch the app
if __name__ == "__main__":
    app.launch()
requirements.txt CHANGED
@@ -1,9 +1,4 @@
1
  gradio==4.13.0
2
  pandas==2.1.1
3
  numpy==1.26.0
4
- matplotlib==3.8.0
5
- seaborn==0.13.0
6
- anthropic==0.8.1
7
- openai==1.1.1
8
- python-dotenv==1.0.0
9
  requests==2.31.0
 
1
  gradio==4.13.0
2
  pandas==2.1.1
3
  numpy==1.26.0
 
 
 
 
 
4
  requests==2.31.0