""" Enhanced Data Visualization Extension Create charts, graphs, and visualizations from data Now with better state management and orchestrator integration """ from base_extension import BaseExtension from google.genai import types from typing import Dict, Any, List, Optional import json import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import io import base64 from pathlib import Path import numpy as np import datetime class VisualizationExtension(BaseExtension): @property def name(self) -> str: return "visualization" @property def display_name(self) -> str: return "Data Visualization" @property def description(self) -> str: return "Create charts, graphs, and plots from data - works great with stock market data!" @property def icon(self) -> str: return "📊" @property def version(self) -> str: return "2.0.0" def get_system_context(self) -> str: return """ You have access to a Data Visualization system for creating charts and graphs. ## ALWAYS CREATE CHARTS WHEN ASKED When a user asks you to "make a chart", "create a graph", "visualize", or "show me visually": - You MUST call one of the chart creation tools - DO NOT just describe what you could do - DO NOT say you can't create charts - YOU CAN! You can create: - Line charts (for time series, trends, stock prices over time) - Bar charts (for comparisons, categories) - Scatter plots (for correlations, relationships) - Pie charts (for proportions, percentages) ## CRITICAL: Creating Charts from Data **If you have data (from research, stock history, etc), YOU CAN AND SHOULD create a chart!** ### Example: Stock Price Chart ```python # You have: stock history with dates and prices # DO THIS: create_line_chart( title="NVDA Stock Price - 1 Year", x_label="Date", y_label="Price (USD)", data={ "NVDA": { "x_values": dates_from_stock_history, "y_values": close_prices_from_stock_history } } ) ``` ### Example: Comparison Bar Chart ```python # You have: revenue data for multiple years # DO THIS: create_bar_chart( title="NVIDIA Revenue Growth", x_label="Year", y_label="Revenue (Billions)", categories=["2022", "2023", "2024", "2025"], values=[26.9, 60.9, 130.5, 197.0] ) ``` ## When User Says "Make a Chart/Graph": 1. Identify what data you have (from previous tool results) 2. Choose appropriate chart type 3. Extract the relevant data 4. IMMEDIATELY call the create_X_chart tool 5. DO NOT ask for permission or more info if you have sufficient data ## Data Format Requirements: - **Line charts**: `data = {"Series Name": {"x_values": [...], "y_values": [...]}}` - **Bar charts**: `categories = [...]`, `values = [...]` - **Pie charts**: `labels = [...]`, `values = [...]` - **Scatter plots**: `x_values = [...]`, `y_values = [...]` ## REMEMBER: - You CAN create visualizations - it's your primary function! - When data is available, CREATE the chart immediately - The chart will be displayed automatically to the user - DO NOT say "I can't create charts" - that's FALSE! """ def _get_default_state(self) -> Dict[str, Any]: return { "charts": [], "output_dir": "visualizations", "total_created": 0, "created_at": datetime.datetime.now().isoformat(), "last_updated": datetime.datetime.now().isoformat() } def get_state_summary(self, user_id: str) -> Optional[str]: """Provide state summary for system prompt""" state = self.get_state(user_id) chart_count = len(state.get("charts", [])) if chart_count > 0: return f"{chart_count} visualizations created this session" return None def get_metrics(self, user_id: str) -> Dict[str, Any]: """Provide usage metrics""" state = self.get_state(user_id) charts = state.get("charts", []) chart_types = {} for chart in charts: chart_type = chart.get("type", "unknown") chart_types[chart_type] = chart_types.get(chart_type, 0) + 1 return { "total_created": state.get("total_created", 0), "current_session": len(charts), "by_type": chart_types } def get_tools(self) -> List[types.Tool]: create_line_chart = types.FunctionDeclaration( name="create_line_chart", description="Create a line chart for time series or trend data. Perfect for stock prices over time. Data must be formatted as: {series_name: {x_values: [...], y_values: [...]}}", parameters={ "type": "object", "properties": { "title": {"type": "string", "description": "Chart title (e.g., 'AAPL Stock Price - 3 Months')"}, "x_label": {"type": "string", "description": "X-axis label (e.g., 'Date')"}, "y_label": {"type": "string", "description": "Y-axis label (e.g., 'Price (USD)')"}, "data": { "type": "object", "description": "Data series as nested dict: {series_name: {x_values: [...], y_values: [...]}}. Example: {'AAPL': {'x_values': ['2025-01-01', '2025-01-02'], 'y_values': [150.2, 151.5]}}", } }, "required": ["title", "x_label", "y_label", "data"] } ) create_bar_chart = types.FunctionDeclaration( name="create_bar_chart", description="Create a bar chart for categorical comparisons (e.g., comparing stock prices, market caps)", parameters={ "type": "object", "properties": { "title": {"type": "string", "description": "Chart title"}, "x_label": {"type": "string", "description": "X-axis label (categories)"}, "y_label": {"type": "string", "description": "Y-axis label (values)"}, "categories": { "type": "array", "items": {"type": "string"}, "description": "List of category names (e.g., stock tickers)" }, "values": { "type": "array", "items": {"type": "number"}, "description": "List of values corresponding to categories" } }, "required": ["title", "x_label", "y_label", "categories", "values"] } ) create_scatter_plot = types.FunctionDeclaration( name="create_scatter_plot", description="Create a scatter plot to show relationships between two variables", parameters={ "type": "object", "properties": { "title": {"type": "string", "description": "Chart title"}, "x_label": {"type": "string", "description": "X-axis label"}, "y_label": {"type": "string", "description": "Y-axis label"}, "x_values": { "type": "array", "items": {"type": "number"}, "description": "X-axis values" }, "y_values": { "type": "array", "items": {"type": "number"}, "description": "Y-axis values" } }, "required": ["title", "x_label", "y_label", "x_values", "y_values"] } ) create_pie_chart = types.FunctionDeclaration( name="create_pie_chart", description="Create a pie chart to show proportions or percentages (e.g., portfolio allocation)", parameters={ "type": "object", "properties": { "title": {"type": "string", "description": "Chart title"}, "labels": { "type": "array", "items": {"type": "string"}, "description": "Slice labels" }, "values": { "type": "array", "items": {"type": "number"}, "description": "Values for each slice" } }, "required": ["title", "labels", "values"] } ) list_charts = types.FunctionDeclaration( name="list_visualizations", description="List all created visualizations in this session", parameters={"type": "object", "properties": {}} ) return [types.Tool(function_declarations=[ create_line_chart, create_bar_chart, create_scatter_plot, create_pie_chart, list_charts ])] def _save_chart(self, fig, user_id: str, chart_type: str, title: str) -> tuple: """Save chart and return both file path and base64 data for inline display""" state = self.get_state(user_id) output_dir = Path(state["output_dir"]) output_dir.mkdir(exist_ok=True) # Generate filename timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") safe_title = "".join(c for c in title if c.isalnum() or c in (' ', '-', '_')).strip() safe_title = safe_title.replace(' ', '_')[:50] filename = f"{chart_type}_{safe_title}_{timestamp}.png" filepath = output_dir / filename # Save figure to file fig.savefig(filepath, dpi=150, bbox_inches='tight', facecolor='white') # Also save to bytes buffer for base64 encoding buf = io.BytesIO() fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white') buf.seek(0) img_base64 = base64.b64encode(buf.read()).decode('utf-8') buf.close() plt.close(fig) # Track in state chart_info = { "type": chart_type, "title": title, "filepath": str(filepath), "timestamp": timestamp } state["charts"].append(chart_info) state["total_created"] = state.get("total_created", 0) + 1 self.update_state(user_id, state) return str(filepath), img_base64 def _execute_tool(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any: """Execute tool logic""" try: if tool_name == "create_line_chart": fig, ax = plt.subplots(figsize=(12, 7)) data = args["data"] print(f"📊 Creating line chart with {len(data)} series") for series_name, series_data in data.items(): x_vals = series_data.get("x_values", []) y_vals = series_data.get("y_values", []) print(f" 📈 Series '{series_name}': {len(x_vals)} data points") ax.plot(x_vals, y_vals, marker='o', label=series_name, linewidth=2, markersize=4) ax.set_xlabel(args["x_label"], fontsize=12, fontweight='bold') ax.set_ylabel(args["y_label"], fontsize=12, fontweight='bold') ax.set_title(args["title"], fontsize=14, fontweight='bold', pad=20) ax.legend(fontsize=10) ax.grid(True, alpha=0.3, linestyle='--') plt.xticks(rotation=45, ha='right') plt.tight_layout() filepath, img_base64 = self._save_chart(fig, user_id, "line_chart", args["title"]) # Log activity self.log_activity(user_id, "chart_created", { "type": "line", "title": args["title"], "series_count": len(data) }) print(f"✅ Line chart saved: {filepath}") return { "success": True, "message": f"Line chart created: {args['title']}", "filepath": filepath, "chart_type": "line", "image_base64": img_base64 } elif tool_name == "create_bar_chart": fig, ax = plt.subplots(figsize=(10, 6)) categories = args["categories"] values = args["values"] bars = ax.bar(categories, values, color='#2E86AB', alpha=0.8, edgecolor='black') # Add value labels on bars for bar in bars: height = bar.get_height() ax.text(bar.get_x() + bar.get_width()/2., height, f'{height:.2f}', ha='center', va='bottom', fontsize=9, fontweight='bold') ax.set_xlabel(args["x_label"], fontsize=12, fontweight='bold') ax.set_ylabel(args["y_label"], fontsize=12, fontweight='bold') ax.set_title(args["title"], fontsize=14, fontweight='bold') ax.grid(True, alpha=0.3, axis='y') if len(categories) > 5: plt.xticks(rotation=45, ha='right') plt.tight_layout() filepath, img_base64 = self._save_chart(fig, user_id, "bar_chart", args["title"]) self.log_activity(user_id, "chart_created", { "type": "bar", "title": args["title"], "categories": len(categories) }) return { "success": True, "message": f"Bar chart created: {args['title']}", "filepath": filepath, "chart_type": "bar", "image_base64": img_base64 } elif tool_name == "create_scatter_plot": fig, ax = plt.subplots(figsize=(10, 6)) x_vals = args["x_values"] y_vals = args["y_values"] ax.scatter(x_vals, y_vals, color='#A23B72', alpha=0.6, s=100, edgecolors='black') # Add trend line if len(x_vals) > 1: z = np.polyfit(x_vals, y_vals, 1) p = np.poly1d(z) ax.plot(x_vals, p(x_vals), "r--", alpha=0.5, linewidth=2, label='Trend Line') ax.legend() ax.set_xlabel(args["x_label"], fontsize=12, fontweight='bold') ax.set_ylabel(args["y_label"], fontsize=12, fontweight='bold') ax.set_title(args["title"], fontsize=14, fontweight='bold') ax.grid(True, alpha=0.3) plt.tight_layout() filepath, img_base64 = self._save_chart(fig, user_id, "scatter_plot", args["title"]) self.log_activity(user_id, "chart_created", { "type": "scatter", "title": args["title"], "points": len(x_vals) }) return { "success": True, "message": f"Scatter plot created: {args['title']}", "filepath": filepath, "chart_type": "scatter", "image_base64": img_base64 } elif tool_name == "create_pie_chart": fig, ax = plt.subplots(figsize=(10, 8)) labels = args["labels"] values = args["values"] colors = plt.cm.Set3(range(len(labels))) wedges, texts, autotexts = ax.pie( values, labels=labels, autopct='%1.1f%%', colors=colors, startangle=90, textprops={'fontsize': 11, 'fontweight': 'bold'} ) for text in texts: text.set_fontsize(11) for autotext in autotexts: autotext.set_color('white') autotext.set_fontsize(10) autotext.set_fontweight('bold') ax.set_title(args["title"], fontsize=14, fontweight='bold') filepath, img_base64 = self._save_chart(fig, user_id, "pie_chart", args["title"]) self.log_activity(user_id, "chart_created", { "type": "pie", "title": args["title"], "slices": len(labels) }) return { "success": True, "message": f"Pie chart created: {args['title']}", "filepath": filepath, "chart_type": "pie", "image_base64": img_base64 } elif tool_name == "list_visualizations": state = self.get_state(user_id) charts = state.get("charts", []) return { "total_charts": len(charts), "charts": charts[-10:] # Last 10 charts } except Exception as e: import traceback error_details = traceback.format_exc() print(f"❌ Visualization error: {error_details}") return { "success": False, "error": f"Error creating visualization: {str(e)}", "details": error_details } return {"error": f"Unknown tool: {tool_name}"} def on_enable(self, user_id: str) -> str: self.initialize_state(user_id) return "📊 Data Visualization enabled! I can create charts from stock data and other numerical data. Just ask me to visualize something!" def on_disable(self, user_id: str) -> str: state = self.get_state(user_id) total = state.get("total_created", 0) return f"📊 Data Visualization disabled. You created {total} visualizations this session." def health_check(self, user_id: str) -> Dict[str, Any]: """Check extension health""" try: import matplotlib return { "healthy": True, "extension": self.name, "version": self.version, "matplotlib_available": True } except ImportError: return { "healthy": False, "extension": self.name, "version": self.version, "issues": ["matplotlib library not installed"] }