# src/safety/sentinel.py
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
import logging
# NOTE(review): calling basicConfig at import time configures the root logger
# as a side effect; conventionally this belongs in the application entry
# point, not a library module — confirm before changing.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration Constants
# Both scores are defined on [0.0, 1.0] by the analyzer/verifier interfaces.
ANGER_THRESHOLD = 0.9  # anger strictly above this is unsafe
FACT_SCORE_THRESHOLD = 0.5  # fact score strictly below this is unsafe
class SafetyStatus(Enum):
    """Outcome of a safety check."""
    SAFE = "SAFE"  # all checks passed
    KILL_SWITCH_ACTIVE = "KILL_SWITCH_ACTIVE"  # a threshold was breached; agent must stop
    CHECK_FAILED = "CHECK_FAILED"  # a check itself errored; result inconclusive
@dataclass
class AgentState:
    """Snapshot of the agent's conversation used as safety-check input."""
    # Prior user utterances fed to sentiment analysis (presumably
    # chronological — TODO confirm against the caller).
    user_history: List[str]
    # The agent's most recent response, fed to fact verification.
    last_response: str
@dataclass
class SafetyCheckResult:
    """Result of a safety check; score fields are None when the
    corresponding check did not complete."""
    status: SafetyStatus
    anger_score: Optional[float] = None  # set once sentiment analysis succeeds
    fact_score: Optional[float] = None  # set once fact verification succeeds
    error_message: Optional[str] = None  # populated on CHECK_FAILED or response failure
class SentimentAnalyzer(ABC):
    """Interface for scoring user anger from conversation history."""

    @abstractmethod
    def analyze(self, user_history: List[str]) -> float:
        """Returns anger score between 0.0 and 1.0."""
        pass
class FactVerifier(ABC):
    """Interface for scoring the factual accuracy of an agent response."""

    @abstractmethod
    def verify_facts(self, response: str) -> float:
        """Returns fact accuracy score between 0.0 and 1.0."""
        pass
class AgentController(ABC):
    """Interface for the actions taken when the kill switch fires."""

    @abstractmethod
    def stop_agent(self) -> None:
        # Halt the running agent.
        pass

    @abstractmethod
    def alert_human_manager(self) -> None:
        # Notify a human operator that intervention is needed.
        pass

    @abstractmethod
    def display_message(self, message: str) -> None:
        # Show a message to the end user.
        pass
class SafetySentinel:
    """Monitors agent behavior and triggers safety responses.

    Runs two checks against an AgentState — user anger (sentiment) and
    factual accuracy of the last response — and activates the kill switch
    (stop the agent, alert a human, show a handoff message) when either
    check breaches its threshold.
    """

    # Shown to the end user when control is handed to a human operator.
    HANDOFF_MESSAGE = "I am having trouble. Connecting you to a human..."

    def __init__(
        self,
        sentiment_analyzer: SentimentAnalyzer,
        fact_verifier: FactVerifier,
        agent_controller: AgentController,
        anger_threshold: float = ANGER_THRESHOLD,
        fact_score_threshold: float = FACT_SCORE_THRESHOLD
    ):
        """Wire up the sentinel's collaborators.

        Args:
            sentiment_analyzer: Scores user anger in [0.0, 1.0].
            fact_verifier: Scores factual accuracy in [0.0, 1.0].
            agent_controller: Executes the stop/alert/display actions.
            anger_threshold: Anger strictly above this is unsafe.
            fact_score_threshold: Fact score strictly below this is unsafe.

        Raises:
            ValueError: If either threshold lies outside [0.0, 1.0].
        """
        # Fail fast on misconfiguration: both scores are defined on
        # [0.0, 1.0], so a threshold outside that range would silently
        # disable (or permanently trip) one of the checks.
        for name, value in (
            ("anger_threshold", anger_threshold),
            ("fact_score_threshold", fact_score_threshold),
        ):
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be in [0.0, 1.0], got {value}")
        self._sentiment_analyzer = sentiment_analyzer
        self._fact_verifier = fact_verifier
        self._agent_controller = agent_controller
        self._anger_threshold = anger_threshold
        self._fact_score_threshold = fact_score_threshold

    def safety_check(self, agent_state: AgentState) -> SafetyCheckResult:
        """Perform safety checks on the current agent state.

        Args:
            agent_state: Snapshot of the conversation to evaluate.

        Returns:
            A SafetyCheckResult with status SAFE, KILL_SWITCH_ACTIVE (a
            threshold was breached), or CHECK_FAILED (a check errored;
            partial scores are included where available).
        """
        # Explicit None check: a dataclass instance is always truthy, so
        # the previous `not agent_state` could only ever catch None.
        if agent_state is None:
            logger.error("Invalid agent state provided")
            return SafetyCheckResult(
                status=SafetyStatus.CHECK_FAILED,
                error_message="Agent state is required"
            )

        try:
            # Check 1: Sentiment Analysis
            user_anger = self._sentiment_analyzer.analyze(
                agent_state.user_history
            )
        except Exception as e:
            logger.exception("Sentiment analysis failed")
            return SafetyCheckResult(
                status=SafetyStatus.CHECK_FAILED,
                error_message=f"Sentiment analysis error: {e}"
            )

        try:
            # Check 2: Hallucination Detection (Fact Check)
            fact_score = self._fact_verifier.verify_facts(
                agent_state.last_response
            )
        except Exception as e:
            logger.exception("Fact verification failed")
            # The anger score was already computed — surface it so the
            # caller still sees a partial result.
            return SafetyCheckResult(
                status=SafetyStatus.CHECK_FAILED,
                anger_score=user_anger,
                error_message=f"Fact verification error: {e}"
            )

        is_unsafe = (
            user_anger > self._anger_threshold or
            fact_score < self._fact_score_threshold
        )
        if is_unsafe:
            # Lazy %-style args: only formatted if WARNING is emitted.
            logger.warning(
                "Safety threshold breached: anger=%s, fact_score=%s",
                user_anger, fact_score,
            )
            return SafetyCheckResult(
                status=SafetyStatus.KILL_SWITCH_ACTIVE,
                anger_score=user_anger,
                fact_score=fact_score
            )

        return SafetyCheckResult(
            status=SafetyStatus.SAFE,
            anger_score=user_anger,
            fact_score=fact_score
        )

    def execute_safety_response(self, agent_state: AgentState) -> SafetyCheckResult:
        """Check safety and execute appropriate response.

        When the kill switch fires, each response step (stop the agent,
        alert a human, display the handoff message) is attempted
        independently: previously a failure in stop_agent() aborted the
        whole sequence, so no human was alerted and the user was never
        informed. Failures are aggregated into result.error_message.

        Args:
            agent_state: Snapshot of the conversation to evaluate.

        Returns:
            The SafetyCheckResult from safety_check, with error_message
            set if any response step failed.
        """
        result = self.safety_check(agent_state)
        if result.status == SafetyStatus.KILL_SWITCH_ACTIVE:
            failures = []
            steps = (
                self._agent_controller.stop_agent,
                self._agent_controller.alert_human_manager,
                lambda: self._agent_controller.display_message(self.HANDOFF_MESSAGE),
            )
            for step in steps:
                try:
                    step()
                except Exception as e:
                    logger.exception("Failed to execute safety response")
                    failures.append(str(e))
            if failures:
                result.error_message = f"Safety response failed: {'; '.join(failures)}"
            else:
                logger.info("Kill switch activated - handed off to human")
        return result