# Repository: wuhp-agents / base_extension.py
# Author: wuhp
# Commit: f626045 (verified) - "Update base_extension.py"
"""
Enhanced Base Extension System for Wuhp Agents
Compatible with the modular app.py - combines state management with capability-based discovery.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional, Callable
from google.genai import types
import json
import datetime
from pathlib import Path
class BaseExtension(ABC):
    """Enhanced base class for all agent extensions with capability-based discovery.

    Concrete extensions must implement the abstract properties ``name``,
    ``display_name`` and ``description`` and the abstract methods
    ``get_system_context`` and ``get_tools``. Everything else has a default
    implementation that can be overridden.

    State is stored per user in ``self.state`` (``user_id`` -> state dict).
    Note that ``self.enabled`` is a single flag on the instance, not a
    per-user flag.
    """

    def __init__(self):
        self.enabled = False
        # Per-user state, keyed by user_id.
        self.state: Dict[str, Any] = {}
        # Optional validators, keyed by state key (see register_state_validator).
        self._state_validators: Dict[str, Callable] = {}
        # Callbacks grouped by hook type; only these four types are accepted.
        self._state_hooks: Dict[str, List[Callable]] = {
            'before_update': [],
            'after_update': [],
            'before_tool_call': [],
            'after_tool_call': []
        }

    # ==========================================
    # REQUIRED PROPERTIES (Use @property decorator)
    # ==========================================
    @property
    @abstractmethod
    def name(self) -> str:
        """Unique identifier for the extension (lowercase, no spaces)"""
        pass

    @property
    @abstractmethod
    def display_name(self) -> str:
        """Human-readable name shown in UI"""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Brief description of what the extension does"""
        pass

    @property
    def icon(self) -> str:
        """Emoji icon for the extension"""
        return "🔧"

    @property
    def version(self) -> str:
        """Extension version for compatibility checking"""
        return "1.0.0"

    # ==========================================
    # REQUIRED ABSTRACT METHODS
    # ==========================================
    @abstractmethod
    def get_system_context(self) -> str:
        """Returns context to inject into system prompt when enabled"""
        pass

    @abstractmethod
    def get_tools(self) -> List["types.Tool"]:
        """Returns Gemini function calling tools for this extension"""
        pass

    # ==========================================
    # CAPABILITY-BASED DISCOVERY SYSTEM (NEW!)
    # ==========================================
    def get_capabilities(self) -> Dict[str, Any]:
        """
        Declare what this extension can do.
        This enables automatic integration without hardcoding.

        Returns dict with capabilities like:
        {
            'provides_data': ['stock_history', 'financial_data'],
            'consumes_data': ['numerical_series', 'comparison_data'],
            'creates_output': ['visualization', 'report'],
            'keywords': ['stock', 'price', 'chart', 'graph'],
            'data_outputs': {
                'stock_history': {
                    'format': 'time_series',
                    'fields': ['dates', 'close_prices', 'ticker']
                }
            }
        }

        Override this in your extension to declare capabilities!
        The default declares no capabilities at all.
        """
        return {
            'provides_data': [],
            'consumes_data': [],
            'creates_output': [],
            'keywords': [],
            'data_outputs': {}
        }

    def can_consume(self, data_type: str) -> bool:
        """Check if this extension can consume a specific data type"""
        return data_type in self.get_capabilities().get('consumes_data', [])

    def can_provide(self, data_type: str) -> bool:
        """Check if this extension can provide a specific data type"""
        return data_type in self.get_capabilities().get('provides_data', [])

    def can_create(self, output_type: str) -> bool:
        """Check if this extension can create a specific output"""
        return output_type in self.get_capabilities().get('creates_output', [])

    def get_suggested_next_action(self, tool_result: Dict[str, Any],
                                  available_extensions: List['BaseExtension']) -> Optional[Dict[str, Any]]:
        """
        Given a tool result, suggest what should happen next.

        Returns None or dict with:
        {'extension': ext_name, 'tool': tool_name, 'reason': str, 'data': dict}

        This makes extensions self-aware of their integration opportunities!

        Example implementation:
        ```python
        # Find visualization extension by capability (NO HARDCODING!)
        viz_ext = next((ext for ext in available_extensions
                        if ext.can_create('visualization')), None)
        if viz_ext and 'dates' in tool_result:
            return {
                'extension': viz_ext.name,
                'tool': 'create_line_chart',
                'reason': 'Data is ready for time-series visualization',
                'data': tool_result
            }
        ```

        Override this in your extension to suggest next actions!
        The default never suggests anything.
        """
        return None

    # ==========================================
    # STATE MANAGEMENT
    # ==========================================
    def get_state_summary(self, user_id: str) -> Optional[str]:
        """
        Override to provide a human-readable summary of current state.
        This will be included in the system prompt for context awareness.
        Example: "You have 2 active timers and 5 pending tasks"
        """
        return None

    def initialize_state(self, user_id: str) -> None:
        """Initialize default state for a new user.

        Does nothing when the user already has state. When state is created,
        the 'after_update' hooks are run with the new state.
        """
        if user_id not in self.state:
            self.state[user_id] = self._get_default_state()
            self._run_hooks('after_update', user_id, self.state[user_id])

    def _get_default_state(self) -> Dict[str, Any]:
        """Override to provide default state structure"""
        # Take the timestamp once so a brand-new state has
        # created_at == last_updated instead of two slightly different values.
        now = datetime.datetime.now().isoformat()
        return {
            'created_at': now,
            'last_updated': now
        }

    def get_state(self, user_id: str) -> Dict[str, Any]:
        """Get state for a specific user, initializing it on first access.

        The returned dict is the live state object, not a copy.
        """
        if user_id not in self.state:
            self.initialize_state(user_id)
        return self.state.get(user_id, {})

    def update_state(self, user_id: str, updates: Dict[str, Any]) -> None:
        """Update state for a specific user with validation and hooks.

        Order of operations: 'before_update' hooks, validation, timestamping,
        merge into the user's state, 'after_update' hooks.

        Raises:
            ValueError: if a registered validator rejects one of the values.
        """
        if user_id not in self.state:
            self.initialize_state(user_id)
        # Work on a copy so the caller's dict is not modified (the
        # 'last_updated' key used to leak into it). Hooks still receive a
        # mutable dict and may adjust it before validation.
        updates = dict(updates)
        # Run before_update hooks
        self._run_hooks('before_update', user_id, updates)
        # Validate updates
        self._validate_state_updates(updates)
        # Update timestamp
        updates['last_updated'] = datetime.datetime.now().isoformat()
        # Apply updates
        self.state[user_id].update(updates)
        # Run after_update hooks
        self._run_hooks('after_update', user_id, self.state[user_id])

    def _validate_state_updates(self, updates: Dict[str, Any]) -> None:
        """Validate state updates using registered validators.

        Keys without a registered validator are accepted as-is.

        Raises:
            ValueError: if a validator returns a falsy result for its value.
        """
        for key, value in updates.items():
            if key in self._state_validators:
                validator = self._state_validators[key]
                if not validator(value):
                    raise ValueError(f"Invalid value for state key '{key}': {value}")

    def register_state_validator(self, key: str, validator: Callable[[Any], bool]) -> None:
        """Register a validation function for a state key (replaces any previous one)"""
        self._state_validators[key] = validator

    def add_hook(self, hook_type: str, func: Callable) -> None:
        """Add a hook function to be called at specific points.

        Valid hook types are 'before_update', 'after_update',
        'before_tool_call' and 'after_tool_call'. Any other hook type is
        silently ignored.
        """
        if hook_type in self._state_hooks:
            self._state_hooks[hook_type].append(func)

    def _run_hooks(self, hook_type: str, *args, **kwargs) -> None:
        """Run all registered hooks of a specific type.

        A failing hook is reported with print() and does not stop the
        remaining hooks or the calling operation.
        """
        for hook_func in self._state_hooks.get(hook_type, []):
            try:
                hook_func(*args, **kwargs)
            except Exception as e:
                print(f"Hook error in {self.name}.{hook_type}: {e}")

    # ==========================================
    # TOOL EXECUTION
    # ==========================================
    def handle_tool_call(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Handle a tool call from Gemini with hooks.
        Override this to implement tool logic, OR override _execute_tool instead.

        Any exception raised by the tool is converted into a result dict of
        the form {"success": False, "error": str, "tool": tool_name}; the
        'after_tool_call' hooks run in both the success and the error case.
        """
        # Run before_tool_call hooks
        self._run_hooks('before_tool_call', user_id, tool_name, args)
        try:
            # Call the actual tool implementation
            result = self._execute_tool(user_id, tool_name, args)
            # Run after_tool_call hooks
            self._run_hooks('after_tool_call', user_id, tool_name, args, result)
            return result
        except Exception as e:
            error_result = {
                "success": False,
                "error": str(e),
                "tool": tool_name
            }
            self._run_hooks('after_tool_call', user_id, tool_name, args, error_result)
            return error_result

    def _execute_tool(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Override this method to implement actual tool logic.
        This is called by handle_tool_call after running before hooks.
        """
        return {"error": f"Tool {tool_name} not implemented"}

    def get_tool_by_name(self, tool_name: str) -> Optional["types.FunctionDeclaration"]:
        """Helper to find a specific tool declaration by name.

        Returns the first matching declaration, or None when no tool
        declares a function with that name.
        """
        for tool in self.get_tools():
            # A tool may carry no function declarations at all (attribute
            # missing or set to None); treat both as "nothing declared".
            for func_decl in getattr(tool, 'function_declarations', None) or []:
                if func_decl.name == tool_name:
                    return func_decl
        return None

    # ==========================================
    # LIFECYCLE HOOKS
    # ==========================================
    def on_enable(self, user_id: str) -> Optional[str]:
        """
        Called when extension is enabled for a user.
        Return a message to show to the user, or None.
        """
        self.initialize_state(user_id)
        self.enabled = True
        return None

    def on_disable(self, user_id: str) -> Optional[str]:
        """
        Called when extension is disabled for a user.
        Return a message to show to the user, or None.
        The user's state is left untouched.
        """
        self.enabled = False
        return None

    def get_proactive_message(self, user_id: str) -> Optional[str]:
        """
        Called periodically to check if extension wants to proactively message user.
        Return message string or None.
        Override this to implement proactive notifications (timers, reminders, etc.)
        """
        return None

    # ==========================================
    # PERSISTENCE
    # ==========================================
    def serialize_state(self, user_id: str) -> str:
        """Serialize state to JSON for persistence.

        Values that JSON cannot represent are converted with str().
        """
        return json.dumps(self.get_state(user_id), indent=2, default=str)

    def deserialize_state(self, user_id: str, state_json: str) -> bool:
        """Load state from JSON.

        The JSON document must be an object; anything else (list, number,
        malformed text) is rejected and the existing state is kept.

        Returns:
            True when the state was replaced, False when loading failed.
            Failures are reported with print() and never raised.
        """
        # Deliberately broad: loading is best-effort and must not raise.
        try:
            loaded = json.loads(state_json)
        except Exception as e:
            print(f"Error loading state for {self.name}: {e}")
            return False
        # State is used as a dict everywhere else (update_state, log_activity),
        # so a non-object payload would only fail later and less clearly.
        if not isinstance(loaded, dict):
            print(f"Error loading state for {self.name}: "
                  f"expected a JSON object, got {type(loaded).__name__}")
            return False
        self.state[user_id] = loaded
        return True

    def save_state_to_file(self, user_id: str, filepath: Optional[str] = None) -> str:
        """Save state to a file and return the path that was written.

        When no path is given, ``state_<name>_<user_id>.json`` in the current
        working directory is used. Missing parent directories are created.
        """
        if filepath is None:
            # NOTE(review): user_id is interpolated into the file name as-is;
            # confirm callers never pass ids containing path separators.
            filepath = f"state_{self.name}_{user_id}.json"
        state_path = Path(filepath)
        state_path.parent.mkdir(parents=True, exist_ok=True)
        with open(state_path, 'w', encoding='utf-8') as f:
            f.write(self.serialize_state(user_id))
        return str(state_path)

    def load_state_from_file(self, user_id: str, filepath: str) -> bool:
        """Load state from a file.

        Returns:
            True when the file was read and its content accepted as state,
            False otherwise (errors are reported with print()).
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                state_json = f.read()
            outcome = self.deserialize_state(user_id, state_json)
        except Exception as e:
            print(f"Error loading state from {filepath}: {e}")
            return False
        # Only an explicit False means failure, so overrides of
        # deserialize_state that still return None keep reporting success.
        return outcome is not False

    def clear_state(self, user_id: str) -> None:
        """Clear all state for a user (useful for testing/reset)"""
        self.state[user_id] = self._get_default_state()

    def export_state(self, user_id: str) -> Dict[str, Any]:
        """Export state in a format suitable for external use.

        The 'state' entry is the live state dict, not a copy.
        """
        return {
            'extension': self.name,
            'version': self.version,
            'exported_at': datetime.datetime.now().isoformat(),
            'state': self.get_state(user_id)
        }

    def import_state(self, user_id: str, exported_data: Dict[str, Any]) -> bool:
        """Import state from exported data, checking the extension name.

        Returns:
            True when the state was imported; False when the extension name
            does not match, the payload has no dict under 'state', or any
            other error occurs (reported with print()).
        """
        try:
            if exported_data.get('extension') != self.name:
                print(f"Extension name mismatch: {exported_data.get('extension')} != {self.name}")
                return False
            # Could add version compatibility checks here
            imported = exported_data['state']
            if not isinstance(imported, dict):
                print(f"Error importing state: 'state' must be a dict, "
                      f"got {type(imported).__name__}")
                return False
            # Shallow copy so later top-level changes to the caller's dict do
            # not leak into this user's state (nested values are still shared).
            self.state[user_id] = dict(imported)
            return True
        except Exception as e:
            print(f"Error importing state: {e}")
            return False

    # ==========================================
    # DEPENDENCIES & VALIDATION
    # ==========================================
    def get_dependencies(self) -> List[str]:
        """
        Return list of other extension names this extension depends on.
        The orchestrator can use this to ensure dependencies are loaded.
        """
        return []

    def validate_dependencies(self, available_extensions: List[str]) -> bool:
        """Check if all required dependencies are available"""
        deps = self.get_dependencies()
        return all(dep in available_extensions for dep in deps)

    # ==========================================
    # LOGGING & METRICS
    # ==========================================
    def log_activity(self, user_id: str, activity: str,
                     details: Optional[Dict[str, Any]] = None) -> None:
        """
        Log extension activity for debugging/auditing.
        Override to implement custom logging.

        Entries are appended to the 'activity_log' list in the user's state;
        only the 100 most recent entries are kept.
        """
        timestamp = datetime.datetime.now().isoformat()
        log_entry = {
            'timestamp': timestamp,
            'extension': self.name,
            'activity': activity,
            'details': details or {}
        }
        # Store in state
        state = self.get_state(user_id)
        if 'activity_log' not in state:
            state['activity_log'] = []
        state['activity_log'].append(log_entry)
        # Keep only last 100 entries
        if len(state['activity_log']) > 100:
            state['activity_log'] = state['activity_log'][-100:]

    def get_recent_activity(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Get the most recent activity log entries, oldest first.

        A limit of zero or less yields an empty list.
        """
        state = self.get_state(user_id)
        activity_log = state.get('activity_log', [])
        # Guard explicitly: activity_log[-0:] is the whole list, and a
        # negative limit would slice from the front instead of the back.
        if limit <= 0:
            return []
        return activity_log[-limit:]

    def get_metrics(self, user_id: str) -> Dict[str, Any]:
        """
        Override to provide usage metrics/statistics.
        Example: {"total_timers_created": 15, "active_timers": 2}
        """
        return {}

    def health_check(self, user_id: str) -> Dict[str, Any]:
        """
        Perform a health check on the extension state.
        Returns dict with 'healthy': bool and optional 'issues': list
        The default implementation always reports healthy.
        """
        return {
            'healthy': True,
            'extension': self.name,
            'version': self.version
        }
# ==========================================
# HELPER FUNCTIONS FOR CAPABILITY MATCHING
# ==========================================
def find_extensions_with_capability(extensions: List[BaseExtension],
                                    capability_type: str,
                                    capability_value: str) -> List[BaseExtension]:
    """
    Select the extensions that declare a given capability.

    Args:
        extensions: Extension instances to search.
        capability_type: Key to look up in each extension's capabilities,
            e.g. 'provides_data', 'consumes_data', or 'creates_output'.
        capability_value: The value that must be listed under that key.

    Returns:
        The matching extensions, in the order they appear in `extensions`.
        Extensions that do not declare `capability_type` never match.

    Example:
        viz_extensions = find_extensions_with_capability(
            all_extensions, 'creates_output', 'visualization'
        )
    """
    return [
        candidate
        for candidate in extensions
        if capability_value in candidate.get_capabilities().get(capability_type, [])
    ]
def get_data_flow_possibilities(extensions: List[BaseExtension]) -> List[Dict[str, Any]]:
    """
    List every provider -> consumers data flow possible among `extensions`.

    For each data type an extension declares under 'provides_data', the other
    extensions (compared by name) that can consume it are collected. Data
    types without any consumer are left out.

    Returns list of dicts with:
        {
            'provider': extension_name,
            'data_type': data_type,
            'consumers': [list of extension names that can consume this data]
        }

    Example usage:
        flows = get_data_flow_possibilities(enabled_extensions)
        for flow in flows:
            print(f"{flow['provider']} produces {flow['data_type']}")
            print(f" → can be consumed by: {', '.join(flow['consumers'])}")
    """
    chains: List[Dict[str, Any]] = []
    for source in extensions:
        offered = source.get_capabilities().get('provides_data', [])
        for kind in offered:
            sinks = []
            for candidate in extensions:
                if not candidate.can_consume(kind):
                    continue
                # An extension is never listed as a consumer of its own data.
                if candidate.name == source.name:
                    continue
                sinks.append(candidate.name)
            if not sinks:
                continue
            chains.append({
                'provider': source.name,
                'data_type': kind,
                'consumers': sinks
            })
    return chains
def detect_relevant_extensions(query: str, extensions: List["BaseExtension"]) -> List[str]:
    """
    Detect which extensions are relevant to a query based on keywords.

    An extension is relevant when at least one of its declared keywords
    occurs in the query as a case-insensitive substring. Keywords that are
    empty, whitespace-only, or not strings are ignored; an empty keyword
    would otherwise match every query.

    Args:
        query: The user's query string
        extensions: List of available extensions

    Returns:
        List of extension names that are relevant to the query, in the
        order the extensions were given.

    Example:
        query = "Show me a chart of AAPL stock prices"
        relevant = detect_relevant_extensions(query, all_extensions)
        # Returns: ['yfinance', 'visualization']
    """
    query_lower = query.lower()
    relevant = []
    for ext in extensions:
        # `or []` also covers an extension that declares 'keywords': None.
        keywords = ext.get_capabilities().get('keywords') or []
        usable = (
            keyword.lower()
            for keyword in keywords
            if isinstance(keyword, str) and keyword.strip()
        )
        if any(keyword in query_lower for keyword in usable):
            relevant.append(ext.name)
    return relevant