|
|
|
|
|
|
|
|
import asyncio |
|
|
import sys |
|
|
from typing import Tuple |
|
|
from ..agents.agent import Agent |
|
|
from ..actions.action import Action |
|
|
from .approval_manager import HITLManager |
|
|
from .hitl import HITLInteractionType, HITLMode, HITLDecision |
|
|
from ..core.registry import MODULE_REGISTRY |
|
|
from ..core.logging import logger |
|
|
|
|
|
class HITLInterceptorAction(Action):
    """HITL Interceptor Action.

    Asks a human, through an ``HITLManager``, to approve or reject the
    execution of ``target_agent_name.target_action_name``.
    """

    def __init__(
        self,
        target_agent_name: str,
        target_action_name: str,
        name: str = None,
        description: str = "A pre-defined action to proceed the Human-In-The-Loop",
        interaction_type: HITLInteractionType = HITLInteractionType.APPROVE_REJECT,
        mode: HITLMode = HITLMode.PRE_EXECUTION,
        **kwargs
    ):
        """Create an interceptor for one target agent/action pair.

        Args:
            target_agent_name: Name of the agent whose action is intercepted.
            target_action_name: Name of the intercepted action.
            name: Name of this action. When empty, one is derived from the
                target names and ``mode.value``.
            description: Description passed to ``Action.__init__``.
            interaction_type: Interaction type forwarded to the HITL manager.
            mode: HITL mode forwarded to the HITL manager.
            **kwargs: Forwarded unchanged to ``Action.__init__``.
        """
        if not name:
            name = f"hitl_intercept_{target_agent_name}_{target_action_name}_mode_{mode.value}_action"
        super().__init__(
            name=name,
            description=description,
            **kwargs
        )
        self.target_agent_name = target_agent_name
        self.target_action_name = target_action_name
        self.interaction_type = interaction_type
        self.mode = mode

    def execute(self, llm, inputs: dict, hitl_manager: HITLManager, sys_msg: str = None, **kwargs) -> Tuple[dict, str]:
        """Synchronous entry point: runs :meth:`async_execute` via ``asyncio.run``.

        Raises:
            RuntimeError: If an event loop is already running in this thread;
                async callers must ``await self.async_execute(...)`` instead.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so starting one is safe.
            return asyncio.run(self.async_execute(llm, inputs, hitl_manager, sys_msg=sys_msg, **kwargs))
        # Raised outside the ``try`` so the ``except RuntimeError`` above cannot
        # swallow it, and asyncio.run() is never attempted inside a running loop.
        raise RuntimeError("Cannot use asyncio.run() in async context. Use async_execute directly.")

    async def async_execute(self, llm, inputs: dict, hitl_manager: HITLManager, sys_msg: str = None, **kwargs) -> Tuple[dict, str]:
        """
        Asynchronous execution of HITL Interceptor.

        Args:
            llm: Not used by this action; kept for the ``Action`` interface.
            inputs: Inputs of the intercepted action. Passed to the manager as
                ``action_inputs_data`` and used to fill the declared outputs.
            hitl_manager: Manager whose ``request_approval`` is awaited.
            sys_msg: Not used by this action; kept for the ``Action`` interface.
            **kwargs: ``wf_task`` (default ``'Unknown Task'``) and ``wf_goal``
                (default ``None``) are read.

        Returns:
            ``(result, prompt)``. ``result`` contains ``hitl_decision``,
            ``target_agent``, ``target_action``, ``hitl_feedback`` and any
            outputs that could be mapped from ``inputs``.

        Note:
            On ``HITLDecision.REJECT`` the prompt is logged and ``sys.exit()``
            is called, which raises ``SystemExit``.
        """
        task_name = kwargs.get('wf_task', 'Unknown Task')
        workflow_goal = kwargs.get('wf_goal', None)

        response = await hitl_manager.request_approval(
            task_name=task_name,
            agent_name=self.target_agent_name,
            action_name=self.target_action_name,
            interaction_type=self.interaction_type,
            mode=self.mode,
            action_inputs_data=inputs,
            workflow_goal=workflow_goal
        )

        result = {
            "hitl_decision": response.decision,
            "target_agent": self.target_agent_name,
            "target_action": self.target_action_name,
            "hitl_feedback": response.feedback
        }

        # Copy inputs into the declared outputs using the manager's
        # output-name -> input-name mapping. Best-effort: a failed lookup is
        # logged and skipped instead of aborting the interception.
        if self.outputs_format:
            for output_name in self.outputs_format.get_attrs():
                try:
                    result[output_name] = inputs[hitl_manager.hitl_input_output_mapping[output_name]]
                except Exception as e:
                    logger.exception(e)

        decision = result["hitl_decision"]
        prompt = f"HITL Interceptor executed for {self.target_agent_name}.{self.target_action_name}"
        if decision == HITLDecision.APPROVE:
            prompt += "\nHITL approved, the action will be executed"
            return result, prompt
        if decision == HITLDecision.REJECT:
            prompt += "\nHITL rejected, the action will not be executed"
            # Rejection terminates the process; log the reason first so it is
            # not silently lost.
            logger.info(prompt)
            sys.exit()
        # Any other decision: still honour the Tuple[dict, str] contract
        # rather than implicitly returning None.
        prompt += f"\nHITL decision: {decision}"
        return result, prompt
|
|
|
|
|
|
|
|
class HITLPostExecutionAction(Action):
    """Placeholder for a post-execution HITL action; no behavior is implemented yet.

    NOTE(review): ``HITLInterceptorAgent`` instantiates this class with
    ``target_agent_name``, ``target_action_name`` and ``interaction_type``
    keyword arguments, but this class defines no ``__init__``. Those arguments
    therefore go straight to ``Action.__init__``; confirm ``Action`` accepts
    them before relying on ``HITLMode.POST_EXECUTION``.
    """
    pass
|
|
|
|
|
class HITLBaseAgent(Agent):
    """
    Common base class shared by the HITL agent classes.
    """

    def _get_unique_class_name(self, candidate_name: str) -> str:
        """Return a name not yet present in ``MODULE_REGISTRY``.

        ``candidate_name`` itself is returned when it is free; otherwise the
        first free ``f"{candidate_name}V{n}"`` with ``n = 1, 2, ...`` is used.
        """
        if not MODULE_REGISTRY.has_module(candidate_name):
            return candidate_name

        suffix = 1
        unique_name = f"{candidate_name}V{suffix}"
        while MODULE_REGISTRY.has_module(unique_name):
            suffix += 1
            unique_name = f"{candidate_name}V{suffix}"
        return unique_name
|
|
|
|
|
class HITLInterceptorAgent(HITLBaseAgent):
    """HITL Interceptor Agent - Intercept the execution of other agents"""

    def __init__(self,
                 target_agent_name: str,
                 target_action_name: str,
                 name: str = None,
                 interaction_type: HITLInteractionType = HITLInteractionType.APPROVE_REJECT,
                 mode: HITLMode = HITLMode.PRE_EXECUTION,
                 **kwargs):
        """Build a human-operated agent holding a single interception action.

        Args:
            target_agent_name: Agent whose execution is intercepted.
            target_action_name: Intercepted action; when falsy, it is left out
                of the agent name and the action targets ``"any"``.
            name: NOTE(review): currently ignored; the agent name is always
                generated from the target names and ``mode.value``.
            interaction_type: Interaction type handed to the created action.
            mode: ``HITLMode.PRE_EXECUTION`` or ``HITLMode.POST_EXECUTION``.
            **kwargs: Forwarded to the base agent constructor.

        Raises:
            ValueError: If ``mode`` is neither supported mode.
        """
        # Only include the action segment in the generated name when one was given.
        action_segment = f"_{target_action_name}" if target_action_name else ""
        agent_name = f"HITL_Interceptor_{target_agent_name}{action_segment}_mode_{mode.value}"

        super().__init__(
            name=agent_name,
            description=f"HITL Interceptor - Intercept the execution of {target_agent_name}",
            is_human=True,
            **kwargs
        )

        self.target_agent_name = target_agent_name
        self.target_action_name = target_action_name
        self.interaction_type = interaction_type
        self.mode = mode

        # Arguments common to both action flavours.
        shared_action_kwargs = {
            "target_agent_name": target_agent_name,
            "target_action_name": target_action_name or "any",
            "interaction_type": interaction_type,
        }
        if mode == HITLMode.PRE_EXECUTION:
            action = HITLInterceptorAction(mode=mode, **shared_action_kwargs)
        elif mode == HITLMode.POST_EXECUTION:
            action = HITLPostExecutionAction(**shared_action_kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        self.add_action(action)

    def get_hitl_agent_name(self) -> str:
        """
        Return this agent's (generated) name so callers can reference it in a workflow.
        """
        return self.name
|
|
|
|
|
|
|
|
class HITLUserInputCollectorAction(Action):
    """HITL User Input Collector Action - Collect user input for the HITL Interceptor"""

    def __init__(
        self,
        name: str = None,
        agent_name: str = None,
        description: str = "A pre-defined action to collect user input for the HITL Interceptor",
        interaction_type: HITLInteractionType = HITLInteractionType.COLLECT_USER_INPUT,
        input_fields: dict = None,
        **kwargs
    ):
        """Create a user-input collection action.

        Args:
            name: Name of this action. When empty, a name is derived from
                ``agent_name`` (or a fixed default when that is also empty).
            agent_name: Name of the owning agent, reported to the HITL manager.
            description: Description passed to ``Action.__init__``.
            interaction_type: Stored on the action.
            input_fields: Field specification passed to the manager's
                ``request_user_input``; defaults to an empty dict.
            **kwargs: Forwarded unchanged to ``Action.__init__``.
        """
        if not name:
            # Previously left as None; give the action a usable default name.
            name = (
                f"hitl_user_input_collector_{agent_name}_action"
                if agent_name
                else "hitl_user_input_collector_action"
            )

        super().__init__(name=name, description=description, **kwargs)

        self.interaction_type = interaction_type
        self.input_fields = input_fields or {}
        self.agent_name = agent_name

    def execute(self, llm, inputs: dict, hitl_manager: HITLManager, sys_msg: str = None, **kwargs) -> Tuple[dict, str]:
        """Synchronous entry point: runs :meth:`async_execute` via ``asyncio.run``.

        Raises:
            RuntimeError: If an event loop is already running in this thread;
                async callers must ``await self.async_execute(...)`` instead.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so starting one is safe.
            return asyncio.run(self.async_execute(llm, inputs, hitl_manager, sys_msg=sys_msg, **kwargs))
        # Raised outside the ``try`` so the ``except RuntimeError`` above cannot
        # swallow it, and asyncio.run() is never attempted inside a running loop.
        raise RuntimeError("Cannot use asyncio.run() in async context. Use async_execute directly.")

    async def async_execute(self, llm, inputs: dict, hitl_manager: HITLManager, sys_msg: str = None, **kwargs) -> Tuple[dict, str]:
        """
        Asynchronous execution of HITL User Input Collector.

        Args:
            llm: Not used by this action; kept for the ``Action`` interface.
            inputs: Not used by this action; kept for the ``Action`` interface.
            hitl_manager: Manager whose ``request_user_input`` is awaited.
            sys_msg: Not used by this action; kept for the ``Action`` interface.
            **kwargs: ``wf_task`` (default ``'Unknown Task'``) and ``wf_goal``
                (default ``None``) are read.

        Returns:
            ``(result, prompt)``. ``result`` contains ``hitl_decision``,
            ``collected_user_input`` (a dict, empty when nothing was returned),
            ``hitl_feedback`` and any declared outputs found in the collected input.

        Note:
            On ``HITLDecision.REJECT`` the prompt is logged and ``sys.exit()``
            is called, which raises ``SystemExit``.
        """
        task_name = kwargs.get('wf_task', 'Unknown Task')
        workflow_goal = kwargs.get('wf_goal', None)

        response = await hitl_manager.request_user_input(
            task_name=task_name,
            agent_name=self.agent_name,
            action_name=self.name,
            input_fields=self.input_fields,
            workflow_goal=workflow_goal
        )

        # modified_content may be None; normalise once and use the normalised
        # dict everywhere (a membership test on None would raise TypeError).
        collected = response.modified_content or {}
        result = {
            "hitl_decision": response.decision,
            "collected_user_input": collected,
            "hitl_feedback": response.feedback
        }

        # Expose collected values under the action's declared output names.
        if self.outputs_format:
            for output_name in self.outputs_format.get_attrs():
                if output_name in collected:
                    result[output_name] = collected[output_name]

        decision = result["hitl_decision"]
        prompt = f"HITL User Input Collector executed: {self.name}"
        if decision == HITLDecision.CONTINUE:
            prompt += f"\nUser input collection completed: {result['collected_user_input']}"
            return result, prompt
        if decision == HITLDecision.REJECT:
            prompt += "\nUser cancelled input or error occurred"
            # Rejection terminates the process; log the reason first so it is
            # not silently lost.
            logger.info(prompt)
            sys.exit()
        # Any other decision: still honour the Tuple[dict, str] contract
        # rather than implicitly returning None.
        prompt += f"\nHITL decision: {decision}"
        return result, prompt
