Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced Data Visualization Extension | |
| Create charts, graphs, and visualizations from data | |
| Now with better state management and orchestrator integration | |
| """ | |
| from base_extension import BaseExtension | |
| from google.genai import types | |
| from typing import Dict, Any, List, Optional | |
| import json | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import io | |
| import base64 | |
| from pathlib import Path | |
| import numpy as np | |
| import datetime | |
class VisualizationExtension(BaseExtension):
    """Chart-rendering extension that exposes matplotlib plots as model tools.

    Every chart tool renders a PNG with matplotlib (the Agg backend is selected
    at import time), writes it under the per-user ``output_dir`` and returns it
    base64-encoded so the caller can show it inline. Created charts are
    recorded in the per-user extension state.
    """

    # Tool name -> (drawing method name, figure size in inches, chart type used
    # in file names and stored state, short type used in results/activity logs,
    # human-readable label used in messages).
    _CHART_TOOLS = {
        "create_line_chart": ("_draw_line_chart", (12, 7), "line_chart", "line", "Line chart"),
        "create_bar_chart": ("_draw_bar_chart", (10, 6), "bar_chart", "bar", "Bar chart"),
        "create_scatter_plot": ("_draw_scatter_plot", (10, 6), "scatter_plot", "scatter", "Scatter plot"),
        "create_pie_chart": ("_draw_pie_chart", (10, 8), "pie_chart", "pie", "Pie chart"),
    }

    # Maximum number of chart records returned by ``list_visualizations``.
    _LIST_LIMIT = 10

    def name(self) -> str:
        """Return the internal identifier of this extension."""
        return "visualization"

    def display_name(self) -> str:
        """Return the human-readable extension name."""
        return "Data Visualization"

    def description(self) -> str:
        """Return a one-line description of the extension."""
        return "Create charts, graphs, and plots from data - works great with stock market data!"

    def icon(self) -> str:
        """Return the emoji icon shown for this extension."""
        return "📊"

    def version(self) -> str:
        """Return the extension version string."""
        return "2.0.0"

    def get_system_context(self) -> str:
        """Return the prompt text that tells the model how to use the chart tools."""
        return """
