Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced Data Visualization Extension v3.0 - MODULAR | |
| Create charts from ANY data source through capability-based integration. | |
| No hardcoded assumptions about data sources! | |
| """ | |
| from base_extension import BaseExtension | |
| from google.genai import types | |
| from typing import Dict, Any, List, Optional | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import io | |
| import base64 | |
| from pathlib import Path | |
| import numpy as np | |
| import datetime | |
class VisualizationExtension(BaseExtension):
    """Extension that renders line, bar, scatter and pie charts with matplotlib.

    Each chart is written as a PNG file into the per-user output directory
    and is also returned base64-encoded in the tool result.

    State access (``get_state`` / ``update_state`` / ``initialize_state``) and
    ``log_activity`` are not defined in this class; presumably they are
    inherited from ``BaseExtension`` -- confirm against that class.
    """

    # NOTE(review): the single-character "π" / "β" prefixes in the icon and in
    # the status strings below look like mis-decoded emoji. They are kept
    # byte-for-byte; confirm against the original file before changing them.

    def name(self) -> str:
        """Return the internal identifier of this extension."""
        return "visualization"

    def display_name(self) -> str:
        """Return the human-readable extension name."""
        return "Data Visualization"

    def description(self) -> str:
        """Return a one-line description of what the extension does."""
        return "Create charts, graphs, and plots from ANY numerical data"

    def icon(self) -> str:
        """Return the icon string shown for this extension."""
        return "π"

    def version(self) -> str:
        """Return the extension version string."""
        return "3.0.0"

    # ==========================================
    # CAPABILITY-BASED DISCOVERY
    # ==========================================
    def get_capabilities(self) -> Dict[str, Any]:
        """Declare what Visualization can do.

        Returns:
            A dict listing the data kinds this extension consumes, the output
            kinds it creates, trigger keywords, and the fields of a chart
            result. ``provides_data`` is empty: charts are an end node.
        """
        return {
            'provides_data': [],  # Visualization is an end node
            'consumes_data': ['time_series', 'comparison_table', 'categorical_data',
                              'numerical_series', 'stock_history', 'financial_data'],
            'creates_output': ['visualization', 'chart', 'graph', 'plot'],
            'keywords': [
                'chart', 'graph', 'plot', 'visualize', 'show', 'display',
                'diagram', 'figure', 'image', 'visual', 'draw'
            ],
            'data_outputs': {
                'chart': {
                    'format': 'image',
                    'fields': ['image_base64', 'filepath', 'title', 'chart_type']
                }
            }
        }

    def get_suggested_next_action(self, tool_result: Dict[str, Any],
                                  available_extensions: List['BaseExtension']) -> Optional[Dict[str, Any]]:
        """Return a follow-up suggestion for a tool result.

        Visualization is typically an end node, so this always returns
        ``None``. It could suggest saving/sharing if such extensions existed.
        """
        return None

    # ==========================================
    # SYSTEM CONTEXT
    # ==========================================
    def get_system_context(self) -> str:
        """Return the prompt text that tells the model how to use the chart tools."""
        return """