|
|
|
|
|
class HITLUserInputCollectorAgent(HITLBaseAgent):
    """HITL User Input Collector Agent - Collect user input for the HITL Interceptor"""

    def __init__(self,
                 name: str = None,
                 input_fields: dict = None,
                 interaction_type: HITLInteractionType = HITLInteractionType.COLLECT_USER_INPUT,
                 **kwargs):
        """Build a human-operated agent holding one user-input collection action.

        Args:
            name: Suffix for the agent name (``HITL_User_Input_Collector_{name}``).
                When empty, the agent is named ``HITL_User_Input_Collector``.
            input_fields: Field specification handed to the collection action;
                defaults to an empty dict.
            interaction_type: Interaction type stored on the agent and action.
            **kwargs: Forwarded to the base agent constructor.
        """
        # Previously `agent_name` was left unbound when `name` was omitted,
        # raising UnboundLocalError; fall back to a fixed default instead.
        if name:
            agent_name = f"HITL_User_Input_Collector_{name}"
        else:
            agent_name = "HITL_User_Input_Collector"

        super().__init__(
            name=agent_name,
            description="HITL User Input Collector - Collect predefined user inputs",
            is_human=True,
            **kwargs
        )

        self.interaction_type = interaction_type
        self.input_fields = input_fields or {}

        # Pick the first "HITLUserInputCollectorAction_<i>" not yet in the
        # registry. The counter must advance on each collision, otherwise the
        # search never terminates once index 0 is taken.
        name_i = 0
        action_name = f"HITLUserInputCollectorAction_{name_i}"
        while MODULE_REGISTRY.has_module(action_name):
            name_i += 1
            action_name = f"HITLUserInputCollectorAction_{name_i}"

        action = HITLUserInputCollectorAction(
            name=action_name,
            agent_name=agent_name,
            interaction_type=interaction_type,
            input_fields=self.input_fields
        )

        self.add_action(action)

    def get_hitl_agent_name(self) -> str:
        """
        Get the name of the HITL agent. Useful when the name of HITL agent is generated dynamically.
        """
        return self.name

    def set_input_fields(self, input_fields: dict):
        """Set the input fields for user input collection.

        The new fields are stored on the agent and propagated to every
        ``HITLUserInputCollectorAction`` among ``self.actions``. ``None`` is
        normalised to an empty dict, matching the constructor's default.
        """
        if input_fields is None:
            input_fields = {}
        self.input_fields = input_fields

        for action in self.actions:
            if isinstance(action, HITLUserInputCollectorAction):
                action.input_fields = input_fields
|
|
|
|
|
class HITLConversationAgent(HITLBaseAgent):
    """Placeholder for a conversational HITL agent; adds nothing to ``HITLBaseAgent`` yet."""
    pass
|
|
|
|
|
class HITLConversationAction(Action):
    """Placeholder for a conversational HITL action; adds nothing to ``Action`` yet."""
    pass