Spaces:
Sleeping
Sleeping
| # flake8: noqa: E501 | |
| import asyncio | |
| import base64 | |
| import io | |
| import json | |
| import logging | |
| import os | |
| import queue | |
| import re | |
| import signal | |
| import sys | |
| import tempfile | |
| import traceback | |
| import uuid | |
| from typing import Optional, Tuple, Type | |
| from jupyter_client import AsyncKernelClient, AsyncKernelManager, AsyncMultiKernelManager | |
| from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed | |
| from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api | |
| from lagent.actions.parser import BaseParser, JsonParser | |
| from lagent.schema import ActionReturn, ActionStatusCode | |
logger = logging.getLogger(__name__)

# Preamble run in every fresh (or freshly reset) kernel before user code.
# It replaces the builtin `input()` with a function that raises
# NotImplementedError, and replaces IPython's `!` shell escape
# (`get_ipython().system`) with a stub that only prints a notice.
# The trailing `{}` is filled via `str.format` with the interpreter's
# `user_data_dir` snippet (an empty string when no directory is configured).
START_CODE = """
def input(*args, **kwargs):
    raise NotImplementedError('Python input() function is disabled.')
get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!')
{}
"""  # noqa
class TimeoutError(Exception):
    """Raised by the SIGALRM handler when a synchronous execution overruns.

    Note: this name shadows the builtin ``TimeoutError`` inside this module.
    """
class KernelDeath(Exception):
    """Raised when the Jupyter kernel dies while its output is being awaited."""
