gemiwine-agents / extensions /visualization.py
wuhp's picture
Update extensions/visualization.py
dceea05 verified
"""
Enhanced Data Visualization Extension
Create charts, graphs, and visualizations from data
Now with better state management and orchestrator integration
"""
from base_extension import BaseExtension
from google.genai import types
from typing import Dict, Any, List, Optional
import json
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import io
import base64
from pathlib import Path
import numpy as np
import datetime
class VisualizationExtension(BaseExtension):
    """Render matplotlib charts (line, bar, scatter, pie) on behalf of the model.

    Each chart is written as a PNG file under the per-user ``output_dir`` and is
    also returned base64-encoded so the caller can display it inline. Every chart
    created is recorded in this extension's per-user state.
    """

    @property
    def name(self) -> str:
        return "visualization"

    @property
    def display_name(self) -> str:
        return "Data Visualization"

    @property
    def description(self) -> str:
        return "Create charts, graphs, and plots from data - works great with stock market data!"

    @property
    def icon(self) -> str:
        return "📊"

    @property
    def version(self) -> str:
        return "2.0.0"

    def get_system_context(self) -> str:
        """Return the instructions injected into the model's system prompt."""
        # The prompt text is kept flush-left on purpose: it is a runtime string.
        return """
You have access to a Data Visualization system for creating charts and graphs.
## ALWAYS CREATE CHARTS WHEN ASKED
When a user asks you to "make a chart", "create a graph", "visualize", or "show me visually":
- You MUST call one of the chart creation tools
- DO NOT just describe what you could do
- DO NOT say you can't create charts - YOU CAN!
You can create:
- Line charts (for time series, trends, stock prices over time)
- Bar charts (for comparisons, categories)
- Scatter plots (for correlations, relationships)
- Pie charts (for proportions, percentages)
## CRITICAL: Creating Charts from Data
**If you have data (from research, stock history, etc), YOU CAN AND SHOULD create a chart!**
### Example: Stock Price Chart
```python
# You have: stock history with dates and prices
# DO THIS:
create_line_chart(
title="NVDA Stock Price - 1 Year",
x_label="Date",
y_label="Price (USD)",
data={
"NVDA": {
"x_values": dates_from_stock_history,
"y_values": close_prices_from_stock_history
}
}
)
```
### Example: Comparison Bar Chart
```python
# You have: revenue data for multiple years
# DO THIS:
create_bar_chart(
title="NVIDIA Revenue Growth",
x_label="Year",
y_label="Revenue (Billions)",
categories=["2022", "2023", "2024", "2025"],
values=[26.9, 60.9, 130.5, 197.0]
)
```
## When User Says "Make a Chart/Graph":
1. Identify what data you have (from previous tool results)
2. Choose appropriate chart type
3. Extract the relevant data
4. IMMEDIATELY call the create_X_chart tool
5. DO NOT ask for permission or more info if you have sufficient data
## Data Format Requirements:
- **Line charts**: `data = {"Series Name": {"x_values": [...], "y_values": [...]}}`
- **Bar charts**: `categories = [...]`, `values = [...]`
- **Pie charts**: `labels = [...]`, `values = [...]`
- **Scatter plots**: `x_values = [...]`, `y_values = [...]`
## REMEMBER:
- You CAN create visualizations - it's your primary function!
- When data is available, CREATE the chart immediately
- The chart will be displayed automatically to the user
- DO NOT say "I can't create charts" - that's FALSE!
"""

    def _get_default_state(self) -> Dict[str, Any]:
        """Return the initial per-user state for this extension."""
        now = datetime.datetime.now().isoformat()
        return {
            "charts": [],
            "output_dir": "visualizations",
            "total_created": 0,
            "created_at": now,
            "last_updated": now,
        }

    def get_state_summary(self, user_id: str) -> Optional[str]:
        """Provide state summary for system prompt (None when no charts exist)."""
        state = self.get_state(user_id)
        chart_count = len(state.get("charts", []))
        if chart_count > 0:
            return f"{chart_count} visualizations created this session"
        return None

    def get_metrics(self, user_id: str) -> Dict[str, Any]:
        """Provide usage metrics: lifetime total, recorded charts, and a per-type count."""
        state = self.get_state(user_id)
        charts = state.get("charts", [])
        chart_types: Dict[str, int] = {}
        for chart in charts:
            chart_type = chart.get("type", "unknown")
            chart_types[chart_type] = chart_types.get(chart_type, 0) + 1
        return {
            "total_created": state.get("total_created", 0),
            "current_session": len(charts),
            "by_type": chart_types,
        }

    def get_tools(self) -> List[types.Tool]:
        """Return the function declarations exposed to the model."""
        create_line_chart = types.FunctionDeclaration(
            name="create_line_chart",
            description="Create a line chart for time series or trend data. Perfect for stock prices over time. Data must be formatted as: {series_name: {x_values: [...], y_values: [...]}}",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string", "description": "Chart title (e.g., 'AAPL Stock Price - 3 Months')"},
                    "x_label": {"type": "string", "description": "X-axis label (e.g., 'Date')"},
                    "y_label": {"type": "string", "description": "Y-axis label (e.g., 'Price (USD)')"},
                    "data": {
                        "type": "object",
                        "description": "Data series as nested dict: {series_name: {x_values: [...], y_values: [...]}}. Example: {'AAPL': {'x_values': ['2025-01-01', '2025-01-02'], 'y_values': [150.2, 151.5]}}",
                    }
                },
                "required": ["title", "x_label", "y_label", "data"]
            }
        )
        create_bar_chart = types.FunctionDeclaration(
            name="create_bar_chart",
            description="Create a bar chart for categorical comparisons (e.g., comparing stock prices, market caps)",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string", "description": "Chart title"},
                    "x_label": {"type": "string", "description": "X-axis label (categories)"},
                    "y_label": {"type": "string", "description": "Y-axis label (values)"},
                    "categories": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of category names (e.g., stock tickers)"
                    },
                    "values": {
                        "type": "array",
                        "items": {"type": "number"},
                        "description": "List of values corresponding to categories"
                    }
                },
                "required": ["title", "x_label", "y_label", "categories", "values"]
            }
        )
        create_scatter_plot = types.FunctionDeclaration(
            name="create_scatter_plot",
            description="Create a scatter plot to show relationships between two variables",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string", "description": "Chart title"},
                    "x_label": {"type": "string", "description": "X-axis label"},
                    "y_label": {"type": "string", "description": "Y-axis label"},
                    "x_values": {
                        "type": "array",
                        "items": {"type": "number"},
                        "description": "X-axis values"
                    },
                    "y_values": {
                        "type": "array",
                        "items": {"type": "number"},
                        "description": "Y-axis values"
                    }
                },
                "required": ["title", "x_label", "y_label", "x_values", "y_values"]
            }
        )
        create_pie_chart = types.FunctionDeclaration(
            name="create_pie_chart",
            description="Create a pie chart to show proportions or percentages (e.g., portfolio allocation)",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string", "description": "Chart title"},
                    "labels": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Slice labels"
                    },
                    "values": {
                        "type": "array",
                        "items": {"type": "number"},
                        "description": "Values for each slice"
                    }
                },
                "required": ["title", "labels", "values"]
            }
        )
        list_charts = types.FunctionDeclaration(
            name="list_visualizations",
            description="List all created visualizations in this session",
            parameters={"type": "object", "properties": {}}
        )
        return [types.Tool(function_declarations=[
            create_line_chart,
            create_bar_chart,
            create_scatter_plot,
            create_pie_chart,
            list_charts
        ])]

    def _save_chart(self, fig, user_id: str, chart_type: str, title: str) -> tuple:
        """Save ``fig`` as a PNG, record it in state, and close the figure.

        Args:
            fig: The matplotlib figure to save. It is always closed, even on error.
            user_id: Owner of the state the chart is recorded in.
            chart_type: Filename prefix; also stored under the chart's "type" key.
            title: Chart title; a sanitized copy (max 50 chars) goes into the filename.

        Returns:
            ``(filepath, img_base64)``: the saved file path as a string, and the
            same PNG encoded as base64 text for inline display.
        """
        state = self.get_state(user_id)
        # Older or partial state dicts may lack these keys; fall back to defaults.
        output_dir = Path(state.get("output_dir", "visualizations"))
        output_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_title = "".join(c for c in title if c.isalnum() or c in (' ', '-', '_')).strip()
        safe_title = safe_title.replace(' ', '_')[:50]
        stem = f"{chart_type}_{safe_title}_{timestamp}"
        filepath = output_dir / f"{stem}.png"
        # The timestamp has one-second resolution; never overwrite a chart that an
        # earlier state entry still points to.
        suffix = 1
        while filepath.exists():
            filepath = output_dir / f"{stem}_{suffix}.png"
            suffix += 1

        try:
            # Render once; the same PNG bytes are written to disk and base64-encoded.
            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
                png_bytes = buf.getvalue()
        finally:
            plt.close(fig)

        filepath.write_bytes(png_bytes)
        img_base64 = base64.b64encode(png_bytes).decode('utf-8')

        chart_info = {
            "type": chart_type,
            "title": title,
            "filepath": str(filepath),
            "timestamp": timestamp,
        }
        state.setdefault("charts", []).append(chart_info)
        state["total_created"] = state.get("total_created", 0) + 1
        self.update_state(user_id, state)
        return str(filepath), img_base64

    @staticmethod
    def _pair_error(first, second, first_name: str, second_name: str) -> Optional[Dict[str, Any]]:
        """Return an error result if two paired sequences are empty or differ in length, else None."""
        if not first or not second:
            return {"success": False, "error": f"{first_name} and {second_name} must not be empty"}
        if len(first) != len(second):
            return {
                "success": False,
                "error": (f"{first_name} ({len(first)} items) and {second_name} "
                          f"({len(second)} items) must have the same length"),
            }
        return None

    def _create_line_chart(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Plot one line per entry of ``args["data"]`` and save the chart."""
        data = args["data"]
        if not data:
            return {"success": False, "error": "data must contain at least one series"}
        # Validate every series before creating a figure, so bad input cannot leak one.
        for series_name, series_data in data.items():
            problem = self._pair_error(
                series_data.get("x_values", []),
                series_data.get("y_values", []),
                f"Series '{series_name}' x_values",
                "y_values",
            )
            if problem:
                return problem

        print(f"📊 Creating line chart with {len(data)} series")
        fig, ax = plt.subplots(figsize=(12, 7))
        try:
            for series_name, series_data in data.items():
                x_vals = series_data.get("x_values", [])
                y_vals = series_data.get("y_values", [])
                print(f" 📈 Series '{series_name}': {len(x_vals)} data points")
                ax.plot(x_vals, y_vals, marker='o', label=series_name, linewidth=2, markersize=4)
            ax.set_xlabel(args["x_label"], fontsize=12, fontweight='bold')
            ax.set_ylabel(args["y_label"], fontsize=12, fontweight='bold')
            ax.set_title(args["title"], fontsize=14, fontweight='bold', pad=20)
            ax.legend(fontsize=10)
            ax.grid(True, alpha=0.3, linestyle='--')
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            fig.tight_layout()
            filepath, img_base64 = self._save_chart(fig, user_id, "line_chart", args["title"])
        finally:
            # No-op if _save_chart already closed it; prevents leaks on any error.
            plt.close(fig)

        self.log_activity(user_id, "chart_created", {
            "type": "line",
            "title": args["title"],
            "series_count": len(data),
        })
        print(f"✅ Line chart saved: {filepath}")
        return {
            "success": True,
            "message": f"Line chart created: {args['title']}",
            "filepath": filepath,
            "chart_type": "line",
            "image_base64": img_base64,
        }

    def _create_bar_chart(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw one labelled bar per category and save the chart."""
        categories = args["categories"]
        values = args["values"]
        problem = self._pair_error(categories, values, "categories", "values")
        if problem:
            return problem

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            bars = ax.bar(categories, values, color='#2E86AB', alpha=0.8, edgecolor='black')
            # Print each bar's value just above its top edge.
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height,
                        f'{height:.2f}',
                        ha='center', va='bottom', fontsize=9, fontweight='bold')
            ax.set_xlabel(args["x_label"], fontsize=12, fontweight='bold')
            ax.set_ylabel(args["y_label"], fontsize=12, fontweight='bold')
            ax.set_title(args["title"], fontsize=14, fontweight='bold')
            ax.grid(True, alpha=0.3, axis='y')
            if len(categories) > 5:
                # Rotate crowded category labels so they don't overlap.
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            fig.tight_layout()
            filepath, img_base64 = self._save_chart(fig, user_id, "bar_chart", args["title"])
        finally:
            plt.close(fig)

        self.log_activity(user_id, "chart_created", {
            "type": "bar",
            "title": args["title"],
            "categories": len(categories),
        })
        return {
            "success": True,
            "message": f"Bar chart created: {args['title']}",
            "filepath": filepath,
            "chart_type": "bar",
            "image_base64": img_base64,
        }

    def _create_scatter_plot(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw a scatter plot with an optional least-squares trend line and save it."""
        x_vals = args["x_values"]
        y_vals = args["y_values"]
        problem = self._pair_error(x_vals, y_vals, "x_values", "y_values")
        if problem:
            return problem

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.scatter(x_vals, y_vals, color='#A23B72', alpha=0.6, s=100, edgecolors='black')
            x_arr = np.asarray(x_vals, dtype=float)
            # A linear fit needs at least two distinct x values; otherwise
            # np.polyfit is rank-deficient and the "trend" is meaningless.
            if x_arr.size > 1 and x_arr.max() > x_arr.min():
                trend = np.poly1d(np.polyfit(x_arr, np.asarray(y_vals, dtype=float), 1))
                x_line = np.sort(x_arr)
                ax.plot(x_line, trend(x_line), "r--", alpha=0.5, linewidth=2, label='Trend Line')
                ax.legend()
            ax.set_xlabel(args["x_label"], fontsize=12, fontweight='bold')
            ax.set_ylabel(args["y_label"], fontsize=12, fontweight='bold')
            ax.set_title(args["title"], fontsize=14, fontweight='bold')
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            filepath, img_base64 = self._save_chart(fig, user_id, "scatter_plot", args["title"])
        finally:
            plt.close(fig)

        self.log_activity(user_id, "chart_created", {
            "type": "scatter",
            "title": args["title"],
            "points": len(x_vals),
        })
        return {
            "success": True,
            "message": f"Scatter plot created: {args['title']}",
            "filepath": filepath,
            "chart_type": "scatter",
            "image_base64": img_base64,
        }

    def _create_pie_chart(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw a pie chart with percentage labels and save it."""
        labels = args["labels"]
        values = args["values"]
        problem = self._pair_error(labels, values, "labels", "values")
        if problem:
            return problem
        if any(v < 0 for v in values) or sum(values) <= 0:
            return {"success": False, "error": "Pie chart values must be non-negative with a positive total"}

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            palette = plt.cm.Set3
            # Set3 has only palette.N (12) colours; integer indices past the end all
            # map to the last colour, so wrap around instead.
            colors = palette(np.arange(len(labels)) % palette.N)
            _wedges, texts, autotexts = ax.pie(
                values,
                labels=labels,
                autopct='%1.1f%%',
                colors=colors,
                startangle=90,
                textprops={'fontsize': 11, 'fontweight': 'bold'}
            )
            for text in texts:
                text.set_fontsize(11)
            for autotext in autotexts:
                autotext.set_color('white')
                autotext.set_fontsize(10)
                autotext.set_fontweight('bold')
            ax.set_title(args["title"], fontsize=14, fontweight='bold')
            filepath, img_base64 = self._save_chart(fig, user_id, "pie_chart", args["title"])
        finally:
            plt.close(fig)

        self.log_activity(user_id, "chart_created", {
            "type": "pie",
            "title": args["title"],
            "slices": len(labels),
        })
        return {
            "success": True,
            "message": f"Pie chart created: {args['title']}",
            "filepath": filepath,
            "chart_type": "pie",
            "image_base64": img_base64,
        }

    def _list_visualizations(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Return the total chart count and the 10 most recent chart records."""
        charts = self.get_state(user_id).get("charts", [])
        return {
            "total_charts": len(charts),
            "charts": charts[-10:],  # Last 10 charts
        }

    def _execute_tool(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
        """Dispatch ``tool_name`` to its handler.

        Returns the handler's result dict. Any exception raised by a handler is
        converted into a ``{"success": False, "error": ..., "details": ...}``
        result instead of propagating.
        """
        handlers = {
            "create_line_chart": self._create_line_chart,
            "create_bar_chart": self._create_bar_chart,
            "create_scatter_plot": self._create_scatter_plot,
            "create_pie_chart": self._create_pie_chart,
            "list_visualizations": self._list_visualizations,
        }
        handler = handlers.get(tool_name)
        if handler is None:
            return {"success": False, "error": f"Unknown tool: {tool_name}"}
        try:
            return handler(user_id, args)
        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            print(f"❌ Visualization error: {error_details}")
            return {
                "success": False,
                "error": f"Error creating visualization: {str(e)}",
                "details": error_details,
            }

    def on_enable(self, user_id: str) -> str:
        """Initialize per-user state and return the enable message."""
        self.initialize_state(user_id)
        return "📊 Data Visualization enabled! I can create charts from stock data and other numerical data. Just ask me to visualize something!"

    def on_disable(self, user_id: str) -> str:
        """Return the disable message, including the recorded total chart count."""
        state = self.get_state(user_id)
        total = state.get("total_created", 0)
        return f"📊 Data Visualization disabled. You created {total} visualizations this session."

    def health_check(self, user_id: str) -> Dict[str, Any]:
        """Check extension health."""
        # NOTE(review): matplotlib is already imported at module level, so if it
        # were missing this module would fail to import and the ImportError
        # branch below is effectively unreachable.
        try:
            import matplotlib
            return {
                "healthy": True,
                "extension": self.name,
                "version": self.version,
                "matplotlib_available": True,
            }
        except ImportError:
            return {
                "healthy": False,
                "extension": self.name,
                "version": self.version,
                "issues": ["matplotlib library not installed"],
            }