Jia12 commited on
Commit
bd37bb9
·
verified ·
1 Parent(s): 91df074

Upload 111 files

Browse files

Add lagent files

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. lagent/__init__.py +4 -0
  2. lagent/__pycache__/__init__.cpython-311.pyc +0 -0
  3. lagent/__pycache__/schema.cpython-311.pyc +0 -0
  4. lagent/__pycache__/version.cpython-311.pyc +0 -0
  5. lagent/actions/__init__.py +26 -0
  6. lagent/actions/__pycache__/__init__.cpython-311.pyc +0 -0
  7. lagent/actions/__pycache__/action_executor.cpython-311.pyc +0 -0
  8. lagent/actions/__pycache__/arxiv_search.cpython-311.pyc +0 -0
  9. lagent/actions/__pycache__/base_action.cpython-311.pyc +0 -0
  10. lagent/actions/__pycache__/bing_map.cpython-311.pyc +0 -0
  11. lagent/actions/__pycache__/builtin_actions.cpython-311.pyc +0 -0
  12. lagent/actions/__pycache__/google_scholar_search.cpython-311.pyc +0 -0
  13. lagent/actions/__pycache__/google_search.cpython-311.pyc +0 -0
  14. lagent/actions/__pycache__/ipython_interactive.cpython-311.pyc +0 -0
  15. lagent/actions/__pycache__/ipython_interpreter.cpython-311.pyc +0 -0
  16. lagent/actions/__pycache__/ipython_manager.cpython-311.pyc +0 -0
  17. lagent/actions/__pycache__/parser.cpython-311.pyc +0 -0
  18. lagent/actions/__pycache__/ppt.cpython-311.pyc +0 -0
  19. lagent/actions/__pycache__/python_interpreter.cpython-311.pyc +0 -0
  20. lagent/actions/__pycache__/weather_query.cpython-311.pyc +0 -0
  21. lagent/actions/__pycache__/web_browser.cpython-311.pyc +0 -0
  22. lagent/actions/action_executor.py +198 -0
  23. lagent/actions/arxiv_search.py +79 -0
  24. lagent/actions/base_action.py +434 -0
  25. lagent/actions/bing_map.py +268 -0
  26. lagent/actions/builtin_actions.py +109 -0
  27. lagent/actions/google_scholar_search.py +438 -0
  28. lagent/actions/google_search.py +244 -0
  29. lagent/actions/ipython_interactive.py +273 -0
  30. lagent/actions/ipython_interpreter.py +584 -0
  31. lagent/actions/ipython_manager.py +220 -0
  32. lagent/actions/parser.py +146 -0
  33. lagent/actions/ppt.py +233 -0
  34. lagent/actions/python_interpreter.py +176 -0
  35. lagent/actions/weather_query.py +71 -0
  36. lagent/actions/web_browser.py +908 -0
  37. lagent/agents/__init__.py +9 -0
  38. lagent/agents/__pycache__/__init__.cpython-311.pyc +0 -0
  39. lagent/agents/__pycache__/agent.cpython-311.pyc +0 -0
  40. lagent/agents/__pycache__/react.cpython-311.pyc +0 -0
  41. lagent/agents/__pycache__/stream.cpython-311.pyc +0 -0
  42. lagent/agents/agent.py +400 -0
  43. lagent/agents/aggregator/__init__.py +4 -0
  44. lagent/agents/aggregator/__pycache__/__init__.cpython-311.pyc +0 -0
  45. lagent/agents/aggregator/__pycache__/default_aggregator.cpython-311.pyc +0 -0
  46. lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-311.pyc +0 -0
  47. lagent/agents/aggregator/default_aggregator.py +44 -0
  48. lagent/agents/aggregator/tool_aggregator.py +106 -0
  49. lagent/agents/react.py +161 -0
  50. lagent/agents/stream.py +316 -0
lagent/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .version import __version__, version_info
3
+
4
+ __all__ = ['__version__', 'version_info']
lagent/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (273 Bytes). View file
 
lagent/__pycache__/schema.cpython-311.pyc ADDED
Binary file (5.3 kB). View file
 
lagent/__pycache__/version.cpython-311.pyc ADDED
Binary file (1.3 kB). View file
 
lagent/actions/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .action_executor import ActionExecutor, AsyncActionExecutor
2
+ from .arxiv_search import ArxivSearch, AsyncArxivSearch
3
+ from .base_action import BaseAction, tool_api
4
+ from .bing_map import AsyncBINGMap, BINGMap
5
+ from .builtin_actions import FinishAction, InvalidAction, NoAction
6
+ from .google_scholar_search import AsyncGoogleScholar, GoogleScholar
7
+ from .google_search import AsyncGoogleSearch, GoogleSearch
8
+ from .ipython_interactive import AsyncIPythonInteractive, IPythonInteractive
9
+ from .ipython_interpreter import AsyncIPythonInterpreter, IPythonInterpreter
10
+ from .ipython_manager import IPythonInteractiveManager
11
+ from .parser import BaseParser, JsonParser, TupleParser
12
+ from .ppt import PPT, AsyncPPT
13
+ from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter
14
+ from .web_browser import AsyncWebBrowser, WebBrowser
15
+ from .weather_query import WeatherQuery
16
+
17
+ __all__ = [
18
+ 'BaseAction', 'ActionExecutor', 'AsyncActionExecutor', 'InvalidAction',
19
+ 'FinishAction', 'NoAction', 'BINGMap', 'AsyncBINGMap', 'ArxivSearch',
20
+ 'AsyncArxivSearch', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar',
21
+ 'AsyncGoogleScholar', 'IPythonInterpreter', 'AsyncIPythonInterpreter',
22
+ 'IPythonInteractive', 'AsyncIPythonInteractive',
23
+ 'IPythonInteractiveManager', 'PythonInterpreter', 'AsyncPythonInterpreter',
24
+ 'PPT', 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', 'BaseParser',
25
+ 'JsonParser', 'TupleParser', 'tool_api', 'WeatherQuery'
26
+ ]
lagent/actions/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.83 kB). View file
 
lagent/actions/__pycache__/action_executor.cpython-311.pyc ADDED
Binary file (10.8 kB). View file
 
lagent/actions/__pycache__/arxiv_search.cpython-311.pyc ADDED
Binary file (4.77 kB). View file
 
lagent/actions/__pycache__/base_action.cpython-311.pyc ADDED
Binary file (19.5 kB). View file
 
lagent/actions/__pycache__/bing_map.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
lagent/actions/__pycache__/builtin_actions.cpython-311.pyc ADDED
Binary file (5.47 kB). View file
 
lagent/actions/__pycache__/google_scholar_search.cpython-311.pyc ADDED
Binary file (18.6 kB). View file
 
lagent/actions/__pycache__/google_search.cpython-311.pyc ADDED
Binary file (11.7 kB). View file
 
lagent/actions/__pycache__/ipython_interactive.cpython-311.pyc ADDED
Binary file (15.1 kB). View file
 
lagent/actions/__pycache__/ipython_interpreter.cpython-311.pyc ADDED
Binary file (32.1 kB). View file
 
lagent/actions/__pycache__/ipython_manager.cpython-311.pyc ADDED
Binary file (13.6 kB). View file
 
lagent/actions/__pycache__/parser.cpython-311.pyc ADDED
Binary file (8.87 kB). View file
 
lagent/actions/__pycache__/ppt.cpython-311.pyc ADDED
Binary file (11.4 kB). View file
 
lagent/actions/__pycache__/python_interpreter.cpython-311.pyc ADDED
Binary file (9.03 kB). View file
 
lagent/actions/__pycache__/weather_query.cpython-311.pyc ADDED
Binary file (4.3 kB). View file
 
lagent/actions/__pycache__/web_browser.cpython-311.pyc ADDED
Binary file (59.1 kB). View file
 