You have access to a Data Visualization system for creating charts and graphs.
## ALWAYS CREATE CHARTS WHEN ASKED
When a user asks you to "make a chart", "create a graph", "visualize", or "show me visually":
- You MUST call one of the chart creation tools
- DO NOT just describe what you could do
- DO NOT say you can't create charts - YOU CAN!
You can create:
- Line charts (for time series, trends, stock prices over time)
- Bar charts (for comparisons, categories)
- Scatter plots (for correlations, relationships)
- Pie charts (for proportions, percentages)
## CRITICAL: Creating Charts from Data
**If you have data (from research, stock history, etc), YOU CAN AND SHOULD create a chart!**
### Example: Stock Price Chart
```python
# You have: stock history with dates and prices
# DO THIS:
create_line_chart(
title="NVDA Stock Price - 1 Year",
x_label="Date",
y_label="Price (USD)",
data={
"NVDA": {
"x_values": dates_from_stock_history,
"y_values": close_prices_from_stock_history
}
}
)
```
### Example: Comparison Bar Chart
```python
# You have: revenue data for multiple years
# DO THIS:
create_bar_chart(
title="NVIDIA Revenue Growth",
x_label="Year",
y_label="Revenue (Billions)",
categories=["2022", "2023", "2024", "2025"],
values=[26.9, 60.9, 130.5, 197.0]
)
```
## When User Says "Make a Chart/Graph":
1. Identify what data you have (from previous tool results)
2. Choose appropriate chart type
3. Extract the relevant data
4. IMMEDIATELY call the create_X_chart tool
5. DO NOT ask for permission or more info if you have sufficient data
## Data Format Requirements:
- **Line charts**: `data = {"Series Name": {"x_values": [...], "y_values": [...]}}`
- **Bar charts**: `categories = [...]`, `values = [...]`
- **Pie charts**: `labels = [...]`, `values = [...]`
- **Scatter plots**: `x_values = [...]`, `y_values = [...]`
## REMEMBER:
- You CAN create visualizations - it's your primary function!
- When data is available, CREATE the chart immediately
- The chart will be displayed automatically to the user
- DO NOT say "I can't create charts" - that's FALSE!
"""

    def _get_default_state(self) -> Dict[str, Any]:
        """Return the initial per-user state (no charts, default output directory)."""
        # Take one timestamp so created_at and last_updated start out identical.
        now = datetime.datetime.now().isoformat()
        return {
            "charts": [],
            "output_dir": "visualizations",
            "total_created": 0,
            "created_at": now,
            "last_updated": now,
        }

    def get_state_summary(self, user_id: str) -> Optional[str]:
        """Provide state summary for system prompt.

        Returns a short sentence with the number of recorded charts, or
        ``None`` when the user has not created any.
        """
        state = self.get_state(user_id)
        chart_count = len(state.get("charts", []))
        if chart_count > 0:
            return f"{chart_count} visualizations created this session"
        return None

    def get_metrics(self, user_id: str) -> Dict[str, Any]:
        """Provide usage metrics: lifetime total, recorded charts, and a per-type count."""
        state = self.get_state(user_id)
        charts = state.get("charts", [])
        chart_types: Dict[str, int] = {}
        for chart in charts:
            chart_type = chart.get("type", "unknown")
            chart_types[chart_type] = chart_types.get(chart_type, 0) + 1
        return {
            "total_created": state.get("total_created", 0),
            "current_session": len(charts),
            "by_type": chart_types,
        }

    def get_tools(self) -> List[types.Tool]:
        """Return the function declarations for the chart and listing tools."""
        create_line_chart = types.FunctionDeclaration(
            name="create_line_chart",
            description="Create a line chart for time series or trend data. Perfect for stock prices over time. Data must be formatted as: {series_name: {x_values: [...], y_values: [...]}}",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string", "description": "Chart title (e.g., 'AAPL Stock Price - 3 Months')"},
                    "x_label": {"type": "string", "description": "X-axis label (e.g., 'Date')"},
                    "y_label": {"type": "string", "description": "Y-axis label (e.g., 'Price (USD)')"},
                    "data": {
                        "type": "object",
                        "description": "Data series as nested dict: {series_name: {x_values: [...], y_values: [...]}}. Example: {'AAPL': {'x_values': ['2025-01-01', '2025-01-02'], 'y_values': [150.2, 151.5]}}",
                    }
                },
                "required": ["title", "x_label", "y_label", "data"]
            }
        )
        create_bar_chart = types.FunctionDeclaration(
            name="create_bar_chart",
            description="Create a bar chart for categorical comparisons (e.g., comparing stock prices, market caps)",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string", "description": "Chart title"},
                    "x_label": {"type": "string", "description": "X-axis label (categories)"},
                    "y_label": {"type": "string", "description": "Y-axis label (values)"},
                    "categories": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of category names (e.g., stock tickers)"
                    },
                    "values": {
                        "type": "array",
                        "items": {"type": "number"},
                        "description": "List of values corresponding to categories"
                    }
                },
                "required": ["title", "x_label", "y_label", "categories", "values"]
            }
        )
        create_scatter_plot = types.FunctionDeclaration(
            name="create_scatter_plot",
            description="Create a scatter plot to show relationships between two variables",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string", "description": "Chart title"},
                    "x_label": {"type": "string", "description": "X-axis label"},
                    "y_label": {"type": "string", "description": "Y-axis label"},
                    "x_values": {
                        "type": "array",
                        "items": {"type": "number"},
                        "description": "X-axis values"
                    },
                    "y_values": {
                        "type": "array",
                        "items": {"type": "number"},
                        "description": "Y-axis values"
                    }
                },
                "required": ["title", "x_label", "y_label", "x_values", "y_values"]
            }
        )
        create_pie_chart = types.FunctionDeclaration(
            name="create_pie_chart",
            description="Create a pie chart to show proportions or percentages (e.g., portfolio allocation)",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string", "description": "Chart title"},
                    "labels": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Slice labels"
                    },
                    "values": {
                        "type": "array",
                        "items": {"type": "number"},
                        "description": "Values for each slice"
                    }
                },
                "required": ["title", "labels", "values"]
            }
        )
        list_charts = types.FunctionDeclaration(
            name="list_visualizations",
            description="List all created visualizations in this session",
            parameters={"type": "object", "properties": {}}
        )
        return [types.Tool(function_declarations=[
            create_line_chart,
            create_bar_chart,
            create_scatter_plot,
            create_pie_chart,
            list_charts
        ])]

    def _save_chart(self, fig, user_id: str, chart_type: str, title: str) -> tuple:
        """Save ``fig`` as a PNG, close it, and record the chart in user state.

        The figure is rendered once; the same PNG bytes are written to disk and
        base64-encoded for inline display.

        Returns:
            tuple: ``(filepath, img_base64)`` — the saved file path as a string
            and the PNG bytes as base64 text.
        """
        state = self.get_state(user_id)
        output_dir = Path(state.get("output_dir", "visualizations"))
        output_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # Keep only filesystem-safe characters from the title.
        safe_title = "".join(c for c in title if c.isalnum() or c in (" ", "-", "_")).strip()
        safe_title = safe_title.replace(" ", "_")[:50]
        stem = f"{chart_type}_{safe_title}_{timestamp}"
        filepath = output_dir / f"{stem}.png"
        # The timestamp has one-second resolution; add a counter rather than
        # overwrite a chart with the same type and title from the same second.
        counter = 1
        while filepath.exists():
            filepath = output_dir / f"{stem}_{counter}.png"
            counter += 1

        try:
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white")
            png_bytes = buf.getvalue()
        finally:
            plt.close(fig)
        filepath.write_bytes(png_bytes)
        img_base64 = base64.b64encode(png_bytes).decode("utf-8")

        state.setdefault("charts", []).append({
            "type": chart_type,
            "title": title,
            "filepath": str(filepath),
            "timestamp": timestamp,
        })
        state["total_created"] = state.get("total_created", 0) + 1
        self.update_state(user_id, state)
        return str(filepath), img_base64

    def _draw_line_chart(self, fig, ax, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw one line per series in ``args["data"]``.

        ``args["data"]`` maps series name -> ``{"x_values": [...], "y_values": [...]}``;
        a JSON-encoded string of that mapping is also accepted. Returns the
        extra fields logged with the ``chart_created`` activity.
        """
        data = args["data"]
        if isinstance(data, str):
            # Models occasionally send the nested object as a JSON string.
            data = json.loads(data)
        print(f"📊 Creating line chart with {len(data)} series")
        for series_name, series_data in data.items():
            x_vals = list(series_data.get("x_values", []))
            y_vals = list(series_data.get("y_values", []))
            print(f" 📈 Series '{series_name}': {len(x_vals)} data points")
            ax.plot(x_vals, y_vals, marker="o", label=series_name, linewidth=2, markersize=4)
        ax.set_xlabel(args["x_label"], fontsize=12, fontweight="bold")
        ax.set_ylabel(args["y_label"], fontsize=12, fontweight="bold")
        ax.set_title(args["title"], fontsize=14, fontweight="bold", pad=20)
        if data:
            ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3, linestyle="--")
        # Rotate x tick labels (often dates) so long labels don't overlap.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        fig.tight_layout()
        return {"series_count": len(data)}

    def _draw_bar_chart(self, fig, ax, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw one labelled bar per category; return the activity-log fields."""
        categories = list(args["categories"])
        values = list(args["values"])
        if len(categories) != len(values):
            raise ValueError(
                f"categories ({len(categories)}) and values ({len(values)}) must have the same length"
            )
        bars = ax.bar(categories, values, color="#2E86AB", alpha=0.8, edgecolor="black")
        for bar in bars:
            height = bar.get_height()
            # Put the value just past the bar end: above positive bars, below negative ones.
            ax.text(bar.get_x() + bar.get_width() / 2.0, height,
                    f"{height:.2f}",
                    ha="center", va="bottom" if height >= 0 else "top",
                    fontsize=9, fontweight="bold")
        ax.set_xlabel(args["x_label"], fontsize=12, fontweight="bold")
        ax.set_ylabel(args["y_label"], fontsize=12, fontweight="bold")
        ax.set_title(args["title"], fontsize=14, fontweight="bold")
        ax.grid(True, alpha=0.3, axis="y")
        if len(categories) > 5:
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        fig.tight_layout()
        return {"categories": len(categories)}

    def _draw_scatter_plot(self, fig, ax, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw the points plus a least-squares trend line; return the activity-log fields."""
        x_vals = list(args["x_values"])
        y_vals = list(args["y_values"])
        ax.scatter(x_vals, y_vals, color="#A23B72", alpha=0.6, s=100, edgecolors="black")
        # A degree-1 fit needs at least two distinct x values; otherwise np.polyfit
        # is rank-deficient and the "trend" is meaningless, so skip it.
        if len(set(x_vals)) > 1:
            slope, intercept = np.polyfit(x_vals, y_vals, 1)
            x_line = np.array([min(x_vals), max(x_vals)], dtype=float)
            ax.plot(x_line, slope * x_line + intercept, "r--", alpha=0.5, linewidth=2, label="Trend Line")
            ax.legend()
        ax.set_xlabel(args["x_label"], fontsize=12, fontweight="bold")
        ax.set_ylabel(args["y_label"], fontsize=12, fontweight="bold")
        ax.set_title(args["title"], fontsize=14, fontweight="bold")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        return {"points": len(x_vals)}

    def _draw_pie_chart(self, fig, ax, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw a pie with percentage labels; return the activity-log fields."""
        labels = list(args["labels"])
        values = list(args["values"])
        cmap = plt.cm.Set3
        # Set3 has only cmap.N distinct colors; indices past that would all map to
        # the same "over" color, so cycle through the palette instead.
        colors = cmap(np.arange(len(labels)) % cmap.N)
        _, _, autotexts = ax.pie(
            values,
            labels=labels,
            autopct="%1.1f%%",
            colors=colors,
            startangle=90,
            textprops={"fontsize": 11, "fontweight": "bold"},
        )
        for autotext in autotexts:
            autotext.set_color("white")
            autotext.set_fontsize(10)
            autotext.set_fontweight("bold")
        ax.set_title(args["title"], fontsize=14, fontweight="bold")
        return {"slices": len(labels)}

    def _execute_tool(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
        """Run the tool ``tool_name`` with the model-supplied ``args``.

        Chart tools return ``{"success": True, "message", "filepath",
        "chart_type", "image_base64"}``. ``list_visualizations`` returns the
        total chart count and the most recent records. Any exception is caught
        and reported as ``{"success": False, "error", "details"}``; an unknown
        tool name yields ``{"success": False, "error"}``.
        """
        fig = None
        try:
            if tool_name == "list_visualizations":
                charts = self.get_state(user_id).get("charts", [])
                return {
                    "total_charts": len(charts),
                    "charts": charts[-self._LIST_LIMIT:],
                }

            spec = self._CHART_TOOLS.get(tool_name)
            if spec is None:
                return {"success": False, "error": f"Unknown tool: {tool_name}"}
            draw_name, figsize, file_type, short_type, label = spec

            fig, ax = plt.subplots(figsize=figsize)
            details = getattr(self, draw_name)(fig, ax, args)
            filepath, img_base64 = self._save_chart(fig, user_id, file_type, args["title"])

            self.log_activity(user_id, "chart_created", {
                "type": short_type,
                "title": args["title"],
                **details,
            })
            print(f"✅ {label} saved: {filepath}")
            return {
                "success": True,
                "message": f"{label} created: {args['title']}",
                "filepath": filepath,
                "chart_type": short_type,
                "image_base64": img_base64,
            }
        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            print(f"❌ Visualization error: {error_details}")
            return {
                "success": False,
                "error": f"Error creating visualization: {str(e)}",
                "details": error_details,
            }
        finally:
            # pyplot keeps every figure alive until it is closed; close it even
            # when drawing or saving failed so failed calls don't leak figures.
            if fig is not None:
                plt.close(fig)

    def on_enable(self, user_id: str) -> str:
        """Initialize per-user state and return the enable message."""
        self.initialize_state(user_id)
        return "📊 Data Visualization enabled! I can create charts from stock data and other numerical data. Just ask me to visualize something!"

    def on_disable(self, user_id: str) -> str:
        """Return the disable message including the user's total chart count."""
        state = self.get_state(user_id)
        total = state.get("total_created", 0)
        return f"📊 Data Visualization disabled. You created {total} visualizations this session."

    def health_check(self, user_id: str) -> Dict[str, Any]:
        """Check extension health.

        matplotlib is imported when this module loads, so if this method can
        run at all the library is available.
        """
        return {
            "healthy": True,
            # name/version are plain methods on this class; call them so the
            # report carries strings rather than bound-method objects.
            "extension": self.name(),
            "version": self.version(),
            "matplotlib_available": True,
        }