| | |
| | import asyncio |
| | import base64 |
| | import io |
| | import json |
| | import logging |
| | import os |
| | import queue |
| | import re |
| | import signal |
| | import sys |
| | import tempfile |
| | import traceback |
| | import uuid |
| | from typing import Optional, Tuple, Type |
| |
|
| | from jupyter_client import AsyncKernelClient, AsyncKernelManager, AsyncMultiKernelManager |
| | from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed |
| |
|
| | from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api |
| | from lagent.actions.parser import BaseParser, JsonParser |
| | from lagent.schema import ActionReturn, ActionStatusCode |
| |
|
logger = logging.getLogger(__name__)

# Bootstrap source that is sent to a kernel before user code runs (and again
# after a ``%reset``). It replaces ``input`` with a function that raises and
# replaces IPython's ``system`` hook (used for ``!cmd``) with a function that
# only prints a notice. The ``{}`` placeholder is filled through ``str.format``
# with the interpreter's ``user_data_dir`` snippet, which is either an
# ``os.chdir`` statement or an empty string.
START_CODE = """
def input(*args, **kwargs):
    raise NotImplementedError('Python input() function is disabled.')

get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!')
{}
"""
| |
|
| |
|
class TimeoutError(Exception):
    """Raised by the ``SIGALRM`` handler when a synchronous call runs too long.

    NOTE: this intentionally module-local class shadows the builtin
    ``TimeoutError`` inside this module; it derives directly from
    ``Exception`` rather than ``OSError``.
    """
| |
|
| |
|
class KernelDeath(Exception):
    """Raised when the kernel manager reports the kernel as dead mid-execution."""