You have access to Data Visualization for creating charts and graphs.
## CRITICAL: Always Create Charts When Requested!
When user says "make a chart", "create a graph", "visualize", or "show me":
- YOU MUST call one of the chart creation tools
- DO NOT just describe what you could do
- DO NOT say you can't create charts - YOU CAN!
## Chart Types:
- **Line charts**: Time series, trends, stock prices over time
- **Bar charts**: Comparisons, categories
- **Scatter plots**: Correlations, relationships
- **Pie charts**: Proportions, percentages
## Data Format:
Charts work with ANY data that has numerical values:
- **Line chart**: `data = {"Series": {"x_values": [...], "y_values": [...]}}`
- **Bar chart**: `categories = [...]`, `values = [...]`
- **Scatter plot**: `x_values = [...]`, `y_values = [...]`
- **Pie chart**: `labels = [...]`, `values = [...]`
## Integration:
This extension automatically receives formatted data from other extensions.
When suggestions appear, follow them to create charts!
## Remember:
- You CAN and SHOULD create visualizations
- Charts display automatically to users
- Extract data from previous tool results to visualize
"""

    # ==========================================
    # STATE MANAGEMENT
    # ==========================================
    def _get_default_state(self) -> Dict[str, Any]:
        """Return a fresh per-user state dict.

        ``created_at`` and ``last_updated`` are taken from a single clock
        reading so they start out identical.
        """
        now_iso = datetime.datetime.now().isoformat()
        return {
            "charts": [],
            "output_dir": "visualizations",
            "total_created": 0,
            "created_at": now_iso,
            "last_updated": now_iso
        }

    def get_state_summary(self, user_id: str) -> Optional[str]:
        """Return a short summary of the user's charts, or ``None`` if there are none."""
        chart_count = len(self.get_state(user_id).get("charts", []))
        if chart_count > 0:
            return f"{chart_count} visualizations created"
        return None

    def get_metrics(self, user_id: str) -> Dict[str, Any]:
        """Return chart counters for ``user_id``.

        Returns:
            ``total_created`` (stored counter), ``current_session`` (number of
            chart records in state) and ``by_type`` (records per stored
            ``type`` value; records without one are counted as "unknown").
        """
        state = self.get_state(user_id)
        charts = state.get("charts", [])
        by_type: Dict[str, int] = {}
        for chart in charts:
            chart_type = chart.get("type", "unknown")
            by_type[chart_type] = by_type.get(chart_type, 0) + 1
        return {
            "total_created": state.get("total_created", 0),
            "current_session": len(charts),
            "by_type": by_type
        }

    # ==========================================
    # TOOLS
    # ==========================================
    def get_tools(self) -> List[types.Tool]:
        """Return the function declarations exposed to the model.

        Five functions are declared: four chart creators and
        ``list_visualizations``.
        """
        create_line_chart = types.FunctionDeclaration(
            name="create_line_chart",
            description="Create line chart for time series or trends. Perfect for stock prices, growth over time, etc.",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string", "description": "Chart title"},
                    "x_label": {"type": "string", "description": "X-axis label"},
                    "y_label": {"type": "string", "description": "Y-axis label"},
                    "data": {
                        "type": "object",
                        "description": "Data series: {series_name: {x_values: [...], y_values: [...]}}"
                    }
                },
                "required": ["title", "x_label", "y_label", "data"]
            }
        )
        create_bar_chart = types.FunctionDeclaration(
            name="create_bar_chart",
            description="Create bar chart for categorical comparisons",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string", "description": "Chart title"},
                    "x_label": {"type": "string", "description": "X-axis label"},
                    "y_label": {"type": "string", "description": "Y-axis label"},
                    "categories": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Category names"
                    },
                    "values": {
                        "type": "array",
                        "items": {"type": "number"},
                        "description": "Values for each category"
                    }
                },
                "required": ["title", "x_label", "y_label", "categories", "values"]
            }
        )
        create_scatter_plot = types.FunctionDeclaration(
            name="create_scatter_plot",
            description="Create scatter plot to show relationships between two variables",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "x_label": {"type": "string"},
                    "y_label": {"type": "string"},
                    "x_values": {"type": "array", "items": {"type": "number"}},
                    "y_values": {"type": "array", "items": {"type": "number"}}
                },
                "required": ["title", "x_label", "y_label", "x_values", "y_values"]
            }
        )
        create_pie_chart = types.FunctionDeclaration(
            name="create_pie_chart",
            description="Create pie chart for proportions or percentages",
            parameters={
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "labels": {"type": "array", "items": {"type": "string"}},
                    "values": {"type": "array", "items": {"type": "number"}}
                },
                "required": ["title", "labels", "values"]
            }
        )
        list_charts = types.FunctionDeclaration(
            name="list_visualizations",
            description="List all created visualizations",
            parameters={"type": "object", "properties": {}}
        )
        return [types.Tool(function_declarations=[
            create_line_chart,
            create_bar_chart,
            create_scatter_plot,
            create_pie_chart,
            list_charts
        ])]

    # ==========================================
    # HELPER METHODS
    # ==========================================
    def _save_chart(self, fig, user_id: str, chart_type: str, title: str) -> tuple:
        """Save a figure as PNG and record it in the user's state.

        The figure is rendered once into memory; the same bytes are written
        to disk and base64-encoded, so the file and the inline image are
        identical. The figure is always closed, even when rendering fails.

        Args:
            fig: The matplotlib figure to save.
            user_id: Owner of the state to update.
            chart_type: Prefix for the file name; also stored as the record's
                ``type``.
            title: Chart title; sanitised to build the file name.

        Returns:
            ``(filepath, image_base64)``, both as ``str``.
        """
        state = self.get_state(user_id)
        output_dir = Path(state.get("output_dir", "visualizations"))
        # parents=True so a nested output_dir does not fail on first use.
        output_dir.mkdir(parents=True, exist_ok=True)

        # Generate filename
        now = datetime.datetime.now()
        timestamp = now.strftime("%Y%m%d_%H%M%S")
        safe_title = "".join(c for c in title if c.isalnum() or c in (' ', '-', '_')).strip()
        safe_title = safe_title.replace(' ', '_')[:50]
        stem = f"{chart_type}_{safe_title}_{timestamp}"
        filepath = output_dir / f"{stem}.png"
        # The timestamp has one-second resolution: avoid silently overwriting
        # a chart with the same type and title created within the same second.
        suffix = 1
        while filepath.exists():
            filepath = output_dir / f"{stem}_{suffix}.png"
            suffix += 1

        try:
            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
                png_bytes = buf.getvalue()
        finally:
            # Figures created through pyplot stay registered until closed.
            plt.close(fig)

        filepath.write_bytes(png_bytes)
        img_base64 = base64.b64encode(png_bytes).decode('utf-8')

        # Track in state
        state.setdefault("charts", []).append({
            "type": chart_type,
            "title": title,
            "filepath": str(filepath),
            "timestamp": timestamp
        })
        state["total_created"] = state.get("total_created", 0) + 1
        state["last_updated"] = now.isoformat()
        self.update_state(user_id, state)
        return str(filepath), img_base64

    @staticmethod
    def _label_axes(ax, args: Dict[str, Any], **title_kwargs) -> None:
        """Apply the x/y axis labels and the title from ``args`` to ``ax``."""
        ax.set_xlabel(args["x_label"], fontsize=12, fontweight='bold')
        ax.set_ylabel(args["y_label"], fontsize=12, fontweight='bold')
        ax.set_title(args["title"], fontsize=14, fontweight='bold', **title_kwargs)

    def _finish_chart(self, fig, user_id: str, args: Dict[str, Any], *,
                      kind: str, file_prefix: str, label: str,
                      detail: Dict[str, Any]) -> Dict[str, Any]:
        """Save ``fig``, log the creation and build the tool result dict."""
        title = args["title"]
        filepath, img_base64 = self._save_chart(fig, user_id, file_prefix, title)
        self.log_activity(user_id, "chart_created",
                          {"type": kind, "title": title, **detail})
        print(f"β {label} saved: {filepath}")
        return {
            "success": True,
            "message": f"{label} created: {title}",
            "filepath": filepath,
            "chart_type": kind,
            "image_base64": img_base64
        }

    def _create_line_chart(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Plot every series in ``args["data"]`` as a line and save the chart."""
        data = args["data"]
        fig, ax = plt.subplots(figsize=(12, 7))
        try:
            print(f"π Creating line chart with {len(data)} series")
            for series_name, series_data in data.items():
                x_vals = series_data.get("x_values", [])
                y_vals = series_data.get("y_values", [])
                print(f" π Series '{series_name}': {len(x_vals)} points")
                ax.plot(x_vals, y_vals, marker='o', label=series_name,
                        linewidth=2, markersize=4)
            self._label_axes(ax, args, pad=20)
            if data:
                # Skip the legend when nothing was plotted: there is nothing to label.
                ax.legend(fontsize=10)
            ax.grid(True, alpha=0.3, linestyle='--')
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            fig.tight_layout()
            return self._finish_chart(
                fig, user_id, args, kind="line", file_prefix="line_chart",
                label="Line chart", detail={"series_count": len(data)})
        finally:
            # No-op when _save_chart already closed it; releases the figure
            # when plotting failed before the save.
            plt.close(fig)

    def _create_bar_chart(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw one labelled bar per category and save the chart."""
        categories = args["categories"]
        values = args["values"]
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            print(f"π Creating bar chart with {len(categories)} categories")
            bars = ax.bar(categories, values, color='#2E86AB', alpha=0.8, edgecolor='black')
            # Add value labels on bars
            for bar in bars:
                height = bar.get_height()
                # Put the label outside the bar on either side of zero.
                ax.text(bar.get_x() + bar.get_width() / 2., height,
                        f'{height:.2f}',
                        ha='center', va='bottom' if height >= 0 else 'top',
                        fontsize=9, fontweight='bold')
            self._label_axes(ax, args)
            ax.grid(True, alpha=0.3, axis='y')
            if len(categories) > 5:
                # Many categories: tilt the labels so they do not overlap.
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            fig.tight_layout()
            return self._finish_chart(
                fig, user_id, args, kind="bar", file_prefix="bar_chart",
                label="Bar chart", detail={"categories": len(categories)})
        finally:
            plt.close(fig)

    def _create_scatter_plot(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw a scatter plot, add a linear trend line when possible, and save it."""
        x_vals = args["x_values"]
        y_vals = args["y_values"]
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            print(f"π Creating scatter plot with {len(x_vals)} points")
            ax.scatter(x_vals, y_vals, color='#A23B72', alpha=0.6,
                       s=100, edgecolors='black')
            # Add trend line if enough points. A degree-1 fit needs at least
            # two distinct x values; otherwise the fit is degenerate.
            if len(set(x_vals)) > 1:
                trend = np.poly1d(np.polyfit(x_vals, y_vals, 1))
                ax.plot(x_vals, trend(x_vals), "r--", alpha=0.5,
                        linewidth=2, label='Trend Line')
                ax.legend()
            self._label_axes(ax, args)
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            return self._finish_chart(
                fig, user_id, args, kind="scatter", file_prefix="scatter_plot",
                label="Scatter plot", detail={"points": len(x_vals)})
        finally:
            plt.close(fig)

    def _create_pie_chart(self, user_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Draw a pie chart with percentage labels and save it."""
        labels = args["labels"]
        values = args["values"]
        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            print(f"π Creating pie chart with {len(labels)} slices")
            palette = plt.cm.Set3
            # Cycle through the palette so slices beyond its size do not all
            # collapse onto the last colour.
            colors = palette(np.arange(len(labels)) % palette.N)
            wedges, texts, autotexts = ax.pie(
                values,
                labels=labels,
                autopct='%1.1f%%',
                colors=colors,
                startangle=90,
                textprops={'fontsize': 11, 'fontweight': 'bold'}
            )
            for text in texts:
                text.set_fontsize(11)
            for autotext in autotexts:
                autotext.set_color('white')
                autotext.set_fontsize(10)
                autotext.set_fontweight('bold')
            ax.set_title(args["title"], fontsize=14, fontweight='bold')
            return self._finish_chart(
                fig, user_id, args, kind="pie", file_prefix="pie_chart",
                label="Pie chart", detail={"slices": len(labels)})
        finally:
            plt.close(fig)

    # ==========================================
    # TOOL EXECUTION
    # ==========================================
    def _execute_tool(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
        """Execute a visualization tool.

        Args:
            user_id: User whose state is read and updated.
            tool_name: One of the function names declared in ``get_tools``.
            args: Arguments for the function, as declared in ``get_tools``.

        Returns:
            A result dict. Chart tools return ``success``, ``message``,
            ``filepath``, ``chart_type`` and ``image_base64``;
            ``list_visualizations`` returns ``total_charts`` and the last ten
            chart records. Any exception is caught and returned as
            ``success: False`` with ``error`` and ``details`` (the
            traceback). An unknown tool name yields an ``error`` entry.
        """
        chart_builders = {
            "create_line_chart": self._create_line_chart,
            "create_bar_chart": self._create_bar_chart,
            "create_scatter_plot": self._create_scatter_plot,
            "create_pie_chart": self._create_pie_chart,
        }
        try:
            if tool_name == "list_visualizations":
                charts = self.get_state(user_id).get("charts", [])
                return {
                    "success": True,
                    "total_charts": len(charts),
                    "charts": charts[-10:]  # Last 10
                }
            builder = chart_builders.get(tool_name)
            if builder is not None:
                return builder(user_id, args)
        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            print(f"β Visualization error: {error_details}")
            return {
                "success": False,
                "error": f"Error creating visualization: {str(e)}",
                "details": error_details
            }
        return {"success": False, "error": f"Unknown tool: {tool_name}"}

    def on_enable(self, user_id: str) -> str:
        """Initialise the user's state and return the enable message."""
        self.initialize_state(user_id)
        return "π Data Visualization enabled! Create charts from any numerical data. Just ask!"

    def on_disable(self, user_id: str) -> str:
        """Return the disable message, including the stored chart counter."""
        total = self.get_state(user_id).get("total_created", 0)
        return f"π Visualization disabled. {total} charts created this session."

    def health_check(self, user_id: str) -> Dict[str, Any]:
        """Report whether matplotlib can be imported.

        ``name`` and ``version`` are plain methods in this class, so they are
        called to obtain their string values. If they are exposed as
        non-callable attributes instead, the attribute value is used as-is.
        """
        ext_name = self.name() if callable(self.name) else self.name
        ext_version = self.version() if callable(self.version) else self.version
        try:
            import matplotlib  # noqa: F401
        except ImportError:
            return {
                "healthy": False,
                "extension": ext_name,
                "version": ext_version,
                "issues": ["matplotlib not installed"]
            }
        return {
            "healthy": True,
            "extension": ext_name,
            "version": ext_version,
            "matplotlib_available": True
        }