Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- aworld/runners/hook/agent_hooks.py +35 -0
- aworld/runners/hook/hook_factory.py +44 -0
- aworld/runners/hook/hooks.py +64 -0
- aworld/runners/hook/template.py +41 -0
- aworld/runners/hook/utils.py +55 -0
aworld/runners/hook/agent_hooks.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
# Copyright (c) 2025 inclusionAI.
|
| 3 |
+
import abc
|
| 4 |
+
from typing import AsyncGenerator
|
| 5 |
+
from aworld.core.context.base import Context, AgentContext
|
| 6 |
+
from aworld.core.event.base import Message
|
| 7 |
+
from aworld.runners.hook.hook_factory import HookFactory
|
| 8 |
+
from aworld.runners.hook.hooks import PostLLMCallHook, PreLLMCallHook
|
| 9 |
+
from aworld.utils.common import convert_to_snake
|
| 10 |
+
|
| 11 |
+
@HookFactory.register(name="PreLLMCallContextProcessHook",
|
| 12 |
+
desc="PreLLMCallContextProcessHook")
|
| 13 |
+
class PreLLMCallContextProcessHook(PreLLMCallHook):
|
| 14 |
+
"""Process in the hook point of the pre_llm_call."""
|
| 15 |
+
__metaclass__ = abc.ABCMeta
|
| 16 |
+
|
| 17 |
+
def name(self):
|
| 18 |
+
return convert_to_snake("PreLLMCallContextProcessHook")
|
| 19 |
+
|
| 20 |
+
async def exec(self, message: Message, context: Context = None) -> Message:
|
| 21 |
+
''' context.get_agent_context(message.sender) ''' # get agent context
|
| 22 |
+
# and do something
|
| 23 |
+
|
| 24 |
+
@HookFactory.register(name="PostLLMCallContextProcessHook",
|
| 25 |
+
desc="PostLLMCallContextProcessHook")
|
| 26 |
+
class PostLLMCallContextProcessHook(PostLLMCallHook):
|
| 27 |
+
"""Process in the hook point of the post_llm_call."""
|
| 28 |
+
__metaclass__ = abc.ABCMeta
|
| 29 |
+
|
| 30 |
+
def name(self):
|
| 31 |
+
return convert_to_snake("PostLLMCallContextProcessHook")
|
| 32 |
+
|
| 33 |
+
async def exec(self, message: Message, context: Context = None) -> Message:
|
| 34 |
+
'''context.get_agent_context(message.sender)''' # get agent context
|
| 35 |
+
|
aworld/runners/hook/hook_factory.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
# Copyright (c) 2025 inclusionAI.
|
| 3 |
+
import sys
|
| 4 |
+
from typing import Dict, List
|
| 5 |
+
|
| 6 |
+
from aworld.core.factory import Factory
|
| 7 |
+
from aworld.logs.util import logger
|
| 8 |
+
from aworld.runners.hook.hooks import Hook, StartHook, HookPoint
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class HookManager(Factory):
|
| 12 |
+
def __init__(self, type_name: str = None):
|
| 13 |
+
super(HookManager, self).__init__(type_name)
|
| 14 |
+
|
| 15 |
+
def __call__(self, name: str, **kwargs):
|
| 16 |
+
if name is None:
|
| 17 |
+
raise ValueError("hook name is None")
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
if name in self._cls:
|
| 21 |
+
act = self._cls[name](**kwargs)
|
| 22 |
+
else:
|
| 23 |
+
raise RuntimeError("The hook was not registered.\nPlease confirm the package has been imported.")
|
| 24 |
+
except Exception:
|
| 25 |
+
err = sys.exc_info()
|
| 26 |
+
logger.warning(f"Failed to create hook with name {name}:\n{err[1]}")
|
| 27 |
+
act = None
|
| 28 |
+
return act
|
| 29 |
+
|
| 30 |
+
def hooks(self, name: str = None) -> Dict[str, List[Hook]]:
|
| 31 |
+
vals = list(filter(lambda s: not s.startswith('__'), dir(HookPoint)))
|
| 32 |
+
results = {val.lower(): [] for val in vals}
|
| 33 |
+
|
| 34 |
+
for k, v in self._cls.items():
|
| 35 |
+
hook = v()
|
| 36 |
+
if name and hook.point() != name:
|
| 37 |
+
continue
|
| 38 |
+
|
| 39 |
+
results.get(hook.point(), []).append(hook)
|
| 40 |
+
|
| 41 |
+
return results
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
HookFactory = HookManager("hook_type")
|
aworld/runners/hook/hooks.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
# Copyright (c) 2025 inclusionAI.
|
| 3 |
+
import abc
|
| 4 |
+
from typing import AsyncGenerator
|
| 5 |
+
from aworld.core.context.base import Context, AgentContext
|
| 6 |
+
from aworld.core.event.base import Message
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class HookPoint:
|
| 10 |
+
START = "start"
|
| 11 |
+
FINISHED = "finished"
|
| 12 |
+
ERROR = "error"
|
| 13 |
+
PRE_LLM_CALL = "pre_llm_call"
|
| 14 |
+
POST_LLM_CALL = "post_llm_call"
|
| 15 |
+
|
| 16 |
+
class Hook:
|
| 17 |
+
"""Runner hook."""
|
| 18 |
+
__metaclass__ = abc.ABCMeta
|
| 19 |
+
|
| 20 |
+
@abc.abstractmethod
|
| 21 |
+
def point(self):
|
| 22 |
+
"""Hook point."""
|
| 23 |
+
|
| 24 |
+
@abc.abstractmethod
|
| 25 |
+
async def exec(self, message: Message, context: Context = None) -> Message:
|
| 26 |
+
"""Execute hook function."""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class StartHook(Hook):
|
| 30 |
+
"""Process in the hook point of the start."""
|
| 31 |
+
__metaclass__ = abc.ABCMeta
|
| 32 |
+
|
| 33 |
+
def point(self):
|
| 34 |
+
return HookPoint.START
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class FinishedHook(Hook):
|
| 38 |
+
"""Process in the hook point of the finished."""
|
| 39 |
+
__metaclass__ = abc.ABCMeta
|
| 40 |
+
|
| 41 |
+
def point(self):
|
| 42 |
+
return HookPoint.FINISHED
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class ErrorHook(Hook):
|
| 46 |
+
"""Process in the hook point of the error."""
|
| 47 |
+
__metaclass__ = abc.ABCMeta
|
| 48 |
+
|
| 49 |
+
def point(self):
|
| 50 |
+
return HookPoint.ERROR
|
| 51 |
+
|
| 52 |
+
class PreLLMCallHook(Hook):
|
| 53 |
+
"""Process in the hook point of the pre_llm_call."""
|
| 54 |
+
__metaclass__ = abc.ABCMeta
|
| 55 |
+
|
| 56 |
+
def point(self):
|
| 57 |
+
return HookPoint.PRE_LLM_CALL
|
| 58 |
+
|
| 59 |
+
class PostLLMCallHook(Hook):
|
| 60 |
+
"""Process in the hook point of the post_llm_call."""
|
| 61 |
+
__metaclass__ = abc.ABCMeta
|
| 62 |
+
|
| 63 |
+
def point(self):
|
| 64 |
+
return HookPoint.POST_LLM_CALL
|
aworld/runners/hook/template.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
# Copyright (c) 2025 inclusionAI.
|
| 3 |
+
|
| 4 |
+
HOOK_TEMPLATE = """
|
| 5 |
+
import traceback
|
| 6 |
+
|
| 7 |
+
from aworld.core.context.base import Context
|
| 8 |
+
|
| 9 |
+
from aworld.core.event.base import Message, Constants, TopicType
|
| 10 |
+
from aworld.runners.hook.hooks import *
|
| 11 |
+
from aworld.runners.hook.hook_factory import HookFactory
|
| 12 |
+
from aworld.logs.util import logger
|
| 13 |
+
|
| 14 |
+
from aworld.utils.common import convert_to_snake
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@HookFactory.register(name="{name}",
|
| 18 |
+
desc="{desc}")
|
| 19 |
+
class {name}({point}Hook):
|
| 20 |
+
def name(self):
|
| 21 |
+
return convert_to_snake("{name}")
|
| 22 |
+
|
| 23 |
+
async def exec(self, message: Message) -> Message:
|
| 24 |
+
{func_import}import {func}
|
| 25 |
+
try:
|
| 26 |
+
res = {func}(message)
|
| 27 |
+
if not res:
|
| 28 |
+
raise ValueError(f"{func} no result return.")
|
| 29 |
+
return Message(payload=res,
|
| 30 |
+
session_id=Context.instance().session_id,
|
| 31 |
+
sender="{name}",
|
| 32 |
+
category=Constants.TASK,
|
| 33 |
+
topic="{topic}")
|
| 34 |
+
except Exception as e:
|
| 35 |
+
logger.error(traceback.format_exc())
|
| 36 |
+
return Message(payload=str(e),
|
| 37 |
+
session_id=Context.instance().session_id,
|
| 38 |
+
sender="{name}",
|
| 39 |
+
category=Constants.TASK,
|
| 40 |
+
topic=TopicType.ERROR)
|
| 41 |
+
"""
|
aworld/runners/hook/utils.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
# Copyright (c) 2025 inclusionAI.
|
| 3 |
+
|
| 4 |
+
import importlib
|
| 5 |
+
import inspect
|
| 6 |
+
import os
|
| 7 |
+
from typing import Callable, Any
|
| 8 |
+
|
| 9 |
+
from aworld.runners.hook.template import HOOK_TEMPLATE
|
| 10 |
+
from aworld.utils.common import snake_to_camel
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def hook(hook_point: str, name: str = None):
|
| 14 |
+
"""Hook decorator.
|
| 15 |
+
|
| 16 |
+
NOTE: Hooks can be annotated, but they need to comply with the protocol agreement.
|
| 17 |
+
The input parameter of the hook function is `Message` type, and the @hook needs to specify `hook_point`.
|
| 18 |
+
|
| 19 |
+
Examples:
|
| 20 |
+
>>> @hook(hook_point=HookPoint.ERROR)
|
| 21 |
+
>>> def error_process(message: Message) -> Message | None:
|
| 22 |
+
>>> print("process error")
|
| 23 |
+
The function `error_process` will be executed when an error message appears in the task,
|
| 24 |
+
you can choose return nothing or return a message.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
hook_point: Hook point that wants to process the message.
|
| 28 |
+
name: Hook name.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
| 32 |
+
# converts python function into a hoop with associated hoop point
|
| 33 |
+
func_import = func.__module__
|
| 34 |
+
if func_import == '__main__':
|
| 35 |
+
path = inspect.getsourcefile(func)
|
| 36 |
+
package = path.replace(os.getcwd(), '').replace('.py', '')
|
| 37 |
+
if package[0] == '/':
|
| 38 |
+
package = package[1:]
|
| 39 |
+
func_import = f"from {package} "
|
| 40 |
+
else:
|
| 41 |
+
func_import = f"from {func_import} "
|
| 42 |
+
|
| 43 |
+
real_name = name if name else func.__name__
|
| 44 |
+
con = HOOK_TEMPLATE.format(func_import=func_import,
|
| 45 |
+
func=func.__name__,
|
| 46 |
+
point=snake_to_camel(hook_point),
|
| 47 |
+
name=real_name,
|
| 48 |
+
topic=hook_point,
|
| 49 |
+
desc='')
|
| 50 |
+
with open(f"{real_name}.py", 'w+') as write:
|
| 51 |
+
write.writelines(con)
|
| 52 |
+
importlib.import_module(real_name)
|
| 53 |
+
return func
|
| 54 |
+
|
| 55 |
+
return decorator
|