File size: 4,716 Bytes
b2e0e38 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# src/safety/escalation_ladder.py
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional
import logging
# Configure the root logger at import time with an INFO threshold.
# NOTE(review): calling basicConfig from a library module affects logging for
# the whole process (it is a no-op only if the root logger already has
# handlers) — confirm this is intended rather than left to the application.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by everything below.
logger = logging.getLogger(__name__)
class RiskLevel(Enum):
    """Risk tier that a ``RiskPolicy`` assigns to a tool.

    Every member's value is its own name as a string. The comments below
    describe how ``EscalationLadder.execute_tool`` reacts to each tier.
    """

    # Executed immediately.
    LOW = "LOW"
    # Executed, with extra informational log lines.
    MEDIUM = "MEDIUM"
    # Not executed; an approval request is created instead.
    HIGH = "HIGH"
    # Never executed; the call is rejected.
    CRITICAL = "CRITICAL"
class ExecutionStatus(Enum):
    """Outcome category carried by an ``ExecutionResult``.

    Every member's value is its own name as a string.
    """

    # The tool ran and returned a result.
    SUCCESS = "SUCCESS"
    # An approval request was created; the tool has not run.
    WAITING_FOR_APPROVAL = "WAITING_FOR_APPROVAL"
    # The call was refused outright.
    REJECTED = "REJECTED"
    # Input validation, risk assessment, the approval workflow, or the tool
    # itself failed.
    ERROR = "ERROR"
@dataclass
class ExecutionResult:
    """Value object describing the outcome of one tool-execution attempt.

    Attributes:
        status: Outcome category; the only required field.
        result: Whatever the tool returned, when one was run successfully.
        error_message: Human-readable reason for a non-success outcome.
        decision_id: Identifier of the approval request, when one was created.
    """

    status: ExecutionStatus
    result: Optional[Any] = None
    error_message: Optional[str] = None
    decision_id: Optional[str] = None
class RiskPolicy(ABC):
    """Interface for anything that can classify a tool by risk."""

    @abstractmethod
    def get_level(self, tool_name: str) -> RiskLevel:
        """Return the risk level assigned to the tool named *tool_name*."""
class Tool(ABC):
    """Interface for an executable tool."""

    @abstractmethod
    def run(self, parameters: dict) -> Any:
        """Run the tool with *parameters* and return its result."""
class ApprovalService(ABC):
    """Interface to the approval workflow used for tools that need sign-off."""

    @abstractmethod
    def create_approval_request(self, tool_name: str, parameters: dict) -> str:
        """Open an approval request for running *tool_name* with *parameters*.

        Returns:
            A string identifying the request; callers pass it to
            ``notify_manager`` as ``decision_id``.
        """

    @abstractmethod
    def notify_manager(self, decision_id: str) -> None:
        """Send a notification about the request identified by *decision_id*."""
