File size: 11,233 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 |
# evoagentx/hitl/interceptor_agent.py
import asyncio
import sys
from typing import Tuple
from ..agents.agent import Agent
from ..actions.action import Action
from .approval_manager import HITLManager
from .hitl import HITLInteractionType, HITLMode, HITLDecision
from ..core.registry import MODULE_REGISTRY
from ..core.logging import logger
class HITLInterceptorAction(Action):
    """Action that asks a human to approve or reject a target agent's action.

    The action forwards the pending inputs to ``hitl_manager.request_approval``
    and then either returns the decision (approve) or terminates the process
    via ``sys.exit()`` (reject).
    """

    def __init__(
        self,
        target_agent_name: str,
        target_action_name: str,
        name: str = None,
        description: str = "A pre-defined action to proceed the Human-In-The-Loop",
        interaction_type: HITLInteractionType = HITLInteractionType.APPROVE_REJECT,
        mode: HITLMode = HITLMode.PRE_EXECUTION,
        **kwargs
    ):
        """
        Args:
            target_agent_name: Name of the agent whose action is intercepted.
            target_action_name: Name of the intercepted action.
            name: Action name; generated from the target names and ``mode``
                when omitted or empty.
            description: Human-readable description passed to ``Action``.
            interaction_type: Kind of interaction requested from the human.
            mode: Interception mode; its ``value`` is embedded in the
                generated name.
            **kwargs: Forwarded unchanged to ``Action.__init__``.
        """
        if not name:
            name = f"hitl_intercept_{target_agent_name}_{target_action_name}_mode_{mode.value}_action"
        super().__init__(
            name=name,
            description=description,
            **kwargs
        )
        self.target_agent_name = target_agent_name
        self.target_action_name = target_action_name
        self.interaction_type = interaction_type
        self.mode = mode

    def execute(self, llm, inputs: dict, hitl_manager: HITLManager, sys_msg: str = None, **kwargs) -> Tuple[dict, str]:
        """Synchronous wrapper around :meth:`async_execute`.

        Raises:
            RuntimeError: If called while an event loop is already running in
                this thread; use :meth:`async_execute` there instead.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so starting one is safe.
            return asyncio.run(self.async_execute(llm, inputs, hitl_manager, sys_msg=sys_msg, **kwargs))
        # Raised outside the ``try`` so it is not swallowed by the handler
        # above, and so no coroutine is created that would never be awaited.
        raise RuntimeError("Cannot use asyncio.run() in async context. Use async_execute directly.")

    async def async_execute(self, llm, inputs: dict, hitl_manager: HITLManager, sys_msg: str = None, **kwargs) -> Tuple[dict, str]:
        """
        Asynchronous execution of HITL Interceptor

        Args:
            llm: Unused here; kept for interface parity with other actions.
            inputs: Inputs of the intercepted action, shown to the human and
                used as the source for mapped outputs.
            hitl_manager: Manager used to request the approval; its
                ``hitl_input_output_mapping`` maps each output name to the
                key in ``inputs`` that supplies its value.
            sys_msg: Unused here.
            **kwargs: ``wf_task`` (task name, default ``'Unknown Task'``) and
                ``wf_goal`` (workflow goal, default ``None``) are read.

        Returns:
            ``(result, prompt)`` when the decision is ``HITLDecision.APPROVE``.

        Raises:
            SystemExit: Via ``sys.exit()`` when the decision is
                ``HITLDecision.REJECT``.
        """
        task_name = kwargs.get('wf_task', 'Unknown Task')
        workflow_goal = kwargs.get('wf_goal', None)

        # request HITL approval
        response = await hitl_manager.request_approval(
            task_name=task_name,
            agent_name=self.target_agent_name,
            action_name=self.target_action_name,
            interaction_type=self.interaction_type,
            mode=self.mode,
            action_inputs_data=inputs,
            workflow_goal=workflow_goal
        )

        result = {
            "hitl_decision": response.decision,
            "target_agent": self.target_agent_name,
            "target_action": self.target_action_name,
            "hitl_feedback": response.feedback
        }

        # Pass the intercepted inputs through as this action's outputs.
        # Guarded like HITLUserInputCollectorAction so a missing output format
        # does not abort the approval flow.
        if self.outputs_format:
            for output_name in self.outputs_format.get_attrs():
                try:
                    result[output_name] = inputs[hitl_manager.hitl_input_output_mapping[output_name]]
                except Exception as e:
                    # Best effort: an unmapped output is logged, not fatal.
                    logger.exception(e)

        prompt = f"HITL Interceptor executed for {self.target_agent_name}.{self.target_action_name}"
        if result["hitl_decision"] == HITLDecision.APPROVE:
            prompt += "\nHITL approved, the action will be executed"
            return result, prompt
        elif result["hitl_decision"] == HITLDecision.REJECT:
            prompt += "\nHITL rejected, the action will not be executed"
            # Rejection stops the whole process rather than returning.
            sys.exit()
        # NOTE(review): any other decision falls through and returns None,
        # which does not match the annotated return type - confirm intent.
class HITLPostExecutionAction(Action):
    """Placeholder for the post-execution HITL action.

    Adds nothing on top of ``Action`` yet.
    """
class HITLBaseAgent(Agent):
    """
    Common base class for the agents used in HITL scenarios.
    """

    def _get_unique_class_name(self, candidate_name: str) -> str:
        """Return ``candidate_name`` or the first free ``<candidate_name>V<i>``.

        A name is free when ``MODULE_REGISTRY.has_module`` reports it is not
        registered; ``i`` starts at 1 and increases until a free name is found.
        """
        if not MODULE_REGISTRY.has_module(candidate_name):
            return candidate_name

        version = 1
        versioned_name = f"{candidate_name}V{version}"
        while MODULE_REGISTRY.has_module(versioned_name):
            version += 1
            versioned_name = f"{candidate_name}V{version}"
        return versioned_name
class HITLInterceptorAgent(HITLBaseAgent):
    """Human agent that intercepts the execution of another agent's action."""

    def __init__(self,
                 target_agent_name: str,
                 target_action_name: str,
                 name: str = None,
                 interaction_type: HITLInteractionType = HITLInteractionType.APPROVE_REJECT,
                 mode: HITLMode = HITLMode.PRE_EXECUTION,
                 **kwargs):
        """
        Args:
            target_agent_name: Name of the agent being intercepted.
            target_action_name: Name of the intercepted action; when falsy it
                is left out of the agent name and the action targets "any".
            name: Accepted but not used; the agent name is always generated
                from the target names and ``mode``.
            interaction_type: Kind of interaction requested from the human.
            mode: ``HITLMode.PRE_EXECUTION`` or ``HITLMode.POST_EXECUTION``.
            **kwargs: Forwarded unchanged to the base agent constructor.

        Raises:
            ValueError: If ``mode`` is neither of the two supported modes.
        """
        # The agent name is derived from the interception target.
        agent_name = (
            f"HITL_Interceptor_{target_agent_name}_{target_action_name}_mode_{mode.value}"
            if target_action_name
            else f"HITL_Interceptor_{target_agent_name}_mode_{mode.value}"
        )
        super().__init__(
            name=agent_name,
            description=f"HITL Interceptor - Intercept the execution of {target_agent_name}",
            is_human=True,
            **kwargs
        )

        self.target_agent_name = target_agent_name
        self.target_action_name = target_action_name
        self.interaction_type = interaction_type
        self.mode = mode

        # Arguments shared by both action flavours.
        common_args = dict(
            target_agent_name=target_agent_name,
            target_action_name=target_action_name or "any",
            interaction_type=interaction_type,
        )
        if mode == HITLMode.PRE_EXECUTION:
            intercept_action = HITLInterceptorAction(mode=mode, **common_args)
        elif mode == HITLMode.POST_EXECUTION:
            intercept_action = HITLPostExecutionAction(**common_args)
        else:
            raise ValueError(f"Invalid mode: {mode}")
        self.add_action(intercept_action)

    def get_hitl_agent_name(self) -> str:
        """
        Return this agent's name, which is generated in ``__init__``.
        """
        return self.name
