Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced Base Extension System for Wuhp Agents | |
| Compatible with the modular app.py - combines state management with capability-based discovery. | |
| """ | |
| from abc import ABC, abstractmethod | |
| from typing import Dict, Any, List, Optional, Callable | |
| from google.genai import types | |
| import json | |
| import datetime | |
| from pathlib import Path | |
class BaseExtension(ABC):
    """Enhanced base class for all agent extensions with capability-based discovery.

    Subclasses must override the ``name``, ``display_name`` and ``description``
    properties and implement ``get_system_context`` and ``get_tools``; until
    they do, the class cannot be instantiated. Per-user state lives in
    ``self.state``, a dict keyed by ``user_id``.
    """

    # Maximum number of entries log_activity() keeps per user (oldest dropped).
    MAX_ACTIVITY_LOG_ENTRIES = 100

    def __init__(self):
        # NOTE: one flag per instance; on_enable/on_disable toggle it no matter
        # which user_id they are called with.
        self.enabled = False
        # Per-user state dicts keyed by user_id.
        self.state: Dict[str, Any] = {}
        # Optional per-key validators consulted by update_state().
        self._state_validators: Dict[str, Callable] = {}
        # Hook lists; add_hook() only registers into these four keys.
        self._state_hooks: Dict[str, List[Callable]] = {
            'before_update': [],
            'after_update': [],
            'before_tool_call': [],
            'after_tool_call': [],
        }

    # ==========================================
    # REQUIRED PROPERTIES (override with @property)
    # ==========================================

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique identifier for the extension (lowercase, no spaces)."""

    @property
    @abstractmethod
    def display_name(self) -> str:
        """Human-readable name shown in UI."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Brief description of what the extension does."""

    @property
    def icon(self) -> str:
        """Emoji icon for the extension (override to customise)."""
        return "🔧"

    @property
    def version(self) -> str:
        """Extension version for compatibility checking (override to customise)."""
        return "1.0.0"

    # ==========================================
    # REQUIRED ABSTRACT METHODS
    # ==========================================

    @abstractmethod
    def get_system_context(self) -> str:
        """Return context to inject into the system prompt when enabled."""

    @abstractmethod
    def get_tools(self) -> List["types.Tool"]:
        """Return the Gemini function-calling tools for this extension."""

    # ==========================================
    # CAPABILITY-BASED DISCOVERY SYSTEM
    # ==========================================

    def get_capabilities(self) -> Dict[str, Any]:
        """
        Declare what this extension can do.

        This enables automatic integration without hardcoding.
        Returns dict with capabilities like:
        {
            'provides_data': ['stock_history', 'financial_data'],
            'consumes_data': ['numerical_series', 'comparison_data'],
            'creates_output': ['visualization', 'report'],
            'keywords': ['stock', 'price', 'chart', 'graph'],
            'data_outputs': {
                'stock_history': {
                    'format': 'time_series',
                    'fields': ['dates', 'close_prices', 'ticker']
                }
            }
        }

        Override this in your extension to declare capabilities!
        """
        return {
            'provides_data': [],
            'consumes_data': [],
            'creates_output': [],
            'keywords': [],
            'data_outputs': {},
        }

    def can_consume(self, data_type: str) -> bool:
        """Return True if this extension lists *data_type* under 'consumes_data'."""
        # `or []` tolerates a capability key explicitly set to None.
        return data_type in (self.get_capabilities().get('consumes_data') or [])

    def can_provide(self, data_type: str) -> bool:
        """Return True if this extension lists *data_type* under 'provides_data'."""
        return data_type in (self.get_capabilities().get('provides_data') or [])

    def can_create(self, output_type: str) -> bool:
        """Return True if this extension lists *output_type* under 'creates_output'."""
        return output_type in (self.get_capabilities().get('creates_output') or [])

    def get_suggested_next_action(self, tool_result: Dict[str, Any],
                                  available_extensions: List['BaseExtension']) -> Optional[Dict[str, Any]]:
        """
        Given a tool result, suggest what should happen next.

        Returns None or dict with: {'extension': ext_name, 'tool': tool_name, 'reason': str, 'data': dict}
        This makes extensions self-aware of their integration opportunities!

        Example implementation:
        ```python
        # Find visualization extension by capability (NO HARDCODING!)
        viz_ext = next((ext for ext in available_extensions
                        if ext.can_create('visualization')), None)
        if viz_ext and 'dates' in tool_result:
            return {
                'extension': viz_ext.name,
                'tool': 'create_line_chart',
                'reason': 'Data is ready for time-series visualization',
                'data': tool_result
            }
        ```
        Override this in your extension to suggest next actions!
        """
        return None

    # ==========================================
    # STATE MANAGEMENT
    # ==========================================

    def get_state_summary(self, user_id: str) -> Optional[str]:
        """
        Override to provide a human-readable summary of current state.

        This will be included in the system prompt for context awareness.
        Example: "You have 2 active timers and 5 pending tasks"
        """
        return None

    def initialize_state(self, user_id: str) -> None:
        """Create default state for *user_id* if none exists, then run 'after_update' hooks."""
        if user_id not in self.state:
            self.state[user_id] = self._get_default_state()
            self._run_hooks('after_update', user_id, self.state[user_id])

    def _get_default_state(self) -> Dict[str, Any]:
        """Return a fresh default state dict (override to add keys)."""
        # One timestamp so created_at and last_updated agree exactly.
        now = datetime.datetime.now().isoformat()
        return {
            'created_at': now,
            'last_updated': now,
        }

    def get_state(self, user_id: str) -> Dict[str, Any]:
        """Return the live state dict for *user_id*, initializing it if needed."""
        if user_id not in self.state:
            self.initialize_state(user_id)
        return self.state.get(user_id, {})

    def update_state(self, user_id: str, updates: Dict[str, Any]) -> None:
        """
        Merge *updates* into the user's state, with hooks and validation.

        Order: 'before_update' hooks (they may edit the pending updates),
        registered validators, 'last_updated' stamp, merge, 'after_update' hooks.

        Raises:
            ValueError: if a registered validator rejects a value.
        """
        if user_id not in self.state:
            self.initialize_state(user_id)
        # Work on a copy so neither the hooks nor the timestamp below
        # mutate the dict the caller passed in.
        pending = dict(updates)
        self._run_hooks('before_update', user_id, pending)
        self._validate_state_updates(pending)
        pending['last_updated'] = datetime.datetime.now().isoformat()
        self.state[user_id].update(pending)
        self._run_hooks('after_update', user_id, self.state[user_id])

    def _validate_state_updates(self, updates: Dict[str, Any]) -> None:
        """Raise ValueError if any registered validator rejects its key's new value."""
        for key, value in updates.items():
            validator = self._state_validators.get(key)
            if validator is not None and not validator(value):
                raise ValueError(f"Invalid value for state key '{key}': {value}")

    def register_state_validator(self, key: str, validator: Callable[[Any], bool]) -> None:
        """Register a predicate that update_state() applies to values for *key*."""
        self._state_validators[key] = validator

    def add_hook(self, hook_type: str, func: Callable) -> None:
        """
        Register *func* for *hook_type*.

        Valid types: 'before_update', 'after_update', 'before_tool_call',
        'after_tool_call'. Unknown types are silently ignored.
        """
        if hook_type in self._state_hooks:
            self._state_hooks[hook_type].append(func)

    def _run_hooks(self, hook_type: str, *args, **kwargs) -> None:
        """Call every hook of *hook_type*; a failing hook is printed, not raised."""
        for hook_func in self._state_hooks.get(hook_type, []):
            try:
                hook_func(*args, **kwargs)
            except Exception as e:  # best-effort: one bad hook must not break the caller
                print(f"Hook error in {self.name}.{hook_type}: {e}")

    # ==========================================
    # TOOL EXECUTION
    # ==========================================

    def handle_tool_call(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Handle a tool call from Gemini with hooks.

        Override this to implement tool logic, OR override _execute_tool instead.
        Exceptions from _execute_tool are converted into
        {"success": False, "error": str(e), "tool": tool_name}.
        """
        self._run_hooks('before_tool_call', user_id, tool_name, args)
        try:
            result = self._execute_tool(user_id, tool_name, args)
            self._run_hooks('after_tool_call', user_id, tool_name, args, result)
            return result
        except Exception as e:
            error_result = {
                "success": False,
                "error": str(e),
                "tool": tool_name,
            }
            self._run_hooks('after_tool_call', user_id, tool_name, args, error_result)
            return error_result

    def _execute_tool(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Override this method to implement actual tool logic.

        Called by handle_tool_call after the 'before_tool_call' hooks run.
        The default reports the tool as unimplemented, using the same
        'success'/'error' keys as handle_tool_call's error result.
        """
        return {"success": False, "error": f"Tool {tool_name} not implemented"}

    def get_tool_by_name(self, tool_name: str) -> Optional["types.FunctionDeclaration"]:
        """Return the function declaration named *tool_name* from get_tools(), or None."""
        for tool in self.get_tools() or []:
            # function_declarations may be missing or None (e.g. a tool that
            # only configures a built-in capability), so treat it as empty.
            for func_decl in getattr(tool, 'function_declarations', None) or []:
                if func_decl.name == tool_name:
                    return func_decl
        return None

    # ==========================================
    # LIFECYCLE HOOKS
    # ==========================================

    def on_enable(self, user_id: str) -> Optional[str]:
        """
        Called when extension is enabled for a user.

        Initializes the user's state and sets the instance-wide ``enabled`` flag.
        Return a message to show to the user, or None.
        """
        self.initialize_state(user_id)
        self.enabled = True
        return None

    def on_disable(self, user_id: str) -> Optional[str]:
        """
        Called when extension is disabled for a user.

        Clears the instance-wide ``enabled`` flag; user state is kept.
        Return a message to show to the user, or None.
        """
        self.enabled = False
        return None

    def get_proactive_message(self, user_id: str) -> Optional[str]:
        """
        Called periodically to check if extension wants to proactively message user.

        Return message string or None.
        Override this to implement proactive notifications (timers, reminders, etc.)
        """
        return None

    # ==========================================
    # PERSISTENCE
    # ==========================================

    def serialize_state(self, user_id: str) -> str:
        """Serialize the user's state to indented JSON (non-JSON values via str())."""
        return json.dumps(self.get_state(user_id), indent=2, default=str)

    def deserialize_state(self, user_id: str, state_json: str) -> bool:
        """
        Replace the user's state with the JSON object in *state_json*.

        Returns:
            True on success. On malformed JSON or a non-object payload the
            error is printed, existing state is left untouched, and False
            is returned.
        """
        try:
            loaded = json.loads(state_json)
        except (ValueError, TypeError, RecursionError) as e:
            print(f"Error loading state for {self.name}: {e}")
            return False
        if not isinstance(loaded, dict):
            # Non-dict state would break update_state()/log_activity() later.
            print(f"Error loading state for {self.name}: expected a JSON object, "
                  f"got {type(loaded).__name__}")
            return False
        self.state[user_id] = loaded
        return True

    def save_state_to_file(self, user_id: str, filepath: Optional[str] = None) -> str:
        """
        Write the user's serialized state to *filepath* and return the path used.

        Defaults to ``state_<name>_<user_id>.json`` in the working directory;
        parent directories are created as needed.
        """
        if filepath is None:
            # NOTE(review): user_id is interpolated into the path unsanitized;
            # confirm user_ids can never contain path separators or '..'.
            filepath = f"state_{self.name}_{user_id}.json"
        state_path = Path(filepath)
        state_path.parent.mkdir(parents=True, exist_ok=True)
        with open(state_path, 'w', encoding='utf-8') as f:
            f.write(self.serialize_state(user_id))
        return str(state_path)

    def load_state_from_file(self, user_id: str, filepath: str) -> bool:
        """
        Load the user's state from a JSON file.

        Returns:
            True if the file was read and its contents accepted, False otherwise
            (the error is printed).
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                state_json = f.read()
            loaded_ok = self.deserialize_state(user_id, state_json)
        except Exception as e:  # best-effort loader: report and signal failure
            print(f"Error loading state from {filepath}: {e}")
            return False
        # deserialize_state signals bad content with False; an override that
        # returns None (pre-bool contract) is treated as success.
        return loaded_ok is not False

    def clear_state(self, user_id: str) -> None:
        """Reset the user's state to defaults (useful for testing/reset); no hooks run."""
        self.state[user_id] = self._get_default_state()

    def export_state(self, user_id: str) -> Dict[str, Any]:
        """Return the user's state wrapped with extension name, version and export time."""
        return {
            'extension': self.name,
            'version': self.version,
            'exported_at': datetime.datetime.now().isoformat(),
            'state': self.get_state(user_id),
        }

    def import_state(self, user_id: str, exported_data: Dict[str, Any]) -> bool:
        """
        Replace the user's state from an export_state() payload.

        Returns False (and prints why) if the payload belongs to another
        extension, is malformed, or its 'state' is not a dict.
        """
        import copy

        try:
            if exported_data.get('extension') != self.name:
                print(f"Extension name mismatch: {exported_data.get('extension')} != {self.name}")
                return False
            # Could add version compatibility checks here
            imported = exported_data['state']
            if not isinstance(imported, dict):
                print(f"Error importing state: expected dict, got {type(imported).__name__}")
                return False
            # Deep-copy so this user's state does not alias the exporter's
            # (export_state returns the live state dict).
            self.state[user_id] = copy.deepcopy(imported)
            return True
        except Exception as e:  # best-effort importer: report and signal failure
            print(f"Error importing state: {e}")
            return False

    # ==========================================
    # DEPENDENCIES & VALIDATION
    # ==========================================

    def get_dependencies(self) -> List[str]:
        """
        Return list of other extension names this extension depends on.

        The orchestrator can use this to ensure dependencies are loaded.
        """
        return []

    def validate_dependencies(self, available_extensions: List[str]) -> bool:
        """Return True if every name from get_dependencies() is in *available_extensions*."""
        return all(dep in available_extensions for dep in self.get_dependencies())

    # ==========================================
    # LOGGING & METRICS
    # ==========================================

    def log_activity(self, user_id: str, activity: str,
                     details: Optional[Dict[str, Any]] = None) -> None:
        """
        Append an activity entry to the user's state['activity_log'].

        Only the newest MAX_ACTIVITY_LOG_ENTRIES entries are kept.
        Override to implement custom logging.
        """
        log_entry = {
            'timestamp': datetime.datetime.now().isoformat(),
            'extension': self.name,
            'activity': activity,
            'details': details or {},
        }
        state = self.get_state(user_id)  # live dict, so edits persist
        activity_log = state.setdefault('activity_log', [])
        activity_log.append(log_entry)
        if len(activity_log) > self.MAX_ACTIVITY_LOG_ENTRIES:
            state['activity_log'] = activity_log[-self.MAX_ACTIVITY_LOG_ENTRIES:]

    def get_recent_activity(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Return up to *limit* most recent activity entries, oldest first."""
        # Guard needed: log[-0:] would return the whole log, not nothing.
        if limit <= 0:
            return []
        activity_log = self.get_state(user_id).get('activity_log', [])
        return activity_log[-limit:]

    def get_metrics(self, user_id: str) -> Dict[str, Any]:
        """
        Override to provide usage metrics/statistics.

        Example: {"total_timers_created": 15, "active_timers": 2}
        """
        return {}

    def health_check(self, user_id: str) -> Dict[str, Any]:
        """
        Perform a health check on the extension state.

        Returns dict with 'healthy': bool and optional 'issues': list.
        The base implementation always reports healthy.
        """
        return {
            'healthy': True,
            'extension': self.name,
            'version': self.version,
        }
| # ========================================== | |
| # HELPER FUNCTIONS FOR CAPABILITY MATCHING | |
| # ========================================== | |
def find_extensions_with_capability(extensions: List['BaseExtension'],
                                    capability_type: str,
                                    capability_value: str) -> List['BaseExtension']:
    """
    Find extensions that declare a specific capability.

    Args:
        extensions: Extension instances to search.
        capability_type: Capability category key, e.g. 'provides_data',
            'consumes_data' or 'creates_output'.
        capability_value: The specific capability to search for.

    Returns:
        The extensions (in input order) whose
        ``get_capabilities()[capability_type]`` contains ``capability_value``.
        A missing or None category counts as empty.

    Example:
        viz_extensions = find_extensions_with_capability(
            all_extensions, 'creates_output', 'visualization'
        )
    """
    return [
        ext for ext in extensions
        # `or []` tolerates a category explicitly set to None.
        if capability_value in (ext.get_capabilities().get(capability_type) or [])
    ]
def get_data_flow_possibilities(extensions: List['BaseExtension']) -> List[Dict[str, Any]]:
    """
    List every (provider, data_type) pair that at least one other extension can consume.

    Returns list of dicts with:
    {
        'provider': extension_name,
        'data_type': data_type,
        'consumers': [list of extension names that can consume this data]
    }
    Providers and data types appear in input/declaration order; an extension
    never counts as a consumer of its own data (compared by ``name``).

    Example usage:
        flows = get_data_flow_possibilities(enabled_extensions)
        for flow in flows:
            print(f"{flow['provider']} produces {flow['data_type']}")
            print(f"  → can be consumed by: {', '.join(flow['consumers'])}")
    """
    flows: List[Dict[str, Any]] = []
    for provider in extensions:
        # `or []` tolerates 'provides_data' explicitly set to None.
        provided_data = provider.get_capabilities().get('provides_data') or []
        for data_type in provided_data:
            consumers = [
                ext.name for ext in extensions
                if ext.can_consume(data_type) and ext.name != provider.name
            ]
            if consumers:
                flows.append({
                    'provider': provider.name,
                    'data_type': data_type,
                    'consumers': consumers,
                })
    return flows
def detect_relevant_extensions(query: str, extensions: List['BaseExtension']) -> List[str]:
    """
    Detect which extensions are relevant to a query based on keywords.

    A match is a case-insensitive substring test of each declared keyword
    against the query. Blank keywords are ignored, because the empty string
    is a substring of every query.

    Args:
        query: The user's query string
        extensions: List of available extensions

    Returns:
        Names of the relevant extensions, in input order.

    Example:
        query = "Show me a chart of AAPL stock prices"
        relevant = detect_relevant_extensions(query, all_extensions)
        # Returns: ['yfinance', 'visualization']
    """
    query_lower = query.lower()
    relevant: List[str] = []
    for ext in extensions:
        # `or []` tolerates 'keywords' explicitly set to None.
        keywords = ext.get_capabilities().get('keywords') or []
        if any(kw.strip() and kw.lower() in query_lower for kw in keywords):
            relevant.append(ext.name)
    return relevant