| |
| import asyncio |
| import base64 |
| import io |
| import json |
| import logging |
| import os |
| import queue |
| import re |
| import signal |
| import sys |
| import tempfile |
| import traceback |
| import uuid |
| from typing import Optional, Tuple, Type |
|
|
| from jupyter_client import AsyncKernelClient, AsyncKernelManager, AsyncMultiKernelManager |
| from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed |
|
|
| from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api |
| from lagent.actions.parser import BaseParser, JsonParser |
| from lagent.schema import ActionReturn, ActionStatusCode |
|
|
logger = logging.getLogger(__name__)

# Bootstrap snippet that is formatted with ``str.format`` and executed in a
# kernel when it is first set up and again after every reset.  It replaces the
# built-in ``input`` with a function that raises, and replaces
# ``get_ipython().system`` (the ``!`` shell escape) with a function that only
# prints a notice.  The ``{}`` placeholder receives the interpreter's
# ``user_data_dir`` snippet, which may be an empty string.
START_CODE = """
def input(*args, **kwargs):
    raise NotImplementedError('Python input() function is disabled.')

get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!')
{}
"""  # noqa: E501
|
|
|
|
class TimeoutError(Exception):
    """Raised by the ``SIGALRM`` handler when a synchronous call runs too long.

    NOTE: inside this module the name shadows the built-in ``TimeoutError``
    (an ``OSError`` subclass); this class derives directly from ``Exception``.
    """
|
|
|
|
class KernelDeath(Exception):
    """Raised when the kernel manager reports the kernel as dead.

    ``async_run_code`` raises it while waiting for an IOPub message if the
    manager's ``"dead"`` restart callback fires first.
    """