async def async_run_code(
    km: AsyncKernelManager,
    code,
    *,
    interrupt_after=30,
    iopub_timeout=40,
    wait_for_ready_timeout=60,
    shutdown_kernel=True,
):
    """Execute ``code`` on the kernel managed by ``km`` and collect its output.

    A fresh client is opened for the call. IOPub messages that belong to the
    execute request are consumed until the kernel reports ``idle``.

    Args:
        km: Manager of the kernel to run the code on.
        code (str): Source to execute.
        interrupt_after (int | None): Seconds after which an interrupt is sent
            to the kernel. A falsy value disables the interrupt.
        iopub_timeout (int): Seconds to wait for each single IOPub message.
            Must be larger than ``interrupt_after``. Otherwise a silent,
            long-running cell would hit this timeout before the interrupt
            is ever sent.
        wait_for_ready_timeout (int): Seconds to wait for the kernel to
            become ready before executing.
        shutdown_kernel (bool): Whether to shut the kernel down afterwards.

    Returns:
        tuple: ``(execute_result, error_traceback, stream_text)``.
        ``execute_result`` is the ``data`` mime bundle of the last
        ``execute_result`` message, or ``None``. ``error_traceback`` is the
        newline-joined traceback of the last ``error`` message, or ``None``.
        ``stream_text`` is all stream (stdout/stderr) text concatenated.

    Raises:
        KernelDeath: If the kernel dies while output is awaited.
        queue.Empty: If no IOPub message arrives within ``iopub_timeout``.
            This is presumably what ``get_iopub_msg`` raises on timeout;
            confirm against the installed jupyter_client.
    """
    if interrupt_after:
        assert iopub_timeout > interrupt_after, (
            'iopub_timeout must be larger than interrupt_after')
    try:

        async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient,
                                                     *,
                                                     timeout=None):
            """Return the next IOPub message, or raise KernelDeath if the
            kernel dies first."""
            loop = asyncio.get_running_loop()
            dead_fut = loop.create_future()

            def restarting():
                assert (
                    False
                ), "Restart shouldn't happen because config.KernelRestarter.restart_limit is expected to be set to 0"

            def dead():
                logger.info("Kernel has died, will NOT restart")
                # Guard against the callback firing twice: set_result on an
                # already finished future raises InvalidStateError.
                if not dead_fut.done():
                    dead_fut.set_result(None)

            msg_task = asyncio.create_task(kc.get_iopub_msg(timeout=timeout))
            km.add_restart_callback(restarting, "restart")
            km.add_restart_callback(dead, "dead")
            try:
                done, _ = await asyncio.wait(
                    [dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED)
                if dead_fut in done:
                    raise KernelDeath()
                assert msg_task in done
                return await msg_task
            finally:
                msg_task.cancel()
                km.remove_restart_callback(restarting, "restart")
                km.remove_restart_callback(dead, "dead")

        async def send_interrupt():
            await asyncio.sleep(interrupt_after)
            logger.info("Sending interrupt to kernel")
            await km.interrupt_kernel()

        async def run():
            execute_result = None
            error_traceback = None
            stream_text_list = []
            kc = km.client()
            assert isinstance(kc, AsyncKernelClient)
            kc.start_channels()
            try:
                await kc.wait_for_ready(timeout=wait_for_ready_timeout)
                msg_id = kc.execute(code)
                while True:
                    message = await get_iopub_msg_with_death_detection(
                        kc, timeout=iopub_timeout)
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.debug(
                            json.dumps(message, indent=2, default=str))
                    # Skip output that is not a reply to this request, e.g.
                    # leftovers of an earlier, interrupted execution or status
                    # messages with an empty parent header.
                    if message["parent_header"].get("msg_id") != msg_id:
                        continue
                    msg_type = message["msg_type"]
                    content = message["content"]
                    if msg_type == "status":
                        if content["execution_state"] == "idle":
                            break
                    elif msg_type == "stream":
                        stream_text_list.append(content["text"])
                    elif msg_type == "execute_result":
                        execute_result = content["data"]
                    elif msg_type == "error":
                        error_traceback = "\n".join(content["traceback"])
                    else:
                        # execute_input, display_data, clear_output,
                        # update_display_data, ... are not reported by this
                        # function. They must not abort the execution.
                        logger.debug("Ignoring IOPub message of type %s",
                                     msg_type)
            finally:
                kc.stop_channels()
            return execute_result, error_traceback, "".join(stream_text_list)

        if not interrupt_after:
            return await run()

        run_task = asyncio.create_task(run())
        send_interrupt_task = asyncio.create_task(send_interrupt())
        try:
            done, _ = await asyncio.wait([run_task, send_interrupt_task],
                                         return_when=asyncio.FIRST_COMPLETED)
            if run_task in done:
                send_interrupt_task.cancel()
            else:
                assert send_interrupt_task in done
            return await run_task
        finally:
            # If this coroutine is cancelled or anything raised, make sure
            # neither helper task outlives the call.
            for task in (run_task, send_interrupt_task):
                if not task.done():
                    task.cancel()
    finally:
        if shutdown_kernel:
            await km.shutdown_kernel()
class IPythonInterpreter(BaseAction):
    """A IPython executor that can execute Python scripts in a jupyter manner.

    One kernel is started per process. It is shared by every instance of this
    class created in that process.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Stored as ``self.timeout``. The synchronous ``run`` only enforces
            the ``timeout`` passed to it explicitly. Defaults to 20.
        user_data_dir (str, optional): Specified the user data directory for files
            loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
            The kernel changes into ``os.path.dirname(user_data_dir)``.
            Defaults to `ENV`.
        work_dir (str, optional): Specify which directory to save output images to.
            Defaults to ``'./work_dir/tmp_dir'``.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    # pid -> (KernelManager, kernel client). This is a class attribute, so all
    # instances in one process share the same kernel.
    _KERNEL_CLIENTS = {}

    def __init__(
        self,
        timeout: int = 20,
        user_data_dir: str = 'ENV',
        work_dir='./work_dir/tmp_dir',
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(description, parser)

        self.timeout = timeout
        if user_data_dir == 'ENV':
            user_data_dir = os.environ.get('USER_DATA_DIR', '')

        if user_data_dir:
            user_data_dir = os.path.dirname(user_data_dir)
            # ``!r`` emits a correctly escaped string literal, so quotes or
            # backslashes (e.g. Windows paths) cannot break the snippet.
            user_data_dir = f"import os\nos.chdir({user_data_dir!r})"
        self.user_data_dir = user_data_dir
        self._initialized = False
        self.work_dir = work_dir
        os.makedirs(self.work_dir, exist_ok=True)

    @staticmethod
    def start_kernel():
        """Start a local Jupyter kernel and return ``(manager, client)``."""
        from jupyter_client import KernelManager

        # start the kernel and manager
        km = KernelManager()
        km.start_kernel()
        kc = km.client()
        # Open the channels and wait until the kernel answers, so the first
        # execute request and its IOPub output are not lost.
        try:
            kc.start_channels()
            kc.wait_for_ready(timeout=60)
        except Exception:
            kc.stop_channels()
            km.shutdown_kernel(now=True)
            raise
        return km, kc

    def initialize(self):
        """Attach to this process's kernel, starting it if needed, and run the
        START_CODE preamble. Does nothing if this instance is initialized."""
        if self._initialized:
            return
        pid = os.getpid()
        if pid not in self._KERNEL_CLIENTS:
            self._KERNEL_CLIENTS[pid] = self.start_kernel()
        self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
        # Set before running the preamble, because ``_call`` calls
        # ``initialize`` again.
        self._initialized = True
        self._call(START_CODE.format(self.user_data_dir), None)

    def reset(self):
        """Clear the kernel namespace and re-run the START_CODE preamble."""
        if not self._initialized:
            self.initialize()
        else:
            code = "get_ipython().run_line_magic('reset', '-f')\n" + \
                START_CODE.format(self.user_data_dir)
            self._call(code, None)

    def _call(self,
              command: str,
              timeout: Optional[int] = None) -> Tuple[bool, 'dict | str']:
        """Execute ``command`` in the kernel and collect text and images.

        Args:
            command (str): Python code. It may be wrapped in a Markdown fence,
                in backticks, or in a JSON object with a ``code`` key.
            timeout (int, optional): Wall-clock limit in seconds, enforced
                with ``SIGALRM`` (Unix only; ``signal.signal`` must be called
                from the main thread). ``None`` or ``0`` disables it.

        Returns:
            tuple: ``(succeed, result)``. ``result`` is
            ``dict(text=..., image=[paths])`` when output collection ends
            normally, or an error string when the wall-clock timeout fires.
        """
        self.initialize()
        command = extract_code(command)

        # Drain output left over from a previous (e.g. timed-out) request so
        # it is not attributed to this one. Stops at the first ``idle``
        # status, or after 5 seconds without any message.
        while True:
            try:
                msg = self.kernel_client.get_iopub_msg(timeout=5)
            except queue.Empty:
                # assume no result
                break
            if (msg['msg_type'] == 'status'
                    and msg['content'].get('execution_state') == 'idle'):
                break

        self.kernel_client.execute(command)

        iopub_wait = 20  # seconds to wait for each single output message

        def save_image(data):
            """Persist the base64 PNG of a mime bundle and return its path."""
            return publish_image_to_local(data['image/png'], self.work_dir)

        def _inner_call():
            result = ''
            images = []
            succeed = True
            while True:
                text = ''
                image_url = None
                finished = False
                try:
                    msg = self.kernel_client.get_iopub_msg(timeout=iopub_wait)
                    msg_type = msg['msg_type']
                    content = msg['content']
                    if msg_type == 'status':
                        if content.get('execution_state') == 'idle':
                            finished = True
                    elif msg_type == 'execute_result':
                        text = content['data'].get('text/plain', '')
                        if 'image/png' in content['data']:
                            image_url = save_image(content['data'])
                    elif msg_type == 'display_data':
                        if 'image/png' in content['data']:
                            image_url = save_image(content['data'])
                        else:
                            text = content['data'].get('text/plain', '')
                    elif msg_type == 'stream':
                        text = content['text']
                    elif msg_type == 'error':
                        succeed = False
                        text = escape_ansi('\n'.join(content['traceback']))
                        if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
                            text = f'Timeout. No response after {timeout} seconds.'  # noqa
                except TimeoutError:
                    # Raised by the SIGALRM handler; let the outer handler
                    # report it instead of treating it as an unexpected error.
                    raise
                except queue.Empty:
                    # stop current task in case break next input.
                    self.kernel_manager.interrupt_kernel()
                    succeed = False
                    text = f'Timeout. No response after {iopub_wait} seconds.'
                    finished = True
                except Exception:
                    succeed = False
                    text = ''.join(traceback.format_exception(*sys.exc_info()))
                    logger.warning(text)
                    finished = True
                if text:
                    result += text
                if image_url:
                    images.append(image_url)
                if finished:
                    return succeed, dict(text=result, image=images)

        previous_handler = None
        try:
            if timeout:

                def handler(signum, frame):
                    raise TimeoutError()

                previous_handler = signal.signal(signal.SIGALRM, handler)
                signal.alarm(timeout)
            succeed, result = _inner_call()
        except TimeoutError:
            # The cell is still running; interrupt it so the next request
            # does not queue up behind it.
            self.kernel_manager.interrupt_kernel()
            succeed = False
            text = f'Timeout. No response after {timeout} seconds.'
            result = f'\n\nerror:\n\n```\n{text}\n```'
        finally:
            if timeout:
                signal.alarm(0)
                # signal.signal returns None when the previous handler was
                # not installed from Python; it cannot be restored then.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)
        return succeed, result

    # ``tool_api`` turns ``run`` into the action's API; the docstring below is
    # its user-facing description and is kept verbatim.
    @tool_api
    def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
        r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.

        Args:
            command (:class:`str`): Python code
            timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
        """
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        tool_return.args = dict(text=command)
        succeed, result = self._call(command, timeout)
        if succeed:
            resp = [dict(type='text', content=result['text'])]
            resp.extend(
                dict(type='image', content=im)
                for im in result.get('image', []))
            tool_return.result = resp
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = result.get('text', '') if isinstance(
                result, dict) else result
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return
class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter):
    """A IPython executor that can execute Python scripts in a jupyter manner.

    Each ``session_id`` is bound to its own kernel, started through an
    ``AsyncMultiKernelManager`` owned by this instance.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Used as the interrupt delay when a call passes no ``timeout``.
            Defaults to 20.
        user_data_dir (str, optional): Specified the user data directory for files
            loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
            Defaults to `ENV`.
        work_dir (str, optional): Specify which directory to save output images to.
            Defaults to ``os.path.join(tempfile.gettempdir(), 'tmp_dir')``.
        max_kernels (int, optional): Maximum number of kernels (bound to a
            session plus pooled for reuse) this instance keeps alive.
            ``None`` means unlimited.
        reuse_kernel (bool): If true, ``close_session`` resets the session's
            kernel and pools it for later sessions instead of shutting it
            down. Defaults to ``True``.
        startup_rate (int): Maximum number of kernel start-ups running
            concurrently. Defaults to 32.
        connection_dir (str): Directory for kernel connection files.
            Defaults to the system temp directory.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    def __init__(
        self,
        timeout: int = 20,
        user_data_dir: str = 'ENV',
        work_dir=os.path.join(tempfile.gettempdir(), 'tmp_dir'),
        max_kernels: Optional[int] = None,
        reuse_kernel: bool = True,
        startup_rate: int = 32,
        connection_dir: str = tempfile.gettempdir(),
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(timeout, user_data_dir, work_dir, description, parser)
        from traitlets.config import Config

        c = Config()
        c.KernelManager.transport = 'ipc'
        # A dead kernel stays dead: the restarter fires its "dead" callbacks,
        # which async_run_code turns into KernelDeath, instead of attempting
        # a "restart" that async_run_code does not support.
        c.KernelRestarter.restart_limit = 0
        self._amkm = AsyncMultiKernelManager(
            config=c, connection_dir=connection_dir)
        self._max_kernels = max_kernels
        self._reuse_kernel = reuse_kernel
        self._sem = asyncio.Semaphore(startup_rate)
        self._lock = asyncio.Lock()
        # Per-instance registries. Kernel ids are only meaningful to this
        # instance's ``_amkm``, so kernels must not be shared with other
        # instances or with the parent class's pid-keyed registry.
        # session_id -> (kernel_id, kernel, client)
        self._KERNEL_CLIENTS = {}
        # Idle (kernel_id, kernel, client) tuples available for reuse.
        # Only the non-blocking Queue methods are used on it.
        self._UNBOUND_KERNEL_CLIENTS = asyncio.Queue()

    def _has_capacity(self) -> bool:
        """Whether one more kernel may be started under ``max_kernels``."""
        if self._max_kernels is None:
            return True
        in_use = len(
            self._KERNEL_CLIENTS) + self._UNBOUND_KERNEL_CLIENTS.qsize()
        return in_use < self._max_kernels

    async def _discard_kernel(self, kernel_id):
        """Shut down ``kernel_id`` and forget it; failures are only logged."""
        try:
            await self._amkm.shutdown_kernel(kernel_id)
        except Exception as e:
            logger.warning('Shutting down kernel %s failed: %s', kernel_id, e)
        self._amkm.remove_kernel(kernel_id)

    async def _start_kernel(self):
        """Start a kernel and run START_CODE in it.

        Returns ``(kernel_id, kernel, client)``. Returns ``None``, after
        cleaning up, if start-up or the preamble failed.
        """
        kernel_id = None
        try:
            kernel_id = await self._amkm.start_kernel()
            kernel = self._amkm.get_kernel(kernel_id)
            client = kernel.client()
            _, error_stacktrace, stream_text = await async_run_code(
                kernel,
                START_CODE.format(self.user_data_dir),
                shutdown_kernel=False)
            # check if the output of START_CODE meets expectations
            if not (error_stacktrace is None and stream_text == ''):
                raise RuntimeError(
                    'Unexpected output from START_CODE: '
                    f'{error_stacktrace or stream_text!r}')
        except Exception as e:
            logger.warning('Starting kernel error: %s', e)
            if kernel_id:
                await self._discard_kernel(kernel_id)
            return None
        return kernel_id, kernel, client

    async def initialize(self, session_id: str):
        """Return the ``(kernel_id, kernel, client)`` bound to ``session_id``.

        Sources are tried in order:
        1. The existing binding for the session.
        2. An idle kernel from the reuse pool, when ``reuse_kernel`` is set.
        3. A newly started kernel, as long as ``max_kernels`` allows one.

        While at capacity, or after a failed start, the call sleeps for one
        second and retries.
        """
        session_id = str(session_id)
        while True:
            if session_id in self._KERNEL_CLIENTS:
                return self._KERNEL_CLIENTS[session_id]
            if self._reuse_kernel:
                try:
                    # Non-blocking: never wait on a pool that another
                    # coroutine may have emptied meanwhile.
                    entry = self._UNBOUND_KERNEL_CLIENTS.get_nowait()
                except asyncio.QueueEmpty:
                    pass
                else:
                    self._KERNEL_CLIENTS[session_id] = entry
                    return entry
            async with self._sem:
                if self._has_capacity():
                    entry = await self._start_kernel()
                    if entry is not None:
                        if self._max_kernels is None:
                            self._KERNEL_CLIENTS[session_id] = entry
                            return entry
                        async with self._lock:
                            if self._has_capacity():
                                self._KERNEL_CLIENTS[session_id] = entry
                                return entry
                        # Another coroutine took the last slot meanwhile.
                        await self._discard_kernel(entry[0])
            # Sleep outside the semaphore and on every non-returning path.
            # An uncontended semaphore never suspends, so without this the
            # loop would spin and block the event loop while at capacity.
            await asyncio.sleep(1)

    async def reset(self, session_id: str):
        """Clear the namespace of the session's kernel and re-run START_CODE."""
        session_id = str(session_id)
        if session_id not in self._KERNEL_CLIENTS:
            return
        _, kernel, _ = self._KERNEL_CLIENTS[session_id]
        code = "get_ipython().run_line_magic('reset', '-f')\n" + \
            START_CODE.format(self.user_data_dir)
        await async_run_code(kernel, code, shutdown_kernel=False)

    async def shutdown(self, session_id: str):
        """Shut down and unregister the kernel bound to ``session_id``."""
        session_id = str(session_id)
        entry = self._KERNEL_CLIENTS.pop(session_id, None)
        if entry is not None:
            await self._discard_kernel(entry[0])

    async def close_session(self, session_id: str):
        """Release the session's kernel: pool it after a reset when
        ``reuse_kernel`` is set, otherwise shut it down."""
        session_id = str(session_id)
        if not self._reuse_kernel:
            await self.shutdown(session_id)
            return
        if session_id not in self._KERNEL_CLIENTS:
            return
        try:
            await self.reset(session_id)
        except Exception as e:
            # A kernel that cannot be reset (e.g. it died) must not be handed
            # to another session.
            logger.warning('Resetting kernel failed, shutting it down: %s', e)
            await self.shutdown(session_id)
            return
        entry = self._KERNEL_CLIENTS.pop(session_id, None)
        if entry is not None:
            self._UNBOUND_KERNEL_CLIENTS.put_nowait(entry)

    async def _call(self, command, timeout=None, session_id=None):
        """Run ``command`` in the session's kernel.

        Returns:
            tuple: ``(True, dict(text=...))`` on success. On a code error,
            interrupt timeout, or kernel failure, ``(False, message)``.
        """
        session_id = str(session_id)
        _, kernel, _ = await self.initialize(session_id)
        interrupt_after = timeout or self.timeout
        # async_run_code requires iopub_timeout > interrupt_after. Keep its
        # default of 40 seconds unless the interrupt delay needs more.
        iopub_timeout = max(40, interrupt_after + 10) if interrupt_after else 40
        try:
            result = await async_run_code(
                kernel,
                extract_code(command),
                interrupt_after=interrupt_after,
                iopub_timeout=iopub_timeout,
                shutdown_kernel=False)
        except (KernelDeath, queue.Empty) as e:
            # The kernel died, or stayed silent even after the interrupt.
            # Drop it so the next call on this session gets a fresh kernel.
            await self.shutdown(session_id)
            if isinstance(e, KernelDeath):
                return False, 'The code interpreter kernel died.'
            return False, 'The code interpreter encountered a timeout error.'
        execute_result, error_stacktrace, stream_text = result
        if error_stacktrace is not None:
            ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace))
            if ret.endswith('KeyboardInterrupt: '):
                ret = 'The code interpreter encountered a timeout error.'
            status, ret = False, ret.strip()
        elif execute_result is not None:
            status, ret = True, dict(text=execute_result.get('text/plain', ''))
        else:
            status, ret = True, dict(text=stream_text.strip())
        return status, ret

    # ``tool_api`` turns ``run`` into the action's API; the docstring below is
    # its user-facing description and is kept verbatim.
    @tool_api
    async def run(self,
                  command: str,
                  timeout: Optional[int] = None,
                  session_id: Optional[str] = None) -> ActionReturn:
        r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.

        Args:
            command (:class:`str`): Python code
            timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
        """
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        tool_return.args = dict(text=command)
        succeed, result = await self._call(command, timeout, session_id)
        if succeed:
            resp = [dict(type='text', content=result['text'])]
            resp.extend(
                dict(type='image', content=im)
                for im in result.get('image', []))
            tool_return.result = resp
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = result.get('text', '') if isinstance(
                result, dict) else result
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return
def extract_code(text):
    """Extract the Python source from model-produced ``text``.

    Sources are tried in order:
    1. The body of the first triple-backtick fence. Any info string after
       the opening fence, such as ``python``, is dropped.
    2. The content of the first single-backtick span.
    3. The ``code`` string field of a JSON/JSON5 object.

    If none applies, ``text`` is returned unchanged.
    """
    # Match triple backtick blocks first
    triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
    if triple_match:
        return triple_match.group(1)
    # Match single backtick blocks second
    single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
    if single_match:
        return single_match.group(1)
    # json5 is optional (it also accepts single quotes, comments, ...); the
    # stdlib parser covers plain JSON when it is not installed.
    try:
        import json5
        loads = json5.loads
    except ImportError:
        loads = json.loads
    try:
        # Best effort: non-JSON text, non-dict JSON and a missing key all
        # fall through to returning the original text.
        code = loads(text)['code']
    except Exception:
        return text
    # Callers execute the result, so only a string is acceptable.
    return code if isinstance(code, str) else text
def escape_ansi(line):
    """Return ``line`` with ANSI/VT100 escape sequences (colours, cursor
    control) removed, e.g. from IPython tracebacks."""
    return re.sub(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]', '', line)
def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
    """Decode a base64 PNG and save it under ``work_dir`` with a UUID name.

    The bytes are round-tripped through PIL, so data PIL cannot open raises
    instead of being written.

    Args:
        image_base64 (str): Base64-encoded PNG data.
        work_dir (str): Target directory; created if it does not exist.

    Returns:
        str: Path of the saved ``.png`` file.
    """
    import PIL.Image

    os.makedirs(work_dir, exist_ok=True)
    local_image_file = os.path.join(work_dir, f'{uuid.uuid4()}.png')
    png_bytes = base64.b64decode(image_base64)
    # Context manager closes the image (and its buffer) after saving.
    with PIL.Image.open(io.BytesIO(png_bytes)) as img:
        img.save(local_image_file, 'png')
    return local_image_file
| # local test for code interpreter | |
def get_multiline_input(hint):
    """Print ``hint``, read stdin lines until EOF (CTRL-D), and return them
    joined with newlines ('' when nothing was entered)."""
    print(hint)
    print('// Press ENTER to make a new line. Press CTRL-D to end input.')
    collected = []
    while True:
        try:
            collected.append(input())
        except EOFError:  # CTRL-D
            break
    print('// Input received.')
    return '\n'.join(collected)
if __name__ == '__main__':
    # Interactive smoke test: read a cell, execute it, print the ActionReturn.
    # An empty cell ends the session, as does CTRL-C. Once stdin is at EOF,
    # every later read returns '' immediately, so continuing would loop
    # forever executing empty cells.
    code_interpreter = IPythonInterpreter()
    try:
        while True:
            code = get_multiline_input('Enter python code:')
            if not code:
                break
            print(code_interpreter(code))
    except KeyboardInterrupt:
        pass