class HITLUserInputCollectorAction(Action):
    """HITL User Input Collector Action - Collect user input for the HITL Interceptor"""

    def __init__(
        self,
        name: str = None,
        agent_name: str = None,
        description: str = "A pre-defined action to collect user input for the HITL Interceptor",
        interaction_type: HITLInteractionType = HITLInteractionType.COLLECT_USER_INPUT,
        input_fields: dict = None,
        **kwargs
    ):
        """
        Args:
            name: Action name, passed to ``Action`` as given.
            agent_name: Name of the owning agent, reported to the HITL manager.
            description: Human-readable description passed to ``Action``.
            interaction_type: Kind of interaction requested from the human.
            input_fields: Definition of the fields to collect; ``None`` is
                stored as an empty dict.
            **kwargs: Forwarded unchanged to ``Action.__init__``.
        """
        if not name:
            pass  # TODO: generate name
        super().__init__(name=name, description=description, **kwargs)
        self.interaction_type = interaction_type
        self.input_fields = input_fields or {}
        self.agent_name = agent_name

    def execute(self, llm, inputs: dict, hitl_manager: HITLManager, sys_msg: str = None, **kwargs) -> Tuple[dict, str]:
        """Synchronous wrapper around :meth:`async_execute`.

        Raises:
            RuntimeError: If called while an event loop is already running in
                this thread; use :meth:`async_execute` there instead.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so starting one is safe.
            return asyncio.run(self.async_execute(llm, inputs, hitl_manager, sys_msg=sys_msg, **kwargs))
        # Raised outside the ``try`` so it is not swallowed by the handler
        # above, and so no coroutine is created that would never be awaited.
        raise RuntimeError("Cannot use asyncio.run() in async context. Use async_execute directly.")

    async def async_execute(self, llm, inputs: dict, hitl_manager: HITLManager, sys_msg: str = None, **kwargs) -> Tuple[dict, str]:
        """
        Asynchronous execution of HITL User Input Collector

        Args:
            llm: Unused here; kept for interface parity with other actions.
            inputs: Unused here.
            hitl_manager: Manager used to request the user input.
            sys_msg: Unused here.
            **kwargs: ``wf_task`` (task name, default ``'Unknown Task'``) and
                ``wf_goal`` (workflow goal, default ``None``) are read.

        Returns:
            ``(result, prompt)`` when the decision is ``HITLDecision.CONTINUE``.

        Raises:
            SystemExit: Via ``sys.exit()`` when the decision is
                ``HITLDecision.REJECT``.
        """
        task_name = kwargs.get('wf_task', 'Unknown Task')
        workflow_goal = kwargs.get('wf_goal', None)

        # request user input from HITL manager
        response = await hitl_manager.request_user_input(
            task_name=task_name,
            agent_name=self.agent_name,
            action_name=self.name,
            input_fields=self.input_fields,
            workflow_goal=workflow_goal
        )

        # Normalise once so a response without content is handled the same
        # way in the result and in the output mapping below.
        collected = response.modified_content or {}
        result = {
            "hitl_decision": response.decision,
            "collected_user_input": collected,
            "hitl_feedback": response.feedback
        }

        # Map collected user input to outputs if output format is defined
        if self.outputs_format:
            for output_name in self.outputs_format.get_attrs():
                if output_name in collected:
                    result[output_name] = collected[output_name]

        prompt = f"HITL User Input Collector executed: {self.name}"
        if result["hitl_decision"] == HITLDecision.CONTINUE:
            prompt += f"\nUser input collection completed: {result['collected_user_input']}"
            return result, prompt
        elif result["hitl_decision"] == HITLDecision.REJECT:
            prompt += "\nUser cancelled input or error occurred"
            # Cancellation stops the whole process rather than returning.
            sys.exit()
        # NOTE(review): any other decision falls through and returns None,
        # which does not match the annotated return type - confirm intent.
class HITLUserInputCollectorAgent(HITLBaseAgent):
    """HITL User Input Collector Agent - Collect user input for the HITL Interceptor"""

    def __init__(self,
                 name: str = None,
                 input_fields: dict = None,
                 interaction_type: HITLInteractionType = HITLInteractionType.COLLECT_USER_INPUT,
                 **kwargs):
        """
        Args:
            name: Suffix for the agent name
                (``HITL_User_Input_Collector_<name>``). When omitted or empty
                the bare prefix ``HITL_User_Input_Collector`` is used, so
                pass distinct names if several collectors must coexist.
            input_fields: Definition of the fields to collect; ``None`` is
                stored as an empty dict.
            interaction_type: Kind of interaction requested from the human.
            **kwargs: Forwarded unchanged to the base agent constructor.
        """
        # generate agent name
        if name:
            agent_name = f"HITL_User_Input_Collector_{name}"
        else:
            # Fall back to the bare prefix so the variable is always bound.
            agent_name = "HITL_User_Input_Collector"

        super().__init__(
            name=agent_name,
            description="HITL User Input Collector - Collect predefined user inputs",
            is_human=True,
            **kwargs
        )

        self.interaction_type = interaction_type
        self.input_fields = input_fields or {}

        # Pick the first "HITLUserInputCollectorAction_<i>" that
        # MODULE_REGISTRY does not already know. The index must advance on
        # every collision, otherwise the search never terminates.
        name_i = 0
        action_name = f"HITLUserInputCollectorAction_{name_i}"
        while MODULE_REGISTRY.has_module(action_name):
            name_i += 1
            action_name = f"HITLUserInputCollectorAction_{name_i}"

        # add user input collector action
        action = HITLUserInputCollectorAction(
            name=action_name,
            agent_name=agent_name,
            interaction_type=interaction_type,
            input_fields=self.input_fields
        )
        self.add_action(action)

    def get_hitl_agent_name(self) -> str:
        """
        Get the name of the HITL agent. Useful when the name of HITL agent is generated dynamically.
        """
        return self.name

    def set_input_fields(self, input_fields: dict):
        """Set the input fields for user input collection"""
        self.input_fields = input_fields
        # Update the action's input fields as well
        for action in self.actions:
            if isinstance(action, HITLUserInputCollectorAction):
                action.input_fields = input_fields
class HITLConversationAgent(HITLBaseAgent):
    """Placeholder for a conversational HITL agent.

    Adds nothing on top of ``HITLBaseAgent`` yet.
    """
class HITLConversationAction(Action):
    """Placeholder for a conversational HITL action.

    Adds nothing on top of ``Action`` yet.
    """