|
|
|
|
async def async_run_code(
    km: AsyncKernelManager,
    code,
    *,
    interrupt_after=30,
    iopub_timeout=40,
    wait_for_ready_timeout=60,
    shutdown_kernel=True,
):
    """Execute ``code`` on the kernel owned by ``km`` and collect its output.

    For every attempt a new client is obtained from ``km``, its channels are
    started, the code is submitted and IOPub messages are consumed until the
    kernel reports the ``idle`` execution state.  The attempt is retried (at
    most three attempts, one second apart) when the collected stream text,
    stripped, equals ``'KeyboardInterrupt'`` or the "Kernel didn't respond"
    message; after the last attempt the most recent result is returned.

    Args:
        km: Kernel manager whose kernel runs the code.
        code: Source handed to the client's ``execute``.
        interrupt_after: Seconds after which ``km.interrupt_kernel()`` is
            called.  A falsy value disables the interrupt.
        iopub_timeout: Timeout passed to every ``get_iopub_msg`` call.  Must
            be greater than ``interrupt_after`` when interrupts are enabled.
        wait_for_ready_timeout: Timeout passed to ``wait_for_ready``.
        shutdown_kernel: Shut the kernel down before returning, whether the
            run succeeded or raised.

    Returns:
        A tuple ``(execute_result, error_traceback, stream_text)``: the
        ``data`` payload of the last ``execute_result`` message (or ``None``),
        the traceback lines joined with newlines (or ``None``), and the
        concatenation of all stream texts.

    Raises:
        KernelDeath: The manager reported the kernel dead while a message was
            being awaited.
        AssertionError: ``iopub_timeout`` is not greater than
            ``interrupt_after``, or a message belongs to another request.
    """
    # Only meaningful when an interrupt is scheduled; comparing against
    # ``None`` would raise TypeError although ``None`` is a supported value.
    assert not interrupt_after or iopub_timeout > interrupt_after
    try:

        async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient,
                                                     *,
                                                     timeout=None):
            """Await the next IOPub message unless the kernel dies first."""
            loop = asyncio.get_running_loop()
            dead_fut = loop.create_future()

            def restarting():
                assert (
                    False
                ), "Restart shouldn't happen because config.KernelRestarter.restart_limit is expected to be set to 0"

            def dead():
                logger.info("Kernel has died, will NOT restart")
                # The callback may fire after the future was resolved.
                if not dead_fut.done():
                    dead_fut.set_result(None)

            msg_task = asyncio.create_task(kc.get_iopub_msg(timeout=timeout))
            km.add_restart_callback(restarting, "restart")
            km.add_restart_callback(dead, "dead")
            try:
                done, _ = await asyncio.wait(
                    [dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED)
                if dead_fut in done:
                    raise KernelDeath()
                assert msg_task in done
                return await msg_task
            finally:
                msg_task.cancel()
                km.remove_restart_callback(restarting, "restart")
                km.remove_restart_callback(dead, "dead")

        async def send_interrupt():
            """Interrupt the kernel once ``interrupt_after`` seconds passed."""
            await asyncio.sleep(interrupt_after)
            logger.info("Sending interrupt to kernel")
            await km.interrupt_kernel()

        @retry(
            retry=retry_if_result(lambda ret: ret[-1].strip() in [
                'KeyboardInterrupt',
                f"Kernel didn't respond in {wait_for_ready_timeout} seconds",
            ] if isinstance(ret, tuple) else False),
            stop=stop_after_attempt(3),
            wait=wait_fixed(1),
            retry_error_callback=lambda state: state.outcome.result())
        async def run():
            """Run ``code`` once and gather everything the kernel emits."""
            execute_result = None
            error_traceback = None
            stream_text_list = []
            kc = km.client()
            assert isinstance(kc, AsyncKernelClient)
            kc.start_channels()
            try:
                await kc.wait_for_ready(timeout=wait_for_ready_timeout)
                msg_id = kc.execute(code)
                while True:
                    message = await get_iopub_msg_with_death_detection(
                        kc, timeout=iopub_timeout)
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.debug(
                            json.dumps(message, indent=2, default=str))
                    assert message["parent_header"]["msg_id"] == msg_id
                    msg_type = message["msg_type"]
                    if msg_type == "status":
                        if message["content"]["execution_state"] == "idle":
                            break
                    elif msg_type == "stream":
                        stream_text_list.append(message["content"]["text"])
                    elif msg_type == "execute_result":
                        execute_result = message["content"]["data"]
                    elif msg_type == "error":
                        error_traceback = "\n".join(
                            message["content"]["traceback"])
                    elif msg_type == "execute_input":
                        pass
                    else:
                        # Other IOPub types (for example ``display_data`` or
                        # ``clear_output``) are valid protocol messages that
                        # carry nothing collected here, so they are skipped
                        # instead of aborting the whole run.
                        logger.debug("Ignoring IOPub message of type %s",
                                     msg_type)
            finally:
                kc.stop_channels()
            return execute_result, error_traceback, "".join(stream_text_list)

        if not interrupt_after:
            return await run()

        run_task = asyncio.create_task(run())
        send_interrupt_task = asyncio.create_task(send_interrupt())
        try:
            done, _ = await asyncio.wait([run_task, send_interrupt_task],
                                         return_when=asyncio.FIRST_COMPLETED)
            if run_task not in done:
                assert send_interrupt_task in done
            # After an interrupt the run still has to drain the remaining
            # messages, so it is awaited in both cases.
            return await run_task
        finally:
            # Neither helper task may outlive this call, e.g. when the caller
            # is cancelled while ``asyncio.wait`` is pending.
            for task in (run_task, send_interrupt_task):
                if not task.done():
                    task.cancel()
    finally:
        if shutdown_kernel:
            await km.shutdown_kernel()
|
|
|
|
class IPythonInterpreter(BaseAction):
    """A IPython executor that can execute Python scripts in a jupyter manner.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Defaults to 20.
        user_data_dir (str, optional): Specified the user data directory for files
            loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
            Defaults to `ENV`.
        work_dir (str, optional): Specify which directory to save output images to.
            Defaults to ``'./work_dir/tmp_dir'``.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    # NOTE(review): the class and ``run`` docstrings are kept verbatim because
    # the action framework may read them at runtime -- confirm before editing.

    # Maps the id of the current process to its ``(kernel_manager,
    # kernel_client)`` pair, so every instance in one process shares a kernel.
    _KERNEL_CLIENTS = {}

    def __init__(
        self,
        timeout: int = 20,
        user_data_dir: str = 'ENV',
        work_dir='./work_dir/tmp_dir',
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(description, parser)

        self.timeout = timeout
        if user_data_dir == 'ENV':
            user_data_dir = os.environ.get('USER_DATA_DIR', '')

        if user_data_dir:
            user_data_dir = os.path.dirname(user_data_dir)
        if user_data_dir:
            # ``!r`` yields a correctly quoted literal even when the path
            # contains quotes or backslashes.  An empty dirname is skipped
            # because ``os.chdir('')`` always fails inside the kernel.
            user_data_dir = f'import os\nos.chdir({user_data_dir!r})'
        # Either an empty string or a snippet appended to START_CODE.
        self.user_data_dir = user_data_dir
        self._initialized = False
        self.work_dir = work_dir
        os.makedirs(self.work_dir, exist_ok=True)

    @staticmethod
    def start_kernel():
        """Start a new kernel and return ``(kernel_manager, kernel_client)``."""
        from jupyter_client import KernelManager

        km = KernelManager()
        km.start_kernel()
        kc = km.client()
        return km, kc

    def initialize(self):
        """Bind this instance to the process-wide kernel and run START_CODE.

        Does nothing when the instance has already been initialized.
        """
        if self._initialized:
            return
        pid = os.getpid()
        if pid not in self._KERNEL_CLIENTS:
            self._KERNEL_CLIENTS[pid] = self.start_kernel()
        self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
        # Must be set before ``_call``: ``_call`` invokes ``initialize`` again.
        self._initialized = True
        self._call(START_CODE.format(self.user_data_dir), None)

    def reset(self):
        """Clear the kernel namespace and re-apply START_CODE."""
        if not self._initialized:
            self.initialize()
        else:
            code = "get_ipython().run_line_magic('reset', '-f')\n" + \
                START_CODE.format(self.user_data_dir)
            self._call(code, None)

    def _call(self,
              command: str,
              timeout: Optional[int] = None) -> Tuple[bool, object]:
        """Execute ``command`` in the kernel and gather its output.

        Args:
            command: Text holding the code; it is passed through
                ``extract_code`` before execution.
            timeout: Seconds before a ``SIGALRM`` aborts the call.  When
                falsy no alarm is armed and ``self.timeout`` only bounds the
                wait for each individual kernel message.

        Returns:
            ``(succeed, result)``.  ``result`` is ``dict(text=..., image=...)``
            where ``image`` lists the paths of saved PNG files, or a plain
            string when the alarm fired.
        """
        self.initialize()
        command = extract_code(command)

        # Wait time for a single IOPub message; also the figure reported in
        # the timeout messages (it used to be a fixed 20 s while the message
        # could read "None seconds").
        wait_seconds = timeout or self.timeout

        # Drain whatever a previous execution left on the channel.
        while True:
            try:
                msg = self.kernel_client.get_iopub_msg(timeout=5)
                if (msg['msg_type'] == 'status' and
                        msg['content'].get('execution_state') == 'idle'):
                    break
            except queue.Empty:
                # Nothing pending.
                break

        self.kernel_client.execute(command)

        def _inner_call():
            result = ''
            images = []
            succeed = True

            while True:
                text = ''
                finished = False
                try:
                    msg = self.kernel_client.get_iopub_msg(
                        timeout=wait_seconds)
                    msg_type = msg['msg_type']
                    if msg_type == 'status':
                        if msg['content'].get('execution_state') == 'idle':
                            finished = True
                    elif msg_type == 'execute_result':
                        data = msg['content']['data']
                        text = data.get('text/plain', '')
                        if 'image/png' in data:
                            images.append(
                                publish_image_to_local(data['image/png'],
                                                       self.work_dir))
                    elif msg_type == 'display_data':
                        data = msg['content']['data']
                        if 'image/png' in data:
                            images.append(
                                publish_image_to_local(data['image/png'],
                                                       self.work_dir))
                        else:
                            text = data.get('text/plain', '')
                    elif msg_type == 'stream':
                        text = msg['content']['text']
                    elif msg_type == 'error':
                        succeed = False
                        text = escape_ansi('\n'.join(
                            msg['content']['traceback']))
                        if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
                            text = f'Timeout. No response after {wait_seconds} seconds.'
                except queue.Empty:
                    # Stop the running cell so it cannot block the next call.
                    self.kernel_manager.interrupt_kernel()
                    succeed = False
                    text = f'Timeout. No response after {wait_seconds} seconds.'
                    finished = True
                except TimeoutError:
                    # Raised by the SIGALRM handler; it has to reach the outer
                    # handler instead of being reported as a generic failure.
                    raise
                except Exception:
                    succeed = False
                    text = traceback.format_exc()
                    logger.warning(text)
                    finished = True
                if text:
                    result += f'{text}'
                if finished:
                    return succeed, dict(text=result, image=images)

        alarm_armed = False
        previous_handler = None
        try:
            if timeout:

                def handler(signum, frame):
                    raise TimeoutError()

                previous_handler = signal.signal(signal.SIGALRM, handler)
                alarm_armed = True
                signal.alarm(timeout)
            succeed, result = _inner_call()
        except TimeoutError:
            # Same reasoning as for the per-message timeout: do not leave the
            # cell running behind the caller's back.
            self.kernel_manager.interrupt_kernel()
            succeed = False
            text = 'The code interpreter encountered an unexpected error.'
            result = f'\n\nerror:\n\n```\n{text}\n```'
        finally:
            if alarm_armed:
                signal.alarm(0)
                # ``signal.signal`` returns None for handlers that were not
                # installed from Python; those cannot be re-installed.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

        return succeed, result

    @tool_api
    def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
        r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.

        Args:
            command (:class:`str`): Python code
            timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
        """
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        tool_return.args = dict(text=command)
        succeed, result = self._call(command, timeout)
        if succeed:
            text = result['text']
            image = result.get('image', [])
            resp = [dict(type='text', content=text)]
            if image:
                resp.extend([dict(type='image', content=im) for im in image])
            tool_return.result = resp
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            # ``result`` is a plain string when the alarm-based timeout fired.
            tool_return.errmsg = result.get('text', '') if isinstance(
                result, dict) else result
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return
|
|
|
|
class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter):
    """A IPython executor that can execute Python scripts in a jupyter manner.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Defaults to 20.
        user_data_dir (str, optional): Specified the user data directory for files
            loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
            Defaults to `ENV`.
        work_dir (str, optional): Specify which directory to save output images to.
            Defaults to ``'./work_dir/tmp_dir'``.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    # NOTE(review): the class and ``run`` docstrings are kept verbatim because
    # the action framework may read them at runtime -- confirm before editing.

    # Kernels released by closed sessions, waiting to be bound to a new
    # session.  Entries are ``(kernel_id, kernel, client)`` tuples, the same
    # shape this class stores in ``_KERNEL_CLIENTS`` under the session id.
    _UNBOUND_KERNEL_CLIENTS = asyncio.Queue()

    def __init__(
        self,
        timeout: int = 20,
        user_data_dir: str = 'ENV',
        work_dir=os.path.join(tempfile.gettempdir(), 'tmp_dir'),
        max_kernels: Optional[int] = None,
        reuse_kernel: bool = True,
        startup_rate: int = 32,
        connection_dir: str = tempfile.gettempdir(),
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        """Create the multi-kernel manager and the startup throttles.

        Only the parameters added on top of the base class are listed.

        Args:
            max_kernels: Upper bound for bound plus unbound kernels; ``None``
                means no bound.
            reuse_kernel: When true, ``close_session`` resets the kernel and
                queues it for reuse instead of shutting it down.
            startup_rate: Size of the semaphore guarding kernel startup.
            connection_dir: Passed to ``AsyncMultiKernelManager``.
        """
        super().__init__(timeout, user_data_dir, work_dir, description, parser)
        from traitlets.config import Config

        c = Config()
        c.KernelManager.transport = 'ipc'
        self._amkm = AsyncMultiKernelManager(
            config=c, connection_dir=connection_dir)
        self._max_kernels = max_kernels
        self._reuse_kernel = reuse_kernel
        self._sem = asyncio.Semaphore(startup_rate)
        self._lock = asyncio.Lock()

    def _kernel_count(self) -> int:
        """Return the number of bound plus unbound kernels."""
        return len(self._KERNEL_CLIENTS) + self._UNBOUND_KERNEL_CLIENTS.qsize()

    async def initialize(self, session_id: str):
        """Return the ``(kernel_id, kernel, client)`` bound to ``session_id``.

        An existing binding is returned as is.  Otherwise an unbound kernel is
        reused when allowed, or a new kernel is started and primed with
        START_CODE.  The loop retries once per second until a kernel becomes
        available, so it does not return while startup keeps failing or the
        kernel limit stays exhausted.
        """
        session_id = str(session_id)
        while True:
            if session_id in self._KERNEL_CLIENTS:
                return self._KERNEL_CLIENTS[session_id]
            if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty():
                self._KERNEL_CLIENTS[
                    session_id] = await self._UNBOUND_KERNEL_CLIENTS.get()
                return self._KERNEL_CLIENTS[session_id]
            async with self._sem:
                if (self._max_kernels is None
                        or self._kernel_count() < self._max_kernels):
                    kernel_id = None
                    try:
                        kernel_id = await self._amkm.start_kernel()
                        kernel = self._amkm.get_kernel(kernel_id)
                        client = kernel.client()
                        _, error_stacktrace, stream_text = await async_run_code(
                            kernel,
                            START_CODE.format(self.user_data_dir),
                            shutdown_kernel=False)
                        # START_CODE must run silently and without error.
                        if not (error_stacktrace is None
                                and stream_text == ''):
                            raise RuntimeError(
                                'unexpected START_CODE output: '
                                f'error={error_stacktrace!r}, '
                                f'stream={stream_text!r}')
                    except Exception as e:
                        logger.warning('Starting kernel error: %s', e)
                        if kernel_id:
                            await self._amkm.shutdown_kernel(kernel_id)
                            self._amkm.remove_kernel(kernel_id)
                        await asyncio.sleep(1)
                        continue
                    if self._max_kernels is None:
                        self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel,
                                                            client)
                        return kernel_id, kernel, client
                    # Re-check under the lock: other coroutines may have
                    # registered kernels while this one was starting.
                    async with self._lock:
                        if self._kernel_count() < self._max_kernels:
                            self._KERNEL_CLIENTS[session_id] = (kernel_id,
                                                                kernel, client)
                            return kernel_id, kernel, client
                    # The limit was reached in the meantime; discard it.
                    await self._amkm.shutdown_kernel(kernel_id)
                    self._amkm.remove_kernel(kernel_id)
            # Back off outside the semaphore before trying again.
            await asyncio.sleep(1)

    async def reset(self, session_id: str):
        """Clear the namespace of the session's kernel and re-run START_CODE.

        Unknown sessions are ignored.
        """
        session_id = str(session_id)
        if session_id not in self._KERNEL_CLIENTS:
            return
        _, kernel, _ = self._KERNEL_CLIENTS[session_id]
        code = "get_ipython().run_line_magic('reset', '-f')\n" + \
            START_CODE.format(self.user_data_dir)
        await async_run_code(kernel, code, shutdown_kernel=False)

    async def shutdown(self, session_id: str):
        """Shut down and forget the kernel bound to ``session_id``, if any."""
        session_id = str(session_id)
        if session_id in self._KERNEL_CLIENTS:
            kernel_id, _, _ = self._KERNEL_CLIENTS.get(session_id)
            await self._amkm.shutdown_kernel(kernel_id)
            self._amkm.remove_kernel(kernel_id)
            del self._KERNEL_CLIENTS[session_id]

    async def close_session(self, session_id: str):
        """Release the session's kernel.

        With ``reuse_kernel`` the kernel is reset and queued as unbound;
        otherwise it is shut down.
        """
        session_id = str(session_id)
        if self._reuse_kernel:
            if session_id in self._KERNEL_CLIENTS:
                await self.reset(session_id)
                await self._UNBOUND_KERNEL_CLIENTS.put(
                    self._KERNEL_CLIENTS.pop(session_id))
        else:
            await self.shutdown(session_id)

    async def _call(self, command, timeout=None, session_id=None):
        """Run ``command`` in the session's kernel.

        Returns:
            ``(status, ret)``: on failure ``ret`` is the cleaned traceback
            text (or a timeout notice); on success it is ``dict(text=...)``
            holding the ``text/plain`` result or, without a result, the
            stripped stream output.
        """
        _, kernel, _ = await self.initialize(str(session_id))
        interrupt_after = timeout or self.timeout
        # ``async_run_code`` requires ``iopub_timeout > interrupt_after``.
        # Its default of 40 would reject every timeout of 40 s or more, so the
        # default 10 s margin is preserved for larger values.
        iopub_timeout = max(40, interrupt_after + 10) if interrupt_after else 40
        result = await async_run_code(
            kernel,
            extract_code(command),
            interrupt_after=interrupt_after,
            iopub_timeout=iopub_timeout,
            shutdown_kernel=False)
        execute_result, error_stacktrace, stream_text = result
        if error_stacktrace is not None:
            ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace))
            if ret.endswith('KeyboardInterrupt: '):
                ret = 'The code interpreter encountered a timeout error.'
            status, ret = False, ret.strip()
        elif execute_result is not None:
            status, ret = True, dict(text=execute_result.get('text/plain', ''))
        else:
            status, ret = True, dict(text=stream_text.strip())
        return status, ret

    @tool_api
    async def run(self,
                  command: str,
                  timeout: Optional[int] = None,
                  session_id: Optional[str] = None) -> ActionReturn:
        r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.

        Args:
            command (:class:`str`): Python code
            timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
        """
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        tool_return.args = dict(text=command)
        succeed, result = await self._call(command, timeout, session_id)
        if succeed:
            text = result['text']
            image = result.get('image', [])
            resp = [dict(type='text', content=text)]
            if image:
                resp.extend([dict(type='image', content=im) for im in image])
            tool_return.result = resp
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            # On failure ``_call`` returns the error text as a plain string.
            tool_return.errmsg = result.get('text', '') if isinstance(
                result, dict) else result
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return
|
|
|
|
def extract_code(text):
    """Pull the code out of ``text``.

    The first match wins:

    1. the body of the first triple-back-tick fence (the rest of the opening
       fence line, e.g. a language tag, is dropped);
    2. the content of the first single-back-tick span;
    3. the ``'code'`` entry of ``text`` parsed as JSON5 (plain JSON when the
       optional ``json5`` package is not installed).

    When none applies, ``text`` is returned unchanged.
    """
    fenced = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
    if fenced:
        return fenced.group(1)

    inline = re.search(r'`([^`]*)`', text, re.DOTALL)
    if inline:
        return inline.group(1)

    # Imported only on this path so the back-tick cases above do not depend
    # on the optional package being installed.
    try:
        import json5 as parser
    except ImportError:
        parser = json

    try:
        return parser.loads(text)['code']
    except Exception:
        # Deliberately best effort: anything that is not a JSON object with a
        # ``code`` entry is treated as raw code.
        return text
