# Source: wuhp-agents / extensions/visualization.py (author: wuhp)
# Last commit: 5383630 "Update extensions/visualization.py" (verified)
"""
Enhanced Data Visualization Extension v3.0 - MODULAR
Create charts from ANY data source through capability-based integration.
No hardcoded assumptions about data sources!
"""
from base_extension import BaseExtension
from google.genai import types
from typing import Dict, Any, List, Optional
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import io
import base64
from pathlib import Path
import numpy as np
import datetime
class VisualizationExtension(BaseExtension):
    """Render numerical data as PNG charts (line, bar, scatter, pie).

    Every ``create_*`` tool draws a matplotlib figure, writes it as a PNG into
    the per-user ``output_dir`` and returns the same image base64-encoded so
    it can be displayed inline. Saved charts are recorded in the per-user
    state (``charts`` list and ``total_created`` counter).

    Emoji in user-facing strings are written as escapes: the previous literals
    had been mis-decoded (UTF-8 bytes read through a Windows code page, e.g.
    "πŸ“Š" instead of the bar-chart emoji).
    """

    # ==========================================
    # METADATA
    # ==========================================

    @property
    def name(self) -> str:
        """Internal identifier of this extension."""
        return "visualization"

    @property
    def display_name(self) -> str:
        """Human-readable name of this extension."""
        return "Data Visualization"

    @property
    def description(self) -> str:
        """One-line summary of what this extension does."""
        return "Create charts, graphs, and plots from ANY numerical data"

    @property
    def icon(self) -> str:
        """Bar-chart emoji (U+1F4CA) used as this extension's icon."""
        return "\U0001F4CA"

    @property
    def version(self) -> str:
        """Extension version string."""
        return "3.0.0"

    # ==========================================
    # CAPABILITY-BASED DISCOVERY
    # ==========================================

    def get_capabilities(self) -> Dict[str, Any]:
        """Declare which data kinds this extension consumes and what it produces.

        Visualization is an end node: it consumes data from other extensions
        and provides none of its own.
        """
        return {
            'provides_data': [],  # Visualization is an end node
            'consumes_data': ['time_series', 'comparison_table', 'categorical_data',
                              'numerical_series', 'stock_history', 'financial_data'],
            'creates_output': ['visualization', 'chart', 'graph', 'plot'],
            'keywords': [
                'chart', 'graph', 'plot', 'visualize', 'show', 'display',
                'diagram', 'figure', 'image', 'visual', 'draw'
            ],
            'data_outputs': {
                'chart': {
                    'format': 'image',
                    'fields': ['image_base64', 'filepath', 'title', 'chart_type']
                }
            }
        }

    def get_suggested_next_action(self, tool_result: Dict[str, Any],
                                  available_extensions: List['BaseExtension']) -> Optional[Dict[str, Any]]:
        """Return no follow-up suggestion; visualization is typically an end node.

        Could suggest saving/sharing if such extensions existed.
        """
        return None

    # ==========================================
    # SYSTEM CONTEXT
    # ==========================================

    def get_system_context(self) -> str:
        """Return the prompt text describing the chart tools to the model."""
        return """
You have access to Data Visualization for creating charts and graphs.
## CRITICAL: Always Create Charts When Requested!
When user says "make a chart", "create a graph", "visualize", or "show me":
- YOU MUST call one of the chart creation tools
- DO NOT just describe what you could do
- DO NOT say you can't create charts - YOU CAN!
## Chart Types:
- **Line charts**: Time series, trends, stock prices over time
- **Bar charts**: Comparisons, categories
- **Scatter plots**: Correlations, relationships
- **Pie charts**: Proportions, percentages
## Data Format:
Charts work with ANY data that has numerical values:
- **Line chart**: `data = {"Series": {"x_values": [...], "y_values": [...]}}`
- **Bar chart**: `categories = [...]`, `values = [...]`
- **Scatter plot**: `x_values = [...]`, `y_values = [...]`
- **Pie chart**: `labels = [...]`, `values = [...]`
## Integration:
This extension automatically receives formatted data from other extensions.
When suggestions appear, follow them to create charts!
## Remember:
- You CAN and SHOULD create visualizations
- Charts display automatically to users
- Extract data from previous tool results to visualize
"""

    # ==========================================
    # STATE MANAGEMENT
    # ==========================================

    def _get_default_state(self) -> Dict[str, Any]:
        """Return the initial per-user state for this extension."""
        now = datetime.datetime.now().isoformat()
        return {
            "charts": [],                     # one record per saved chart (see _save_chart)
            "output_dir": "visualizations",   # directory the PNG files are written to
            "total_created": 0,
            "created_at": now,
            "last_updated": now,
        }

    def get_state_summary(self, user_id: str) -> Optional[str]:
        """Return "<n> visualizations created", or None when none exist yet."""
        state = self.get_state(user_id)
        chart_count = len(state.get("charts", []))
        if chart_count > 0:
            return f"{chart_count} visualizations created"
        return None

    def get_metrics(self, user_id: str) -> Dict[str, Any]:
        """Return chart counters: lifetime total, recorded charts, and a per-type breakdown."""
        state = self.get_state(user_id)
        charts = state.get("charts", [])
        chart_types: Dict[str, int] = {}
        for chart in charts:
            chart_type = chart.get("type", "unknown")
            chart_types[chart_type] = chart_types.get(chart_type, 0) + 1
        return {
            "total_created": state.get("total_created", 0),
            "current_session": len(charts),
            "by_type": chart_types
        }

    # ==========================================
    # TOOLS
    # ==========================================

    def get_tools(self) -> List[types.Tool]:
        """Return the function declarations exposed to the model."""
        create_line_chart = types.FunctionDeclaration(
            name="create_line_chart",
            description="Create line chart for time series or trends. Perfect for stock prices, growth over time, etc.",
            parameters={
                "type": "object",
                "properties": {
                    "title": {
                        "type": "string",
                        "description": "Chart title"
                    },
                    "x_label": {
                        "type": "string",
                        "description": "X-axis label"
                    },
                    "y_label": {
                        "type": "string",
                        "description": "Y-axis label"
                    },
                    "data": {
                        "type": "object",
                        "description": "Data series: {series_name: {x_values: [...], y_values: [...]}}"
                    }
                },
                "required": ["title", "x_label", "y_label", "data"]
            }
        )
        create_bar_chart = types.FunctionDeclaration(
            name="create_bar_chart",
            description="Create bar chart for categorical comparisons",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string", "description": "Chart title"},
                    "x_label": {"type": "string", "description": "X-axis label"},
                    "y_label": {"type": "string", "description": "Y-axis label"},
                    "categories": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Category names"
                    },
                    "values": {
                        "type": "array",
                        "items": {"type": "number"},
                        "description": "Values for each category"
                    }
                },
                "required": ["title", "x_label", "y_label", "categories", "values"]
            }
        )
        create_scatter_plot = types.FunctionDeclaration(
            name="create_scatter_plot",
            description="Create scatter plot to show relationships between two variables",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "x_label": {"type": "string"},
                    "y_label": {"type": "string"},
                    "x_values": {
                        "type": "array",
                        "items": {"type": "number"}
                    },
                    "y_values": {
                        "type": "array",
                        "items": {"type": "number"}
                    }
                },
                "required": ["title", "x_label", "y_label", "x_values", "y_values"]
            }
        )
        create_pie_chart = types.FunctionDeclaration(
            name="create_pie_chart",
            description="Create pie chart for proportions or percentages",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "labels": {
                        "type": "array",
                        "items": {"type": "string"}
                    },
                    "values": {
                        "type": "array",
                        "items": {"type": "number"}
                    }
                },
                "required": ["title", "labels", "values"]
            }
        )
        list_charts = types.FunctionDeclaration(
            name="list_visualizations",
            description="List all created visualizations",
            parameters={"type": "object", "properties": {}}
        )
        return [types.Tool(function_declarations=[
            create_line_chart,
            create_bar_chart,
            create_scatter_plot,
            create_pie_chart,
            list_charts
        ])]

    # ==========================================
    # HELPER METHODS
    # ==========================================

    @staticmethod
    def _unique_path(directory: Path, stem: str, suffix: str = ".png") -> Path:
        """Return ``directory/<stem><suffix>``, appending ``_1``, ``_2``, ... if it exists.

        Filenames carry only one-second timestamps, so two charts of the same
        type and title saved within one second would otherwise overwrite each
        other on disk.
        """
        candidate = directory / f"{stem}{suffix}"
        counter = 1
        while candidate.exists():
            candidate = directory / f"{stem}_{counter}{suffix}"
            counter += 1
        return candidate

    def _save_chart(self, fig, user_id: str, chart_type: str, title: str) -> tuple:
        """Save *fig* as a PNG, record it in the user's state and close it.

        The figure is rendered once into memory. Those bytes are written to
        disk and also returned base64-encoded for inline display. The figure
        is closed even if rendering fails.

        Args:
            fig: matplotlib figure to save.
            user_id: user whose state the chart is recorded in.
            chart_type: filename prefix and ``type`` stored in the chart record.
            title: chart title; a filesystem-safe form is used in the filename.

        Returns:
            ``(filepath, img_base64)``: the PNG path as a string and the PNG
            contents as a base64 string.
        """
        try:
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=150, bbox_inches='tight',
                        facecolor='white')
            png_bytes = buf.getvalue()
        finally:
            plt.close(fig)

        state = self.get_state(user_id)
        output_dir = Path(state.get("output_dir", "visualizations"))
        output_dir.mkdir(parents=True, exist_ok=True)

        # Keep only alphanumerics, spaces, '-' and '_' so the title is filename-safe.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_title = "".join(
            ch for ch in title if ch.isalnum() or ch in (' ', '-', '_')
        ).strip().replace(' ', '_')[:50]
        filepath = self._unique_path(output_dir, f"{chart_type}_{safe_title}_{timestamp}")
        filepath.write_bytes(png_bytes)

        chart_info = {
            "type": chart_type,
            "title": title,
            "filepath": str(filepath),
            "timestamp": timestamp
        }
        state.setdefault("charts", []).append(chart_info)
        state["total_created"] = state.get("total_created", 0) + 1
        state["last_updated"] = datetime.datetime.now().isoformat()
        self.update_state(user_id, state)

        return str(filepath), base64.b64encode(png_bytes).decode('utf-8')

    def _report_chart(self, user_id: str, kind: str, label: str, title: str,
                      filepath: str, img_base64: str, **details: Any) -> Dict[str, Any]:
        """Log a created chart and build the success result returned to the model.

        Args:
            kind: short chart kind ("line", "bar", ...) used in logs and results.
            label: human-readable name used in messages ("Line chart", ...).
            details: extra fields added to the activity-log entry.
        """
        self.log_activity(user_id, "chart_created", {
            "type": kind,
            "title": title,
            **details
        })
        print(f"\u2705 {label} saved: {filepath}")
        return {
            "success": True,
            "message": f"{label} created: {title}",
            "filepath": filepath,
            "chart_type": kind,
            "image_base64": img_base64
        }

    @staticmethod
    def _require_same_length(first_name: str, first: Any,
                             second_name: str, second: Any) -> None:
        """Raise ValueError unless the two parallel sequences have equal length."""
        if len(first) != len(second):
            raise ValueError(
                f"'{first_name}' ({len(first)} items) and '{second_name}' "
                f"({len(second)} items) must have the same length"
            )

    # ==========================================
    # TOOL HANDLERS
    # ==========================================

    def _create_line_chart(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Plot one line per series in ``args["data"]``.

        Each series is ``{"x_values": [...], "y_values": [...]}``. A series
        without x values is plotted against its index (0, 1, 2, ...).
        """
        title = args["title"]
        data = args["data"]
        print(f"\U0001F4CA Creating line chart with {len(data)} series")

        fig, ax = plt.subplots(figsize=(12, 7))
        try:
            for series_name, series_data in data.items():
                y_vals = series_data.get("y_values") or []
                x_vals = series_data.get("x_values") or list(range(len(y_vals)))
                print(f" \U0001F4C8 Series '{series_name}': {len(y_vals)} points")
                ax.plot(x_vals, y_vals, marker='o', label=series_name,
                        linewidth=2, markersize=4)

            ax.set_xlabel(args["x_label"], fontsize=12, fontweight='bold')
            ax.set_ylabel(args["y_label"], fontsize=12, fontweight='bold')
            ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
            if data:
                # A legend with no labelled lines only triggers a warning.
                ax.legend(fontsize=10)
            ax.grid(True, alpha=0.3, linestyle='--')
            # Operate on this figure's axes, not pyplot's "current" figure.
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            fig.tight_layout()
            filepath, img_base64 = self._save_chart(fig, user_id, "line_chart", title)
        finally:
            plt.close(fig)  # no-op if _save_chart already closed it

        return self._report_chart(user_id, "line", "Line chart", title,
                                  filepath, img_base64, series_count=len(data))

    def _create_bar_chart(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw one bar per category, annotated with its value."""
        title = args["title"]
        categories = args["categories"]
        values = args["values"]
        self._require_same_length("categories", categories, "values", values)
        print(f"\U0001F4CA Creating bar chart with {len(categories)} categories")

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            bars = ax.bar(categories, values, color='#2E86AB', alpha=0.8, edgecolor='black')
            # Label each bar at its outer end: above positive bars, below negative ones.
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height,
                        f'{height:.2f}',
                        ha='center', va='bottom' if height >= 0 else 'top',
                        fontsize=9, fontweight='bold')

            ax.set_xlabel(args["x_label"], fontsize=12, fontweight='bold')
            ax.set_ylabel(args["y_label"], fontsize=12, fontweight='bold')
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.grid(True, alpha=0.3, axis='y')
            if len(categories) > 5:
                # Rotate crowded category names so they don't overlap.
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            fig.tight_layout()
            filepath, img_base64 = self._save_chart(fig, user_id, "bar_chart", title)
        finally:
            plt.close(fig)

        return self._report_chart(user_id, "bar", "Bar chart", title,
                                  filepath, img_base64, categories=len(categories))

    def _create_scatter_plot(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Scatter ``y_values`` against ``x_values`` with a least-squares trend line."""
        title = args["title"]
        x_vals = args["x_values"]
        y_vals = args["y_values"]
        self._require_same_length("x_values", x_vals, "y_values", y_vals)
        print(f"\U0001F4CA Creating scatter plot with {len(x_vals)} points")

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.scatter(x_vals, y_vals, color='#A23B72', alpha=0.6,
                       s=100, edgecolors='black')

            # A degree-1 fit needs at least two distinct x values; otherwise
            # np.polyfit is rank-deficient and the line is meaningless.
            if len(set(x_vals)) > 1:
                slope, intercept = np.polyfit(x_vals, y_vals, 1)
                x_line = np.array([min(x_vals), max(x_vals)], dtype=float)
                ax.plot(x_line, slope * x_line + intercept, "r--", alpha=0.5,
                        linewidth=2, label='Trend Line')
                ax.legend()

            ax.set_xlabel(args["x_label"], fontsize=12, fontweight='bold')
            ax.set_ylabel(args["y_label"], fontsize=12, fontweight='bold')
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            filepath, img_base64 = self._save_chart(fig, user_id, "scatter_plot", title)
        finally:
            plt.close(fig)

        return self._report_chart(user_id, "scatter", "Scatter plot", title,
                                  filepath, img_base64, points=len(x_vals))

    def _create_pie_chart(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw a pie chart with one slice per label, annotated with percentages."""
        title = args["title"]
        labels = args["labels"]
        values = args["values"]
        self._require_same_length("labels", labels, "values", values)
        print(f"\U0001F4CA Creating pie chart with {len(labels)} slices")

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            # Set3 holds only palette.N (12) colours. Integer indices past the end
            # all map to the same colour, so cycle through the palette instead.
            palette = plt.cm.Set3
            colors = palette(np.arange(len(labels)) % palette.N)
            _, _, autotexts = ax.pie(
                values,
                labels=labels,
                autopct='%1.1f%%',
                colors=colors,
                startangle=90,
                textprops={'fontsize': 11, 'fontweight': 'bold'}
            )
            # Percentage labels drawn inside the wedges.
            for autotext in autotexts:
                autotext.set_color('white')
                autotext.set_fontsize(10)
                autotext.set_fontweight('bold')

            ax.set_title(title, fontsize=14, fontweight='bold')
            filepath, img_base64 = self._save_chart(fig, user_id, "pie_chart", title)
        finally:
            plt.close(fig)

        return self._report_chart(user_id, "pie", "Pie chart", title,
                                  filepath, img_base64, slices=len(labels))

    def _list_visualizations(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Return the total number of recorded charts and the 10 most recent records."""
        charts = self.get_state(user_id).get("charts", [])
        return {
            "success": True,
            "total_charts": len(charts),
            "charts": charts[-10:]  # Last 10
        }

    # ==========================================
    # TOOL EXECUTION
    # ==========================================

    def _execute_tool(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
        """Dispatch *tool_name* to its handler and return the handler's result.

        Any exception raised by a handler is caught and reported as
        ``{"success": False, "error": ..., "details": <traceback>}``. An
        unrecognised tool name yields ``{"success": False, "error": ...}``.
        """
        handlers = {
            "create_line_chart": self._create_line_chart,
            "create_bar_chart": self._create_bar_chart,
            "create_scatter_plot": self._create_scatter_plot,
            "create_pie_chart": self._create_pie_chart,
            "list_visualizations": self._list_visualizations,
        }
        handler = handlers.get(tool_name)
        if handler is None:
            return {"success": False, "error": f"Unknown tool: {tool_name}"}

        try:
            return handler(user_id, args)
        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            print(f"\u274c Visualization error: {error_details}")
            return {
                "success": False,
                "error": f"Error creating visualization: {str(e)}",
                "details": error_details
            }

    # ==========================================
    # LIFECYCLE
    # ==========================================

    def on_enable(self, user_id: str) -> str:
        """Initialise the user's state and return a confirmation message."""
        self.initialize_state(user_id)
        return "\U0001F4CA Data Visualization enabled! Create charts from any numerical data. Just ask!"

    def on_disable(self, user_id: str) -> str:
        """Return a goodbye message including the lifetime chart count."""
        state = self.get_state(user_id)
        total = state.get("total_created", 0)
        return f"\U0001F4CA Visualization disabled. {total} charts created this session."

    def health_check(self, user_id: str) -> Dict[str, Any]:
        """Report whether matplotlib can be imported.

        Note: this module already imports matplotlib at load time, so in
        practice the ImportError branch is only reachable if that import is
        ever made optional.
        """
        try:
            import matplotlib
            return {
                "healthy": True,
                "extension": self.name,
                "version": self.version,
                "matplotlib_available": True
            }
        except ImportError:
            return {
                "healthy": False,
                "extension": self.name,
                "version": self.version,
                "issues": ["matplotlib not installed"]
            }