lagent/actions/action_executor.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from collections import OrderedDict
3
+ from typing import Callable, Dict, List, Union
4
+
5
+ from lagent.actions.base_action import BaseAction
6
+ from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction
7
+ from lagent.hooks import Hook, RemovableHandle
8
+ from lagent.schema import ActionReturn, ActionValidCode, AgentMessage, FunctionCall
9
+ from lagent.utils import create_object
10
+
11
+
12
+ class ActionExecutor:
13
+ """The action executor class.
14
+
15
+ Args:
16
+ actions (Union[BaseAction, List[BaseAction]]): The action or actions.
17
+ invalid_action (BaseAction, optional): The invalid action. Defaults to
18
+ InvalidAction().
19
+ no_action (BaseAction, optional): The no action.
20
+ Defaults to NoAction().
21
+ finish_action (BaseAction, optional): The finish action. Defaults to
22
+ FinishAction().
23
+ finish_in_action (bool, optional): Whether the finish action is in the
24
+ action list. Defaults to False.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]],
30
+ invalid_action: BaseAction = dict(type=InvalidAction),
31
+ no_action: BaseAction = dict(type=NoAction),
32
+ finish_action: BaseAction = dict(type=FinishAction),
33
+ finish_in_action: bool = False,
34
+ hooks: List[Dict] = None,
35
+ ):
36
+
37
+ if not isinstance(actions, list):
38
+ actions = [actions]
39
+ finish_action = create_object(finish_action)
40
+ if finish_in_action:
41
+ actions.append(finish_action)
42
+ for i, action in enumerate(actions):
43
+ actions[i] = create_object(action)
44
+ self.actions = {action.name: action for action in actions}
45
+
46
+ self.invalid_action = create_object(invalid_action)
47
+ self.no_action = create_object(no_action)
48
+ self.finish_action = finish_action
49
+ self._hooks: Dict[int, Hook] = OrderedDict()
50
+ if hooks:
51
+ for hook in hooks:
52
+ hook = create_object(hook)
53
+ self.register_hook(hook)
54
+
55
+ def description(self) -> List[Dict]:
56
+ actions = []
57
+ for action_name, action in self.actions.items():
58
+ if action.is_toolkit:
59
+ for api in action.description['api_list']:
60
+ api_desc = api.copy()
61
+ api_desc['name'] = f"{action_name}.{api_desc['name']}"
62
+ actions.append(api_desc)
63
+ else:
64
+ action_desc = action.description.copy()
65
+ actions.append(action_desc)
66
+ return actions
67
+
68
+ def __contains__(self, name: str):
69
+ return name in self.actions
70
+
71
+ def keys(self):
72
+ return list(self.actions.keys())
73
+
74
+ def __setitem__(self, name: str, action: Union[BaseAction, Dict]):
75
+ action = create_object(action)
76
+ self.actions[action.name] = action
77
+
78
+ def __delitem__(self, name: str):
79
+ del self.actions[name]
80
+
81
+ def forward(self, name, parameters, **kwargs) -> ActionReturn:
82
+ action_name, api_name = (
83
+ name.split('.') if '.' in name else (name, 'run'))
84
+ action_return: ActionReturn = ActionReturn()
85
+ if action_name not in self:
86
+ if name == self.no_action.name:
87
+ action_return = self.no_action(parameters)
88
+ elif name == self.finish_action.name:
89
+ action_return = self.finish_action(parameters)
90
+ else:
91
+ action_return = self.invalid_action(parameters)
92
+ else:
93
+ action_return = self.actions[action_name](parameters, api_name)
94
+ action_return.valid = ActionValidCode.OPEN
95
+ return action_return
96
+
97
+ def __call__(self,
98
+ message: AgentMessage,
99
+ session_id=0,
100
+ **kwargs) -> AgentMessage:
101
+ # message.receiver = self.name
102
+ for hook in self._hooks.values():
103
+ result = hook.before_action(self, message, session_id)
104
+ if result:
105
+ message = result
106
+
107
+ assert isinstance(message.content, FunctionCall) or (
108
+ isinstance(message.content, dict) and 'name' in message.content
109
+ and 'parameters' in message.content)
110
+ if isinstance(message.content, dict):
111
+ name = message.content.get('name')
112
+ parameters = message.content.get('parameters')
113
+ else:
114
+ name = message.content.name
115
+ parameters = message.content.parameters
116
+
117
+ response_message = self.forward(
118
+ name=name, parameters=parameters, **kwargs)
119
+ if not isinstance(response_message, AgentMessage):
120
+ response_message = AgentMessage(
121
+ sender=self.__class__.__name__,
122
+ content=response_message,
123
+ )
124
+
125
+ for hook in self._hooks.values():
126
+ result = hook.after_action(self, response_message, session_id)
127
+ if result:
128
+ response_message = result
129
+ return response_message
130
+
131
+ def register_hook(self, hook: Callable):
132
+ handle = RemovableHandle(self._hooks)
133
+ self._hooks[handle.id] = hook
134
+ return handle
135
+
136
+
137
+ class AsyncActionExecutor(ActionExecutor):
138
+
139
+ async def forward(self, name, parameters, **kwargs) -> ActionReturn:
140
+ action_name, api_name = (
141
+ name.split('.') if '.' in name else (name, 'run'))
142
+ action_return: ActionReturn = ActionReturn()
143
+ if action_name not in self:
144
+ if name == self.no_action.name:
145
+ action_return = self.no_action(parameters)
146
+ elif name == self.finish_action.name:
147
+ action_return = self.finish_action(parameters)
148
+ else:
149
+ action_return = self.invalid_action(parameters)
150
+ else:
151
+ action = self.actions[action_name]
152
+ if inspect.iscoroutinefunction(action.__call__):
153
+ action_return = await action(parameters, api_name)
154
+ else:
155
+ action_return = action(parameters, api_name)
156
+ action_return.valid = ActionValidCode.OPEN
157
+ return action_return
158
+
159
+ async def __call__(self,
160
+ message: AgentMessage,
161
+ session_id=0,
162
+ **kwargs) -> AgentMessage:
163
+ # message.receiver = self.name
164
+ for hook in self._hooks.values():
165
+ if inspect.iscoroutinefunction(hook.before_action):
166
+ result = await hook.before_action(self, message, session_id)
167
+ else:
168
+ result = hook.before_action(self, message, session_id)
169
+ if result:
170
+ message = result
171
+
172
+ assert isinstance(message.content, FunctionCall) or (
173
+ isinstance(message.content, dict) and 'name' in message.content
174
+ and 'parameters' in message.content)
175
+ if isinstance(message.content, dict):
176
+ name = message.content.get('name')
177
+ parameters = message.content.get('parameters')
178
+ else:
179
+ name = message.content.name
180
+ parameters = message.content.parameters
181
+
182
+ response_message = await self.forward(
183
+ name=name, parameters=parameters, **kwargs)
184
+ if not isinstance(response_message, AgentMessage):
185
+ response_message = AgentMessage(
186
+ sender=self.__class__.__name__,
187
+ content=response_message,
188
+ )
189
+
190
+ for hook in self._hooks.values():
191
+ if inspect.iscoroutinefunction(hook.after_action):
192
+ result = await hook.after_action(self, response_message,
193
+ session_id)
194
+ else:
195
+ result = hook.after_action(self, response_message, session_id)
196
+ if result:
197
+ response_message = result
198
+ return response_message
lagent/actions/arxiv_search.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Type
2
+
3
+ from asyncer import asyncify
4
+
5
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
6
+ from lagent.actions.parser import BaseParser, JsonParser
7
+ from lagent.schema import ActionReturn, ActionStatusCode
8
+
9
+
10
+ class ArxivSearch(BaseAction):
11
+ """Search information from Arxiv.org. \
12
+ Useful for when you need to answer questions about Physics, Mathematics, \
13
+ Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \
14
+ Electrical Engineering, and Economics from scientific articles on arxiv.org.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ top_k_results: int = 3,
20
+ max_query_len: int = 300,
21
+ doc_content_chars_max: int = 1500,
22
+ description: Optional[dict] = None,
23
+ parser: Type[BaseParser] = JsonParser,
24
+ ):
25
+ super().__init__(description, parser)
26
+ self.top_k_results = top_k_results
27
+ self.max_query_len = max_query_len
28
+ self.doc_content_chars_max = doc_content_chars_max
29
+
30
+ @tool_api(explode_return=True)
31
+ def get_arxiv_article_information(self, query: str) -> dict:
32
+ """Run Arxiv search and get the article meta information.
33
+
34
+ Args:
35
+ query (:class:`str`): the content of search query
36
+
37
+ Returns:
38
+ :class:`dict`: article information
39
+ * content (str): a list of 3 arxiv search papers
40
+ """
41
+ import arxiv
42
+
43
+ try:
44
+ results = arxiv.Search( # type: ignore
45
+ query[: self.max_query_len], max_results=self.top_k_results
46
+ ).results()
47
+ except Exception as exc:
48
+ return ActionReturn(errmsg=f'Arxiv exception: {exc}', state=ActionStatusCode.HTTP_ERROR)
49
+ docs = [
50
+ f'Published: {result.updated.date()}\nTitle: {result.title}\n'
51
+ f'Authors: {", ".join(a.name for a in result.authors)}\n'
52
+ f'Summary: {result.summary[:self.doc_content_chars_max]}'
53
+ for result in results
54
+ ]
55
+ if docs:
56
+ return {'content': '\n\n'.join(docs)}
57
+ return {'content': 'No good Arxiv Result was found'}
58
+
59
+
60
+ class AsyncArxivSearch(AsyncActionMixin, ArxivSearch):
61
+ """Search information from Arxiv.org. \
62
+ Useful for when you need to answer questions about Physics, Mathematics, \
63
+ Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \
64
+ Electrical Engineering, and Economics from scientific articles on arxiv.org.
65
+ """
66
+
67
+ @tool_api(explode_return=True)
68
+ @asyncify
69
+ def get_arxiv_article_information(self, query: str) -> dict:
70
+ """Run Arxiv search and get the article meta information.
71
+
72
+ Args:
73
+ query (:class:`str`): the content of search query
74
+
75
+ Returns:
76
+ :class:`dict`: article information
77
+ * content (str): a list of 3 arxiv search papers
78
+ """
79
+ return super().get_arxiv_article_information(query)
lagent/actions/base_action.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import logging
3
+ import re
4
+ from abc import ABCMeta
5
+ from copy import deepcopy
6
+ from functools import wraps
7
+ from typing import Callable, Optional, Type, get_args, get_origin
8
+
9
+ try:
10
+ from typing import Annotated
11
+ except ImportError:
12
+ from typing_extensions import Annotated
13
+
14
+ from griffe import Docstring
15
+
16
+ try:
17
+ from griffe import DocstringSectionKind
18
+ except ImportError:
19
+ from griffe.enumerations import DocstringSectionKind
20
+
21
+ from ..schema import ActionReturn, ActionStatusCode
22
+ from .parser import BaseParser, JsonParser, ParseError
23
+
24
+ logging.getLogger('griffe').setLevel(logging.ERROR)
25
+
26
+
27
+ def tool_api(func: Optional[Callable] = None,
28
+ *,
29
+ explode_return: bool = False,
30
+ returns_named_value: bool = False,
31
+ **kwargs):
32
+ """Turn functions into tools. It will parse typehints as well as docstrings
33
+ to build the tool description and attach it to functions via an attribute
34
+ ``api_description``.
35
+
36
+ Examples:
37
+
38
+ .. code-block:: python
39
+
40
+ # typehints has higher priority than docstrings
41
+ from typing import Annotated
42
+
43
+ @tool_api
44
+ def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
45
+ '''Add operation
46
+
47
+ Args:
48
+ x (int): a
49
+ y (int): b
50
+ '''
51
+ return a + b
52
+
53
+ print(add.api_description)
54
+
55
+ Args:
56
+ func (Optional[Callable]): function to decorate. Defaults to ``None``.
57
+ explode_return (bool): whether to flatten the dictionary or tuple return
58
+ as the ``return_data`` field. When enabled, it is recommended to
59
+ annotate the member in docstrings. Defaults to ``False``.
60
+
61
+ .. code-block:: python
62
+
63
+ @tool_api(explode_return=True)
64
+ def foo(a, b):
65
+ '''A simple function
66
+
67
+ Args:
68
+ a (int): a
69
+ b (int): b
70
+
71
+ Returns:
72
+ dict: information of inputs
73
+ * x: value of a
74
+ * y: value of b
75
+ '''
76
+ return {'x': a, 'y': b}
77
+
78
+ print(foo.api_description)
79
+
80
+ returns_named_value (bool): whether to parse ``thing: Description`` in
81
+ returns sections as a name and description, rather than a type and
82
+ description. When true, type must be wrapped in parentheses:
83
+ ``(int): Description``. When false, parentheses are optional but
84
+ the items cannot be named: ``int: Description``. Defaults to ``False``.
85
+
86
+ Returns:
87
+ Callable: wrapped function or partial decorator
88
+
89
+ Important:
90
+ ``return_data`` field will be added to ``api_description`` only
91
+ when ``explode_return`` or ``returns_named_value`` is enabled.
92
+ """
93
+
94
+ def _detect_type(string):
95
+ field_type = 'STRING'
96
+ if 'list' in string:
97
+ field_type = 'Array'
98
+ elif 'str' not in string:
99
+ if 'float' in string:
100
+ field_type = 'FLOAT'
101
+ elif 'int' in string:
102
+ field_type = 'NUMBER'
103
+ elif 'bool' in string:
104
+ field_type = 'BOOLEAN'
105
+ return field_type
106
+
107
+ def _explode(desc):
108
+ kvs = []
109
+ desc = '\nArgs:\n' + '\n'.join([
110
+ ' ' + item.lstrip(' -+*#.')
111
+ for item in desc.split('\n')[1:] if item.strip()
112
+ ])
113
+ docs = Docstring(desc).parse('google')
114
+ if not docs:
115
+ return kvs
116
+ if docs[0].kind is DocstringSectionKind.parameters:
117
+ for d in docs[0].value:
118
+ d = d.as_dict()
119
+ if not d['annotation']:
120
+ d.pop('annotation')
121
+ else:
122
+ d['type'] = _detect_type(d.pop('annotation').lower())
123
+ kvs.append(d)
124
+ return kvs
125
+
126
+ def _parse_tool(function):
127
+ # remove rst syntax
128
+ docs = Docstring(
129
+ re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
130
+ 'google', returns_named_value=returns_named_value, **kwargs)
131
+ desc = dict(
132
+ name=function.__name__,
133
+ description=docs[0].value
134
+ if docs[0].kind is DocstringSectionKind.text else '',
135
+ parameters=[],
136
+ required=[],
137
+ )
138
+ args_doc, returns_doc = {}, []
139
+ for doc in docs:
140
+ if doc.kind is DocstringSectionKind.parameters:
141
+ for d in doc.value:
142
+ d = d.as_dict()
143
+ d['type'] = _detect_type(d.pop('annotation').lower())
144
+ args_doc[d['name']] = d
145
+ if doc.kind is DocstringSectionKind.returns:
146
+ for d in doc.value:
147
+ d = d.as_dict()
148
+ if not d['name']:
149
+ d.pop('name')
150
+ if not d['annotation']:
151
+ d.pop('annotation')
152
+ else:
153
+ d['type'] = _detect_type(d.pop('annotation').lower())
154
+ returns_doc.append(d)
155
+
156
+ sig = inspect.signature(function)
157
+ for name, param in sig.parameters.items():
158
+ if name == 'self':
159
+ continue
160
+ parameter = dict(
161
+ name=param.name,
162
+ type='STRING',
163
+ description=args_doc.get(param.name,
164
+ {}).get('description', ''))
165
+ annotation = param.annotation
166
+ if annotation is inspect.Signature.empty:
167
+ parameter['type'] = args_doc.get(param.name,
168
+ {}).get('type', 'STRING')
169
+ else:
170
+ if get_origin(annotation) is Annotated:
171
+ annotation, info = get_args(annotation)
172
+ if info:
173
+ parameter['description'] = info
174
+ while get_origin(annotation):
175
+ annotation = get_args(annotation)
176
+ parameter['type'] = _detect_type(str(annotation))
177
+ desc['parameters'].append(parameter)
178
+ if param.default is inspect.Signature.empty:
179
+ desc['required'].append(param.name)
180
+
181
+ return_data = []
182
+ if explode_return:
183
+ return_data = _explode(returns_doc[0]['description'])
184
+ elif returns_named_value:
185
+ return_data = returns_doc
186
+ if return_data:
187
+ desc['return_data'] = return_data
188
+ return desc
189
+
190
+ if callable(func):
191
+
192
+ if inspect.iscoroutinefunction(func):
193
+
194
+ @wraps(func)
195
+ async def wrapper(self, *args, **kwargs):
196
+ return await func(self, *args, **kwargs)
197
+
198
+ else:
199
+
200
+ @wraps(func)
201
+ def wrapper(self, *args, **kwargs):
202
+ return func(self, *args, **kwargs)
203
+
204
+ wrapper.api_description = _parse_tool(func)
205
+ return wrapper
206
+
207
+ def decorate(func):
208
+
209
+ if inspect.iscoroutinefunction(func):
210
+
211
+ @wraps(func)
212
+ async def wrapper(self, *args, **kwargs):
213
+ return await func(self, *args, **kwargs)
214
+
215
+ else:
216
+
217
+ @wraps(func)
218
+ def wrapper(self, *args, **kwargs):
219
+ return func(self, *args, **kwargs)
220
+
221
+ wrapper.api_description = _parse_tool(func)
222
+ return wrapper
223
+
224
+ return decorate
225
+
226
+
227
+ class ToolMeta(ABCMeta):
228
+ """Metaclass of tools."""
229
+
230
+ def __new__(mcs, name, base, attrs):
231
+ is_toolkit, tool_desc = True, dict(
232
+ name=name,
233
+ description=Docstring(attrs.get('__doc__',
234
+ '')).parse('google')[0].value)
235
+ for key, value in attrs.items():
236
+ if callable(value) and hasattr(value, 'api_description'):
237
+ api_desc = getattr(value, 'api_description')
238
+ if key == 'run':
239
+ tool_desc['parameters'] = api_desc['parameters']
240
+ tool_desc['required'] = api_desc['required']
241
+ if api_desc['description']:
242
+ tool_desc['description'] = api_desc['description']
243
+ if api_desc.get('return_data'):
244
+ tool_desc['return_data'] = api_desc['return_data']
245
+ is_toolkit = False
246
+ else:
247
+ tool_desc.setdefault('api_list', []).append(api_desc)
248
+ if not is_toolkit and 'api_list' in tool_desc:
249
+ raise KeyError('`run` and other tool APIs can not be implemented '
250
+ 'at the same time')
251
+ if is_toolkit and 'api_list' not in tool_desc:
252
+ is_toolkit = False
253
+ if callable(attrs.get('run')):
254
+ run_api = tool_api(attrs['run'])
255
+ api_desc = run_api.api_description
256
+ tool_desc['parameters'] = api_desc['parameters']
257
+ tool_desc['required'] = api_desc['required']
258
+ if api_desc['description']:
259
+ tool_desc['description'] = api_desc['description']
260
+ if api_desc.get('return_data'):
261
+ tool_desc['return_data'] = api_desc['return_data']
262
+ attrs['run'] = run_api
263
+ else:
264
+ tool_desc['parameters'], tool_desc['required'] = [], []
265
+ attrs['_is_toolkit'] = is_toolkit
266
+ attrs['__tool_description__'] = tool_desc
267
+ return super().__new__(mcs, name, base, attrs)
268
+
269
+
270
+ class BaseAction(metaclass=ToolMeta):
271
+ """Base class for all actions.
272
+
273
+ Args:
274
+ description (:class:`Optional[dict]`): The description of the action.
275
+ Defaults to ``None``.
276
+ parser (:class:`Type[BaseParser]`): The parser class to process the
277
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
278
+
279
+ Examples:
280
+
281
+ * simple tool
282
+
283
+ .. code-block:: python
284
+
285
+ class Bold(BaseAction):
286
+ '''Make text bold'''
287
+
288
+ def run(self, text: str):
289
+ '''
290
+ Args:
291
+ text (str): input text
292
+
293
+ Returns:
294
+ str: bold text
295
+ '''
296
+ return '**' + text + '**'
297
+
298
+ action = Bold()
299
+
300
+ * toolkit with multiple APIs
301
+
302
+ .. code-block:: python
303
+
304
+ class Calculator(BaseAction):
305
+ '''Calculator'''
306
+
307
+ @tool_api
308
+ def add(self, a, b):
309
+ '''Add operation
310
+
311
+ Args:
312
+ a (int): augend
313
+ b (int): addend
314
+
315
+ Returns:
316
+ int: sum
317
+ '''
318
+ return a + b
319
+
320
+ @tool_api
321
+ def sub(self, a, b):
322
+ '''Subtraction operation
323
+
324
+ Args:
325
+ a (int): minuend
326
+ b (int): subtrahend
327
+
328
+ Returns:
329
+ int: difference
330
+ '''
331
+ return a - b
332
+
333
+ action = Calculator()
334
+ """
335
+
336
+ def __init__(
337
+ self,
338
+ description: Optional[dict] = None,
339
+ parser: Type[BaseParser] = JsonParser,
340
+ ):
341
+ self._description = deepcopy(description or self.__tool_description__)
342
+ self._name = self._description['name']
343
+ self._parser = parser(self)
344
+
345
+ def __call__(self, inputs: str, name='run') -> ActionReturn:
346
+ fallback_args = {'inputs': inputs, 'name': name}
347
+ if not hasattr(self, name):
348
+ return ActionReturn(
349
+ fallback_args,
350
+ type=self.name,
351
+ errmsg=f'invalid API: {name}',
352
+ state=ActionStatusCode.API_ERROR)
353
+ try:
354
+ inputs = self._parser.parse_inputs(inputs, name)
355
+ except ParseError as exc:
356
+ return ActionReturn(
357
+ fallback_args,
358
+ type=self.name,
359
+ errmsg=exc.err_msg,
360
+ state=ActionStatusCode.ARGS_ERROR)
361
+ try:
362
+ outputs = getattr(self, name)(**inputs)
363
+ except Exception as exc:
364
+ return ActionReturn(
365
+ inputs,
366
+ type=self.name,
367
+ errmsg=str(exc),
368
+ state=ActionStatusCode.API_ERROR)
369
+ if isinstance(outputs, ActionReturn):
370
+ action_return = outputs
371
+ if not action_return.args:
372
+ action_return.args = inputs
373
+ if not action_return.type:
374
+ action_return.type = self.name
375
+ else:
376
+ result = self._parser.parse_outputs(outputs)
377
+ action_return = ActionReturn(inputs, type=self.name, result=result)
378
+ return action_return
379
+
380
+ @property
381
+ def name(self):
382
+ return self._name
383
+
384
+ @property
385
+ def is_toolkit(self):
386
+ return self._is_toolkit
387
+
388
+ @property
389
+ def description(self) -> dict:
390
+ """Description of the tool."""
391
+ return self._description
392
+
393
+ def __repr__(self):
394
+ return f'{self.description}'
395
+
396
+ __str__ = __repr__
397
+
398
+
399
+ class AsyncActionMixin:
400
+
401
+ async def __call__(self, inputs: str, name='run') -> ActionReturn:
402
+ fallback_args = {'inputs': inputs, 'name': name}
403
+ if not hasattr(self, name):
404
+ return ActionReturn(
405
+ fallback_args,
406
+ type=self.name,
407
+ errmsg=f'invalid API: {name}',
408
+ state=ActionStatusCode.API_ERROR)
409
+ try:
410
+ inputs = self._parser.parse_inputs(inputs, name)
411
+ except ParseError as exc:
412
+ return ActionReturn(
413
+ fallback_args,
414
+ type=self.name,
415
+ errmsg=exc.err_msg,
416
+ state=ActionStatusCode.ARGS_ERROR)
417
+ try:
418
+ outputs = await getattr(self, name)(**inputs)
419
+ except Exception as exc:
420
+ return ActionReturn(
421
+ inputs,
422
+ type=self.name,
423
+ errmsg=str(exc),
424
+ state=ActionStatusCode.API_ERROR)
425
+ if isinstance(outputs, ActionReturn):
426
+ action_return = outputs
427
+ if not action_return.args:
428
+ action_return.args = inputs
429
+ if not action_return.type:
430
+ action_return.type = self.name
431
+ else:
432
+ result = self._parser.parse_outputs(outputs)
433
+ action_return = ActionReturn(inputs, type=self.name, result=result)
434
+ return action_return
lagent/actions/bing_map.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E501
2
+ import json
3
+ import os
4
+ from typing import Optional, Type
5
+
6
+ import aiohttp
7
+ import requests
8
+
9
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
10
+ from lagent.actions.parser import BaseParser, JsonParser
11
+
12
+
13
+ class BINGMap(BaseAction):
14
+ """BING Map plugin for looking up map information."""
15
+
16
+ def __init__(
17
+ self,
18
+ key: Optional[str] = None,
19
+ description: Optional[dict] = None,
20
+ parser: Type[BaseParser] = JsonParser,
21
+ ) -> None:
22
+ super().__init__(description, parser)
23
+ key = os.environ.get('BING_MAP_KEY', key)
24
+ if key is None:
25
+ raise ValueError(
26
+ 'Please set BING Map API key either in the environment '
27
+ 'as BING_MAP_KEY or pass it as `key` parameter.')
28
+ self.key = key
29
+ self.base_url = 'http://dev.virtualearth.net/REST/V1/'
30
+
31
+ @tool_api(explode_return=True)
32
+ def get_distance(self, start: str, end: str) -> dict:
33
+ """Get the distance between two locations in km.
34
+
35
+ Args:
36
+ start (:class:`str`): The start location
37
+ end (:class:`str`): The end location
38
+
39
+ Returns:
40
+ :class:`dict`: distance information
41
+ * distance (str): the distance in km.
42
+ """
43
+ # Request URL
44
+ url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
45
+ # GET request
46
+ r = requests.get(url)
47
+ # TODO check request status?
48
+ data = json.loads(r.text)
49
+ # Extract route information
50
+ route = data['resourceSets'][0]['resources'][0]
51
+ # Extract distance in miles
52
+ distance = route['travelDistance']
53
+ return dict(distance=distance)
54
+
55
+ @tool_api(explode_return=True)
56
+ def get_route(self, start: str, end: str) -> dict:
57
+ """Get the route between two locations in km.
58
+
59
+ Args:
60
+ start (:class:`str`): The start location
61
+ end (:class:`str`): The end location
62
+
63
+ Returns:
64
+ :class:`dict`: route information
65
+ * route (list): the route, a list of actions.
66
+ """
67
+ # Request URL
68
+ url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
69
+ # GET request
70
+ r = requests.get(url)
71
+ data = json.loads(r.text)
72
+ # Extract route information
73
+ route = data['resourceSets'][0]['resources'][0]
74
+ itinerary = route['routeLegs'][0]['itineraryItems']
75
+ # Extract route text information
76
+ route_text = []
77
+ for item in itinerary:
78
+ if 'instruction' in item:
79
+ route_text.append(item['instruction']['text'])
80
+ return dict(route=route_text)
81
+
82
+ @tool_api(explode_return=True)
83
+ def get_coordinates(self, location: str) -> dict:
84
+ """Get the coordinates of a location.
85
+
86
+ Args:
87
+ location (:class:`str`): the location need to get coordinates.
88
+
89
+ Returns:
90
+ :class:`dict`: coordinates information
91
+ * latitude (float): the latitude of the location.
92
+ * longitude (float): the longitude of the location.
93
+ """
94
+ url = self.base_url + 'Locations'
95
+ params = {'query': location, 'key': self.key}
96
+ response = requests.get(url, params=params)
97
+ json_data = response.json()
98
+ coordinates = json_data['resourceSets'][0]['resources'][0]['point'][
99
+ 'coordinates']
100
+ return dict(latitude=coordinates[0], longitude=coordinates[1])
101
+
102
+ @tool_api(explode_return=True)
103
+ def search_nearby(self,
104
+ search_term: str,
105
+ places: str = 'unknown',
106
+ latitude: float = 0.0,
107
+ longitude: float = 0.0,
108
+ radius: int = 5000) -> dict:
109
+ """Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude.
110
+
111
+ Args:
112
+ search_term (:class:`str`): the place name.
113
+ places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
114
+ latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
115
+ longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
116
+ radius (:class:`int`): radius in meters. Defaults to ``5000``.
117
+
118
+ Returns:
119
+ :class:`dict`: places information
120
+ * places (list): the list of places, each place is a dict with name and address, at most 5 places.
121
+ """
122
+ url = self.base_url + 'LocalSearch'
123
+ if places != 'unknown':
124
+ pos = self.get_coordinates(**{'location': places})
125
+ latitude, longitude = pos[1]['latitude'], pos[1]['longitude']
126
+ # Build the request query string
127
+ params = {
128
+ 'query': search_term,
129
+ 'userLocation': f'{latitude},{longitude}',
130
+ 'radius': radius,
131
+ 'key': self.key
132
+ }
133
+ # Make the request
134
+ response = requests.get(url, params=params)
135
+ # Parse the response
136
+ response_data = json.loads(response.content)
137
+ # Get the results
138
+ results = response_data['resourceSets'][0]['resources']
139
+ addresses = []
140
+ for result in results:
141
+ name = result['name']
142
+ address = result['Address']['formattedAddress']
143
+ addresses.append(dict(name=name, address=address))
144
+ if len(addresses) == 5:
145
+ break
146
+ return dict(place=addresses)
147
+
148
+
149
+ class AsyncBINGMap(AsyncActionMixin, BINGMap):
150
+ """BING Map plugin for looking up map information."""
151
+
152
+ @tool_api(explode_return=True)
153
+ async def get_distance(self, start: str, end: str) -> dict:
154
+ """Get the distance between two locations in km.
155
+
156
+ Args:
157
+ start (:class:`str`): The start location
158
+ end (:class:`str`): The end location
159
+
160
+ Returns:
161
+ :class:`dict`: distance information
162
+ * distance (str): the distance in km.
163
+ """
164
+ # Request URL
165
+ url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
166
+ # GET request
167
+ async with aiohttp.ClientSession() as session:
168
+ async with session.get(url) as resp:
169
+ # TODO check request status?
170
+ data = await resp.json()
171
+ # Extract route information
172
+ route = data['resourceSets'][0]['resources'][0]
173
+ # Extract distance in miles
174
+ distance = route['travelDistance']
175
+ return dict(distance=distance)
176
+
177
+ @tool_api(explode_return=True)
178
+ async def get_route(self, start: str, end: str) -> dict:
179
+ """Get the route between two locations in km.
180
+
181
+ Args:
182
+ start (:class:`str`): The start location
183
+ end (:class:`str`): The end location
184
+
185
+ Returns:
186
+ :class:`dict`: route information
187
+ * route (list): the route, a list of actions.
188
+ """
189
+ # Request URL
190
+ url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
191
+ # GET request
192
+ async with aiohttp.ClientSession() as session:
193
+ async with session.get(url) as resp:
194
+ data = await resp.json()
195
+ # Extract route information
196
+ route = data['resourceSets'][0]['resources'][0]
197
+ itinerary = route['routeLegs'][0]['itineraryItems']
198
+ # Extract route text information
199
+ route_text = []
200
+ for item in itinerary:
201
+ if 'instruction' in item:
202
+ route_text.append(item['instruction']['text'])
203
+ return dict(route=route_text)
204
+
205
+ @tool_api(explode_return=True)
206
+ async def get_coordinates(self, location: str) -> dict:
207
+ """Get the coordinates of a location.
208
+
209
+ Args:
210
+ location (:class:`str`): the location need to get coordinates.
211
+
212
+ Returns:
213
+ :class:`dict`: coordinates information
214
+ * latitude (float): the latitude of the location.
215
+ * longitude (float): the longitude of the location.
216
+ """
217
+ url = self.base_url + 'Locations'
218
+ params = {'query': location, 'key': self.key}
219
+ async with aiohttp.ClientSession() as session:
220
+ async with session.get(url, params=params) as resp:
221
+ data = await resp.json()
222
+ coordinates = data['resourceSets'][0]['resources'][0]['point'][
223
+ 'coordinates']
224
+ return dict(latitude=coordinates[0], longitude=coordinates[1])
225
+
226
+ @tool_api(explode_return=True)
227
+ async def search_nearby(self,
228
+ search_term: str,
229
+ places: str = 'unknown',
230
+ latitude: float = 0.0,
231
+ longitude: float = 0.0,
232
+ radius: int = 5000) -> dict:
233
+ """Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude.
234
+
235
+ Args:
236
+ search_term (:class:`str`): the place name.
237
+ places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
238
+ latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
239
+ longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
240
+ radius (:class:`int`): radius in meters. Defaults to ``5000``.
241
+
242
+ Returns:
243
+ :class:`dict`: places information
244
+ * places (list): the list of places, each place is a dict with name and address, at most 5 places.
245
+ """
246
+ url = self.base_url + 'LocalSearch'
247
+ if places != 'unknown':
248
+ pos = self.get_coordinates(**{'location': places})
249
+ latitude, longitude = pos[1]['latitude'], pos[1]['longitude']
250
+ # Build the request query string
251
+ params = {
252
+ 'query': search_term,
253
+ 'userLocation': f'{latitude},{longitude}',
254
+ 'radius': radius,
255
+ 'key': self.key
256
+ }
257
+ async with aiohttp.ClientSession() as session:
258
+ async with session.get(url, params=params) as resp:
259
+ data = await resp.json()
260
+ results = data['resourceSets'][0]['resources']
261
+ addresses = []
262
+ for result in results:
263
+ name = result['name']
264
+ address = result['Address']['formattedAddress']
265
+ addresses.append(dict(name=name, address=address))
266
+ if len(addresses) == 5:
267
+ break
268
+ return dict(place=addresses)
lagent/actions/builtin_actions.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from lagent.actions.base_action import BaseAction, tool_api
4
+ from lagent.actions.parser import BaseParser
5
+ from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode
6
+
7
+
8
+ class InvalidAction(BaseAction):
9
+ """This is a invalid action class, which is used to return error message
10
+ when the action is invalid.
11
+
12
+ Args:
13
+ err_msg (str): The error message. Defaults to 'The action is invalid,
14
+ please check the action name'.
15
+
16
+ Returns:
17
+ ActionReturn: The action return.
18
+ """
19
+
20
+ def __init__(self,
21
+ err_msg:
22
+ str = 'The action is invalid, please check the action name.',
23
+ description: Optional[dict] = None,
24
+ parser=BaseParser) -> None:
25
+ super().__init__(description, parser)
26
+ self._err_msg = err_msg
27
+
28
+ @tool_api
29
+ def run(self, err_msg: Optional[str] = None) -> ActionReturn:
30
+ """Return the error message.
31
+
32
+ Args:
33
+ err_msg (str, optional): The error message. If err_msg is not None,
34
+ it will be returned, otherwise the default error message will
35
+ be returned. Defaults to None.
36
+ """
37
+ action_return = ActionReturn(
38
+ url=None,
39
+ args=dict(text=err_msg),
40
+ errmsg=err_msg or self._err_msg,
41
+ type=self.name,
42
+ valid=ActionValidCode.INVALID,
43
+ state=ActionStatusCode.API_ERROR)
44
+ return action_return
45
+
46
+
47
+ class NoAction(BaseAction):
48
+ """This is a no action class, which is used to return error message when
49
+ the response does not follow the format.
50
+
51
+ Args:
52
+ err_msg (str): The error message. Defaults to
53
+ 'Please follow the format'.
54
+ """
55
+
56
+ def __init__(self,
57
+ err_msg: str = 'Please follow the format',
58
+ description: Optional[dict] = None,
59
+ parser=BaseParser):
60
+ super().__init__(description, parser)
61
+ self._err_msg = err_msg
62
+
63
+ @tool_api
64
+ def run(self, err_msg: Optional[str] = None) -> ActionReturn:
65
+ """Return the error message.
66
+
67
+ Args:
68
+ err_msg (str, optional): The error message. If err_msg is not None,
69
+ it will be returned, otherwise the default error message will
70
+ be returned. Defaults to None.
71
+
72
+ Returns:
73
+ ActionReturn: The action return.
74
+ """
75
+ action_return = ActionReturn(
76
+ url=None,
77
+ args=dict(text=err_msg),
78
+ type=self.name,
79
+ errmsg=err_msg or self._err_msg,
80
+ valid=ActionValidCode.INVALID,
81
+ state=ActionStatusCode.API_ERROR)
82
+ return action_return
83
+
84
+
85
+ class FinishAction(BaseAction):
86
+ """This is a finish action class, which is used to return the final
87
+ result."""
88
+
89
+ def __init__(self, description: Optional[dict] = None, parser=BaseParser):
90
+ super().__init__(description, parser)
91
+
92
+ @tool_api
93
+ def run(self, response: str) -> ActionReturn:
94
+ """Return the final result.
95
+
96
+ Args:
97
+ response (str): The final result.
98
+
99
+ Returns:
100
+ ActionReturn: The action return.
101
+ """
102
+ action_return = ActionReturn(
103
+ url=None,
104
+ args=dict(text=response),
105
+ result=[dict(type='text', content=response)],
106
+ type=self.name,
107
+ valid=ActionValidCode.FINISH,
108
+ state=ActionStatusCode.SUCCESS)
109
+ return action_return
lagent/actions/google_scholar_search.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E501
2
+ import os
3
+ from typing import Optional, Type
4
+
5
+ from asyncer import asyncify
6
+
7
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
8
+ from lagent.schema import ActionReturn, ActionStatusCode
9
+ from .parser import BaseParser, JsonParser
10
+
11
+
12
+ class GoogleScholar(BaseAction):
13
+ """Plugin for google scholar search.
14
+
15
+ Args:
16
+ api_key (str): API KEY to use serper google search API,
17
+ You can create a free API key at https://serper.dev.
18
+ description (dict): The description of the action. Defaults to ``None``.
19
+ parser (Type[BaseParser]): The parser class to process the
20
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ api_key: Optional[str] = None,
26
+ description: Optional[dict] = None,
27
+ parser: Type[BaseParser] = JsonParser,
28
+ ):
29
+ super().__init__(description, parser)
30
+ api_key = os.environ.get('SERPER_API_KEY', api_key)
31
+ if api_key is None:
32
+ raise ValueError(
33
+ 'Please set Serper API key either in the environment '
34
+ 'as SERPER_API_KEY or pass it as `api_key` parameter.'
35
+ )
36
+ self.api_key = api_key
37
+
38
+ @tool_api(explode_return=True)
39
+ def search_google_scholar(
40
+ self,
41
+ query: str,
42
+ cites: Optional[str] = None,
43
+ as_ylo: Optional[int] = None,
44
+ as_yhi: Optional[int] = None,
45
+ scisbd: Optional[int] = None,
46
+ cluster: Optional[str] = None,
47
+ hl: Optional[str] = None,
48
+ lr: Optional[str] = None,
49
+ start: Optional[int] = None,
50
+ num: Optional[int] = None,
51
+ as_sdt: Optional[str] = None,
52
+ safe: Optional[str] = None,
53
+ filter: Optional[str] = None,
54
+ as_vis: Optional[str] = None,
55
+ ) -> dict:
56
+ """Search for scholarly articles based on a query according to the google scholar.
57
+
58
+ Args:
59
+ query (str): The query to search for.
60
+ cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches.
61
+ as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted).
62
+ as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted).
63
+ scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything.
64
+ cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches.
65
+ hl (Optional[str]): The language to use for the Google Scholar search.
66
+ lr (Optional[str]): One or multiple languages to limit the search to.
67
+ start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
68
+ num (Optional[int]): The maximum number of results to return, limited to 20.
69
+ as_sdt (Optional[str]): Can be used either as a search type or a filter.
70
+ safe (Optional[str]): The level of filtering for adult content.
71
+ filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off.
72
+ as_vis (Optional[str]): Defines whether to include citations or not.
73
+
74
+ Returns:
75
+ :class:`dict`: article information
76
+ - title: a list of the titles of the three selected papers
77
+ - cited_by: a list of the citation numbers of the three selected papers
78
+ - organic_id: a list of the organic results' ids of the three selected papers
79
+ - pub_info: publication information of selected papers
80
+ """
81
+ from serpapi import GoogleSearch
82
+
83
+ params = {
84
+ 'q': query,
85
+ 'engine': 'google_scholar',
86
+ 'api_key': self.api_key,
87
+ 'cites': cites,
88
+ 'as_ylo': as_ylo,
89
+ 'as_yhi': as_yhi,
90
+ 'scisbd': scisbd,
91
+ 'cluster': cluster,
92
+ 'hl': hl,
93
+ 'lr': lr,
94
+ 'start': start,
95
+ 'num': num,
96
+ 'as_sdt': as_sdt,
97
+ 'safe': safe,
98
+ 'filter': filter,
99
+ 'as_vis': as_vis,
100
+ }
101
+ search = GoogleSearch(params)
102
+ try:
103
+ r = search.get_dict()
104
+ results = r['organic_results']
105
+ title = []
106
+ snippets = []
107
+ cited_by = []
108
+ organic_id = []
109
+ pub_info = []
110
+ for item in results[:3]:
111
+ title.append(item['title'])
112
+ pub_info.append(item['publication_info']['summary'])
113
+ citation = item['inline_links'].get('cited_by', {'total': ''})
114
+ cited_by.append(citation['total'])
115
+ snippets.append(item['snippet'])
116
+ organic_id.append(item['result_id'])
117
+ return dict(title=title, cited_by=cited_by, organic_id=organic_id, snippets=snippets)
118
+ except Exception as e:
119
+ return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
120
+
121
+ @tool_api(explode_return=True)
122
+ def get_author_information(
123
+ self,
124
+ author_id: str,
125
+ hl: Optional[str] = None,
126
+ view_op: Optional[str] = None,
127
+ sort: Optional[str] = None,
128
+ citation_id: Optional[str] = None,
129
+ start: Optional[int] = None,
130
+ num: Optional[int] = None,
131
+ no_cache: Optional[bool] = None,
132
+ async_req: Optional[bool] = None,
133
+ output: Optional[str] = None,
134
+ ) -> dict:
135
+ """Search for an author's information by author's id provided by get_author_id.
136
+
137
+ Args:
138
+ author_id (str): Required. The ID of an author.
139
+ hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'.
140
+ view_op (Optional[str]): Used for viewing specific parts of a page.
141
+ sort (Optional[str]): Used for sorting and refining articles.
142
+ citation_id (Optional[str]): Used for retrieving individual article citation.
143
+ start (Optional[int]): Defines the result offset. Default is 0.
144
+ num (Optional[int]): Defines the number of results to return. Default is 20.
145
+ no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
146
+ async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False.
147
+ output (Optional[str]): Defines the final output you want. Default is 'json'.
148
+
149
+ Returns:
150
+ :class:`dict`: author information
151
+ * name: author's name
152
+ * affliation: the affliation of the author
153
+ * articles: at most 3 articles by the author
154
+ * website: the author's homepage url
155
+ """
156
+ from serpapi import GoogleSearch
157
+
158
+ params = {
159
+ 'engine': 'google_scholar_author',
160
+ 'author_id': author_id,
161
+ 'api_key': self.api_key,
162
+ 'hl': hl,
163
+ 'view_op': view_op,
164
+ 'sort': sort,
165
+ 'citation_id': citation_id,
166
+ 'start': start,
167
+ 'num': num,
168
+ 'no_cache': no_cache,
169
+ 'async': async_req,
170
+ 'output': output,
171
+ }
172
+ try:
173
+ search = GoogleSearch(params)
174
+ results = search.get_dict()
175
+ author = results['author']
176
+ articles = results.get('articles', [])
177
+ return dict(
178
+ name=author['name'],
179
+ affiliations=author.get('affiliations', ''),
180
+ website=author.get('website', ''),
181
+ articles=[dict(title=article['title'], authors=article['authors']) for article in articles[:3]],
182
+ )
183
+ except Exception as e:
184
+ return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
185
+
186
+ @tool_api(explode_return=True)
187
+ def get_citation_format(
188
+ self,
189
+ q: str,
190
+ no_cache: Optional[bool] = None,
191
+ async_: Optional[bool] = None,
192
+ output: Optional[str] = 'json',
193
+ ) -> dict:
194
+ """Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.
195
+
196
+ Args:
197
+ q (str): ID of an individual Google Scholar organic search result.
198
+ no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
199
+ async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
200
+ output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
201
+
202
+ Returns:
203
+ :class:`dict`: citation format
204
+ * authors: the authors of the article
205
+ * citation: the citation format of the article
206
+ """
207
+ from serpapi import GoogleSearch
208
+
209
+ params = {
210
+ 'q': q,
211
+ 'engine': 'google_scholar_cite',
212
+ 'api_key': self.api_key,
213
+ 'no_cache': no_cache,
214
+ 'async': async_,
215
+ 'output': output,
216
+ }
217
+ try:
218
+ search = GoogleSearch(params)
219
+ results = search.get_dict()
220
+ citation = results['citations']
221
+ citation_info = citation[0]['snippet']
222
+ return citation_info
223
+ except Exception as e:
224
+ return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
225
+
226
+ @tool_api(explode_return=True)
227
+ def get_author_id(
228
+ self,
229
+ mauthors: str,
230
+ hl: Optional[str] = 'en',
231
+ after_author: Optional[str] = None,
232
+ before_author: Optional[str] = None,
233
+ no_cache: Optional[bool] = False,
234
+ _async: Optional[bool] = False,
235
+ output: Optional[str] = 'json',
236
+ ) -> dict:
237
+ """The getAuthorId function is used to get the author's id by his or her name.
238
+
239
+ Args:
240
+ mauthors (str): Defines the author you want to search for.
241
+ hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
242
+ after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None.
243
+ before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
244
+ no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
245
+ _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
246
+ output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
247
+
248
+ Returns:
249
+ :class:`dict`: author id
250
+ * author_id: the author_id of the author
251
+ """
252
+ from serpapi import GoogleSearch
253
+
254
+ params = {
255
+ 'mauthors': mauthors,
256
+ 'engine': 'google_scholar_profiles',
257
+ 'api_key': self.api_key,
258
+ 'hl': hl,
259
+ 'after_author': after_author,
260
+ 'before_author': before_author,
261
+ 'no_cache': no_cache,
262
+ 'async': _async,
263
+ 'output': output,
264
+ }
265
+ try:
266
+ search = GoogleSearch(params)
267
+ results = search.get_dict()
268
+ profile = results['profiles']
269
+ author_info = dict(author_id=profile[0]['author_id'])
270
+ return author_info
271
+ except Exception as e:
272
+ return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
273
+
274
+
275
+ class AsyncGoogleScholar(AsyncActionMixin, GoogleScholar):
276
+ """Plugin for google scholar search.
277
+
278
+ Args:
279
+ api_key (str): API KEY to use serper google search API,
280
+ You can create a free API key at https://serper.dev.
281
+ description (dict): The description of the action. Defaults to ``None``.
282
+ parser (Type[BaseParser]): The parser class to process the
283
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
284
+ """
285
+
286
+ @tool_api(explode_return=True)
287
+ @asyncify
288
+ def search_google_scholar(
289
+ self,
290
+ query: str,
291
+ cites: Optional[str] = None,
292
+ as_ylo: Optional[int] = None,
293
+ as_yhi: Optional[int] = None,
294
+ scisbd: Optional[int] = None,
295
+ cluster: Optional[str] = None,
296
+ hl: Optional[str] = None,
297
+ lr: Optional[str] = None,
298
+ start: Optional[int] = None,
299
+ num: Optional[int] = None,
300
+ as_sdt: Optional[str] = None,
301
+ safe: Optional[str] = None,
302
+ filter: Optional[str] = None,
303
+ as_vis: Optional[str] = None,
304
+ ) -> dict:
305
+ """Search for scholarly articles based on a query according to the google scholar.
306
+
307
+ Args:
308
+ query (str): The query to search for.
309
+ cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches.
310
+ as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted).
311
+ as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted).
312
+ scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything.
313
+ cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches.
314
+ hl (Optional[str]): The language to use for the Google Scholar search.
315
+ lr (Optional[str]): One or multiple languages to limit the search to.
316
+ start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
317
+ num (Optional[int]): The maximum number of results to return, limited to 20.
318
+ as_sdt (Optional[str]): Can be used either as a search type or a filter.
319
+ safe (Optional[str]): The level of filtering for adult content.
320
+ filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off.
321
+ as_vis (Optional[str]): Defines whether to include citations or not.
322
+
323
+ Returns:
324
+ :class:`dict`: article information
325
+ - title: a list of the titles of the three selected papers
326
+ - cited_by: a list of the citation numbers of the three selected papers
327
+ - organic_id: a list of the organic results' ids of the three selected papers
328
+ - pub_info: publication information of selected papers
329
+ """
330
+ return super().search_google_scholar(
331
+ query,
332
+ cites,
333
+ as_ylo,
334
+ as_yhi,
335
+ scisbd,
336
+ cluster,
337
+ hl,
338
+ lr,
339
+ start,
340
+ num,
341
+ as_sdt,
342
+ safe,
343
+ filter,
344
+ as_vis,
345
+ )
346
+
347
+ @tool_api(explode_return=True)
348
+ @asyncify
349
+ def get_author_information(
350
+ self,
351
+ author_id: str,
352
+ hl: Optional[str] = None,
353
+ view_op: Optional[str] = None,
354
+ sort: Optional[str] = None,
355
+ citation_id: Optional[str] = None,
356
+ start: Optional[int] = None,
357
+ num: Optional[int] = None,
358
+ no_cache: Optional[bool] = None,
359
+ async_req: Optional[bool] = None,
360
+ output: Optional[str] = None,
361
+ ) -> dict:
362
+ """Search for an author's information by author's id provided by get_author_id.
363
+
364
+ Args:
365
+ author_id (str): Required. The ID of an author.
366
+ hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'.
367
+ view_op (Optional[str]): Used for viewing specific parts of a page.
368
+ sort (Optional[str]): Used for sorting and refining articles.
369
+ citation_id (Optional[str]): Used for retrieving individual article citation.
370
+ start (Optional[int]): Defines the result offset. Default is 0.
371
+ num (Optional[int]): Defines the number of results to return. Default is 20.
372
+ no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
373
+ async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False.
374
+ output (Optional[str]): Defines the final output you want. Default is 'json'.
375
+
376
+ Returns:
377
+ :class:`dict`: author information
378
+ * name: author's name
379
+ * affliation: the affliation of the author
380
+ * articles: at most 3 articles by the author
381
+ * website: the author's homepage url
382
+ """
383
+ return super().get_author_information(
384
+ author_id, hl, view_op, sort, citation_id, start, num, no_cache, async_req, output
385
+ )
386
+
387
+ @tool_api(explode_return=True)
388
+ @asyncify
389
+ def get_citation_format(
390
+ self,
391
+ q: str,
392
+ no_cache: Optional[bool] = None,
393
+ async_: Optional[bool] = None,
394
+ output: Optional[str] = 'json',
395
+ ) -> dict:
396
+ """Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.
397
+
398
+ Args:
399
+ q (str): ID of an individual Google Scholar organic search result.
400
+ no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
401
+ async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
402
+ output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
403
+
404
+ Returns:
405
+ :class:`dict`: citation format
406
+ * authors: the authors of the article
407
+ * citation: the citation format of the article
408
+ """
409
+ return super().get_citation_format(q, no_cache, async_, output)
410
+
411
+ @tool_api(explode_return=True)
412
+ @asyncify
413
+ def get_author_id(
414
+ self,
415
+ mauthors: str,
416
+ hl: Optional[str] = 'en',
417
+ after_author: Optional[str] = None,
418
+ before_author: Optional[str] = None,
419
+ no_cache: Optional[bool] = False,
420
+ _async: Optional[bool] = False,
421
+ output: Optional[str] = 'json',
422
+ ) -> dict:
423
+ """The getAuthorId function is used to get the author's id by his or her name.
424
+
425
+ Args:
426
+ mauthors (str): Defines the author you want to search for.
427
+ hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
428
+ after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None.
429
+ before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
430
+ no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
431
+ _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
432
+ output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
433
+
434
+ Returns:
435
+ :class:`dict`: author id
436
+ * author_id: the author_id of the author
437
+ """
438
+ return super().get_author_id(mauthors, hl, after_author, before_author, no_cache, _async, output)
lagent/actions/google_search.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Optional, Tuple, Type, Union
3
+
4
+ import aiohttp
5
+ import requests
6
+
7
+ from lagent.schema import ActionReturn, ActionStatusCode
8
+ from .base_action import AsyncActionMixin, BaseAction, tool_api
9
+ from .parser import BaseParser, JsonParser
10
+
11
+
12
+ class GoogleSearch(BaseAction):
13
+ """Wrapper around the Serper.dev Google Search API.
14
+
15
+ To use, you should pass your serper API key to the constructor.
16
+
17
+ Code is modified from lang-chain GoogleSerperAPIWrapper
18
+ (https://github.com/langchain-ai/langchain/blob/ba5f
19
+ baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/
20
+ langchain/utilities/google_serper.py)
21
+
22
+ Args:
23
+ api_key (str): API KEY to use serper google search API,
24
+ You can create a free API key at https://serper.dev.
25
+ timeout (int): Upper bound of waiting time for a serper request.
26
+ search_type (str): Serper API support ['search', 'images', 'news',
27
+ 'places'] types of search, currently we only support 'search'.
28
+ description (dict): The description of the action. Defaults to ``None``.
29
+ parser (Type[BaseParser]): The parser class to process the
30
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
31
+ """
32
+ result_key_for_type = {
33
+ 'news': 'news',
34
+ 'places': 'places',
35
+ 'images': 'images',
36
+ 'search': 'organic',
37
+ }
38
+
39
+ def __init__(
40
+ self,
41
+ api_key: Optional[str] = None,
42
+ timeout: int = 5,
43
+ search_type: str = 'search',
44
+ description: Optional[dict] = None,
45
+ parser: Type[BaseParser] = JsonParser,
46
+ ):
47
+ super().__init__(description, parser)
48
+ api_key = os.environ.get('SERPER_API_KEY', api_key)
49
+ if api_key is None:
50
+ raise ValueError(
51
+ 'Please set Serper API key either in the environment '
52
+ 'as SERPER_API_KEY or pass it as `api_key` parameter.')
53
+ self.api_key = api_key
54
+ self.timeout = timeout
55
+ self.search_type = search_type
56
+
57
+ @tool_api
58
+ def run(self, query: str, k: int = 10) -> ActionReturn:
59
+ """一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。
60
+
61
+ Args:
62
+ query (str): the search content
63
+ k (int): select first k results in the search results as response
64
+ """
65
+ tool_return = ActionReturn(type=self.name)
66
+ status_code, response = self._search(query, k=k)
67
+ # convert search results to ToolReturn format
68
+ if status_code == -1:
69
+ tool_return.errmsg = response
70
+ tool_return.state = ActionStatusCode.HTTP_ERROR
71
+ elif status_code == 200:
72
+ parsed_res = self._parse_results(response, k)
73
+ tool_return.result = [dict(type='text', content=str(parsed_res))]
74
+ tool_return.state = ActionStatusCode.SUCCESS
75
+ else:
76
+ tool_return.errmsg = str(status_code)
77
+ tool_return.state = ActionStatusCode.API_ERROR
78
+ return tool_return
79
+
80
+ def _parse_results(self, results: dict, k: int) -> Union[str, List[str]]:
81
+ """Parse the search results from Serper API.
82
+
83
+ Args:
84
+ results (dict): The search content from Serper API
85
+ in json format.
86
+
87
+ Returns:
88
+ List[str]: The parsed search results.
89
+ """
90
+
91
+ snippets = []
92
+
93
+ if results.get('answerBox'):
94
+ answer_box = results.get('answerBox', {})
95
+ if answer_box.get('answer'):
96
+ return [answer_box.get('answer')]
97
+ elif answer_box.get('snippet'):
98
+ return [answer_box.get('snippet').replace('\n', ' ')]
99
+ elif answer_box.get('snippetHighlighted'):
100
+ return answer_box.get('snippetHighlighted')
101
+
102
+ if results.get('knowledgeGraph'):
103
+ kg = results.get('knowledgeGraph', {})
104
+ title = kg.get('title')
105
+ entity_type = kg.get('type')
106
+ if entity_type:
107
+ snippets.append(f'{title}: {entity_type}.')
108
+ description = kg.get('description')
109
+ if description:
110
+ snippets.append(description)
111
+ for attribute, value in kg.get('attributes', {}).items():
112
+ snippets.append(f'{title} {attribute}: {value}.')
113
+
114
+ for result in results[self.result_key_for_type[
115
+ self.search_type]][:k]:
116
+ if 'snippet' in result:
117
+ snippets.append(result['snippet'])
118
+ for attribute, value in result.get('attributes', {}).items():
119
+ snippets.append(f'{attribute}: {value}.')
120
+
121
+ if len(snippets) == 0:
122
+ return ['No good Google Search Result was found']
123
+ return snippets
124
+
125
+ def _search(self,
126
+ search_term: str,
127
+ search_type: Optional[str] = None,
128
+ **kwargs) -> Tuple[int, Union[dict, str]]:
129
+ """HTTP requests to Serper API.
130
+
131
+ Args:
132
+ search_term (str): The search query.
133
+ search_type (str): search type supported by Serper API,
134
+ default to 'search'.
135
+
136
+ Returns:
137
+ tuple: the return value is a tuple contains:
138
+ - status_code (int): HTTP status code from Serper API.
139
+ - response (dict): response context with json format.
140
+ """
141
+ headers = {
142
+ 'X-API-KEY': self.api_key or '',
143
+ 'Content-Type': 'application/json',
144
+ }
145
+ params = {
146
+ 'q': search_term,
147
+ **{
148
+ key: value
149
+ for key, value in kwargs.items() if value is not None
150
+ },
151
+ }
152
+ try:
153
+ response = requests.post(
154
+ f'https://google.serper.dev/{search_type or self.search_type}',
155
+ headers=headers,
156
+ params=params,
157
+ timeout=self.timeout)
158
+ except Exception as e:
159
+ return -1, str(e)
160
+ return response.status_code, response.json()
161
+
162
+
163
+ class AsyncGoogleSearch(AsyncActionMixin, GoogleSearch):
164
+ """Wrapper around the Serper.dev Google Search API.
165
+
166
+ To use, you should pass your serper API key to the constructor.
167
+
168
+ Code is modified from lang-chain GoogleSerperAPIWrapper
169
+ (https://github.com/langchain-ai/langchain/blob/ba5f
170
+ baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/
171
+ langchain/utilities/google_serper.py)
172
+
173
+ Args:
174
+ api_key (str): API KEY to use serper google search API,
175
+ You can create a free API key at https://serper.dev.
176
+ timeout (int): Upper bound of waiting time for a serper request.
177
+ search_type (str): Serper API support ['search', 'images', 'news',
178
+ 'places'] types of search, currently we only support 'search'.
179
+ description (dict): The description of the action. Defaults to ``None``.
180
+ parser (Type[BaseParser]): The parser class to process the
181
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
182
+ """
183
+
184
+ @tool_api
185
+ async def run(self, query: str, k: int = 10) -> ActionReturn:
186
+ """一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。
187
+
188
+ Args:
189
+ query (str): the search content
190
+ k (int): select first k results in the search results as response
191
+ """
192
+ tool_return = ActionReturn(type=self.name)
193
+ status_code, response = await self._search(query, k=k)
194
+ # convert search results to ToolReturn format
195
+ if status_code == -1:
196
+ tool_return.errmsg = response
197
+ tool_return.state = ActionStatusCode.HTTP_ERROR
198
+ elif status_code == 200:
199
+ parsed_res = self._parse_results(response)
200
+ tool_return.result = [dict(type='text', content=str(parsed_res))]
201
+ tool_return.state = ActionStatusCode.SUCCESS
202
+ else:
203
+ tool_return.errmsg = str(status_code)
204
+ tool_return.state = ActionStatusCode.API_ERROR
205
+ return tool_return
206
+
207
+ async def _search(self,
208
+ search_term: str,
209
+ search_type: Optional[str] = None,
210
+ **kwargs) -> Tuple[int, Union[dict, str]]:
211
+ """HTTP requests to Serper API.
212
+
213
+ Args:
214
+ search_term (str): The search query.
215
+ search_type (str): search type supported by Serper API,
216
+ default to 'search'.
217
+
218
+ Returns:
219
+ tuple: the return value is a tuple contains:
220
+ - status_code (int): HTTP status code from Serper API.
221
+ - response (dict): response context with json format.
222
+ """
223
+ headers = {
224
+ 'X-API-KEY': self.api_key or '',
225
+ 'Content-Type': 'application/json',
226
+ }
227
+ params = {
228
+ 'q': search_term,
229
+ **{
230
+ key: value
231
+ for key, value in kwargs.items() if value is not None
232
+ },
233
+ }
234
+ timeout = aiohttp.ClientTimeout(total=self.timeout)
235
+ async with aiohttp.ClientSession(timeout=timeout) as session:
236
+ try:
237
+ async with session.post(
238
+ f'https://google.serper.dev/{search_type or self.search_type}',
239
+ headers=headers,
240
+ params=params) as resp:
241
+ code, ret = resp.status, await resp.json()
242
+ except aiohttp.ClientError as e:
243
+ code, ret = -1, str(e)
244
+ return code, ret
lagent/actions/ipython_interactive.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import signal
3
+ from contextlib import contextmanager, redirect_stdout
4
+ from dataclasses import dataclass
5
+ from enum import Enum
6
+ from io import StringIO
7
+ from typing import Optional, Type
8
+
9
+ from ..schema import ActionReturn, ActionStatusCode
10
+ from .base_action import AsyncActionMixin, BaseAction, tool_api
11
+ from .parser import BaseParser, JsonParser
12
+
13
+
14
+ class Status(str, Enum):
15
+ """Execution status."""
16
+ SUCCESS = 'success'
17
+ FAILURE = 'failure'
18
+
19
+
20
+ @dataclass
21
+ class ExecutionResult:
22
+ """Execution result."""
23
+ status: Status
24
+ value: Optional[str] = None
25
+ msg: Optional[str] = None
26
+
27
+
28
+ @contextmanager
29
+ def _raise_timeout(timeout):
30
+
31
+ def _handler(signum, frame):
32
+ raise TimeoutError()
33
+
34
+ signal.signal(signal.SIGALRM, _handler)
35
+ signal.alarm(timeout)
36
+
37
+ try:
38
+ yield
39
+ finally:
40
+ signal.alarm(0)
41
+
42
+
43
+ class IPythonInteractive(BaseAction):
44
+ """An interactive IPython shell for code execution.
45
+
46
+ Args:
47
+ timeout (int): Upper bound of waiting time for Python script execution.
48
+ Defaults to ``20``.
49
+ max_out_len (int): maximum output length. No truncation occurs if negative.
50
+ Defaults to ``2048``.
51
+ use_signals (bool): whether signals should be used for timing function out
52
+ or the multiprocessing. Set to ``False`` when not running in the main
53
+ thread, e.g. web applications. Defaults to ``True``
54
+ description (dict): The description of the action. Defaults to ``None``.
55
+ parser (Type[BaseParser]): The parser class to process the
56
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ timeout: int = 30,
62
+ max_out_len: int = 8192,
63
+ use_signals: bool = True,
64
+ description: Optional[dict] = None,
65
+ parser: Type[BaseParser] = JsonParser,
66
+ ):
67
+ super().__init__(description, parser)
68
+ self.timeout = timeout
69
+ self._executor = self.create_shell()
70
+ self._highlighting = re.compile(
71
+ r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
72
+ self._max_out_len = max_out_len if max_out_len >= 0 else None
73
+ self._use_signals = use_signals
74
+
75
+ def reset(self):
76
+ """Clear the context."""
77
+ self._executor.reset()
78
+
79
+ @tool_api
80
+ def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
81
+ """Launch an IPython Interactive Shell to execute code.
82
+
83
+ Args:
84
+ command (:class:`str`): Python code snippet
85
+ timeout (:class:`Optional[int]`): timeout for execution.
86
+ This argument only works in the main thread. Defaults to ``None``.
87
+ """
88
+ from timeout_decorator import timeout as timer
89
+ tool_return = ActionReturn(args={'text': command}, type=self.name)
90
+ ret = (
91
+ timer(timeout or self.timeout)(self.exec)(command)
92
+ if self._use_signals else self.exec(command))
93
+ if ret.status is Status.SUCCESS:
94
+ tool_return.result = [{'type': 'text', 'content': ret.value}]
95
+ tool_return.state = ActionStatusCode.SUCCESS
96
+ else:
97
+ tool_return.errmsg = ret.msg
98
+ tool_return.state = ActionStatusCode.API_ERROR
99
+ return tool_return
100
+
101
+ def exec(self, code: str) -> ExecutionResult:
102
+ """Run Python scripts in IPython shell.
103
+
104
+ Args:
105
+ code (:class:`str`): code block
106
+
107
+ Returns:
108
+ :py:class:`ExecutionResult`: execution result
109
+ """
110
+ with StringIO() as io:
111
+ with redirect_stdout(io):
112
+ ret = self._executor.run_cell(self.extract_code(code))
113
+ result = ret.result
114
+ if result is not None:
115
+ return ExecutionResult(Status.SUCCESS,
116
+ str(result)[:self._max_out_len])
117
+ outs = io.getvalue().strip().split('\n')
118
+ if not outs:
119
+ return ExecutionResult(Status.SUCCESS, '')
120
+ for i, out in enumerate(outs):
121
+ if re.search('Error|Traceback', out, re.S):
122
+ if 'TimeoutError' in out:
123
+ return ExecutionResult(
124
+ Status.FAILURE,
125
+ msg=('The code interpreter encountered '
126
+ 'a timeout error.'))
127
+ err_idx = i
128
+ break
129
+ else:
130
+ return ExecutionResult(Status.SUCCESS,
131
+ outs[-1].strip()[:self._max_out_len])
132
+ return ExecutionResult(
133
+ Status.FAILURE,
134
+ msg=self._highlighting.sub(
135
+ '', '\n'.join(outs[err_idx:])[:self._max_out_len]),
136
+ )
137
+
138
+ @staticmethod
139
+ def create_shell():
140
+ from IPython import InteractiveShell
141
+ from traitlets.config import Config
142
+
143
+ c = Config()
144
+ c.HistoryManager.enabled = False
145
+ c.HistoryManager.hist_file = ':memory:'
146
+ return InteractiveShell(
147
+ user_ns={'_raise_timeout': _raise_timeout}, config=c)
148
+
149
+ @staticmethod
150
+ def extract_code(text: str) -> str:
151
+ """Extract Python code from markup languages.
152
+
153
+ Args:
154
+ text (:class:`str`): Markdown-formatted text
155
+
156
+ Returns:
157
+ :class:`str`: Python code
158
+ """
159
+ import json5
160
+
161
+ # Match triple backtick blocks first
162
+ triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
163
+ # Match single backtick blocks second
164
+ single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
165
+ if triple_match:
166
+ text = triple_match.group(1)
167
+ elif single_match:
168
+ text = single_match.group(1)
169
+ else:
170
+ try:
171
+ text = json5.loads(text)['code']
172
+ except Exception:
173
+ pass
174
+ # If no code blocks found, return original text
175
+ return text
176
+
177
+ @staticmethod
178
+ def wrap_code_with_timeout(code: str, timeout: int) -> str:
179
+ if not code.strip():
180
+ return code
181
+ code = code.strip('\n').rstrip()
182
+ indent = len(code) - len(code.lstrip())
183
+ handle = ' ' * indent + f'with _raise_timeout({timeout}):\n'
184
+ block = '\n'.join([' ' + line for line in code.split('\n')])
185
+ wrapped_code = handle + block
186
+ last_line = code.split('\n')[-1]
187
+ is_expression = True
188
+ try:
189
+ compile(last_line.lstrip(), '<stdin>', 'eval')
190
+ except SyntaxError:
191
+ is_expression = False
192
+ if is_expression:
193
+ wrapped_code += '\n' * 5 + last_line
194
+ return wrapped_code
195
+
196
+
197
+ class AsyncIPythonInteractive(AsyncActionMixin, IPythonInteractive):
198
+ """An interactive IPython shell for code execution.
199
+
200
+ Args:
201
+ timeout (int): Upper bound of waiting time for Python script execution.
202
+ Defaults to ``20``.
203
+ max_out_len (int): maximum output length. No truncation occurs if negative.
204
+ Defaults to ``2048``.
205
+ use_signals (bool): whether signals should be used for timing function out
206
+ or the multiprocessing. Set to ``False`` when not running in the main
207
+ thread, e.g. web applications. Defaults to ``True``
208
+ description (dict): The description of the action. Defaults to ``None``.
209
+ parser (Type[BaseParser]): The parser class to process the
210
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
211
+ """
212
+
213
+ @tool_api
214
+ async def run(self,
215
+ command: str,
216
+ timeout: Optional[int] = None) -> ActionReturn:
217
+ """Launch an IPython Interactive Shell to execute code.
218
+
219
+ Args:
220
+ command (:class:`str`): Python code snippet
221
+ timeout (:class:`Optional[int]`): timeout for execution.
222
+ This argument only works in the main thread. Defaults to ``None``.
223
+ """
224
+ tool_return = ActionReturn(args={'text': command}, type=self.name)
225
+ ret = await self.exec(command, timeout)
226
+ if ret.status is Status.SUCCESS:
227
+ tool_return.result = [{'type': 'text', 'content': ret.value}]
228
+ tool_return.state = ActionStatusCode.SUCCESS
229
+ else:
230
+ tool_return.errmsg = ret.msg
231
+ tool_return.state = ActionStatusCode.API_ERROR
232
+ return tool_return
233
+
234
+ async def exec(self, code: str, timeout: int = None) -> ExecutionResult:
235
+ """Asynchronously run Python scripts in IPython shell.
236
+
237
+ Args:
238
+ code (:class:`str`): code block
239
+ timeout (:class:`int`): max waiting time for code execution
240
+
241
+ Returns:
242
+ :py:class:`ExecutionResult`: execution result
243
+ """
244
+ with StringIO() as io:
245
+ with redirect_stdout(io):
246
+ ret = await self._executor.run_cell_async(
247
+ # ret = await self.create_shell().run_cell_async(
248
+ self.wrap_code_with_timeout(
249
+ self.extract_code(code), timeout or self.timeout))
250
+ result = ret.result
251
+ if result is not None:
252
+ return ExecutionResult(Status.SUCCESS,
253
+ str(result)[:self._max_out_len])
254
+ outs = io.getvalue().strip().split('\n')
255
+ if not outs:
256
+ return ExecutionResult(Status.SUCCESS, '')
257
+ for i, out in enumerate(outs):
258
+ if re.search('Error|Traceback', out, re.S):
259
+ if 'TimeoutError' in out:
260
+ return ExecutionResult(
261
+ Status.FAILURE,
262
+ msg=('The code interpreter encountered a '
263
+ 'timeout error.'))
264
+ err_idx = i
265
+ break
266
+ else:
267
+ return ExecutionResult(Status.SUCCESS,
268
+ outs[-1].strip()[:self._max_out_len])
269
+ return ExecutionResult(
270
+ Status.FAILURE,
271
+ msg=self._highlighting.sub(
272
+ '', '\n'.join(outs[err_idx:])[:self._max_out_len]),
273
+ )
lagent/actions/ipython_interpreter.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E501
2
+ import asyncio
3
+ import base64
4
+ import io
5
+ import json
6
+ import logging
7
+ import os
8
+ import queue
9
+ import re
10
+ import signal
11
+ import sys
12
+ import tempfile
13
+ import traceback
14
+ import uuid
15
+ from typing import Optional, Tuple, Type
16
+
17
+ from jupyter_client import AsyncKernelClient, AsyncKernelManager, AsyncMultiKernelManager
18
+ from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed
19
+
20
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
21
+ from lagent.actions.parser import BaseParser, JsonParser
22
+ from lagent.schema import ActionReturn, ActionStatusCode
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ START_CODE = """
27
+ def input(*args, **kwargs):
28
+ raise NotImplementedError('Python input() function is disabled.')
29
+
30
+ get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!')
31
+ {}
32
+ """ # noqa
33
+
34
+
35
+ class TimeoutError(Exception):
36
+ pass
37
+
38
+
39
+ class KernelDeath(Exception):
40
+ pass
41
+
42
+
43
+ async def async_run_code(
44
+ km: AsyncKernelManager,
45
+ code,
46
+ *,
47
+ interrupt_after=30,
48
+ iopub_timeout=40,
49
+ wait_for_ready_timeout=60,
50
+ shutdown_kernel=True,
51
+ ):
52
+ assert iopub_timeout > interrupt_after
53
+ try:
54
+
55
+ async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient,
56
+ *,
57
+ timeout=None):
58
+ loop = asyncio.get_running_loop()
59
+ dead_fut = loop.create_future()
60
+
61
+ def restarting():
62
+ assert (
63
+ False
64
+ ), "Restart shouldn't happen because config.KernelRestarter.restart_limit is expected to be set to 0"
65
+
66
+ def dead():
67
+ logger.info("Kernel has died, will NOT restart")
68
+ dead_fut.set_result(None)
69
+
70
+ msg_task = asyncio.create_task(kc.get_iopub_msg(timeout=timeout))
71
+ km.add_restart_callback(restarting, "restart")
72
+ km.add_restart_callback(dead, "dead")
73
+ try:
74
+ done, _ = await asyncio.wait(
75
+ [dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED)
76
+ if dead_fut in done:
77
+ raise KernelDeath()
78
+ assert msg_task in done
79
+ return await msg_task
80
+ finally:
81
+ msg_task.cancel()
82
+ km.remove_restart_callback(restarting, "restart")
83
+ km.remove_restart_callback(dead, "dead")
84
+
85
+ async def send_interrupt():
86
+ await asyncio.sleep(interrupt_after)
87
+ logger.info("Sending interrupt to kernel")
88
+ await km.interrupt_kernel()
89
+
90
+ @retry(
91
+ retry=retry_if_result(lambda ret: ret[-1].strip() in [
92
+ 'KeyboardInterrupt',
93
+ f"Kernel didn't respond in {wait_for_ready_timeout} seconds",
94
+ ] if isinstance(ret, tuple) else False),
95
+ stop=stop_after_attempt(3),
96
+ wait=wait_fixed(1),
97
+ retry_error_callback=lambda state: state.outcome.result())
98
+ async def run():
99
+ execute_result = None
100
+ error_traceback = None
101
+ stream_text_list = []
102
+ kc = km.client()
103
+ assert isinstance(kc, AsyncKernelClient)
104
+ kc.start_channels()
105
+ try:
106
+ await kc.wait_for_ready(timeout=wait_for_ready_timeout)
107
+ msg_id = kc.execute(code)
108
+ while True:
109
+ message = await get_iopub_msg_with_death_detection(
110
+ kc, timeout=iopub_timeout)
111
+ if logger.isEnabledFor(logging.DEBUG):
112
+ logger.debug(
113
+ json.dumps(message, indent=2, default=str))
114
+ assert message["parent_header"]["msg_id"] == msg_id
115
+ msg_type = message["msg_type"]
116
+ if msg_type == "status":
117
+ if message["content"]["execution_state"] == "idle":
118
+ break
119
+ elif msg_type == "stream":
120
+ stream_name = message["content"]["name"]
121
+ stream_text = message["content"]["text"]
122
+ stream_text_list.append(stream_text)
123
+ elif msg_type == "execute_result":
124
+ execute_result = message["content"]["data"]
125
+ elif msg_type == "error":
126
+ error_traceback_lines = message["content"]["traceback"]
127
+ error_traceback = "\n".join(error_traceback_lines)
128
+ elif msg_type == "execute_input":
129
+ pass
130
+ else:
131
+ assert False, f"Unknown message_type: {msg_type}"
132
+ finally:
133
+ kc.stop_channels()
134
+ return execute_result, error_traceback, "".join(stream_text_list)
135
+
136
+ if interrupt_after:
137
+ run_task = asyncio.create_task(run())
138
+ send_interrupt_task = asyncio.create_task(send_interrupt())
139
+ done, _ = await asyncio.wait([run_task, send_interrupt_task],
140
+ return_when=asyncio.FIRST_COMPLETED)
141
+ if run_task in done:
142
+ send_interrupt_task.cancel()
143
+ else:
144
+ assert send_interrupt_task in done
145
+ result = await run_task
146
+ else:
147
+ result = await run()
148
+ return result
149
+ finally:
150
+ if shutdown_kernel:
151
+ await km.shutdown_kernel()
152
+
153
+
154
+ class IPythonInterpreter(BaseAction):
155
+ """A IPython executor that can execute Python scripts in a jupyter manner.
156
+
157
+ Args:
158
+ timeout (int): Upper bound of waiting time for Python script execution.
159
+ Defaults to 20.
160
+ user_data_dir (str, optional): Specified the user data directory for files
161
+ loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
162
+ Defaults to `ENV`.
163
+ work_dir (str, optional): Specify which directory to save output images to.
164
+ Defaults to ``'./work_dir/tmp_dir'``.
165
+ description (dict): The description of the action. Defaults to ``None``.
166
+ parser (Type[BaseParser]): The parser class to process the
167
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
168
+ """
169
+
170
+ _KERNEL_CLIENTS = {}
171
+
172
+ def __init__(
173
+ self,
174
+ timeout: int = 20,
175
+ user_data_dir: str = 'ENV',
176
+ work_dir='./work_dir/tmp_dir',
177
+ description: Optional[dict] = None,
178
+ parser: Type[BaseParser] = JsonParser,
179
+ ):
180
+ super().__init__(description, parser)
181
+
182
+ self.timeout = timeout
183
+ if user_data_dir == 'ENV':
184
+ user_data_dir = os.environ.get('USER_DATA_DIR', '')
185
+
186
+ if user_data_dir:
187
+ user_data_dir = os.path.dirname(user_data_dir)
188
+ user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
189
+ self.user_data_dir = user_data_dir
190
+ self._initialized = False
191
+ self.work_dir = work_dir
192
+ if not os.path.exists(self.work_dir):
193
+ os.makedirs(self.work_dir, exist_ok=True)
194
+
195
+ @staticmethod
196
+ def start_kernel():
197
+ from jupyter_client import KernelManager
198
+
199
+ # start the kernel and manager
200
+ km = KernelManager()
201
+ km.start_kernel()
202
+ kc = km.client()
203
+ return km, kc
204
+
205
+ def initialize(self):
206
+ if self._initialized:
207
+ return
208
+ pid = os.getpid()
209
+ if pid not in self._KERNEL_CLIENTS:
210
+ self._KERNEL_CLIENTS[pid] = self.start_kernel()
211
+ self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
212
+ self._initialized = True
213
+ self._call(START_CODE.format(self.user_data_dir), None)
214
+
215
+ def reset(self):
216
+ if not self._initialized:
217
+ self.initialize()
218
+ else:
219
+ code = "get_ipython().run_line_magic('reset', '-f')\n" + \
220
+ START_CODE.format(self.user_data_dir)
221
+ self._call(code, None)
222
+
223
+ def _call(self,
224
+ command: str,
225
+ timeout: Optional[int] = None) -> Tuple[str, bool]:
226
+ self.initialize()
227
+ command = extract_code(command)
228
+
229
+ # check previous remaining result
230
+ while True:
231
+ try:
232
+ msg = self.kernel_client.get_iopub_msg(timeout=5)
233
+ msg_type = msg['msg_type']
234
+ if msg_type == 'status':
235
+ if msg['content'].get('execution_state') == 'idle':
236
+ break
237
+ except queue.Empty:
238
+ # assume no result
239
+ break
240
+
241
+ self.kernel_client.execute(command)
242
+
243
+ def _inner_call():
244
+ result = ''
245
+ images = []
246
+ succeed = True
247
+ image_idx = 0
248
+
249
+ while True:
250
+ text = ''
251
+ image = ''
252
+ finished = False
253
+ msg_type = 'error'
254
+ try:
255
+ msg = self.kernel_client.get_iopub_msg(timeout=20)
256
+ msg_type = msg['msg_type']
257
+ if msg_type == 'status':
258
+ if msg['content'].get('execution_state') == 'idle':
259
+ finished = True
260
+ elif msg_type == 'execute_result':
261
+ text = msg['content']['data'].get('text/plain', '')
262
+ if 'image/png' in msg['content']['data']:
263
+ image_b64 = msg['content']['data']['image/png']
264
+ image_url = publish_image_to_local(
265
+ image_b64, self.work_dir)
266
+ image_idx += 1
267
+ image = '![fig-%03d](%s)' % (image_idx, image_url)
268
+
269
+ elif msg_type == 'display_data':
270
+ if 'image/png' in msg['content']['data']:
271
+ image_b64 = msg['content']['data']['image/png']
272
+ image_url = publish_image_to_local(
273
+ image_b64, self.work_dir)
274
+ image_idx += 1
275
+ image = '![fig-%03d](%s)' % (image_idx, image_url)
276
+
277
+ else:
278
+ text = msg['content']['data'].get('text/plain', '')
279
+ elif msg_type == 'stream':
280
+ msg_type = msg['content']['name'] # stdout, stderr
281
+ text = msg['content']['text']
282
+ elif msg_type == 'error':
283
+ succeed = False
284
+ text = escape_ansi('\n'.join(
285
+ msg['content']['traceback']))
286
+ if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
287
+ text = f'Timeout. No response after {timeout} seconds.' # noqa
288
+ except queue.Empty:
289
+ # stop current task in case break next input.
290
+ self.kernel_manager.interrupt_kernel()
291
+ succeed = False
292
+ text = f'Timeout. No response after {timeout} seconds.'
293
+ finished = True
294
+ except Exception:
295
+ succeed = False
296
+ msg = ''.join(traceback.format_exception(*sys.exc_info()))
297
+ # text = 'The code interpreter encountered an unexpected error.' # noqa
298
+ text = msg
299
+ logging.warning(msg)
300
+ finished = True
301
+ if text:
302
+ # result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
303
+ result += f'{text}'
304
+
305
+ if image:
306
+ images.append(image_url)
307
+ if finished:
308
+ return succeed, dict(text=result, image=images)
309
+
310
+ try:
311
+ if timeout:
312
+
313
+ def handler(signum, frame):
314
+ raise TimeoutError()
315
+
316
+ signal.signal(signal.SIGALRM, handler)
317
+ signal.alarm(timeout)
318
+ succeed, result = _inner_call()
319
+ except TimeoutError:
320
+ succeed = False
321
+ text = 'The code interpreter encountered an unexpected error.'
322
+ result = f'\n\nerror:\n\n```\n{text}\n```'
323
+ finally:
324
+ if timeout:
325
+ signal.alarm(0)
326
+
327
+ # result = result.strip('\n')
328
+ return succeed, result
329
+
330
+ @tool_api
331
+ def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
332
+ r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
333
+
334
+ Args:
335
+ command (:class:`str`): Python code
336
+ timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
337
+ """
338
+ tool_return = ActionReturn(url=None, args=None, type=self.name)
339
+ tool_return.args = dict(text=command)
340
+ succeed, result = self._call(command, timeout)
341
+ if succeed:
342
+ text = result['text']
343
+ image = result.get('image', [])
344
+ resp = [dict(type='text', content=text)]
345
+ if image:
346
+ resp.extend([dict(type='image', content=im) for im in image])
347
+ tool_return.result = resp
348
+ # tool_return.result = dict(
349
+ # text=result['text'], image=result.get('image', [])[0])
350
+ tool_return.state = ActionStatusCode.SUCCESS
351
+ else:
352
+ tool_return.errmsg = result.get('text', '') if isinstance(
353
+ result, dict) else result
354
+ tool_return.state = ActionStatusCode.API_ERROR
355
+ return tool_return
356
+
357
+
358
+ class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter):
359
+ """A IPython executor that can execute Python scripts in a jupyter manner.
360
+
361
+ Args:
362
+ timeout (int): Upper bound of waiting time for Python script execution.
363
+ Defaults to 20.
364
+ user_data_dir (str, optional): Specified the user data directory for files
365
+ loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
366
+ Defaults to `ENV`.
367
+ work_dir (str, optional): Specify which directory to save output images to.
368
+ Defaults to ``'./work_dir/tmp_dir'``.
369
+ description (dict): The description of the action. Defaults to ``None``.
370
+ parser (Type[BaseParser]): The parser class to process the
371
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
372
+ """
373
+
374
+ _UNBOUND_KERNEL_CLIENTS = asyncio.Queue()
375
+
376
+ def __init__(
377
+ self,
378
+ timeout: int = 20,
379
+ user_data_dir: str = 'ENV',
380
+ work_dir=os.path.join(tempfile.gettempdir(), 'tmp_dir'),
381
+ max_kernels: Optional[int] = None,
382
+ reuse_kernel: bool = True,
383
+ startup_rate: bool = 32,
384
+ connection_dir: str = tempfile.gettempdir(),
385
+ description: Optional[dict] = None,
386
+ parser: Type[BaseParser] = JsonParser,
387
+ ):
388
+ super().__init__(timeout, user_data_dir, work_dir, description, parser)
389
+ from traitlets.config import Config
390
+
391
+ c = Config()
392
+ c.KernelManager.transport = 'ipc'
393
+ self._amkm = AsyncMultiKernelManager(
394
+ config=c, connection_dir=connection_dir)
395
+ self._max_kernels = max_kernels
396
+ self._reuse_kernel = reuse_kernel
397
+ self._sem = asyncio.Semaphore(startup_rate)
398
+ self._lock = asyncio.Lock()
399
+
400
+ async def initialize(self, session_id: str):
401
+ session_id = str(session_id)
402
+ while True:
403
+ if session_id in self._KERNEL_CLIENTS:
404
+ return self._KERNEL_CLIENTS[session_id]
405
+ if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty():
406
+ self._KERNEL_CLIENTS[
407
+ session_id] = await self._UNBOUND_KERNEL_CLIENTS.get()
408
+ return self._KERNEL_CLIENTS[session_id]
409
+ async with self._sem:
410
+ if self._max_kernels is None or len(
411
+ self._KERNEL_CLIENTS
412
+ ) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels:
413
+ kernel_id = None
414
+ try:
415
+ kernel_id = await self._amkm.start_kernel()
416
+ kernel = self._amkm.get_kernel(kernel_id)
417
+ client = kernel.client()
418
+ _, error_stacktrace, stream_text = await async_run_code(
419
+ kernel,
420
+ START_CODE.format(self.user_data_dir),
421
+ shutdown_kernel=False)
422
+ # check if the output of START_CODE meets expectations
423
+ if not (error_stacktrace is None
424
+ and stream_text == ''):
425
+ raise RuntimeError
426
+ except Exception as e:
427
+ print(f'Starting kernel error: {e}')
428
+ if kernel_id:
429
+ await self._amkm.shutdown_kernel(kernel_id)
430
+ self._amkm.remove_kernel(kernel_id)
431
+ await asyncio.sleep(1)
432
+ continue
433
+ if self._max_kernels is None:
434
+ self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel,
435
+ client)
436
+ return kernel_id, kernel, client
437
+ async with self._lock:
438
+ if len(self._KERNEL_CLIENTS
439
+ ) + self._UNBOUND_KERNEL_CLIENTS.qsize(
440
+ ) < self._max_kernels:
441
+ self._KERNEL_CLIENTS[session_id] = (kernel_id,
442
+ kernel, client)
443
+ return kernel_id, kernel, client
444
+ await self._amkm.shutdown_kernel(kernel_id)
445
+ self._amkm.remove_kernel(kernel_id)
446
+ await asyncio.sleep(1)
447
+
448
+ async def reset(self, session_id: str):
449
+ session_id = str(session_id)
450
+ if session_id not in self._KERNEL_CLIENTS:
451
+ return
452
+ _, kernel, _ = self._KERNEL_CLIENTS[session_id]
453
+ code = "get_ipython().run_line_magic('reset', '-f')\n" + \
454
+ START_CODE.format(self.user_data_dir)
455
+ await async_run_code(kernel, code, shutdown_kernel=False)
456
+
457
+ async def shutdown(self, session_id: str):
458
+ session_id = str(session_id)
459
+ if session_id in self._KERNEL_CLIENTS:
460
+ kernel_id, _, _ = self._KERNEL_CLIENTS.get(session_id)
461
+ await self._amkm.shutdown_kernel(kernel_id)
462
+ self._amkm.remove_kernel(kernel_id)
463
+ del self._KERNEL_CLIENTS[session_id]
464
+
465
+ async def close_session(self, session_id: str):
466
+ session_id = str(session_id)
467
+ if self._reuse_kernel:
468
+ if session_id in self._KERNEL_CLIENTS:
469
+ await self.reset(session_id)
470
+ await self._UNBOUND_KERNEL_CLIENTS.put(
471
+ self._KERNEL_CLIENTS.pop(session_id))
472
+ else:
473
+ await self.shutdown(session_id)
474
+
475
+ async def _call(self, command, timeout=None, session_id=None):
476
+ _, kernel, _ = await self.initialize(str(session_id))
477
+ result = await async_run_code(
478
+ kernel,
479
+ extract_code(command),
480
+ interrupt_after=timeout or self.timeout,
481
+ shutdown_kernel=False)
482
+ execute_result, error_stacktrace, stream_text = result
483
+ if error_stacktrace is not None:
484
+ ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace))
485
+ if ret.endswith('KeyboardInterrupt: '):
486
+ ret = 'The code interpreter encountered a timeout error.'
487
+ status, ret = False, ret.strip()
488
+ elif execute_result is not None:
489
+ status, ret = True, dict(text=execute_result.get('text/plain', ''))
490
+ else:
491
+ status, ret = True, dict(text=stream_text.strip())
492
+ return status, ret
493
+
494
+ @tool_api
495
+ async def run(self,
496
+ command: str,
497
+ timeout: Optional[int] = None,
498
+ session_id: Optional[str] = None) -> ActionReturn:
499
+ r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
500
+
501
+ Args:
502
+ command (:class:`str`): Python code
503
+ timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
504
+ """
505
+ tool_return = ActionReturn(url=None, args=None, type=self.name)
506
+ tool_return.args = dict(text=command)
507
+ succeed, result = await self._call(command, timeout, session_id)
508
+ if succeed:
509
+ text = result['text']
510
+ image = result.get('image', [])
511
+ resp = [dict(type='text', content=text)]
512
+ if image:
513
+ resp.extend([dict(type='image', content=im) for im in image])
514
+ tool_return.result = resp
515
+ # tool_return.result = dict(
516
+ # text=result['text'], image=result.get('image', [])[0])
517
+ tool_return.state = ActionStatusCode.SUCCESS
518
+ else:
519
+ tool_return.errmsg = result.get('text', '') if isinstance(
520
+ result, dict) else result
521
+ tool_return.state = ActionStatusCode.API_ERROR
522
+ return tool_return
523
+
524
+
525
+ def extract_code(text):
526
+ import json5
527
+
528
+ # Match triple backtick blocks first
529
+ triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
530
+ # Match single backtick blocks second
531
+ single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
532
+ if triple_match:
533
+ text = triple_match.group(1)
534
+ elif single_match:
535
+ text = single_match.group(1)
536
+ else:
537
+ try:
538
+ text = json5.loads(text)['code']
539
+ except Exception:
540
+ pass
541
+ # If no code blocks found, return original text
542
+ return text
543
+
544
+
545
+ def escape_ansi(line):
546
+ ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
547
+ return ansi_escape.sub('', line)
548
+
549
+
550
+ def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
551
+ import PIL.Image
552
+ image_file = str(uuid.uuid4()) + '.png'
553
+ local_image_file = os.path.join(work_dir, image_file)
554
+
555
+ png_bytes = base64.b64decode(image_base64)
556
+ assert isinstance(png_bytes, bytes)
557
+ bytes_io = io.BytesIO(png_bytes)
558
+ PIL.Image.open(bytes_io).save(local_image_file, 'png')
559
+
560
+ return local_image_file
561
+
562
+
563
+ # local test for code interpreter
564
+ def get_multiline_input(hint):
565
+ print(hint)
566
+ print('// Press ENTER to make a new line. Press CTRL-D to end input.')
567
+ lines = []
568
+ while True:
569
+ try:
570
+ line = input()
571
+ except EOFError: # CTRL-D
572
+ break
573
+ lines.append(line)
574
+ print('// Input received.')
575
+ if lines:
576
+ return '\n'.join(lines)
577
+ else:
578
+ return ''
579
+
580
+
581
+ if __name__ == '__main__':
582
+ code_interpreter = IPythonInterpreter()
583
+ while True:
584
+ print(code_interpreter(get_multiline_input('Enter python code:')))
lagent/actions/ipython_manager.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import sys
3
+ from collections import defaultdict
4
+ from contextlib import nullcontext
5
+ from io import StringIO
6
+ from multiprocessing import Process, Queue
7
+ from typing import List, Optional, Type, Union
8
+
9
+ from filelock import FileLock
10
+ from timeout_decorator import timeout as tm
11
+
12
+ from ..schema import ActionReturn, ActionStatusCode
13
+ from .base_action import BaseAction
14
+ from .parser import BaseParser, JsonParser
15
+
16
+
17
+ class IPythonProcess(Process):
18
+
19
+ def __init__(self,
20
+ in_q: Queue,
21
+ out_q: Queue,
22
+ timeout: int = 20,
23
+ ci_lock: str = None,
24
+ daemon: bool = True):
25
+ super().__init__(daemon=daemon)
26
+ self.in_q = in_q
27
+ self.out_q = out_q
28
+ self.timeout = timeout
29
+ self.session_id2shell = defaultdict(self.create_shell)
30
+ self.ci_lock = FileLock(
31
+ ci_lock) if ci_lock else nullcontext() # avoid core corruption
32
+ self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m')
33
+
34
+ def run(self):
35
+ while True:
36
+ msg = self.in_q.get()
37
+ if msg == 'reset':
38
+ for session_id, shell in self.session_id2shell.items():
39
+ with self.ci_lock:
40
+ try:
41
+ shell.reset(new_session=False)
42
+ # shell.run_line_magic('reset', '-sf')
43
+ except Exception:
44
+ self.session_id2shell[
45
+ session_id] = self.create_shell()
46
+ self.out_q.put('ok')
47
+ elif isinstance(msg, tuple) and len(msg) == 3:
48
+ i, session_id, code = msg
49
+ res = self.exec(session_id, code)
50
+ self.out_q.put((i, session_id, res))
51
+
52
+ def exec(self, session_id, code):
53
+ try:
54
+ shell = self.session_id2shell[session_id]
55
+ with StringIO() as io:
56
+ old_stdout = sys.stdout
57
+ sys.stdout = io
58
+ if self.timeout is False or self.timeout < 0:
59
+ shell.run_cell(self.extract_code(code))
60
+ else:
61
+ tm(self.timeout)(shell.run_cell)(self.extract_code(code))
62
+ sys.stdout = old_stdout
63
+ output = self._highlighting.sub('', io.getvalue().strip())
64
+ output = re.sub(r'^Out\[\d+\]: ', '', output)
65
+ if 'Error' in output or 'Traceback' in output:
66
+ output = output.lstrip('-').strip()
67
+ if output.startswith('TimeoutError'):
68
+ output = 'The code interpreter encountered a timeout error.'
69
+ return {'status': 'FAILURE', 'msg': output, 'code': code}
70
+ return {'status': 'SUCCESS', 'value': output, 'code': code}
71
+ except Exception as e:
72
+ return {'status': 'FAILURE', 'msg': str(e), 'code': code}
73
+
74
+ @staticmethod
75
+ def create_shell(enable_history: bool = False, in_memory: bool = True):
76
+ from IPython import InteractiveShell
77
+ from traitlets.config import Config
78
+
79
+ c = Config()
80
+ c.HistoryManager.enabled = enable_history
81
+ if in_memory:
82
+ c.HistoryManager.hist_file = ':memory:'
83
+ shell = InteractiveShell(config=c)
84
+ return shell
85
+
86
+ @staticmethod
87
+ def extract_code(text: str) -> str:
88
+ """Extract Python code from markup languages.
89
+
90
+ Args:
91
+ text (:class:`str`): Markdown-formatted text
92
+
93
+ Returns:
94
+ :class:`str`: Python code
95
+ """
96
+ import json5
97
+
98
+ # Match triple backtick blocks first
99
+ triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
100
+ # Match single backtick blocks second
101
+ single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
102
+ if triple_match:
103
+ text = triple_match.group(1)
104
+ elif single_match:
105
+ text = single_match.group(1)
106
+ else:
107
+ try:
108
+ text = json5.loads(text)['code']
109
+ except Exception:
110
+ pass
111
+ # If no code blocks found, return original text
112
+ return text
113
+
114
+
115
+ class IPythonInteractiveManager(BaseAction):
116
+ """An interactive IPython shell manager for code execution"""
117
+
118
+ def __init__(
119
+ self,
120
+ max_workers: int = 50,
121
+ timeout: int = 20,
122
+ ci_lock: str = None,
123
+ description: Optional[dict] = None,
124
+ parser: Type[BaseParser] = JsonParser,
125
+ ):
126
+ super().__init__(description, parser)
127
+ self.max_workers = max_workers
128
+ self.timeout = timeout
129
+ self.ci_lock = ci_lock
130
+ self.id2queue = defaultdict(Queue)
131
+ self.id2process = {}
132
+ self.out_queue = Queue()
133
+
134
+ def __call__(self,
135
+ commands: Union[str, List[str]],
136
+ session_ids: Union[int, List[int]] = None):
137
+ if isinstance(commands, list):
138
+ batch_size = len(commands)
139
+ is_batch = True
140
+ else:
141
+ batch_size = 1
142
+ commands = [commands]
143
+ is_batch = False
144
+ if session_ids is None:
145
+ session_ids = range(batch_size)
146
+ elif isinstance(session_ids, int):
147
+ session_ids = [session_ids]
148
+ if len(session_ids) != batch_size or len(session_ids) != len(
149
+ set(session_ids)):
150
+ raise ValueError(
151
+ 'the size of `session_ids` must equal that of `commands`')
152
+ try:
153
+ exec_results = self.run_code_blocks([
154
+ (session_id, command)
155
+ for session_id, command in zip(session_ids, commands)
156
+ ])
157
+ except KeyboardInterrupt:
158
+ self.clear()
159
+ exit(1)
160
+ action_returns = []
161
+ for result, code in zip(exec_results, commands):
162
+ action_return = ActionReturn({'command': code}, type=self.name)
163
+ if result['status'] == 'SUCCESS':
164
+ action_return.result = [
165
+ dict(type='text', content=result['value'])
166
+ ]
167
+ action_return.state = ActionStatusCode.SUCCESS
168
+ else:
169
+ action_return.errmsg = result['msg']
170
+ action_return.state = ActionStatusCode.API_ERROR
171
+ action_returns.append(action_return)
172
+ if not is_batch:
173
+ return action_returns[0]
174
+ return action_returns
175
+
176
+ def process_code(self, index, session_id, code):
177
+ ipy_id = session_id % self.max_workers
178
+ input_queue = self.id2queue[ipy_id]
179
+ proc = self.id2process.setdefault(
180
+ ipy_id,
181
+ IPythonProcess(
182
+ input_queue,
183
+ self.out_queue,
184
+ self.timeout,
185
+ self.ci_lock,
186
+ daemon=True))
187
+ if not proc.is_alive():
188
+ proc.start()
189
+ input_queue.put((index, session_id, code))
190
+
191
+ def run_code_blocks(self, session_code_pairs):
192
+ size = len(session_code_pairs)
193
+ for index, (session_id, code) in enumerate(session_code_pairs):
194
+ self.process_code(index, session_id, code)
195
+ results = []
196
+ while len(results) < size:
197
+ msg = self.out_queue.get()
198
+ if isinstance(msg, tuple) and len(msg) == 3:
199
+ index, _, result = msg
200
+ results.append((index, result))
201
+ results.sort()
202
+ return [item[1] for item in results]
203
+
204
+ def clear(self):
205
+ self.id2queue.clear()
206
+ for proc in self.id2process.values():
207
+ proc.terminate()
208
+ self.id2process.clear()
209
+ while not self.out_queue.empty():
210
+ self.out_queue.get()
211
+
212
+ def reset(self):
213
+ cnt = 0
214
+ for q in self.id2queue.values():
215
+ q.put('reset')
216
+ cnt += 1
217
+ while cnt > 0:
218
+ msg = self.out_queue.get()
219
+ if msg == 'ok':
220
+ cnt -= 1
lagent/actions/parser.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from ast import literal_eval
4
+ from typing import Any, List, Union
5
+
6
+
7
+ class ParseError(Exception):
8
+ """Parsing exception class."""
9
+
10
+ def __init__(self, err_msg: str):
11
+ self.err_msg = err_msg
12
+
13
+
14
+ class BaseParser:
15
+ """Base parser to process inputs and outputs of actions.
16
+
17
+ Args:
18
+ action (:class:`BaseAction`): action to validate
19
+
20
+ Attributes:
21
+ PARAMETER_DESCRIPTION (:class:`str`): declare the input format which
22
+ LLMs should follow when generating arguments for decided tools.
23
+ """
24
+
25
+ PARAMETER_DESCRIPTION: str = ''
26
+
27
+ def __init__(self, action):
28
+ self.action = action
29
+ self._api2param = {}
30
+ self._api2required = {}
31
+ # perform basic argument validation
32
+ if action.description:
33
+ for api in action.description.get('api_list',
34
+ [action.description]):
35
+ name = (f'{action.name}.{api["name"]}'
36
+ if self.action.is_toolkit else api['name'])
37
+ required_parameters = set(api['required'])
38
+ all_parameters = {j['name'] for j in api['parameters']}
39
+ if not required_parameters.issubset(all_parameters):
40
+ raise ValueError(
41
+ f'unknown parameters for function "{name}": '
42
+ f'{required_parameters - all_parameters}')
43
+ if self.PARAMETER_DESCRIPTION:
44
+ api['parameter_description'] = self.PARAMETER_DESCRIPTION
45
+ api_name = api['name'] if self.action.is_toolkit else 'run'
46
+ self._api2param[api_name] = api['parameters']
47
+ self._api2required[api_name] = api['required']
48
+
49
+ def parse_inputs(self, inputs: str, name: str = 'run') -> dict:
50
+ """Parse inputs LLMs generate for the action.
51
+
52
+ Args:
53
+ inputs (:class:`str`): input string extracted from responses
54
+
55
+ Returns:
56
+ :class:`dict`: processed input
57
+ """
58
+ inputs = {self._api2param[name][0]['name']: inputs}
59
+ return inputs
60
+
61
+ def parse_outputs(self, outputs: Any) -> List[dict]:
62
+ """Parser outputs returned by the action.
63
+
64
+ Args:
65
+ outputs (:class:`Any`): raw output of the action
66
+
67
+ Returns:
68
+ :class:`List[dict]`: processed output of which each member is a
69
+ dictionary with two keys - 'type' and 'content'.
70
+ """
71
+ if isinstance(outputs, dict):
72
+ outputs = json.dumps(outputs, ensure_ascii=False)
73
+ elif not isinstance(outputs, str):
74
+ outputs = str(outputs)
75
+ return [{
76
+ 'type': 'text',
77
+ 'content': outputs.encode('gbk', 'ignore').decode('gbk')
78
+ }]
79
+
80
+
81
+ class JsonParser(BaseParser):
82
+ """Json parser to convert input string into a dictionary.
83
+
84
+ Args:
85
+ action (:class:`BaseAction`): action to validate
86
+ """
87
+
88
+ PARAMETER_DESCRIPTION = (
89
+ 'If you call this tool, you must pass arguments in '
90
+ 'the JSON format {key: value}, where the key is the parameter name.')
91
+
92
+ def parse_inputs(self,
93
+ inputs: Union[str, dict],
94
+ name: str = 'run') -> dict:
95
+ if not isinstance(inputs, dict):
96
+ try:
97
+ match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs,
98
+ re.S)
99
+ if match:
100
+ inputs = match.group(2).strip()
101
+ inputs = json.loads(inputs)
102
+ except json.JSONDecodeError as exc:
103
+ raise ParseError(f'invalid json format: {inputs}') from exc
104
+ input_keys = set(inputs)
105
+ all_keys = {param['name'] for param in self._api2param[name]}
106
+ if not input_keys.issubset(all_keys):
107
+ raise ParseError(f'unknown arguments: {input_keys - all_keys}')
108
+ required_keys = set(self._api2required[name])
109
+ if not input_keys.issuperset(required_keys):
110
+ raise ParseError(
111
+ f'missing required arguments: {required_keys - input_keys}')
112
+ return inputs
113
+
114
+
115
+ class TupleParser(BaseParser):
116
+ """Tuple parser to convert input string into a tuple.
117
+
118
+ Args:
119
+ action (:class:`BaseAction`): action to validate
120
+ """
121
+
122
+ PARAMETER_DESCRIPTION = (
123
+ 'If you call this tool, you must pass arguments in the tuple format '
124
+ 'like (arg1, arg2, arg3), and the arguments are ordered.')
125
+
126
+ def parse_inputs(self,
127
+ inputs: Union[str, tuple],
128
+ name: str = 'run') -> dict:
129
+ if not isinstance(inputs, tuple):
130
+ try:
131
+ inputs = literal_eval(inputs)
132
+ except Exception as exc:
133
+ raise ParseError(f'invalid tuple format: {inputs}') from exc
134
+ if len(inputs) < len(self._api2required[name]):
135
+ raise ParseError(
136
+ f'API takes {len(self._api2required[name])} required positional '
137
+ f'arguments but {len(inputs)} were given')
138
+ if len(inputs) > len(self._api2param[name]):
139
+ raise ParseError(
140
+ f'API takes {len(self._api2param[name])} positional arguments '
141
+ f'but {len(inputs)} were given')
142
+ inputs = {
143
+ self._api2param[name][i]['name']: item
144
+ for i, item in enumerate(inputs)
145
+ }
146
+ return inputs
lagent/actions/ppt.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Type
2
+
3
+ from asyncer import asyncify
4
+
5
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
6
+ from lagent.actions.parser import BaseParser, JsonParser
7
+
8
+ THEME_MAPPING = {
9
+ 'Default': {
10
+ 'template': None,
11
+ 'title': 'Title Slide',
12
+ 'single': 'Title and Content',
13
+ 'two': 'Two Content',
14
+ }
15
+ }
16
+
17
+
18
+ class PPT(BaseAction):
19
+ """Plugin to create ppt slides with text, paragraph, images in good looking styles."""
20
+
21
+ def __init__(
22
+ self,
23
+ theme_mapping: Optional[Dict[str, dict]] = None,
24
+ description: Optional[dict] = None,
25
+ parser: Type[BaseParser] = JsonParser,
26
+ ):
27
+ super().__init__(description, parser)
28
+ self.theme_mapping = theme_mapping or THEME_MAPPING
29
+ self.pointer = None
30
+ self.location = None
31
+
32
+ @tool_api(explode_return=True)
33
+ def create_file(self, theme: str, abs_location: str) -> dict:
34
+ """Create a pptx file with specific themes.
35
+
36
+ Args:
37
+ theme (:class:`str`): the theme used. The value should be one of ['Default'].
38
+ abs_location (:class:`str`): the ppt file's absolute location
39
+
40
+ Returns:
41
+ :class:`dict`: operation status
42
+ * status: the result of the execution
43
+ """
44
+ from pptx import Presentation
45
+
46
+ self.location = abs_location
47
+ try:
48
+ self.pointer = Presentation(self.theme_mapping[theme]['template'])
49
+ self.pointer.slide_master.name = theme
50
+ # print('created')
51
+ except Exception as e:
52
+ print(e)
53
+ return dict(status='created a ppt file.')
54
+
55
+ @tool_api(explode_return=True)
56
+ def add_first_page(self, title: str, subtitle: str) -> dict:
57
+ """Add the first page of ppt.
58
+
59
+ Args:
60
+ title (:class:`str`): the title of ppt
61
+ subtitle (:class:`str`): the subtitle of ppt
62
+
63
+ Returns:
64
+ :class:`dict`: operation status
65
+ * status: the result of the execution
66
+ """
67
+ layout_name = self.theme_mapping[self.pointer.slide_master.name]['title']
68
+ layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name)
69
+ slide = self.pointer.slides.add_slide(layout)
70
+ ph_title, ph_subtitle = slide.placeholders
71
+ ph_title.text = title
72
+ if subtitle:
73
+ ph_subtitle.text = subtitle
74
+ return dict(status='added page')
75
+
76
+ @tool_api(explode_return=True)
77
+ def add_text_page(self, title: str, bullet_items: str) -> dict:
78
+ """Add text page of ppt.
79
+
80
+ Args:
81
+ title (:class:`str`): the title of the page
82
+ bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
83
+
84
+ Returns:
85
+ :class:`dict`: operation status
86
+ * status: the result of the execution
87
+ """ # noqa: E501
88
+ layout_name = self.theme_mapping[self.pointer.slide_master.name]['single']
89
+ layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name)
90
+ slide = self.pointer.slides.add_slide(layout)
91
+ ph_title, ph_body = slide.placeholders
92
+ ph_title.text = title
93
+ ph = ph_body
94
+ tf = ph.text_frame
95
+ for i, item in enumerate(bullet_items.split('[SPAN]')):
96
+ if i == 0:
97
+ p = tf.paragraphs[0]
98
+ else:
99
+ p = tf.add_paragraph()
100
+ p.text = item.strip()
101
+ p.level = 0
102
+ return dict(status='added page')
103
+
104
+ @tool_api(explode_return=True)
105
+ def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict:
106
+ """Add a text page with one image. Image should be a path.
107
+
108
+ Args:
109
+ title (:class:`str`): the title of the page
110
+ bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
111
+ image (:class:`str`): the path of the image
112
+
113
+ Returns:
114
+ :class:`dict`: operation status
115
+ * status: the result of the execution
116
+ """ # noqa: E501
117
+ from PIL import Image
118
+
119
+ layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
120
+ layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name)
121
+ slide = self.pointer.slides.add_slide(layout)
122
+ ph_title, ph_body1, ph_body2 = slide.placeholders
123
+ ph_title.text = title
124
+ ph = ph_body2
125
+ image = Image.open(image)
126
+ image_pil = image.to_pil()
127
+ left = ph.left
128
+ width = ph.width
129
+ height = int(width / image_pil.width * image_pil.height)
130
+ top = (ph.top + (ph.top + ph.height)) // 2 - height // 2
131
+ slide.shapes.add_picture(image.to_path(), left, top, width, height)
132
+
133
+ ph = ph_body1
134
+ tf = ph.text_frame
135
+ for i, item in enumerate(bullet_items.split('[SPAN]')):
136
+ if i == 0:
137
+ p = tf.paragraphs[0]
138
+ else:
139
+ p = tf.add_paragraph()
140
+ p.text = item.strip()
141
+ p.level = 0
142
+
143
+ return dict(status='added page')
144
+
145
+ @tool_api(explode_return=True)
146
+ def submit_file(self) -> dict:
147
+ """When all steps done, YOU MUST use submit_file() to submit your work.
148
+
149
+ Returns:
150
+ :class:`dict`: operation status
151
+ * status: the result of the execution
152
+ """
153
+ # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx')
154
+ # self.pointer.save(file_path)
155
+ # retreival_url = upload_file(file_path)
156
+ self.pointer.save(self.location)
157
+ return dict(status=f'submitted. view ppt at {self.location}')
158
+
159
+
160
+ class AsyncPPT(AsyncActionMixin, PPT):
161
+ """Plugin to create ppt slides with text, paragraph, images in good looking styles."""
162
+
163
+ @tool_api(explode_return=True)
164
+ @asyncify
165
+ def create_file(self, theme: str, abs_location: str) -> dict:
166
+ """Create a pptx file with specific themes.
167
+
168
+ Args:
169
+ theme (:class:`str`): the theme used. The value should be one of ['Default'].
170
+ abs_location (:class:`str`): the ppt file's absolute location
171
+
172
+ Returns:
173
+ :class:`dict`: operation status
174
+ * status: the result of the execution
175
+ """
176
+ return super().create_file(theme, abs_location)
177
+
178
+ @tool_api(explode_return=True)
179
+ @asyncify
180
+ def add_first_page(self, title: str, subtitle: str) -> dict:
181
+ """Add the first page of ppt.
182
+
183
+ Args:
184
+ title (:class:`str`): the title of ppt
185
+ subtitle (:class:`str`): the subtitle of ppt
186
+
187
+ Returns:
188
+ :class:`dict`: operation status
189
+ * status: the result of the execution
190
+ """
191
+ return super().add_first_page(title, subtitle)
192
+
193
+ @tool_api(explode_return=True)
194
+ @asyncify
195
+ def add_text_page(self, title: str, bullet_items: str) -> dict:
196
+ """Add text page of ppt.
197
+
198
+ Args:
199
+ title (:class:`str`): the title of the page
200
+ bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
201
+
202
+ Returns:
203
+ :class:`dict`: operation status
204
+ * status: the result of the execution
205
+ """ # noqa: E501
206
+ return super().add_text_page(title, bullet_items)
207
+
208
+ @tool_api(explode_return=True)
209
+ @asyncify
210
+ def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict:
211
+ """Add a text page with one image. Image should be a path.
212
+
213
+ Args:
214
+ title (:class:`str`): the title of the page
215
+ bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
216
+ image (:class:`str`): the path of the image
217
+
218
+ Returns:
219
+ :class:`dict`: operation status
220
+ * status: the result of the execution
221
+ """ # noqa: E501
222
+ return super().add_text_image_page(title, bullet_items, image)
223
+
224
+ @tool_api(explode_return=True)
225
+ @asyncify
226
+ def submit_file(self) -> dict:
227
+ """When all steps done, YOU MUST use submit_file() to submit your work.
228
+
229
+ Returns:
230
+ :class:`dict`: operation status
231
+ * status: the result of the execution
232
+ """
233
+ return super().submit_file()
lagent/actions/python_interpreter.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E501
2
+ import copy
3
+ import io
4
+ from contextlib import redirect_stdout
5
+ from typing import Any, Optional, Type
6
+
7
+ from asyncer import asyncify
8
+
9
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
10
+ from lagent.actions.parser import BaseParser, JsonParser
11
+ from lagent.schema import ActionReturn, ActionStatusCode
12
+
13
+
14
+ class GenericRuntime:
15
+ GLOBAL_DICT = {}
16
+ LOCAL_DICT = None
17
+ HEADERS = []
18
+
19
+ def __init__(self):
20
+ self._global_vars = copy.copy(self.GLOBAL_DICT)
21
+ self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
22
+
23
+ for c in self.HEADERS:
24
+ self.exec_code(c)
25
+
26
+ def exec_code(self, code_piece: str) -> None:
27
+ exec(code_piece, self._global_vars)
28
+
29
+ def eval_code(self, expr: str) -> Any:
30
+ return eval(expr, self._global_vars)
31
+
32
+
33
+ class PythonInterpreter(BaseAction):
34
+ """A Python executor that can execute Python scripts.
35
+
36
+ Args:
37
+ answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``.
38
+ answer_expr (str, Optional): the answer function name of the Python
39
+ script. Defaults to ``'solution()'``.
40
+ answer_from_stdout (boolean, Optional): whether the execution results is from
41
+ stdout. Defaults to ``False``.
42
+ timeout (int, Optional): Upper bound of waiting time for Python script execution.
43
+ Defaults to ``20``.
44
+ description (dict, Optional): The description of the action. Defaults to ``None``.
45
+ parser (Type[BaseParser]): The parser class to process the
46
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ answer_symbol: Optional[str] = None,
52
+ answer_expr: Optional[str] = 'solution()',
53
+ answer_from_stdout: bool = False,
54
+ timeout: int = 20,
55
+ description: Optional[dict] = None,
56
+ parser: Type[BaseParser] = JsonParser,
57
+ ) -> None:
58
+ super().__init__(description, parser)
59
+ self.answer_symbol = answer_symbol
60
+ self.answer_expr = answer_expr
61
+ self.answer_from_stdout = answer_from_stdout
62
+ self.timeout = timeout
63
+
64
+ @tool_api
65
+ def run(self, command: str) -> ActionReturn:
66
+ """用来执行Python代码。代码必须是一个函数,函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
67
+
68
+ ```python
69
+ # import 依赖包
70
+ import xxx
71
+ def solution():
72
+ # 初始化一些变量
73
+ variable_names_with_real_meaning = xxx
74
+ # 步骤一
75
+ mid_variable = func(variable_names_with_real_meaning)
76
+ # 步骤 x
77
+ mid_variable = func(mid_variable)
78
+ # 最后结果
79
+ final_answer = func(mid_variable)
80
+ return final_answer
81
+ ```
82
+
83
+ Args:
84
+ command (:class:`str`): Python code snippet
85
+ """
86
+ from func_timeout import FunctionTimedOut, func_set_timeout
87
+
88
+ self.runtime = GenericRuntime()
89
+ try:
90
+ tool_return = func_set_timeout(self.timeout)(self._call)(command)
91
+ except FunctionTimedOut as e:
92
+ tool_return = ActionReturn(type=self.name)
93
+ tool_return.errmsg = repr(e)
94
+ tool_return.state = ActionStatusCode.API_ERROR
95
+ return tool_return
96
+
97
+ def _call(self, command: str) -> ActionReturn:
98
+ tool_return = ActionReturn(type=self.name)
99
+ try:
100
+ if '```python' in command:
101
+ command = command.split('```python')[1].split('```')[0]
102
+ elif '```' in command:
103
+ command = command.split('```')[1].split('```')[0]
104
+ tool_return.args = dict(text='```python\n' + command + '\n```')
105
+ command = command.split('\n')
106
+
107
+ if self.answer_from_stdout:
108
+ program_io = io.StringIO()
109
+ with redirect_stdout(program_io):
110
+ self.runtime.exec_code('\n'.join(command))
111
+ program_io.seek(0)
112
+ res = program_io.readlines()[-1]
113
+ elif self.answer_symbol:
114
+ self.runtime.exec_code('\n'.join(command))
115
+ res = self.runtime._global_vars[self.answer_symbol]
116
+ elif self.answer_expr:
117
+ self.runtime.exec_code('\n'.join(command))
118
+ res = self.runtime.eval_code(self.answer_expr)
119
+ else:
120
+ self.runtime.exec_code('\n'.join(command[:-1]))
121
+ res = self.runtime.eval_code(command[-1])
122
+ except Exception as e:
123
+ tool_return.errmsg = repr(e)
124
+ tool_return.type = self.name
125
+ tool_return.state = ActionStatusCode.API_ERROR
126
+ return tool_return
127
+ try:
128
+ tool_return.result = [dict(type='text', content=str(res))]
129
+ tool_return.state = ActionStatusCode.SUCCESS
130
+ except Exception as e:
131
+ tool_return.errmsg = repr(e)
132
+ tool_return.type = self.name
133
+ tool_return.state = ActionStatusCode.API_ERROR
134
+ return tool_return
135
+
136
+
137
+ class AsyncPythonInterpreter(AsyncActionMixin, PythonInterpreter):
138
+ """A Python executor that can execute Python scripts.
139
+
140
+ Args:
141
+ answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``.
142
+ answer_expr (str, Optional): the answer function name of the Python
143
+ script. Defaults to ``'solution()'``.
144
+ answer_from_stdout (boolean, Optional): whether the execution results is from
145
+ stdout. Defaults to ``False``.
146
+ timeout (int, Optional): Upper bound of waiting time for Python script execution.
147
+ Defaults to ``20``.
148
+ description (dict, Optional): The description of the action. Defaults to ``None``.
149
+ parser (Type[BaseParser]): The parser class to process the
150
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
151
+ """
152
+
153
+ @tool_api
154
+ @asyncify
155
+ def run(self, command: str) -> ActionReturn:
156
+ """用来执行Python代码。代码必须是一个函数,函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
157
+
158
+ ```python
159
+ # import 依赖包
160
+ import xxx
161
+ def solution():
162
+ # 初始化一些变量
163
+ variable_names_with_real_meaning = xxx
164
+ # 步骤一
165
+ mid_variable = func(variable_names_with_real_meaning)
166
+ # 步骤 x
167
+ mid_variable = func(mid_variable)
168
+ # 最后结果
169
+ final_answer = func(mid_variable)
170
+ return final_answer
171
+ ```
172
+
173
+ Args:
174
+ command (:class:`str`): Python code snippet
175
+ """
176
+ return super().run(command)
lagent/actions/weather_query.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from lagent.actions.base_action import BaseAction, tool_api
4
+ from lagent.schema import ActionReturn, ActionStatusCode
5
+
6
+ class WeatherQuery(BaseAction):
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.api_key = os.getenv("weather_token")
10
+ print(self.api_key)
11
+ if not self.api_key:
12
+ raise EnvironmentError("未找到环境变量 'token'。请设置你的和风天气 API Key 到 'weather_token' 环境变量中,比如export weather_token='xxx' ")
13
+
14
+ @tool_api
15
+ def run(self, location: str) -> dict:
16
+ """
17
+ 查询实时天气信息。
18
+
19
+ Args:
20
+ location (str): 要查询的地点名称、LocationID 或经纬度坐标(如 "101010100" 或 "116.41,39.92")。
21
+
22
+ Returns:
23
+ dict: 包含天气信息的字典
24
+ * location: 地点名称
25
+ * weather: 天气状况
26
+ * temperature: 当前温度
27
+ * wind_direction: 风向
28
+ * wind_speed: 风速(公里/小时)
29
+ * humidity: 相对湿度(%)
30
+ * report_time: 数据报告时间
31
+ """
32
+ try:
33
+ # 如果 location 不是坐标格式(例如 "116.41,39.92"),则调用 GeoAPI 获取 LocationID
34
+ if not ("," in location and location.replace(",", "").replace(".", "").isdigit()):
35
+ # 使用 GeoAPI 获取 LocationID
36
+ geo_url = f"https://geoapi.qweather.com/v2/city/lookup?location={location}&key={self.api_key}"
37
+ geo_response = requests.get(geo_url)
38
+ geo_data = geo_response.json()
39
+
40
+ if geo_data.get("code") != "200" or not geo_data.get("location"):
41
+ raise Exception(f"GeoAPI 返回错误码:{geo_data.get('code')} 或未找到位置")
42
+
43
+ location = geo_data["location"][0]["id"]
44
+
45
+ # 构建天气查询的 API 请求 URL
46
+ weather_url = f"https://devapi.qweather.com/v7/weather/now?location={location}&key={self.api_key}"
47
+ response = requests.get(weather_url)
48
+ data = response.json()
49
+
50
+ # 检查 API 响应码
51
+ if data.get("code") != "200":
52
+ raise Exception(f"Weather API 返回错误码:{data.get('code')}")
53
+
54
+ # 解析和组织天气信息
55
+ weather_info = {
56
+ "location": location,
57
+ "weather": data["now"]["text"],
58
+ "temperature": data["now"]["temp"] + "°C",
59
+ "wind_direction": data["now"]["windDir"],
60
+ "wind_speed": data["now"]["windSpeed"] + " km/h",
61
+ "humidity": data["now"]["humidity"] + "%",
62
+ "report_time": data["updateTime"]
63
+ }
64
+
65
+ return {"result": weather_info}
66
+
67
+ except Exception as exc:
68
+ return ActionReturn(
69
+ errmsg=f"WeatherQuery 异常:{exc}",
70
+ state=ActionStatusCode.HTTP_ERROR
71
+ )
lagent/actions/web_browser.py ADDED
@@ -0,0 +1,908 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import hashlib
3
+ import hmac
4
+ import json
5
+ import logging
6
+ import random
7
+ import re
8
+ import time
9
+ import warnings
10
+ from concurrent.futures import ThreadPoolExecutor, as_completed
11
+ from datetime import datetime
12
+ from http.client import HTTPSConnection
13
+ from typing import List, Optional, Tuple, Type, Union
14
+
15
+ import aiohttp
16
+ import aiohttp.client_exceptions
17
+ import requests
18
+ from asyncache import cached as acached
19
+ from bs4 import BeautifulSoup
20
+ from cachetools import TTLCache, cached
21
+ from duckduckgo_search import DDGS, AsyncDDGS
22
+
23
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
24
+ from lagent.actions.parser import BaseParser, JsonParser
25
+ from lagent.utils import async_as_completed
26
+
27
+
28
+ class BaseSearch:
29
+
30
+ def __init__(self, topk: int = 3, black_list: List[str] = None):
31
+ self.topk = topk
32
+ self.black_list = black_list
33
+
34
+ def _filter_results(self, results: List[tuple]) -> dict:
35
+ filtered_results = {}
36
+ count = 0
37
+ for url, snippet, title in results:
38
+ if all(domain not in url
39
+ for domain in self.black_list) and not url.endswith('.pdf'):
40
+ filtered_results[count] = {
41
+ 'url': url,
42
+ 'summ': json.dumps(snippet, ensure_ascii=False)[1:-1],
43
+ 'title': title
44
+ }
45
+ count += 1
46
+ if count >= self.topk:
47
+ break
48
+ return filtered_results
49
+
50
+
51
+ class DuckDuckGoSearch(BaseSearch):
52
+
53
+ def __init__(self,
54
+ topk: int = 3,
55
+ black_list: List[str] = [
56
+ 'enoN',
57
+ 'youtube.com',
58
+ 'bilibili.com',
59
+ 'researchgate.net',
60
+ ],
61
+ **kwargs):
62
+ self.proxy = kwargs.get('proxy')
63
+ self.timeout = kwargs.get('timeout', 30)
64
+ super().__init__(topk, black_list)
65
+
66
+ @cached(cache=TTLCache(maxsize=100, ttl=600))
67
+ def search(self, query: str, max_retry: int = 3) -> dict:
68
+ for attempt in range(max_retry):
69
+ try:
70
+ response = self._call_ddgs(
71
+ query, timeout=self.timeout, proxy=self.proxy)
72
+ return self._parse_response(response)
73
+ except Exception as e:
74
+ logging.exception(str(e))
75
+ warnings.warn(
76
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
77
+ time.sleep(random.randint(2, 5))
78
+ raise Exception(
79
+ 'Failed to get search results from DuckDuckGo after retries.')
80
+
81
+ @acached(cache=TTLCache(maxsize=100, ttl=600))
82
+ async def asearch(self, query: str, max_retry: int = 3) -> dict:
83
+ for attempt in range(max_retry):
84
+ try:
85
+ ddgs = AsyncDDGS(timeout=self.timeout, proxy=self.proxy)
86
+ response = await ddgs.atext(query.strip("'"), max_results=10)
87
+ return self._parse_response(response)
88
+ except Exception as e:
89
+ if isinstance(e, asyncio.TimeoutError):
90
+ logging.exception('Request to DDGS timed out.')
91
+ logging.exception(str(e))
92
+ warnings.warn(
93
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
94
+ await asyncio.sleep(random.randint(2, 5))
95
+ raise Exception(
96
+ 'Failed to get search results from DuckDuckGo after retries.')
97
+
98
+ async def _async_call_ddgs(self, query: str, **kwargs) -> dict:
99
+ ddgs = DDGS(**kwargs)
100
+ try:
101
+ response = await asyncio.wait_for(
102
+ asyncio.to_thread(ddgs.text, query.strip("'"), max_results=10),
103
+ timeout=self.timeout)
104
+ return response
105
+ except asyncio.TimeoutError:
106
+ logging.exception('Request to DDGS timed out.')
107
+ raise
108
+
109
+ def _call_ddgs(self, query: str, **kwargs) -> dict:
110
+ loop = asyncio.new_event_loop()
111
+ asyncio.set_event_loop(loop)
112
+ try:
113
+ response = loop.run_until_complete(
114
+ self._async_call_ddgs(query, **kwargs))
115
+ return response
116
+ finally:
117
+ loop.close()
118
+
119
+ def _parse_response(self, response: dict) -> dict:
120
+ raw_results = []
121
+ for item in response:
122
+ raw_results.append(
123
+ (item['href'], item['description']
124
+ if 'description' in item else item['body'], item['title']))
125
+ return self._filter_results(raw_results)
126
+
127
+
128
+ class BingSearch(BaseSearch):
129
+
130
+ def __init__(self,
131
+ api_key: str,
132
+ region: str = 'zh-CN',
133
+ topk: int = 3,
134
+ black_list: List[str] = [
135
+ 'enoN',
136
+ 'youtube.com',
137
+ 'bilibili.com',
138
+ 'researchgate.net',
139
+ ],
140
+ **kwargs):
141
+ self.api_key = api_key
142
+ self.market = region
143
+ self.proxy = kwargs.get('proxy')
144
+ super().__init__(topk, black_list)
145
+
146
+ @cached(cache=TTLCache(maxsize=100, ttl=600))
147
+ def search(self, query: str, max_retry: int = 3) -> dict:
148
+ for attempt in range(max_retry):
149
+ try:
150
+ response = self._call_bing_api(query)
151
+ return self._parse_response(response)
152
+ except Exception as e:
153
+ logging.exception(str(e))
154
+ warnings.warn(
155
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
156
+ time.sleep(random.randint(2, 5))
157
+ raise Exception(
158
+ 'Failed to get search results from Bing Search after retries.')
159
+
160
+ @acached(cache=TTLCache(maxsize=100, ttl=600))
161
+ async def asearch(self, query: str, max_retry: int = 3) -> dict:
162
+ for attempt in range(max_retry):
163
+ try:
164
+ response = await self._async_call_bing_api(query)
165
+ return self._parse_response(response)
166
+ except Exception as e:
167
+ logging.exception(str(e))
168
+ warnings.warn(
169
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
170
+ await asyncio.sleep(random.randint(2, 5))
171
+ raise Exception(
172
+ 'Failed to get search results from Bing Search after retries.')
173
+
174
+ def _call_bing_api(self, query: str) -> dict:
175
+ endpoint = 'https://api.bing.microsoft.com/v7.0/search'
176
+ params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
177
+ headers = {'Ocp-Apim-Subscription-Key': self.api_key}
178
+ response = requests.get(
179
+ endpoint, headers=headers, params=params, proxies=self.proxy)
180
+ response.raise_for_status()
181
+ return response.json()
182
+
183
+ async def _async_call_bing_api(self, query: str) -> dict:
184
+ endpoint = 'https://api.bing.microsoft.com/v7.0/search'
185
+ params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
186
+ headers = {'Ocp-Apim-Subscription-Key': self.api_key}
187
+ async with aiohttp.ClientSession(raise_for_status=True) as session:
188
+ async with session.get(
189
+ endpoint,
190
+ headers=headers,
191
+ params=params,
192
+ proxy=self.proxy and
193
+ (self.proxy.get('http') or self.proxy.get('https'))) as resp:
194
+ return await resp.json()
195
+
196
+ def _parse_response(self, response: dict) -> dict:
197
+ webpages = {
198
+ w['id']: w
199
+ for w in response.get('webPages', {}).get('value', [])
200
+ }
201
+ raw_results = []
202
+
203
+ for item in response.get('rankingResponse',
204
+ {}).get('mainline', {}).get('items', []):
205
+ if item['answerType'] == 'WebPages':
206
+ webpage = webpages.get(item['value']['id'])
207
+ if webpage:
208
+ raw_results.append(
209
+ (webpage['url'], webpage['snippet'], webpage['name']))
210
+ elif item['answerType'] == 'News' and item['value'][
211
+ 'id'] == response.get('news', {}).get('id'):
212
+ for news in response.get('news', {}).get('value', []):
213
+ raw_results.append(
214
+ (news['url'], news['description'], news['name']))
215
+
216
+ return self._filter_results(raw_results)
217
+
218
+
219
+ class BraveSearch(BaseSearch):
220
+ """
221
+ Wrapper around the Brave Search API.
222
+
223
+ To use, you should pass your Brave Search API key to the constructor.
224
+
225
+ Args:
226
+ api_key (str): API KEY to use Brave Search API.
227
+ You can create a free API key at https://api.search.brave.com/app/keys.
228
+ search_type (str): Brave Search API supports ['web', 'news', 'images', 'videos'],
229
+ currently only supports 'news' and 'web'.
230
+ topk (int): The number of search results returned in response from API search results.
231
+ region (str): The country code string. Specifies the country where the search results come from.
232
+ language (str): The language code string. Specifies the preferred language for the search results.
233
+ extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the search results.
234
+ **kwargs: Any other parameters related to the Brave Search API. Find more details at
235
+ https://api.search.brave.com/app/documentation/web-search/get-started.
236
+ """
237
+
238
+ def __init__(self,
239
+ api_key: str,
240
+ region: str = 'ALL',
241
+ language: str = 'zh-hans',
242
+ extra_snippests: bool = True,
243
+ topk: int = 3,
244
+ black_list: List[str] = [
245
+ 'enoN',
246
+ 'youtube.com',
247
+ 'bilibili.com',
248
+ 'researchgate.net',
249
+ ],
250
+ **kwargs):
251
+ self.api_key = api_key
252
+ self.market = region
253
+ self.proxy = kwargs.get('proxy')
254
+ self.language = language
255
+ self.extra_snippests = extra_snippests
256
+ self.search_type = kwargs.get('search_type', 'web')
257
+ self.kwargs = kwargs
258
+ super().__init__(topk, black_list)
259
+
260
+ @cached(cache=TTLCache(maxsize=100, ttl=600))
261
+ def search(self, query: str, max_retry: int = 3) -> dict:
262
+ for attempt in range(max_retry):
263
+ try:
264
+ response = self._call_brave_api(query)
265
+ return self._parse_response(response)
266
+ except Exception as e:
267
+ logging.exception(str(e))
268
+ warnings.warn(
269
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
270
+ time.sleep(random.randint(2, 5))
271
+ raise Exception(
272
+ 'Failed to get search results from Brave Search after retries.')
273
+
274
+ @acached(cache=TTLCache(maxsize=100, ttl=600))
275
+ async def asearch(self, query: str, max_retry: int = 3) -> dict:
276
+ for attempt in range(max_retry):
277
+ try:
278
+ response = await self._async_call_brave_api(query)
279
+ return self._parse_response(response)
280
+ except Exception as e:
281
+ logging.exception(str(e))
282
+ warnings.warn(
283
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
284
+ await asyncio.sleep(random.randint(2, 5))
285
+ raise Exception(
286
+ 'Failed to get search results from Brave Search after retries.')
287
+
288
+ def _call_brave_api(self, query: str) -> dict:
289
+ endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search'
290
+ params = {
291
+ 'q': query,
292
+ 'country': self.market,
293
+ 'search_lang': self.language,
294
+ 'extra_snippets': self.extra_snippests,
295
+ 'count': self.topk,
296
+ **{
297
+ key: value
298
+ for key, value in self.kwargs.items() if value is not None
299
+ },
300
+ }
301
+ headers = {
302
+ 'X-Subscription-Token': self.api_key or '',
303
+ 'Accept': 'application/json'
304
+ }
305
+ response = requests.get(
306
+ endpoint, headers=headers, params=params, proxies=self.proxy)
307
+ response.raise_for_status()
308
+ return response.json()
309
+
310
+ async def _async_call_brave_api(self, query: str) -> dict:
311
+ endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search'
312
+ params = {
313
+ 'q': query,
314
+ 'country': self.market,
315
+ 'search_lang': self.language,
316
+ 'extra_snippets': self.extra_snippests,
317
+ 'count': self.topk,
318
+ **{
319
+ key: value
320
+ for key, value in self.kwargs.items() if value is not None
321
+ },
322
+ }
323
+ headers = {
324
+ 'X-Subscription-Token': self.api_key or '',
325
+ 'Accept': 'application/json'
326
+ }
327
+ async with aiohttp.ClientSession(raise_for_status=True) as session:
328
+ async with session.get(
329
+ endpoint,
330
+ headers=headers,
331
+ params=params,
332
+ proxy=self.proxy and
333
+ (self.proxy.get('http') or self.proxy.get('https'))) as resp:
334
+ return await resp.json()
335
+
336
+ def _parse_response(self, response: dict) -> dict:
337
+ if self.search_type == 'web':
338
+ filtered_result = response.get('web', {}).get('results', [])
339
+ else:
340
+ filtered_result = response.get('results', {})
341
+ raw_results = []
342
+
343
+ for item in filtered_result:
344
+ raw_results.append((
345
+ item.get('url', ''),
346
+ ' '.join(
347
+ filter(None, [
348
+ item.get('description'),
349
+ *item.get('extra_snippets', [])
350
+ ])),
351
+ item.get('title', ''),
352
+ ))
353
+ return self._filter_results(raw_results)
354
+
355
+
356
+ class GoogleSearch(BaseSearch):
357
+ """
358
+ Wrapper around the Serper.dev Google Search API.
359
+
360
+ To use, you should pass your serper API key to the constructor.
361
+
362
+ Args:
363
+ api_key (str): API KEY to use serper google search API.
364
+ You can create a free API key at https://serper.dev.
365
+ search_type (str): Serper API supports ['search', 'images', 'news',
366
+ 'places'] types of search, currently we only support 'search' and 'news'.
367
+ topk (int): The number of search results returned in response from api search results.
368
+ **kwargs: Any other parameters related to the Serper API. Find more details at
369
+ https://serper.dev/playground
370
+ """
371
+
372
+ result_key_for_type = {
373
+ 'news': 'news',
374
+ 'places': 'places',
375
+ 'images': 'images',
376
+ 'search': 'organic',
377
+ }
378
+
379
+ def __init__(self,
380
+ api_key: str,
381
+ topk: int = 3,
382
+ black_list: List[str] = [
383
+ 'enoN',
384
+ 'youtube.com',
385
+ 'bilibili.com',
386
+ 'researchgate.net',
387
+ ],
388
+ **kwargs):
389
+ self.api_key = api_key
390
+ self.proxy = kwargs.get('proxy')
391
+ self.search_type = kwargs.get('search_type', 'search')
392
+ self.kwargs = kwargs
393
+ super().__init__(topk, black_list)
394
+
395
+ @cached(cache=TTLCache(maxsize=100, ttl=600))
396
+ def search(self, query: str, max_retry: int = 3) -> dict:
397
+ for attempt in range(max_retry):
398
+ try:
399
+ response = self._call_serper_api(query)
400
+ return self._parse_response(response)
401
+ except Exception as e:
402
+ logging.exception(str(e))
403
+ warnings.warn(
404
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
405
+ time.sleep(random.randint(2, 5))
406
+ raise Exception(
407
+ 'Failed to get search results from Google Serper Search after retries.'
408
+ )
409
+
410
+ @acached(cache=TTLCache(maxsize=100, ttl=600))
411
+ async def asearch(self, query: str, max_retry: int = 3) -> dict:
412
+ for attempt in range(max_retry):
413
+ try:
414
+ response = await self._async_call_serper_api(query)
415
+ return self._parse_response(response)
416
+ except Exception as e:
417
+ logging.exception(str(e))
418
+ warnings.warn(
419
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
420
+ await asyncio.sleep(random.randint(2, 5))
421
+ raise Exception(
422
+ 'Failed to get search results from Google Serper Search after retries.'
423
+ )
424
+
425
+ def _call_serper_api(self, query: str) -> dict:
426
+ endpoint = f'https://google.serper.dev/{self.search_type}'
427
+ params = {
428
+ 'q': query,
429
+ 'num': self.topk,
430
+ **{
431
+ key: value
432
+ for key, value in self.kwargs.items() if value is not None
433
+ },
434
+ }
435
+ headers = {
436
+ 'X-API-KEY': self.api_key or '',
437
+ 'Content-Type': 'application/json'
438
+ }
439
+ response = requests.get(
440
+ endpoint, headers=headers, params=params, proxies=self.proxy)
441
+ response.raise_for_status()
442
+ return response.json()
443
+
444
+ async def _async_call_serper_api(self, query: str) -> dict:
445
+ endpoint = f'https://google.serper.dev/{self.search_type}'
446
+ params = {
447
+ 'q': query,
448
+ 'num': self.topk,
449
+ **{
450
+ key: value
451
+ for key, value in self.kwargs.items() if value is not None
452
+ },
453
+ }
454
+ headers = {
455
+ 'X-API-KEY': self.api_key or '',
456
+ 'Content-Type': 'application/json'
457
+ }
458
+ async with aiohttp.ClientSession(raise_for_status=True) as session:
459
+ async with session.get(
460
+ endpoint,
461
+ headers=headers,
462
+ params=params,
463
+ proxy=self.proxy and
464
+ (self.proxy.get('http') or self.proxy.get('https'))) as resp:
465
+ return await resp.json()
466
+
467
+ def _parse_response(self, response: dict) -> dict:
468
+ raw_results = []
469
+
470
+ if response.get('answerBox'):
471
+ answer_box = response.get('answerBox', {})
472
+ if answer_box.get('answer'):
473
+ raw_results.append(('', answer_box.get('answer'), ''))
474
+ elif answer_box.get('snippet'):
475
+ raw_results.append(
476
+ ('', answer_box.get('snippet').replace('\n', ' '), ''))
477
+ elif answer_box.get('snippetHighlighted'):
478
+ raw_results.append(
479
+ ('', answer_box.get('snippetHighlighted'), ''))
480
+
481
+ if response.get('knowledgeGraph'):
482
+ kg = response.get('knowledgeGraph', {})
483
+ description = kg.get('description', '')
484
+ attributes = '. '.join(
485
+ f'{attribute}: {value}'
486
+ for attribute, value in kg.get('attributes', {}).items())
487
+ raw_results.append(
488
+ (kg.get('descriptionLink', ''),
489
+ f'{description}. {attributes}' if attributes else description,
490
+ f"{kg.get('title', '')}: {kg.get('type', '')}."))
491
+
492
+ for result in response[self.result_key_for_type[
493
+ self.search_type]][:self.topk]:
494
+ description = result.get('snippet', '')
495
+ attributes = '. '.join(
496
+ f'{attribute}: {value}'
497
+ for attribute, value in result.get('attributes', {}).items())
498
+ raw_results.append(
499
+ (result.get('link', ''),
500
+ f'{description}. {attributes}' if attributes else description,
501
+ result.get('title', '')))
502
+
503
+ return self._filter_results(raw_results)
504
+
505
+
506
+ class TencentSearch(BaseSearch):
507
+ """Wrapper around the tencentclound Search API.
508
+
509
+ To use, you should pass your secret_id and secret_key to the constructor.
510
+
511
+ Args:
512
+ secret_id (str): Your Tencent Cloud secret ID for accessing the API.
513
+ For more details, refer to the documentation: https://cloud.tencent.com/document/product/598/40488.
514
+ secret_key (str): Your Tencent Cloud secret key for accessing the API.
515
+ api_key (str, optional): Additional API key, if required.
516
+ action (str): The action for this interface, use `SearchCommon`.
517
+ version (str): The API version, use `2020-12-29`.
518
+ service (str): The service name, use `tms`.
519
+ host (str): The API host, use `tms.tencentcloudapi.com`.
520
+ topk (int): The maximum number of search results to return.
521
+ tsn (int): Time filter for search results. Valid values:
522
+ 1 (within 1 day), 2 (within 1 week), 3 (within 1 month),
523
+ 4 (within 1 year), 5 (within 6 months), 6 (within 3 years).
524
+ insite (str): Specify a site to search within (supports only a single site).
525
+ If not specified, the entire web is searched. Example: `zhihu.com`.
526
+ category (str): Vertical category for filtering results. Optional values include:
527
+ `baike` (encyclopedia), `weather`, `calendar`, `medical`, `news`, `train`, `star` (horoscope).
528
+ vrid (str): Result card type(s). Different `vrid` values represent different types of result cards.
529
+ Supports multiple values separated by commas. Example: `30010255`.
530
+ """
531
+
532
+ def __init__(self,
533
+ secret_id: str = 'Your SecretId',
534
+ secret_key: str = 'Your SecretKey',
535
+ api_key: str = '',
536
+ action: str = 'SearchCommon',
537
+ version: str = '2020-12-29',
538
+ service: str = 'tms',
539
+ host: str = 'tms.tencentcloudapi.com',
540
+ topk: int = 3,
541
+ tsn: int = None,
542
+ insite: str = None,
543
+ category: str = None,
544
+ vrid: str = None,
545
+ black_list: List[str] = [
546
+ 'enoN',
547
+ 'youtube.com',
548
+ 'bilibili.com',
549
+ 'researchgate.net',
550
+ ]):
551
+ self.secret_id = secret_id
552
+ self.secret_key = secret_key
553
+ self.api_key = api_key
554
+ self.action = action
555
+ self.version = version
556
+ self.service = service
557
+ self.host = host
558
+ self.tsn = tsn
559
+ self.insite = insite
560
+ self.category = category
561
+ self.vrid = vrid
562
+ super().__init__(topk, black_list=black_list)
563
+
564
+ @cached(cache=TTLCache(maxsize=100, ttl=600))
565
+ def search(self, query: str, max_retry: int = 3) -> dict:
566
+ for attempt in range(max_retry):
567
+ try:
568
+ response = self._call_tencent_api(query)
569
+ return self._parse_response(response)
570
+ except Exception as e:
571
+ logging.exception(str(e))
572
+ warnings.warn(
573
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
574
+ time.sleep(random.randint(2, 5))
575
+ raise Exception(
576
+ 'Failed to get search results from Bing Search after retries.')
577
+
578
+ @acached(cache=TTLCache(maxsize=100, ttl=600))
579
+ async def asearch(self, query: str, max_retry: int = 3) -> dict:
580
+ for attempt in range(max_retry):
581
+ try:
582
+ response = await self._async_call_tencent_api(query)
583
+ return self._parse_response(response)
584
+ except Exception as e:
585
+ logging.exception(str(e))
586
+ warnings.warn(
587
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
588
+ await asyncio.sleep(random.randint(2, 5))
589
+ raise Exception(
590
+ 'Failed to get search results from Bing Search after retries.')
591
+
592
+ def _get_headers_and_payload(self, query: str) -> tuple:
593
+
594
+ def sign(key, msg):
595
+ return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()
596
+
597
+ params = dict(Query=query)
598
+ # if self.topk:
599
+ # params['Cnt'] = self.topk
600
+ if self.tsn:
601
+ params['Tsn'] = self.tsn
602
+ if self.insite:
603
+ params['Insite'] = self.insite
604
+ if self.category:
605
+ params['Category'] = self.category
606
+ if self.vrid:
607
+ params['Vrid'] = self.vrid
608
+ payload = json.dumps(params)
609
+ algorithm = 'TC3-HMAC-SHA256'
610
+ timestamp = int(time.time())
611
+ date = datetime.utcfromtimestamp(timestamp).strftime('%Y-%m-%d')
612
+
613
+ # ************* 步骤 1:拼接规范请求串 *************
614
+ http_request_method = 'POST'
615
+ canonical_uri = '/'
616
+ canonical_querystring = ''
617
+ ct = 'application/json; charset=utf-8'
618
+ canonical_headers = f'content-type:{ct}\nhost:{self.host}\nx-tc-action:{self.action.lower()}\n'
619
+ signed_headers = 'content-type;host;x-tc-action'
620
+ hashed_request_payload = hashlib.sha256(
621
+ payload.encode('utf-8')).hexdigest()
622
+ canonical_request = (
623
+ http_request_method + '\n' + canonical_uri + '\n' +
624
+ canonical_querystring + '\n' + canonical_headers + '\n' +
625
+ signed_headers + '\n' + hashed_request_payload)
626
+
627
+ # ************* 步骤 2:拼接待签名字符串 *************
628
+ credential_scope = date + '/' + self.service + '/' + 'tc3_request'
629
+ hashed_canonical_request = hashlib.sha256(
630
+ canonical_request.encode('utf-8')).hexdigest()
631
+ string_to_sign = (
632
+ algorithm + '\n' + str(timestamp) + '\n' + credential_scope +
633
+ '\n' + hashed_canonical_request)
634
+
635
+ # ************* 步骤 3:计算签名 *************
636
+ secret_date = sign(('TC3' + self.secret_key).encode('utf-8'), date)
637
+ secret_service = sign(secret_date, self.service)
638
+ secret_signing = sign(secret_service, 'tc3_request')
639
+ signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'),
640
+ hashlib.sha256).hexdigest()
641
+
642
+ # ************* 步骤 4:拼接 Authorization *************
643
+ authorization = (
644
+ algorithm + ' ' + 'Credential=' + self.secret_id + '/' +
645
+ credential_scope + ', ' + 'SignedHeaders=' + signed_headers +
646
+ ', ' + 'Signature=' + signature)
647
+
648
+ # ************* 步骤 5:构造并发起请求 *************
649
+ headers = {
650
+ 'Authorization': authorization,
651
+ 'Content-Type': 'application/json; charset=utf-8',
652
+ 'Host': self.host,
653
+ 'X-TC-Action': self.action,
654
+ 'X-TC-Timestamp': str(timestamp),
655
+ 'X-TC-Version': self.version
656
+ }
657
+ # if self.region:
658
+ # headers["X-TC-Region"] = self.region
659
+ if self.api_key:
660
+ headers['X-TC-Token'] = self.api_key
661
+ return headers, payload
662
+
663
+ def _call_tencent_api(self, query: str) -> dict:
664
+ headers, payload = self._get_headers_and_payload(query)
665
+ req = HTTPSConnection(self.host)
666
+ req.request('POST', '/', headers=headers, body=payload.encode('utf-8'))
667
+ resp = req.getresponse()
668
+ try:
669
+ resp = json.loads(resp.read().decode('utf-8'))
670
+ except Exception as e:
671
+ logging.warning(str(e))
672
+ import ast
673
+ resp = ast.literal_eval(resp)
674
+ return resp.get('Response', dict())
675
+
676
+ async def _async_call_tencent_api(self, query: str):
677
+ headers, payload = self._get_headers_and_payload(query)
678
+ async with aiohttp.ClientSession(raise_for_status=True) as session:
679
+ async with session.post(
680
+ 'https://' + self.host.lstrip('/'),
681
+ headers=headers,
682
+ data=payload) as resp:
683
+ return (await resp.json()).get('Response', {})
684
+
685
+ def _parse_response(self, response: dict) -> dict:
686
+ raw_results = []
687
+ for item in response.get('Pages', []):
688
+ display = json.loads(item['Display'])
689
+ if not display['url']:
690
+ continue
691
+ raw_results.append((display['url'], display['content']
692
+ or display['abstract_info'], display['title']))
693
+ return self._filter_results(raw_results)
694
+
695
+
696
+ class ContentFetcher:
697
+
698
+ def __init__(self, timeout: int = 5):
699
+ self.timeout = timeout
700
+
701
+ @cached(cache=TTLCache(maxsize=100, ttl=600))
702
+ def fetch(self, url: str) -> Tuple[bool, str]:
703
+ try:
704
+ response = requests.get(url, timeout=self.timeout)
705
+ response.raise_for_status()
706
+ html = response.content
707
+ except requests.RequestException as e:
708
+ return False, str(e)
709
+
710
+ text = BeautifulSoup(html, 'html.parser').get_text()
711
+ cleaned_text = re.sub(r'\n+', '\n', text)
712
+ return True, cleaned_text
713
+
714
+ @acached(cache=TTLCache(maxsize=100, ttl=600))
715
+ async def afetch(self, url: str) -> Tuple[bool, str]:
716
+ try:
717
+ async with aiohttp.ClientSession(
718
+ raise_for_status=True,
719
+ timeout=aiohttp.ClientTimeout(self.timeout)) as session:
720
+ async with session.get(url) as resp:
721
+ html = await resp.text(errors='ignore')
722
+ text = BeautifulSoup(html, 'html.parser').get_text()
723
+ cleaned_text = re.sub(r'\n+', '\n', text)
724
+ return True, cleaned_text
725
+ except Exception as e:
726
+ return False, str(e)
727
+
728
+
729
+ class WebBrowser(BaseAction):
730
+ """Wrapper around the Web Browser Tool.
731
+ """
732
+
733
+ def __init__(self,
734
+ searcher_type: str = 'DuckDuckGoSearch',
735
+ timeout: int = 5,
736
+ black_list: Optional[List[str]] = [
737
+ 'enoN',
738
+ 'youtube.com',
739
+ 'bilibili.com',
740
+ 'researchgate.net',
741
+ ],
742
+ topk: int = 20,
743
+ description: Optional[dict] = None,
744
+ parser: Type[BaseParser] = JsonParser,
745
+ **kwargs):
746
+ self.searcher = eval(searcher_type)(
747
+ black_list=black_list, topk=topk, **kwargs)
748
+ self.fetcher = ContentFetcher(timeout=timeout)
749
+ self.search_results = None
750
+ super().__init__(description, parser)
751
+
752
+ @tool_api
753
+ def search(self, query: Union[str, List[str]]) -> dict:
754
+ """BING search API
755
+ Args:
756
+ query (List[str]): list of search query strings
757
+ """
758
+ queries = query if isinstance(query, list) else [query]
759
+ search_results = {}
760
+
761
+ with ThreadPoolExecutor() as executor:
762
+ future_to_query = {
763
+ executor.submit(self.searcher.search, q): q
764
+ for q in queries
765
+ }
766
+
767
+ for future in as_completed(future_to_query):
768
+ query = future_to_query[future]
769
+ try:
770
+ results = future.result()
771
+ except Exception as exc:
772
+ warnings.warn(f'{query} generated an exception: {exc}')
773
+ else:
774
+ for result in results.values():
775
+ if result['url'] not in search_results:
776
+ search_results[result['url']] = result
777
+ else:
778
+ search_results[
779
+ result['url']]['summ'] += f"\n{result['summ']}"
780
+
781
+ self.search_results = {
782
+ idx: result
783
+ for idx, result in enumerate(search_results.values())
784
+ }
785
+ return self.search_results
786
+
787
+ @tool_api
788
+ def select(self, select_ids: List[int]) -> dict:
789
+ """get the detailed content on the selected pages.
790
+
791
+ Args:
792
+ select_ids (List[int]): list of index to select. Max number of index to be selected is no more than 4.
793
+ """
794
+ if not self.search_results:
795
+ raise ValueError('No search results to select from.')
796
+
797
+ new_search_results = {}
798
+ with ThreadPoolExecutor() as executor:
799
+ future_to_id = {
800
+ executor.submit(self.fetcher.fetch, self.search_results[select_id]['url']): select_id
801
+ for select_id in select_ids if select_id in self.search_results
802
+ }
803
+ for future in as_completed(future_to_id):
804
+ select_id = future_to_id[future]
805
+ try:
806
+ web_success, web_content = future.result()
807
+ except Exception as exc:
808
+ warnings.warn(f'{select_id} generated an exception: {exc}')
809
+ else:
810
+ if web_success:
811
+ self.search_results[select_id][
812
+ 'content'] = web_content[:8192]
813
+ new_search_results[select_id] = self.search_results[
814
+ select_id].copy()
815
+ new_search_results[select_id].pop('summ')
816
+
817
+ return new_search_results
818
+
819
+ @tool_api
820
+ def open_url(self, url: str) -> dict:
821
+ print(f'Start Browsing: {url}')
822
+ web_success, web_content = self.fetcher.fetch(url)
823
+ if web_success:
824
+ return {'type': 'text', 'content': web_content}
825
+ else:
826
+ return {'error': web_content}
827
+
828
+
829
+ class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
830
+ """Wrapper around the Web Browser Tool.
831
+ """
832
+
833
+ @tool_api
834
+ async def search(self, query: Union[str, List[str]]) -> dict:
835
+ """BING search API
836
+
837
+ Args:
838
+ query (List[str]): list of search query strings
839
+ """
840
+ queries = query if isinstance(query, list) else [query]
841
+ search_results = {}
842
+
843
+ tasks = []
844
+ for q in queries:
845
+ task = asyncio.create_task(self.searcher.asearch(q))
846
+ task.query = q
847
+ tasks.append(task)
848
+ async for future in async_as_completed(tasks):
849
+ query = future.query
850
+ try:
851
+ results = await future
852
+ except Exception as exc:
853
+ warnings.warn(f'{query} generated an exception: {exc}')
854
+ else:
855
+ for result in results.values():
856
+ if result['url'] not in search_results:
857
+ search_results[result['url']] = result
858
+ else:
859
+ search_results[
860
+ result['url']]['summ'] += f"\n{result['summ']}"
861
+
862
+ self.search_results = {
863
+ idx: result
864
+ for idx, result in enumerate(search_results.values())
865
+ }
866
+ return self.search_results
867
+
868
+ @tool_api
869
+ async def select(self, select_ids: List[int]) -> dict:
870
+ """get the detailed content on the selected pages.
871
+
872
+ Args:
873
+ select_ids (List[int]): list of index to select. Max number of index to be selected is no more than 4.
874
+ """
875
+ if not self.search_results:
876
+ raise ValueError('No search results to select from.')
877
+
878
+ new_search_results = {}
879
+ tasks = []
880
+ for select_id in select_ids:
881
+ if select_id in self.search_results:
882
+ task = asyncio.create_task(
883
+ self.fetcher.afetch(self.search_results[select_id]['url']))
884
+ task.select_id = select_id
885
+ tasks.append(task)
886
+ async for future in async_as_completed(tasks):
887
+ select_id = future.select_id
888
+ try:
889
+ web_success, web_content = await future
890
+ except Exception as exc:
891
+ warnings.warn(f'{select_id} generated an exception: {exc}')
892
+ else:
893
+ if web_success:
894
+ self.search_results[select_id][
895
+ 'content'] = web_content[:8192]
896
+ new_search_results[select_id] = self.search_results[
897
+ select_id].copy()
898
+ new_search_results[select_id].pop('summ')
899
+ return new_search_results
900
+
901
+ @tool_api
902
+ async def open_url(self, url: str) -> dict:
903
+ print(f'Start Browsing: {url}')
904
+ web_success, web_content = await self.fetcher.afetch(url)
905
+ if web_success:
906
+ return {'type': 'text', 'content': web_content}
907
+ else:
908
+ return {'error': web_content}
lagent/agents/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential
2
+ from .react import AsyncReAct, ReAct
3
+ from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder
4
+
5
+ __all__ = [
6
+ 'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM',
7
+ 'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct',
8
+ 'AsyncReAct', 'Sequential', 'AsyncSequential'
9
+ ]
lagent/agents/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (676 Bytes). View file
 
lagent/agents/__pycache__/agent.cpython-311.pyc ADDED
Binary file (24 kB). View file
 
lagent/agents/__pycache__/react.cpython-311.pyc ADDED
Binary file (8.92 kB). View file
 
lagent/agents/__pycache__/stream.cpython-311.pyc ADDED
Binary file (16.8 kB). View file
 
lagent/agents/agent.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import warnings
3
+ from collections import OrderedDict, UserDict, UserList, abc
4
+ from functools import wraps
5
+ from itertools import chain, repeat
6
+ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
7
+
8
+ from lagent.agents.aggregator import DefaultAggregator
9
+ from lagent.hooks import Hook, RemovableHandle
10
+ from lagent.llms import BaseLLM
11
+ from lagent.memory import Memory, MemoryManager
12
+ from lagent.prompts.parsers import StrParser
13
+ from lagent.prompts.prompt_template import PromptTemplate
14
+ from lagent.schema import AgentMessage
15
+ from lagent.utils import create_object
16
+
17
+
18
+ class Agent:
19
+ """Agent is the basic unit of the system. It is responsible for
20
+ communicating with the LLM, managing the memory, and handling the
21
+ message aggregation and parsing. It can also be extended with hooks
22
+
23
+ Args:
24
+ llm (Union[BaseLLM, Dict]): The language model used by the agent.
25
+ template (Union[PromptTemplate, str]): The template used to format the
26
+ messages.
27
+ memory (Dict): The memory used by the agent.
28
+ output_format (Dict): The output format used by the agent.
29
+ aggregator (Dict): The aggregator used by the agent.
30
+ name (Optional[str]): The name of the agent.
31
+ description (Optional[str]): The description of the agent.
32
+ hooks (Optional[Union[List[Dict], Dict]]): The hooks used by the agent.
33
+
34
+ Returns:
35
+ AgentMessage: The response message.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ llm: Union[BaseLLM, Dict] = None,
41
+ template: Union[PromptTemplate, str, dict, List[dict]] = None,
42
+ memory: Dict = dict(type=Memory),
43
+ output_format: Optional[Dict] = None,
44
+ aggregator: Dict = dict(type=DefaultAggregator),
45
+ name: Optional[str] = None,
46
+ description: Optional[str] = None,
47
+ hooks: Optional[Union[List[Dict], Dict]] = None,
48
+ ):
49
+ self.name = name or self.__class__.__name__
50
+ self.llm: BaseLLM = create_object(llm)
51
+ self.memory: MemoryManager = MemoryManager(memory) if memory else None
52
+ self.output_format: StrParser = create_object(output_format)
53
+ self.template = template
54
+ self.description = description
55
+ self.aggregator: DefaultAggregator = create_object(aggregator)
56
+ self._hooks: Dict[int, Hook] = OrderedDict()
57
+ if hooks:
58
+ for hook in hooks:
59
+ hook = create_object(hook)
60
+ self.register_hook(hook)
61
+
62
+ def update_memory(self, message, session_id=0):
63
+ if self.memory:
64
+ self.memory.add(message, session_id=session_id)
65
+
66
+ def __call__(
67
+ self,
68
+ *message: Union[str, AgentMessage, List[AgentMessage]],
69
+ session_id=0,
70
+ **kwargs,
71
+ ) -> AgentMessage:
72
+ # message.receiver = self.name
73
+ message = [
74
+ AgentMessage(sender='user', content=m)
75
+ if isinstance(m, str) else copy.deepcopy(m) for m in message
76
+ ]
77
+ for hook in self._hooks.values():
78
+ result = hook.before_agent(self, message, session_id)
79
+ if result:
80
+ message = result
81
+ self.update_memory(message, session_id=session_id)
82
+ response_message = self.forward(
83
+ *message, session_id=session_id, **kwargs)
84
+ if not isinstance(response_message, AgentMessage):
85
+ response_message = AgentMessage(
86
+ sender=self.name,
87
+ content=response_message,
88
+ )
89
+ self.update_memory(response_message, session_id=session_id)
90
+ response_message = copy.deepcopy(response_message)
91
+ for hook in self._hooks.values():
92
+ result = hook.after_agent(self, response_message, session_id)
93
+ if result:
94
+ response_message = result
95
+ return response_message
96
+
97
+ def forward(self,
98
+ *message: AgentMessage,
99
+ session_id=0,
100
+ **kwargs) -> Union[AgentMessage, str]:
101
+ formatted_messages = self.aggregator.aggregate(
102
+ self.memory.get(session_id),
103
+ self.name,
104
+ self.output_format,
105
+ self.template,
106
+ )
107
+ llm_response = self.llm.chat(formatted_messages, **kwargs)
108
+ if self.output_format:
109
+ formatted_messages = self.output_format.parse_response(
110
+ llm_response)
111
+ return AgentMessage(
112
+ sender=self.name,
113
+ content=llm_response,
114
+ formatted=formatted_messages,
115
+ )
116
+ return llm_response
117
+
118
+ def __setattr__(self, __name: str, __value: Any) -> None:
119
+ if isinstance(__value, Agent):
120
+ _agents = getattr(self, '_agents', OrderedDict())
121
+ _agents[__name] = __value
122
+ super().__setattr__('_agents', _agents)
123
+ super().__setattr__(__name, __value)
124
+
125
+ def state_dict(self, session_id=0):
126
+ state_dict, stack = {}, [('', self)]
127
+ while stack:
128
+ prefix, node = stack.pop()
129
+ key = prefix + 'memory'
130
+ if node.memory is not None:
131
+ if session_id not in node.memory.memory_map:
132
+ warnings.warn(f'No session id {session_id} in {key}')
133
+ memory = node.memory.get(session_id)
134
+ state_dict[key] = memory and memory.save() or []
135
+ if hasattr(node, '_agents'):
136
+ for name, value in reversed(node._agents.items()):
137
+ stack.append((prefix + name + '.', value))
138
+ return state_dict
139
+
140
+ def load_state_dict(self, state_dict: Dict, session_id=0):
141
+ _state_dict = self.state_dict()
142
+ missing_keys = set(_state_dict) - set(state_dict)
143
+ if missing_keys:
144
+ raise KeyError(f'Missing keys: {missing_keys}')
145
+ extra_keys = set(state_dict) - set(_state_dict)
146
+ if extra_keys:
147
+ warnings.warn(f'Mismatch keys which are not used: {extra_keys}')
148
+ for key in _state_dict:
149
+ obj = self
150
+ for attr in key.split('.')[:-1]:
151
+ if isinstance(obj, AgentList):
152
+ assert attr.isdigit()
153
+ obj = obj[int(attr)]
154
+ elif isinstance(obj, AgentDict):
155
+ obj = obj[attr]
156
+ else:
157
+ obj = getattr(obj, attr)
158
+ if obj.memory is not None:
159
+ if session_id not in obj.memory.memory_map:
160
+ obj.memory.create_instance(session_id)
161
+ obj.memory.memory_map[session_id].load(state_dict[key] or [])
162
+
163
+ def register_hook(self, hook: Callable):
164
+ handle = RemovableHandle(self._hooks)
165
+ self._hooks[handle.id] = hook
166
+ return handle
167
+
168
+ def reset(self,
169
+ session_id=0,
170
+ keypath: Optional[str] = None,
171
+ recursive: bool = False):
172
+ assert not (keypath and
173
+ recursive), 'keypath and recursive can\'t be used together'
174
+ if keypath:
175
+ keys, agent = keypath.split('.'), self
176
+ for key in keys:
177
+ agents = getattr(agent, '_agents', {})
178
+ if key not in agents:
179
+ raise KeyError(f'No sub-agent named {key} in {agent}')
180
+ agent = agents[key]
181
+ agent.reset(session_id, recursive=False)
182
+ else:
183
+ if self.memory:
184
+ self.memory.reset(session_id=session_id)
185
+ if recursive:
186
+ for agent in getattr(self, '_agents', {}).values():
187
+ agent.reset(session_id, recursive=True)
188
+
189
+ def __repr__(self):
190
+
191
+ def _rcsv_repr(agent, n_indent=1):
192
+ res = agent.__class__.__name__ + (f"(name='{agent.name}')"
193
+ if agent.name else '')
194
+ modules = [
195
+ f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}"
196
+ for name, agent in getattr(agent, '_agents', {}).items()
197
+ ]
198
+ if modules:
199
+ res += '(\n' + '\n'.join(
200
+ modules) + f'\n{(n_indent - 1) * " "})'
201
+ elif not res.endswith(')'):
202
+ res += '()'
203
+ return res
204
+
205
+ return _rcsv_repr(self)
206
+
207
+
208
+ class AsyncAgent(Agent):
209
+
210
+ async def __call__(self,
211
+ *message: AgentMessage | List[AgentMessage],
212
+ session_id=0,
213
+ **kwargs) -> AgentMessage:
214
+ message = [
215
+ AgentMessage(sender='user', content=m)
216
+ if isinstance(m, str) else copy.deepcopy(m) for m in message
217
+ ]
218
+ for hook in self._hooks.values():
219
+ result = hook.before_agent(self, message, session_id)
220
+ if result:
221
+ message = result
222
+ self.update_memory(message, session_id=session_id)
223
+ response_message = await self.forward(
224
+ *message, session_id=session_id, **kwargs)
225
+ if not isinstance(response_message, AgentMessage):
226
+ response_message = AgentMessage(
227
+ sender=self.name,
228
+ content=response_message,
229
+ )
230
+ self.update_memory(response_message, session_id=session_id)
231
+ response_message = copy.deepcopy(response_message)
232
+ for hook in self._hooks.values():
233
+ result = hook.after_agent(self, response_message, session_id)
234
+ if result:
235
+ response_message = result
236
+ return response_message
237
+
238
+ async def forward(self,
239
+ *message: AgentMessage,
240
+ session_id=0,
241
+ **kwargs) -> Union[AgentMessage, str]:
242
+ formatted_messages = self.aggregator.aggregate(
243
+ self.memory.get(session_id),
244
+ self.name,
245
+ self.output_format,
246
+ self.template,
247
+ )
248
+ llm_response = await self.llm.chat(formatted_messages, session_id,
249
+ **kwargs)
250
+ if self.output_format:
251
+ formatted_messages = self.output_format.parse_response(
252
+ llm_response)
253
+ return AgentMessage(
254
+ sender=self.name,
255
+ content=llm_response,
256
+ formatted=formatted_messages,
257
+ )
258
+ return llm_response
259
+
260
+
261
+ class Sequential(Agent):
262
+ """Sequential is an agent container that forwards messages to each agent
263
+ in the order they are added."""
264
+
265
+ def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
266
+ super().__init__(**kwargs)
267
+ self._agents = OrderedDict()
268
+ if not agents:
269
+ raise ValueError('At least one agent should be provided')
270
+ if isinstance(agents[0],
271
+ Iterable) and not isinstance(agents[0], Agent):
272
+ if not agents[0]:
273
+ raise ValueError('At least one agent should be provided')
274
+ agents = agents[0]
275
+ for key, agent in enumerate(agents):
276
+ if isinstance(agents, Mapping):
277
+ key, agent = agent, agents[agent]
278
+ elif isinstance(agent, tuple):
279
+ key, agent = agent
280
+ self.add_agent(key, agent)
281
+
282
+ def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]):
283
+ assert isinstance(
284
+ agent, (Agent, AsyncAgent
285
+ )), f'{type(agent)} is not an Agent or AsyncAgent subclass'
286
+ self._agents[str(name)] = agent
287
+
288
+ def forward(self,
289
+ *message: AgentMessage,
290
+ session_id=0,
291
+ exit_at: Optional[int] = None,
292
+ **kwargs) -> AgentMessage:
293
+ assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
294
+ if exit_at is None:
295
+ exit_at = len(self) - 1
296
+ iterator = chain.from_iterable(repeat(self._agents.values()))
297
+ for _ in range(exit_at + 1):
298
+ agent = next(iterator)
299
+ if isinstance(message, AgentMessage):
300
+ message = (message, )
301
+ message = agent(*message, session_id=session_id, **kwargs)
302
+ return message
303
+
304
+ def __getitem__(self, key):
305
+ if isinstance(key, int) and key < 0:
306
+ assert key >= -len(self), 'index out of range'
307
+ key = len(self) + key
308
+ return self._agents[str(key)]
309
+
310
+ def __len__(self):
311
+ return len(self._agents)
312
+
313
+
314
+ class AsyncSequential(Sequential, AsyncAgent):
315
+
316
+ async def forward(self,
317
+ *message: AgentMessage,
318
+ session_id=0,
319
+ exit_at: Optional[int] = None,
320
+ **kwargs) -> AgentMessage:
321
+ assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
322
+ if exit_at is None:
323
+ exit_at = len(self) - 1
324
+ iterator = chain.from_iterable(repeat(self._agents.values()))
325
+ for _ in range(exit_at + 1):
326
+ agent = next(iterator)
327
+ if isinstance(message, AgentMessage):
328
+ message = (message, )
329
+ message = await agent(*message, session_id=session_id, **kwargs)
330
+ return message
331
+
332
+
333
+ class AgentContainerMixin:
334
+
335
+ def __init_subclass__(cls):
336
+ super().__init_subclass__()
337
+
338
+ def wrap_api(func):
339
+
340
+ @wraps(func)
341
+ def wrapped_func(self, *args, **kwargs):
342
+ data = self.data.copy() if hasattr(self, 'data') else None
343
+
344
+ def _backup(d):
345
+ if d is None:
346
+ self.data.clear()
347
+ else:
348
+ self.data = d
349
+
350
+ ret = func(self, *args, **kwargs)
351
+ agents = OrderedDict()
352
+ for k, item in (self.data.items() if isinstance(
353
+ self.data, abc.Mapping) else enumerate(self.data)):
354
+ if isinstance(self.data,
355
+ abc.Mapping) and not isinstance(k, str):
356
+ _backup(data)
357
+ raise KeyError(
358
+ f'agent name should be a string, got {type(k)}')
359
+ if isinstance(k, str) and '.' in k:
360
+ _backup(data)
361
+ raise KeyError(
362
+ f'agent name can\'t contain ".", got {k}')
363
+ if not isinstance(item, (Agent, AsyncAgent)):
364
+ _backup(data)
365
+ raise TypeError(
366
+ f'{type(item)} is not an Agent or AsyncAgent subclass'
367
+ )
368
+ agents[str(k)] = item
369
+ self._agents = agents
370
+ return ret
371
+
372
+ return wrapped_func
373
+
374
+ for method in [
375
+ 'append', 'sort', 'reverse', 'pop', 'clear', 'update',
376
+ 'insert', 'extend', 'remove', '__init__', '__setitem__',
377
+ '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
378
+ '__imul__', '__rmul__'
379
+ ]:
380
+ if hasattr(cls, method):
381
+ setattr(cls, method, wrap_api(getattr(cls, method)))
382
+
383
+
384
+ class AgentList(Agent, UserList, AgentContainerMixin):
385
+
386
+ def __init__(self,
387
+ agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None):
388
+ Agent.__init__(self, memory=None)
389
+ UserList.__init__(self, agents)
390
+ self.name = None
391
+
392
+
393
+ class AgentDict(Agent, UserDict, AgentContainerMixin):
394
+
395
+ def __init__(self,
396
+ agents: Optional[Mapping[str, Union[Agent,
397
+ AsyncAgent]]] = None):
398
+ Agent.__init__(self, memory=None)
399
+ UserDict.__init__(self, agents)
400
+ self.name = None
lagent/agents/aggregator/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .default_aggregator import DefaultAggregator
2
+ from .tool_aggregator import InternLMToolAggregator
3
+
4
+ __all__ = ['DefaultAggregator', 'InternLMToolAggregator']
lagent/agents/aggregator/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (358 Bytes). View file
 