| |
|
| |
|
async def async_run_code(
    km: AsyncKernelManager,
    code,
    *,
    interrupt_after=30,
    iopub_timeout=40,
    wait_for_ready_timeout=60,
    shutdown_kernel=True,
):
    """Execute ``code`` on the kernel managed by ``km`` and collect its output.

    A fresh client is created for every attempt, the code is submitted, and
    IOPub messages are consumed until the kernel reports ``idle``.

    Args:
        km (AsyncKernelManager): manager of the kernel to run the code on.
        code (str): source code sent to the kernel.
        interrupt_after (int | None): seconds after which an interrupt is sent
            to the kernel. A falsy value disables the interrupt.
        iopub_timeout (int): timeout passed to every ``get_iopub_msg`` call.
            Must be greater than ``interrupt_after`` when the latter is set.
        wait_for_ready_timeout (int): timeout passed to ``wait_for_ready``.
        shutdown_kernel (bool): whether to shut the kernel down on exit,
            regardless of success or failure.

    Returns:
        tuple: ``(execute_result, error_traceback, stream_text)`` where
        ``execute_result`` is the ``data`` mapping of the last
        ``execute_result`` message or ``None``, ``error_traceback`` is the
        newline-joined traceback of the last ``error`` message or ``None``,
        and ``stream_text`` is the concatenation of all ``stream`` texts.

    Raises:
        KernelDeath: if the manager's ``dead`` callback fires while waiting
            for a message.
    """
    # Only compare when an interrupt is actually requested; comparing against
    # ``None`` would raise ``TypeError`` although a falsy value is supported.
    if interrupt_after:
        assert iopub_timeout > interrupt_after
    try:

        async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient,
                                                     *,
                                                     timeout=None):
            """Return the next IOPub message or raise ``KernelDeath``."""
            loop = asyncio.get_running_loop()
            dead_fut = loop.create_future()

            def restarting():
                assert (
                    False
                ), "Restart shouldn't happen because config.KernelRestarter.restart_limit is expected to be set to 0"

            def dead():
                logger.info("Kernel has died, will NOT restart")
                # Guard against a second notification on an already resolved
                # future, which would raise ``InvalidStateError``.
                if not dead_fut.done():
                    dead_fut.set_result(None)

            msg_task = asyncio.create_task(kc.get_iopub_msg(timeout=timeout))
            km.add_restart_callback(restarting, "restart")
            km.add_restart_callback(dead, "dead")
            try:
                done, _ = await asyncio.wait(
                    [dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED)
                if dead_fut in done:
                    raise KernelDeath()
                assert msg_task in done
                return await msg_task
            finally:
                # Cancelling a finished task is a no-op; this only matters
                # when the kernel died or the wait itself was cancelled.
                msg_task.cancel()
                km.remove_restart_callback(restarting, "restart")
                km.remove_restart_callback(dead, "dead")

        async def send_interrupt():
            """Interrupt the kernel once ``interrupt_after`` seconds elapsed."""
            await asyncio.sleep(interrupt_after)
            logger.info("Sending interrupt to kernel")
            await km.interrupt_kernel()

        @retry(
            retry=retry_if_result(lambda ret: ret[-1].strip() in [
                'KeyboardInterrupt',
                f"Kernel didn't respond in {wait_for_ready_timeout} seconds",
            ] if isinstance(ret, tuple) else False),
            stop=stop_after_attempt(3),
            wait=wait_fixed(1),
            retry_error_callback=lambda state: state.outcome.result())
        async def run():
            """Run ``code`` once and gather results from the IOPub channel."""
            execute_result = None
            error_traceback = None
            stream_text_list = []
            kc = km.client()
            assert isinstance(kc, AsyncKernelClient)
            kc.start_channels()
            try:
                await kc.wait_for_ready(timeout=wait_for_ready_timeout)
                msg_id = kc.execute(code)
                while True:
                    message = await get_iopub_msg_with_death_detection(
                        kc, timeout=iopub_timeout)
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.debug(
                            json.dumps(message, indent=2, default=str))
                    assert message["parent_header"]["msg_id"] == msg_id
                    msg_type = message["msg_type"]
                    if msg_type == "status":
                        if message["content"]["execution_state"] == "idle":
                            break
                    elif msg_type == "stream":
                        stream_text_list.append(message["content"]["text"])
                    elif msg_type == "execute_result":
                        execute_result = message["content"]["data"]
                    elif msg_type == "error":
                        error_traceback = "\n".join(
                            message["content"]["traceback"])
                    elif msg_type == "execute_input":
                        pass
                    else:
                        # Other IOPub traffic (e.g. ``display_data`` emitted
                        # by plotting code) carries nothing this function
                        # returns; skip it instead of aborting the run.
                        logger.debug("Ignoring IOPub message of type %s",
                                     msg_type)
            finally:
                kc.stop_channels()
            return execute_result, error_traceback, "".join(stream_text_list)

        if interrupt_after:
            run_task = asyncio.create_task(run())
            send_interrupt_task = asyncio.create_task(send_interrupt())
            try:
                done, _ = await asyncio.wait(
                    [run_task, send_interrupt_task],
                    return_when=asyncio.FIRST_COMPLETED)
                if run_task not in done:
                    assert send_interrupt_task in done
                result = await run_task
            finally:
                # Make sure neither helper task outlives this call, including
                # when the caller is cancelled while waiting.
                send_interrupt_task.cancel()
                if not run_task.done():
                    run_task.cancel()
        else:
            result = await run()
        return result
    finally:
        if shutdown_kernel:
            await km.shutdown_kernel()
| |
|
| |
|
class IPythonInterpreter(BaseAction):
    """An IPython executor that can execute Python scripts in a jupyter manner.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Defaults to 20.
        user_data_dir (str, optional): Specified the user data directory for files
            loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
            Defaults to `ENV`.
        work_dir (str, optional): Specify which directory to save output images to.
            Defaults to ``'./work_dir/tmp_dir'``.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    # Class-level cache shared by all instances; this class stores one
    # ``(kernel_manager, kernel_client)`` pair per process id.
    _KERNEL_CLIENTS = {}

    def __init__(
        self,
        timeout: int = 20,
        user_data_dir: str = 'ENV',
        work_dir='./work_dir/tmp_dir',
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(description, parser)

        self.timeout = timeout
        if user_data_dir == 'ENV':
            user_data_dir = os.environ.get('USER_DATA_DIR', '')

        if user_data_dir:
            user_data_dir = os.path.dirname(user_data_dir)
            # ``!r`` yields a valid Python string literal even when the path
            # contains quotes or backslashes.
            user_data_dir = f"import os\nos.chdir({user_data_dir!r})"
        # Either an ``os.chdir`` snippet or an empty string; it is substituted
        # into ``START_CODE``.
        self.user_data_dir = user_data_dir
        self._initialized = False
        self.work_dir = work_dir
        os.makedirs(self.work_dir, exist_ok=True)

    @staticmethod
    def start_kernel():
        """Start a new kernel and return ``(kernel_manager, kernel_client)``."""
        from jupyter_client import KernelManager

        km = KernelManager()
        km.start_kernel()
        kc = km.client()
        return km, kc

    def initialize(self):
        """Bind this instance to the per-process kernel and run ``START_CODE``."""
        if self._initialized:
            return
        pid = os.getpid()
        if pid not in self._KERNEL_CLIENTS:
            self._KERNEL_CLIENTS[pid] = self.start_kernel()
        self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
        # The flag is set before ``_call`` because ``_call`` invokes
        # ``initialize`` again and must hit the early return above.
        self._initialized = True
        self._call(START_CODE.format(self.user_data_dir), None)

    def reset(self):
        """Clear the kernel namespace and re-run ``START_CODE``."""
        if not self._initialized:
            self.initialize()
        else:
            code = "get_ipython().run_line_magic('reset', '-f')\n" + \
                START_CODE.format(self.user_data_dir)
            self._call(code, None)

    def _call(self,
              command: str,
              timeout: Optional[int] = None) -> Tuple[bool, object]:
        """Execute ``command`` in the kernel and collect its output.

        Args:
            command (str): text containing the code; it is passed through
                ``extract_code`` first.
            timeout (int, optional): if truthy, a ``SIGALRM`` alarm of that
                many seconds bounds the whole collection phase.

        Returns:
            tuple: ``(succeed, result)``. ``result`` is a dict with ``text``
            and ``image`` keys, or a plain error string when the alarm fired.
        """
        self.initialize()
        command = extract_code(command)

        # Drain messages left over from an earlier execution: stop at the
        # next ``idle`` status or after 5 seconds without any message.
        while True:
            try:
                msg = self.kernel_client.get_iopub_msg(timeout=5)
            except queue.Empty:
                break
            if (msg['msg_type'] == 'status'
                    and msg['content'].get('execution_state') == 'idle'):
                break

        self.kernel_client.execute(command)

        # Seconds to wait for each IOPub message before giving up.
        poll_timeout = 20
        # Report the wait that applied instead of printing ``None`` when no
        # per-call timeout was given.
        timeout_text = (
            f'Timeout. No response after {timeout or poll_timeout} seconds.')

        def _inner_call():
            result = ''
            images = []
            succeed = True

            while True:
                text = ''
                image_url = None
                finished = False
                try:
                    msg = self.kernel_client.get_iopub_msg(
                        timeout=poll_timeout)
                    msg_type = msg['msg_type']
                    if msg_type == 'status':
                        if msg['content'].get('execution_state') == 'idle':
                            finished = True
                    elif msg_type == 'execute_result':
                        data = msg['content']['data']
                        text = data.get('text/plain', '')
                        if 'image/png' in data:
                            image_url = publish_image_to_local(
                                data['image/png'], self.work_dir)
                    elif msg_type == 'display_data':
                        data = msg['content']['data']
                        if 'image/png' in data:
                            image_url = publish_image_to_local(
                                data['image/png'], self.work_dir)
                        else:
                            text = data.get('text/plain', '')
                    elif msg_type == 'stream':
                        text = msg['content']['text']
                    elif msg_type == 'error':
                        succeed = False
                        text = escape_ansi('\n'.join(
                            msg['content']['traceback']))
                        if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
                            text = timeout_text
                except queue.Empty:
                    # The kernel stayed silent for ``poll_timeout`` seconds.
                    self.kernel_manager.interrupt_kernel()
                    succeed = False
                    text = timeout_text
                    finished = True
                except TimeoutError:
                    # Raised by the SIGALRM handler; let the outer handler
                    # deal with it instead of reporting it as a traceback.
                    raise
                except Exception:
                    succeed = False
                    text = ''.join(traceback.format_exception(*sys.exc_info()))
                    logger.warning(text)
                    finished = True
                if text:
                    result += f'{text}'
                if image_url:
                    images.append(image_url)
                if finished:
                    return succeed, dict(text=result, image=images)

        previous_handler = None
        try:
            if timeout:

                def handler(signum, frame):
                    raise TimeoutError()

                previous_handler = signal.signal(signal.SIGALRM, handler)
                signal.alarm(timeout)
            succeed, result = _inner_call()
        except TimeoutError:
            # Stop the still-running code, as the silent-kernel path does.
            self.kernel_manager.interrupt_kernel()
            succeed = False
            text = 'The code interpreter encountered an unexpected error.'
            result = f'\n\nerror:\n\n```\n{text}\n```'
        finally:
            if timeout:
                signal.alarm(0)
                # ``signal.signal`` returns ``None`` when the previous handler
                # was not installed from Python; nothing to restore then.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

        return succeed, result

    @tool_api
    def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
        r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.

        Args:
            command (:class:`str`): Python code
            timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
        """
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        tool_return.args = dict(text=command)
        succeed, result = self._call(command, timeout)
        if succeed:
            text = result['text']
            image = result.get('image', [])
            resp = [dict(type='text', content=text)]
            if image:
                resp.extend([dict(type='image', content=im) for im in image])
            tool_return.result = resp
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            # ``result`` is a dict on ordinary failures and a plain string
            # when the alarm-based timeout fired.
            tool_return.errmsg = result.get('text', '') if isinstance(
                result, dict) else result
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return
| |
|
| |
|
class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter):
    """An IPython executor that can execute Python scripts in a jupyter manner.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Defaults to 20.
        user_data_dir (str, optional): Specified the user data directory for files
            loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
            Defaults to `ENV`.
        work_dir (str, optional): Specify which directory to save output images to.
            Defaults to ``tmp_dir`` under the system temporary directory.
        max_kernels (int, optional): Upper bound on the number of session-bound
            plus pooled kernels. ``None`` means no limit. Defaults to ``None``.
        reuse_kernel (bool): If ``True``, the kernel of a closed session is reset
            and pooled for later sessions; otherwise it is shut down.
            Defaults to ``True``.
        startup_rate (int): Size of the semaphore guarding kernel start-up.
            Defaults to 32.
        connection_dir (str): Passed to :class:`AsyncMultiKernelManager` as
            ``connection_dir``. Defaults to the system temporary directory.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    # Class-level pool of ``(kernel_id, kernel, client)`` tuples released by
    # closed sessions and waiting to be bound to a new session.
    _UNBOUND_KERNEL_CLIENTS = asyncio.Queue()

    def __init__(
        self,
        timeout: int = 20,
        user_data_dir: str = 'ENV',
        work_dir=os.path.join(tempfile.gettempdir(), 'tmp_dir'),
        max_kernels: Optional[int] = None,
        reuse_kernel: bool = True,
        startup_rate: int = 32,
        connection_dir: str = tempfile.gettempdir(),
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(timeout, user_data_dir, work_dir, description, parser)
        from traitlets.config import Config

        c = Config()
        c.KernelManager.transport = 'ipc'
        self._amkm = AsyncMultiKernelManager(
            config=c, connection_dir=connection_dir)
        self._max_kernels = max_kernels
        self._reuse_kernel = reuse_kernel
        self._sem = asyncio.Semaphore(startup_rate)
        self._lock = asyncio.Lock()

    async def initialize(self, session_id: str):
        """Return the ``(kernel_id, kernel, client)`` tuple bound to a session.

        An existing binding is returned as is. Otherwise a pooled kernel is
        taken when reuse is enabled, or a new kernel is started and primed
        with ``START_CODE``. The loop retries every second until a kernel
        can be bound.
        """
        session_id = str(session_id)
        while True:
            if session_id in self._KERNEL_CLIENTS:
                return self._KERNEL_CLIENTS[session_id]
            if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty():
                self._KERNEL_CLIENTS[
                    session_id] = await self._UNBOUND_KERNEL_CLIENTS.get()
                return self._KERNEL_CLIENTS[session_id]
            async with self._sem:
                if self._max_kernels is None or len(
                        self._KERNEL_CLIENTS
                ) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels:
                    kernel_id = None
                    try:
                        kernel_id = await self._amkm.start_kernel()
                        kernel = self._amkm.get_kernel(kernel_id)
                        client = kernel.client()
                        _, error_stacktrace, stream_text = await async_run_code(
                            kernel,
                            START_CODE.format(self.user_data_dir),
                            shutdown_kernel=False)
                        # START_CODE is expected to produce neither an error
                        # nor any stream output.
                        if not (error_stacktrace is None
                                and stream_text == ''):
                            raise RuntimeError(
                                'Unexpected output from START_CODE: '
                                f'error={error_stacktrace!r}, '
                                f'stream={stream_text!r}')
                    except Exception as e:
                        logger.warning('Starting kernel error: %r', e)
                        if kernel_id:
                            await self._amkm.shutdown_kernel(kernel_id)
                            self._amkm.remove_kernel(kernel_id)
                        await asyncio.sleep(1)
                        continue
                    if self._max_kernels is None:
                        self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel,
                                                            client)
                        return kernel_id, kernel, client
                    async with self._lock:
                        # Re-check the limit: other tasks may have bound
                        # kernels while this one was starting.
                        if len(self._KERNEL_CLIENTS
                               ) + self._UNBOUND_KERNEL_CLIENTS.qsize(
                               ) < self._max_kernels:
                            self._KERNEL_CLIENTS[session_id] = (kernel_id,
                                                                kernel, client)
                            return kernel_id, kernel, client
                    await self._amkm.shutdown_kernel(kernel_id)
                    self._amkm.remove_kernel(kernel_id)
            await asyncio.sleep(1)

    async def reset(self, session_id: str):
        """Clear the namespace of a session's kernel and re-run ``START_CODE``."""
        session_id = str(session_id)
        if session_id not in self._KERNEL_CLIENTS:
            return
        _, kernel, _ = self._KERNEL_CLIENTS[session_id]
        code = "get_ipython().run_line_magic('reset', '-f')\n" + \
            START_CODE.format(self.user_data_dir)
        await async_run_code(kernel, code, shutdown_kernel=False)

    async def shutdown(self, session_id: str):
        """Shut down and forget the kernel bound to ``session_id``, if any."""
        session_id = str(session_id)
        if session_id in self._KERNEL_CLIENTS:
            kernel_id, _, _ = self._KERNEL_CLIENTS.get(session_id)
            await self._amkm.shutdown_kernel(kernel_id)
            self._amkm.remove_kernel(kernel_id)
            del self._KERNEL_CLIENTS[session_id]

    async def close_session(self, session_id: str):
        """Release a session: pool its kernel when reusing, else shut it down."""
        session_id = str(session_id)
        if self._reuse_kernel:
            if session_id in self._KERNEL_CLIENTS:
                await self.reset(session_id)
                await self._UNBOUND_KERNEL_CLIENTS.put(
                    self._KERNEL_CLIENTS.pop(session_id))
        else:
            await self.shutdown(session_id)

    async def _call(self, command, timeout=None, session_id=None):
        """Run ``command`` on the session's kernel.

        Returns:
            tuple: ``(status, ret)``. On error ``status`` is ``False`` and
            ``ret`` is the cleaned traceback string; otherwise ``status`` is
            ``True`` and ``ret`` is ``dict(text=...)``.
        """
        _, kernel, _ = await self.initialize(str(session_id))
        interrupt_after = timeout or self.timeout
        # ``async_run_code`` requires ``iopub_timeout > interrupt_after`` and
        # defaults ``iopub_timeout`` to 40, so widen it for long timeouts
        # instead of tripping that assertion.
        iopub_timeout = max(40, interrupt_after + 10) if interrupt_after else 40
        result = await async_run_code(
            kernel,
            extract_code(command),
            interrupt_after=interrupt_after,
            iopub_timeout=iopub_timeout,
            shutdown_kernel=False)
        execute_result, error_stacktrace, stream_text = result
        if error_stacktrace is not None:
            ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace))
            if ret.endswith('KeyboardInterrupt: '):
                ret = 'The code interpreter encountered a timeout error.'
            status, ret = False, ret.strip()
        elif execute_result is not None:
            status, ret = True, dict(text=execute_result.get('text/plain', ''))
        else:
            status, ret = True, dict(text=stream_text.strip())
        return status, ret

    @tool_api
    async def run(self,
                  command: str,
                  timeout: Optional[int] = None,
                  session_id: Optional[str] = None) -> ActionReturn:
        r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.

        Args:
            command (:class:`str`): Python code
            timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
        """
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        tool_return.args = dict(text=command)
        succeed, result = await self._call(command, timeout, session_id)
        if succeed:
            text = result['text']
            image = result.get('image', [])
            resp = [dict(type='text', content=text)]
            if image:
                resp.extend([dict(type='image', content=im) for im in image])
            tool_return.result = resp
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = result.get('text', '') if isinstance(
                result, dict) else result
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return
| |
|
| |
|
def extract_code(text):
    """Pull the code payload out of ``text``.

    The first matching rule wins:

    1. the body of the first triple-backtick fenced block;
    2. the content of the first single-backtick span;
    3. the ``code`` entry of ``text`` parsed as JSON (strict JSON first,
       then JSON5 if the optional ``json5`` package is installed);
    4. ``text`` itself, unchanged.

    Args:
        text (str): raw input, possibly wrapping code in markdown or JSON.

    Returns:
        The extracted code, or ``text`` when nothing could be extracted.
    """
    fenced = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
    if fenced:
        return fenced.group(1)
    inline = re.search(r'`([^`]*)`', text, re.DOTALL)
    if inline:
        return inline.group(1)

    # Strict JSON is tried with the standard library first; ``json5`` is
    # imported only here so that it stays an optional dependency.
    loaders = [json.loads]
    try:
        import json5
    except ImportError:
        pass
    else:
        loaders.append(json5.loads)

    for loads in loaders:
        try:
            return loads(text)['code']
        except Exception:
            # Deliberate best effort: input that is not a JSON object with a
            # ``code`` entry falls through to the raw text.
            continue
    return text
| |
|
| |
|
def escape_ansi(line):
    """Return ``line`` with ANSI escape sequences (e.g. colour codes) removed."""
    return re.sub(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]', '', line)
| |
|
| |
|
def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
    """Decode a base64 image and save it as a PNG file under ``work_dir``.

    Args:
        image_base64 (str): base64-encoded image data.
        work_dir (str): directory to write into; it is created when missing.
            An empty string writes to the current working directory.

    Returns:
        str: path of the written file, named with a random UUID.
    """
    import PIL.Image

    # Create the target directory on demand so the function also works when
    # called on its own; ``os.makedirs('')`` would raise, hence the guard.
    if work_dir:
        os.makedirs(work_dir, exist_ok=True)
    local_image_file = os.path.join(work_dir, f'{uuid.uuid4()}.png')

    png_bytes = base64.b64decode(image_base64)
    # Going through PIL re-encodes the data as PNG; the context manager
    # releases the image once it has been written.
    with PIL.Image.open(io.BytesIO(png_bytes)) as image:
        image.save(local_image_file, 'png')

    return local_image_file
| |
|
| |
|
| | |
def get_multiline_input(hint):
    """Prompt with ``hint`` and read lines from stdin until end-of-file.

    Returns:
        str: the lines joined with newlines, or ``''`` when nothing was read.
    """
    print(hint)
    print('// Press ENTER to make a new line. Press CTRL-D to end input.')
    collected = []
    try:
        while True:
            collected.append(input())
    except EOFError:
        # CTRL-D (end-of-file) terminates the input.
        pass
    print('// Input received.')
    return '\n'.join(collected)
| |
|
| |
|
if __name__ == '__main__':
    # Minimal interactive loop: read a snippet from stdin, run it through the
    # interpreter action and print whatever the action returns.
    code_interpreter = IPythonInterpreter()
    while True:
        snippet = get_multiline_input('Enter python code:')
        outcome = code_interpreter(snippet)
        print(outcome)
| |
|