|
|
|
|
# Matches an escape introducer (ESC followed by a byte in ``@``-``_``, or a
# single character in the \x80-\x9F range), then optional parameter and
# intermediate characters, then one final character in ``@``-``~``.
# Compiled once at import time instead of on every call.
_ANSI_ESCAPE_RE = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')


def escape_ansi(line):
    """Return ``line`` with ANSI escape sequences (e.g. colours) removed."""
    return _ANSI_ESCAPE_RE.sub('', line)
|
|
|
|
def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
    """Decode a base64 image and store it as a PNG file inside ``work_dir``.

    Args:
        image_base64: Base64 encoded image data.
        work_dir: Directory receiving the file; created when missing.

    Returns:
        Path of the written file, named ``<uuid4>.png``.
    """
    import PIL.Image

    png_bytes = base64.b64decode(image_base64)
    # Callers in this module create the directory up front, but a standalone
    # call with the default value must not fail on a missing directory.
    os.makedirs(work_dir, exist_ok=True)
    local_image_file = os.path.join(work_dir, str(uuid.uuid4()) + '.png')
    # The context manager releases the decoder once the file is written.
    with PIL.Image.open(io.BytesIO(png_bytes)) as image:
        image.save(local_image_file, 'png')

    return local_image_file
|
|
|
|
| |
def get_multiline_input(hint):
    """Read lines from standard input until end-of-file (CTRL-D).

    Args:
        hint: Prompt printed before reading starts.

    Returns:
        The entered lines joined with newlines, or ``''`` when nothing was
        entered.
    """
    print(hint)
    print('// Press ENTER to make a new line. Press CTRL-D to end input.')
    lines = []
    while True:
        try:
            lines.append(input())
        except EOFError:
            break
    print('// Input received.')
    # Joining an empty list already yields '', so no special case is needed.
    return '\n'.join(lines)
|
|
|
|
if __name__ == '__main__':
    # Minimal interactive demo: read a block of code, run it, print the
    # action result, repeat.
    code_interpreter = IPythonInterpreter()
    try:
        while True:
            print(code_interpreter(get_multiline_input('Enter python code:')))
    except KeyboardInterrupt:
        # CTRL-D only ends one input block, so CTRL-C is the way out of the
        # loop; leave without printing a traceback.
        print()
|
|