lagent/agents/aggregator/__pycache__/default_aggregator.cpython-311.pyc ADDED
Binary file (2.87 kB). View file
 
lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-311.pyc ADDED
Binary file (5.61 kB). View file
 
lagent/agents/aggregator/default_aggregator.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+
3
+ from lagent.memory import Memory
4
+ from lagent.prompts import StrParser
5
+
6
+
7
+ class DefaultAggregator:
8
+
9
+ def aggregate(self,
10
+ messages: Memory,
11
+ name: str,
12
+ parser: StrParser = None,
13
+ system_instruction: str = None) -> List[Dict[str, str]]:
14
+ _message = []
15
+ messages = messages.get_memory()
16
+ if system_instruction:
17
+ _message.extend(
18
+ self.aggregate_system_intruction(system_instruction))
19
+ for message in messages:
20
+ if message.sender == name:
21
+ _message.append(
22
+ dict(role='assistant', content=str(message.content)))
23
+ else:
24
+ user_message = message.content
25
+ if len(_message) > 0 and _message[-1]['role'] == 'user':
26
+ _message[-1]['content'] += user_message
27
+ else:
28
+ _message.append(dict(role='user', content=user_message))
29
+ return _message
30
+
31
+ @staticmethod
32
+ def aggregate_system_intruction(system_intruction) -> List[dict]:
33
+ if isinstance(system_intruction, str):
34
+ system_intruction = dict(role='system', content=system_intruction)
35
+ if isinstance(system_intruction, dict):
36
+ system_intruction = [system_intruction]
37
+ if isinstance(system_intruction, list):
38
+ for msg in system_intruction:
39
+ if not isinstance(msg, dict):
40
+ raise TypeError(f'Unsupported message type: {type(msg)}')
41
+ if not ('role' in msg and 'content' in msg):
42
+ raise KeyError(
43
+ f"Missing required key 'role' or 'content': {msg}")
44
+ return system_intruction
lagent/agents/aggregator/tool_aggregator.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Union
2
+
3
+ from lagent.agents.aggregator.default_aggregator import DefaultAggregator
4
+ from lagent.memory.base_memory import Memory
5
+ from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser, ToolStatusCode
6
+
7
+
8
+ class InternLMToolAggregator(DefaultAggregator):
9
+
10
+ def __init__(self,
11
+ environment_role='environment',
12
+ environment_begin='',
13
+ environment_end='',
14
+ user_names: Optional[List[str]] = None,
15
+ few_shot: Optional[List[List[dict]]] = None):
16
+ self.environment_role = environment_role
17
+ self.environment_begin = environment_begin
18
+ self.environment_end = environment_end
19
+ self.user_names = user_names or ['user']
20
+ self.few_shot = few_shot or []
21
+
22
+ def aggregate(self,
23
+ messages: Memory,
24
+ name: str,
25
+ parser: Union[ToolParser, MixedToolParser],
26
+ system_instruction: str = None) -> List[Dict[str, str]]:
27
+ _message = []
28
+ messages = messages.get_memory()
29
+ if system_instruction:
30
+ _message.extend(
31
+ self.aggregate_system_intruction(system_instruction))
32
+ tool_instruction = parser.format_instruction()
33
+ if tool_instruction:
34
+ if isinstance(tool_instruction, str):
35
+ tool_instruction = dict(
36
+ role='system', content=tool_instruction)
37
+ if parser.tool_type:
38
+ tool_instruction['name'] = parser.tool_type
39
+ if isinstance(tool_instruction, dict):
40
+ tool_instruction = [tool_instruction]
41
+ _message.extend(tool_instruction)
42
+
43
+ for shot in self.few_shot:
44
+ i = 0
45
+ while i < len(shot):
46
+ msg = shot[i]
47
+ if msg['role'] in ['assistant', 'user', 'system']:
48
+ _message.append(msg)
49
+ elif msg['role'] == self.environment_role:
50
+ if not msg['content'].startswith(self.environment_begin):
51
+ msg['content'] = self.environment_begin + msg['content']
52
+ if not msg['content'].endswith(self.environment_end):
53
+ msg['content'] += self.environment_end
54
+ _message.append(msg)
55
+ elif msg['role'] in ['thought', 'language']:
56
+ if i < len(shot) - 1 and shot[i + 1]['role'] == 'tool':
57
+ _message.append(
58
+ dict(
59
+ role='assistant',
60
+ content=parser.format_response(
61
+ dict(
62
+ tool_type=shot[i + 1]['name'],
63
+ thought=msg['content'],
64
+ action=shot[i + 1]['content'],
65
+ status=None))))
66
+ i += 1
67
+ else:
68
+ _message.append(
69
+ dict(
70
+ role='assistant',
71
+ content=parser.format_response(
72
+ dict(
73
+ tool_type=None,
74
+ thought=msg['content'],
75
+ action=None,
76
+ status=None))))
77
+ else:
78
+ raise KeyError(f'Unkown role: {msg["role"]}')
79
+ i += 1
80
+
81
+ tool_type = None
82
+ for message in messages:
83
+ if message.sender == name:
84
+ if isinstance(message.formatted, dict):
85
+ parsed = message.formatted
86
+ if parsed['status'] == ToolStatusCode.PARSING_ERROR:
87
+ continue
88
+ _message.append(
89
+ dict(
90
+ role='assistant',
91
+ content=parser.format_response(parsed)))
92
+ tool_type = parsed['tool_type']
93
+ else:
94
+ _message.append(
95
+ dict(role='assistant', content=str(message.content)))
96
+ elif message.sender in self.user_names:
97
+ _message.append(dict(role='user', content=message.content))
98
+ else:
99
+ msg = dict(
100
+ role=self.environment_role,
101
+ content=self.environment_begin + str(message.content) +
102
+ self.environment_end)
103
+ if tool_type:
104
+ msg['name'] = tool_type
105
+ _message.append(msg)
106
+ return _message
lagent/agents/react.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Callable, Dict, List, Union
3
+
4
+ from pydantic import BaseModel, Field
5
+
6
+ from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction
7
+ from lagent.agents.agent import Agent, AsyncAgent
8
+ from lagent.agents.aggregator import DefaultAggregator
9
+ from lagent.hooks import ActionPreprocessor
10
+ from lagent.llms import BaseLLM
11
+ from lagent.memory import Memory
12
+ from lagent.prompts.parsers.json_parser import JSONParser
13
+ from lagent.prompts.prompt_template import PromptTemplate
14
+ from lagent.schema import AgentMessage
15
+ from lagent.utils import create_object
16
+
17
+ select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:
18
+ {action_info}
19
+ {output_format}
20
+ 开始!"""
21
+
22
+ output_format_template = """如果使用工具请遵循以下格式回复:
23
+ {function_format}
24
+
25
+ 如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复
26
+ {finish_format}"""
27
+
28
+
29
+ class ReAct(Agent):
30
+
31
+ def __init__(self,
32
+ llm: Union[BaseLLM, Dict],
33
+ actions: Union[BaseAction, List[BaseAction]],
34
+ template: Union[PromptTemplate, str] = None,
35
+ memory: Dict = dict(type=Memory),
36
+ output_format: Dict = dict(type=JSONParser),
37
+ aggregator: Dict = dict(type=DefaultAggregator),
38
+ hooks: List = [dict(type=ActionPreprocessor)],
39
+ finish_condition: Callable[[AgentMessage], bool] = lambda m:
40
+ 'conclusion' in m.content or 'conclusion' in m.formatted,
41
+ max_turn: int = 5,
42
+ **kwargs):
43
+ self.max_turn = max_turn
44
+ self.finish_condition = finish_condition
45
+ actions = dict(
46
+ type=ActionExecutor,
47
+ actions=actions,
48
+ hooks=hooks,
49
+ )
50
+ self.actions: ActionExecutor = create_object(actions)
51
+ select_agent = dict(
52
+ type=Agent,
53
+ llm=llm,
54
+ template=template.format(
55
+ action_info=json.dumps(self.actions.description()),
56
+ output_format=output_format.format_instruction()),
57
+ output_format=output_format,
58
+ memory=memory,
59
+ aggregator=aggregator,
60
+ hooks=hooks,
61
+ )
62
+ self.select_agent = create_object(select_agent)
63
+ super().__init__(**kwargs)
64
+
65
+ def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
66
+ for _ in range(self.max_turn):
67
+ message = self.select_agent(message)
68
+ if self.finish_condition(message):
69
+ return message
70
+ message = self.actions(message)
71
+ return message
72
+
73
+
74
+ class AsyncReAct(AsyncAgent):
75
+
76
+ def __init__(self,
77
+ llm: Union[BaseLLM, Dict],
78
+ actions: Union[BaseAction, List[BaseAction]],
79
+ template: Union[PromptTemplate, str] = None,
80
+ memory: Dict = dict(type=Memory),
81
+ output_format: Dict = dict(type=JSONParser),
82
+ aggregator: Dict = dict(type=DefaultAggregator),
83
+ hooks: List = [dict(type=ActionPreprocessor)],
84
+ finish_condition: Callable[[AgentMessage], bool] = lambda m:
85
+ 'conclusion' in m.content or 'conclusion' in m.formatted,
86
+ max_turn: int = 5,
87
+ **kwargs):
88
+ self.max_turn = max_turn
89
+ self.finish_condition = finish_condition
90
+ actions = dict(
91
+ type=AsyncActionExecutor,
92
+ actions=actions,
93
+ hooks=hooks,
94
+ )
95
+ self.actions: AsyncActionExecutor = create_object(actions)
96
+ select_agent = dict(
97
+ type=AsyncAgent,
98
+ llm=llm,
99
+ template=template.format(
100
+ action_info=json.dumps(self.actions.description()),
101
+ output_format=output_format.format_instruction()),
102
+ output_format=output_format,
103
+ memory=memory,
104
+ aggregator=aggregator,
105
+ hooks=hooks,
106
+ )
107
+ self.select_agent = create_object(select_agent)
108
+ super().__init__(**kwargs)
109
+
110
+ async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
111
+ for _ in range(self.max_turn):
112
+ message = await self.select_agent(message)
113
+ if self.finish_condition(message):
114
+ return message
115
+ message = await self.actions(message)
116
+ return message
117
+
118
+
119
+ if __name__ == '__main__':
120
+ from lagent.llms import GPTAPI
121
+
122
+ class ActionCall(BaseModel):
123
+ name: str = Field(description='调用的函数名称')
124
+ parameters: Dict = Field(description='调用函数的参数')
125
+
126
+ class ActionFormat(BaseModel):
127
+ thought_process: str = Field(
128
+ description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
129
+ action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名��和参数。')
130
+
131
+ class FinishFormat(BaseModel):
132
+ thought_process: str = Field(
133
+ description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
134
+ conclusion: str = Field(description='总结当前的搜索结果,回答问题。')
135
+
136
+ prompt_template = PromptTemplate(select_action_template)
137
+ output_format = JSONParser(
138
+ output_format_template,
139
+ function_format=ActionFormat,
140
+ finish_format=FinishFormat)
141
+
142
+ llm = dict(
143
+ type=GPTAPI,
144
+ model_type='gpt-4o-2024-05-13',
145
+ key=None,
146
+ max_new_tokens=4096,
147
+ proxies=dict(),
148
+ retry=1000)
149
+
150
+ agent = ReAct(
151
+ llm=llm,
152
+ template=prompt_template,
153
+ output_format=output_format,
154
+ aggregator=dict(type='DefaultAggregator'),
155
+ actions=[dict(type='PythonInterpreter')],
156
+ )
157
+ response = agent(
158
+ AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))
159
+ print(response)
160
+ response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢'))
161
+ print(response)
lagent/agents/stream.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import warnings
3
+ from copy import deepcopy
4
+ from typing import Callable, Dict, List, Union
5
+
6
+ from lagent.actions import ActionExecutor, AsyncActionExecutor, AsyncIPythonInterpreter, IPythonInteractive
7
+ from lagent.agents.agent import Agent, AsyncAgent
8
+ from lagent.agents.aggregator import InternLMToolAggregator
9
+ from lagent.hooks import InternLMActionProcessor
10
+ from lagent.llms import BaseLLM
11
+ from lagent.memory import Memory
12
+ from lagent.prompts.parsers import InterpreterParser, MixedToolParser, PluginParser, ToolStatusCode
13
+ from lagent.schema import AgentMessage
14
+ from lagent.utils import create_object
15
+
16
+ API_PREFIX = (
17
+ "This is the subfunction for tool '{tool_name}', you can use this tool. "
18
+ 'The description of this function is: \n{description}')
19
+
20
+ META_CN = ('当开启工具以及代码时,根据需求选择合适的工具进行调用')
21
+
22
+ INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
23
+ '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
24
+ '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
25
+ '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
26
+ '文本处理和分析(比如文本解析和自然语言处理),'
27
+ '机器学习和数据科学(用于展示模型训练和数据可视化),'
28
+ '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。')
29
+
30
+ PLUGIN_CN = ('你可以使用如下工具:'
31
+ '\n{prompt}\n'
32
+ '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
33
+ '同时注意你可以使用的工具,不要随意捏造!')
34
+
35
+
36
+ def get_plugin_prompt(actions, api_desc_template=API_PREFIX):
37
+ plugin_descriptions = []
38
+ for action in actions if isinstance(actions, list) else [actions]:
39
+ action = create_object(action)
40
+ action_desc = deepcopy(action.description)
41
+ if action.is_toolkit:
42
+ for api in action_desc['api_list']:
43
+ api['name'] = f"{action.name}.{api['name']}"
44
+ api['description'] = api_desc_template.format(
45
+ tool_name=action.name, description=api['description'])
46
+ api['parameters'] = [
47
+ param for param in api['parameters']
48
+ if param['name'] in api['required']
49
+ ]
50
+ plugin_descriptions.append(api)
51
+ else:
52
+ action_desc['description'] = api_desc_template.format(
53
+ tool_name=action.name, description=action_desc['description'])
54
+ action_desc['parameters'] = [
55
+ param for param in action_desc['parameters']
56
+ if param['name'] in action_desc['required']
57
+ ]
58
+ plugin_descriptions.append(action_desc)
59
+ return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4)
60
+
61
+
62
+ class AgentForInternLM(Agent):
63
+
64
+ _INTERNAL_AGENT_CLS = Agent
65
+
66
+ def __init__(
67
+ self,
68
+ llm: Union[BaseLLM, Dict],
69
+ plugins: Union[dict, List[dict]] = None,
70
+ interpreter: dict = None,
71
+ template: Union[str, dict, List[dict]] = None,
72
+ memory: Dict = dict(type=Memory),
73
+ output_format: Dict = dict(
74
+ type=MixedToolParser,
75
+ template=META_CN,
76
+ parsers=[
77
+ dict(type=PluginParser, template=PLUGIN_CN),
78
+ dict(type=InterpreterParser, template=INTERPRETER_CN),
79
+ ]),
80
+ aggregator: Dict = dict(type=InternLMToolAggregator),
81
+ action_hooks: List = [dict(type=InternLMActionProcessor)],
82
+ finish_condition: Callable[
83
+ [AgentMessage],
84
+ bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
85
+ max_turn: int = 4,
86
+ **kwargs,
87
+ ):
88
+ agent = dict(
89
+ type=self._INTERNAL_AGENT_CLS,
90
+ llm=llm,
91
+ template=template,
92
+ output_format=output_format,
93
+ memory=memory,
94
+ aggregator=aggregator,
95
+ hooks=kwargs.pop('hooks', None),
96
+ )
97
+ self.agent = create_object(agent)
98
+ self.plugin_executor = plugins and ActionExecutor(
99
+ plugins, hooks=action_hooks)
100
+ self.interpreter_executor = interpreter and ActionExecutor(
101
+ interpreter, hooks=action_hooks)
102
+ if not (self.plugin_executor or self.interpreter_executor):
103
+ warnings.warn(
104
+ 'Neither plugin nor interpreter executor is initialized. '
105
+ 'An exception will be thrown when the agent call a tool.')
106
+ self.finish_condition = finish_condition
107
+ self.max_turn = max_turn
108
+ super().__init__(**kwargs)
109
+
110
+ def forward(self, message: AgentMessage, session_id=0, **kwargs):
111
+ if isinstance(message, str):
112
+ message = AgentMessage(sender='user', content=message)
113
+ for _ in range(self.max_turn):
114
+ message = self.agent(message, session_id=session_id, **kwargs)
115
+ assert isinstance(message.formatted, dict)
116
+ if self.finish_condition(message):
117
+ return message
118
+ if message.formatted['tool_type']:
119
+ tool_type = message.formatted["tool_type"]
120
+ executor = getattr(self, f'{tool_type}_executor', None)
121
+ if not executor:
122
+ raise RuntimeError(f'No available {tool_type} executor')
123
+ message = executor(message, session_id=session_id)
124
+ return message
125
+
126
+ def get_steps(self, session_id=0):
127
+ steps, tool_type = [], None
128
+ for msg in self.agent.memory.get_memory(session_id):
129
+ if msg.sender == self.agent.name:
130
+ steps.append(
131
+ dict(role='thought', content=msg.formatted['thought']))
132
+ if msg.formatted['tool_type']:
133
+ tool_type = msg.formatted['tool_type']
134
+ steps.append(
135
+ dict(
136
+ role='tool',
137
+ content=msg.formatted['action'],
138
+ name=tool_type))
139
+ elif msg.sender != 'user':
140
+ feedback = dict(role='environment', content=msg.content)
141
+ if tool_type:
142
+ feedback['name'] = tool_type
143
+ steps.append(feedback)
144
+ return steps
145
+
146
+
147
+ class MathCoder(AgentForInternLM):
148
+
149
+ def __init__(
150
+ self,
151
+ llm: Union[BaseLLM, Dict],
152
+ interpreter: dict = dict(
153
+ type=IPythonInteractive, timeout=20, max_out_len=8192),
154
+ template: Union[str, dict, List[dict]] = None,
155
+ memory: Dict = dict(type=Memory),
156
+ output_format: Dict = dict(
157
+ type=InterpreterParser,
158
+ template=
159
+ ('Integrate step-by-step reasoning and Python code to solve math problems '
160
+ 'using the following guidelines:\n'
161
+ '- Analyze the question and write jupyter code to solve the problem;\n'
162
+ r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
163
+ 'units. \n')),
164
+ aggregator: Dict = dict(type=InternLMToolAggregator),
165
+ action_hooks: List = [dict(type=InternLMActionProcessor)],
166
+ finish_condition: Callable[
167
+ [AgentMessage],
168
+ bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
169
+ max_turn: int = 6,
170
+ **kwargs,
171
+ ):
172
+ kwargs.pop('plugins', None)
173
+ super().__init__(
174
+ llm=llm,
175
+ interpreter=interpreter,
176
+ template=template,
177
+ memory=memory,
178
+ output_format=output_format,
179
+ aggregator=aggregator,
180
+ action_hooks=action_hooks,
181
+ finish_condition=finish_condition,
182
+ max_turn=max_turn,
183
+ **kwargs)
184
+
185
+
186
+ class AsyncAgentForInternLM(AsyncAgent):
187
+
188
+ _INTERNAL_AGENT_CLS = AsyncAgent
189
+
190
+ def __init__(
191
+ self,
192
+ llm: Union[BaseLLM, Dict],
193
+ plugins: Union[dict, List[dict]] = None,
194
+ interpreter: dict = None,
195
+ template: Union[str, dict, List[dict]] = None,
196
+ memory: Dict = dict(type=Memory),
197
+ output_format: Dict = dict(
198
+ type=MixedToolParser,
199
+ template=META_CN,
200
+ parsers=[
201
+ dict(type=PluginParser, template=PLUGIN_CN),
202
+ dict(type=InterpreterParser, template=INTERPRETER_CN),
203
+ ]),
204
+ aggregator: Dict = dict(type=InternLMToolAggregator),
205
+ action_hooks: List = [dict(type=InternLMActionProcessor)],
206
+ finish_condition: Callable[
207
+ [AgentMessage],
208
+ bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
209
+ max_turn: int = 4,
210
+ **kwargs,
211
+ ):
212
+ agent = dict(
213
+ type=self._INTERNAL_AGENT_CLS,
214
+ llm=llm,
215
+ template=template,
216
+ output_format=output_format,
217
+ memory=memory,
218
+ aggregator=aggregator,
219
+ hooks=kwargs.pop('hooks', None),
220
+ )
221
+ self.agent = create_object(agent)
222
+ self.plugin_executor = plugins and AsyncActionExecutor(
223
+ plugins, hooks=action_hooks)
224
+ self.interpreter_executor = interpreter and AsyncActionExecutor(
225
+ interpreter, hooks=action_hooks)
226
+ if not (self.plugin_executor or self.interpreter_executor):
227
+ warnings.warn(
228
+ 'Neither plugin nor interpreter executor is initialized. '
229
+ 'An exception will be thrown when the agent call a tool.')
230
+ self.finish_condition = finish_condition
231
+ self.max_turn = max_turn
232
+ super().__init__(**kwargs)
233
+
234
+ async def forward(self, message: AgentMessage, session_id=0, **kwargs):
235
+ if isinstance(message, str):
236
+ message = AgentMessage(sender='user', content=message)
237
+ for _ in range(self.max_turn):
238
+ message = await self.agent(
239
+ message, session_id=session_id, **kwargs)
240
+ assert isinstance(message.formatted, dict)
241
+ if self.finish_condition(message):
242
+ return message
243
+ if message.formatted['tool_type']:
244
+ tool_type = message.formatted["tool_type"]
245
+ executor = getattr(self, f'{tool_type}_executor', None)
246
+ if not executor:
247
+ raise RuntimeError(f'No available {tool_type} executor')
248
+ message = await executor(message, session_id=session_id)
249
+ return message
250
+
251
+ def get_steps(self, session_id=0):
252
+ steps, tool_type = [], None
253
+ for msg in self.agent.memory.get_memory(session_id):
254
+ if msg.sender == self.agent.name:
255
+ steps.append(
256
+ dict(role='thought', content=msg.formatted['thought']))
257
+ if msg.formatted['tool_type']:
258
+ tool_type = msg.formatted['tool_type']
259
+ steps.append(
260
+ dict(
261
+ role='tool',
262
+ content=msg.formatted['action'],
263
+ name=tool_type))
264
+ elif msg.sender != 'user':
265
+ feedback = dict(role='environment', content=msg.content)
266
+ if tool_type:
267
+ feedback['name'] = tool_type
268
+ steps.append(feedback)
269
+ return steps
270
+
271
+
272
+ class AsyncMathCoder(AsyncAgentForInternLM):
273
+
274
+ def __init__(
275
+ self,
276
+ llm: Union[BaseLLM, Dict],
277
+ interpreter: dict = dict(type=AsyncIPythonInterpreter),
278
+ template: Union[str, dict, List[dict]] = None,
279
+ memory: Dict = dict(type=Memory),
280
+ output_format: Dict = dict(
281
+ type=InterpreterParser,
282
+ template=
283
+ ('Integrate step-by-step reasoning and Python code to solve math problems '
284
+ 'using the following guidelines:\n'
285
+ '- Analyze the question and write jupyter code to solve the problem;\n'
286
+ r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
287
+ 'units. \n')),
288
+ aggregator: Dict = dict(type=InternLMToolAggregator),
289
+ action_hooks: List = [dict(type=InternLMActionProcessor)],
290
+ finish_condition: Callable[
291
+ [AgentMessage],
292
+ bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
293
+ max_turn: int = 6,
294
+ **kwargs,
295
+ ):
296
+ kwargs.pop('plugins', None)
297
+ super().__init__(
298
+ llm=llm,
299
+ interpreter=interpreter,
300
+ template=template,
301
+ memory=memory,
302
+ output_format=output_format,
303
+ aggregator=aggregator,
304
+ action_hooks=action_hooks,
305
+ finish_condition=finish_condition,
306
+ max_turn=max_turn,
307
+ **kwargs)
308
+
309
+ async def forward(self, message: AgentMessage, session_id=0, **kwargs):
310
+ try:
311
+ return await super().forward(message, session_id, **kwargs)
312
+ finally:
313
+ interpreter = next(
314
+ iter(self.interpreter_executor.actions.values()))
315
+ if interpreter.name == 'AsyncIPythonInterpreter':
316
+ await interpreter.close_session(session_id)