class EscalationLadder:
    """Gate execution of a tool behind a risk-based escalation policy.

    The ladder asks the injected ``RiskPolicy`` for the risk level of the
    requested tool name and then reacts per level:

    * ``CRITICAL`` - rejected, never executed.
    * ``HIGH``     - not executed; an approval request is created and the
      manager notified.
    * ``MEDIUM``   - executed, with extra informational logging.
    * ``LOW``      - executed immediately.
    * anything else - treated as an error and NOT executed (fail closed).

    Note that the single injected ``Tool`` is what gets run; ``tool_name`` is
    used only for the risk lookup, approval request and log messages.

    No method raises for operational failures: every failure is logged and
    reported as an ``ExecutionResult`` with status ``ERROR``.
    """

    def __init__(
        self,
        risk_policy: RiskPolicy,
        tool: Tool,
        approval_service: ApprovalService,
    ):
        """Store the collaborators.

        Args:
            risk_policy: Classifies tool names into risk levels.
            tool: The tool that is run for LOW and MEDIUM risk calls.
            approval_service: Creates approval requests for HIGH risk calls.
        """
        self._risk_policy = risk_policy
        self._tool = tool
        self._approval_service = approval_service

    def execute_tool(self, tool_name: str, parameters: dict) -> ExecutionResult:
        """Execute a tool with appropriate risk-based escalation.

        Args:
            tool_name: Non-empty name used to look up the risk level.
            parameters: Dictionary handed to the tool / approval request.

        Returns:
            An ``ExecutionResult`` whose status is ``SUCCESS``,
            ``WAITING_FOR_APPROVAL``, ``REJECTED`` or ``ERROR``.
        """
        if not tool_name or not isinstance(tool_name, str):
            logger.error("Invalid tool_name provided")
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message="Invalid tool_name: must be a non-empty string",
            )

        if not isinstance(parameters, dict):
            logger.error("Invalid parameters provided")
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message="Invalid parameters: must be a dictionary",
            )

        try:
            risk_level = self._risk_policy.get_level(tool_name)
        except Exception as e:
            logger.exception("Failed to assess risk level")
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message=f"Risk assessment failed: {e}",
            )

        if risk_level == RiskLevel.CRITICAL:
            logger.warning("CRITICAL risk tool '%s' blocked", tool_name)
            return ExecutionResult(
                status=ExecutionStatus.REJECTED,
                error_message="Critical risk tools are not permitted",
            )

        if risk_level == RiskLevel.HIGH:
            return self._handle_high_risk(tool_name, parameters)

        if risk_level == RiskLevel.MEDIUM:
            logger.info("MEDIUM risk tool '%s' - logging for audit", tool_name)
            return self._execute_with_logging(tool_name, parameters)

        if risk_level == RiskLevel.LOW:
            return self._execute_tool(parameters)

        # Fail closed: a policy that returns None, a raw string, or any value
        # this ladder does not know must not be treated as LOW risk.
        logger.error(
            "Unrecognized risk level %r for tool '%s' - refusing to execute",
            risk_level,
            tool_name,
        )
        return ExecutionResult(
            status=ExecutionStatus.ERROR,
            error_message=f"Unrecognized risk level: {risk_level!r}",
        )

    def _handle_high_risk(self, tool_name: str, parameters: dict) -> ExecutionResult:
        """Handle high-risk tool execution with approval workflow.

        Creates an approval request and notifies the manager. The tool itself
        is not run here.

        Returns:
            ``WAITING_FOR_APPROVAL`` with the ``decision_id`` on success;
            ``ERROR`` if either step fails. When the request was created but
            the notification failed, the ``decision_id`` is still included so
            the pending request is not lost to the caller.
        """
        try:
            decision_id = self._approval_service.create_approval_request(
                tool_name, parameters
            )
        except Exception as e:
            logger.exception("Failed to create approval request")
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message=f"Approval workflow failed: {e}",
            )

        # Separate try: at this point a request exists, so a notification
        # failure must be reported as such and must keep the decision_id.
        try:
            self._approval_service.notify_manager(decision_id)
        except Exception as e:
            logger.exception(
                "Failed to notify manager for approval request %s", decision_id
            )
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message=f"Approval workflow failed: {e}",
                decision_id=decision_id,
            )

        logger.info("Approval request created: %s", decision_id)
        return ExecutionResult(
            status=ExecutionStatus.WAITING_FOR_APPROVAL,
            decision_id=decision_id,
        )

    def _execute_with_logging(self, tool_name: str, parameters: dict) -> ExecutionResult:
        """Execute tool with enhanced audit logging.

        Emits an INFO line naming the tool, then delegates to
        ``_execute_tool``.
        """
        logger.info("Executing medium-risk tool: %s", tool_name)
        return self._execute_tool(parameters)

    def _execute_tool(self, parameters: dict) -> ExecutionResult:
        """Execute the tool and return result.

        Any exception raised by the tool is logged with its traceback and
        converted into an ``ERROR`` result rather than propagated.
        """
        try:
            result = self._tool.run(parameters)
        except Exception as e:
            logger.exception("Tool execution failed")
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message=f"Execution failed: {e}",
            )
        return ExecutionResult(status=ExecutionStatus.SUCCESS, result=result)
|