import json
import os
from typing import Any, Dict

from IPython.core.interactiveshell import InteractiveShell
from IPython.core.magic import (
    Magics,
    cell_magic,
    line_cell_magic,
    line_magic,
    magics_class,
    needs_local_scope,
)

from taskweaver.ces.runtime.executor import Executor


def fmt_response(is_success: bool, message: str, data: Any = None) -> Dict[str, Any]:
    """Build the response dict returned by every magic in this module.

    Args:
        is_success: Whether the operation succeeded.
        message: Human-readable status message.
        data: Optional payload attached to the response.

    Returns:
        A dict with the keys ``is_success``, ``message`` and ``data``.
    """
    return {
        "is_success": is_success,
        "message": message,
        "data": data,
    }


@magics_class
class TaskWeaverContextMagic(Magics):
    """Magics that manage the execution context through the shared ``Executor``."""

    def __init__(self, shell: InteractiveShell, executor: Executor, **kwargs: Any):
        """Store the executor and forward ``shell`` and ``kwargs`` to ``Magics``."""
        super().__init__(shell, **kwargs)
        self.executor = executor

    @needs_local_scope
    @line_magic
    def _taskweaver_session_init(self, line: str, local_ns: Dict[str, Any]):
        """Pass the caller's namespace to ``executor.load_lib``; ``line`` is unused."""
        self.executor.load_lib(local_ns)
        return fmt_response(True, "TaskWeaver context initialized.")

    @cell_magic
    def _taskweaver_update_session_var(self, line: str, cell: str):
        """Parse the cell body as JSON and hand it to ``executor.update_session_var``.

        Returns a success response whose ``data`` is ``executor.session_var``.
        A malformed JSON body is not caught here and propagates to the caller.
        """
        session_var_dict = json.loads(cell)
        self.executor.update_session_var(session_var_dict)
        return fmt_response(True, "Session var updated.", self.executor.session_var)

    @cell_magic
    def _taskweaver_convert_path(self, line: str, cell: str):
        """Return the absolute form of the path given as the cell body."""
        # The cell text is used verbatim as the path; ``os`` comes from the
        # module-level import, so no local import is needed.
        full_path = os.path.abspath(cell)
        return fmt_response(True, "Path converted.", full_path)

    @line_magic
    def _taskweaver_exec_pre_check(self, line: str):
        """Run ``executor.pre_execution`` for the execution described by ``line``.

        ``line`` must be ``"<exec_idx> <exec_id>"`` separated by exactly one
        space; ``exec_idx`` is converted to ``int``. Any other shape raises
        ``ValueError`` from the unpacking or the conversion.
        """
        exec_idx, exec_id = line.split(" ")
        exec_idx = int(exec_idx)
        return fmt_response(True, "", self.executor.pre_execution(exec_idx, exec_id))

    @needs_local_scope
    @line_magic
    def _taskweaver_exec_post_check(self, line: str, local_ns: Dict[str, Any]):
        """Record the ``_`` value (if present) and return the post-execution state."""
        # When the namespace has a ``_`` entry, forward it as the execution output.
        if "_" in local_ns:
            self.executor.ctx.set_output(local_ns["_"])
        return fmt_response(True, "", self.executor.get_post_execution_state())


@magics_class
class TaskWeaverPluginMagic(Magics):
    """Magics that register, test, load and unload plugins via the ``Executor``."""

    def __init__(self, shell: InteractiveShell, executor: Executor, **kwargs: Any):
        """Store the executor and forward ``shell`` and ``kwargs`` to ``Magics``."""
        super().__init__(shell, **kwargs)
        self.executor = executor

    @line_cell_magic
    def _taskweaver_plugin_register(self, line: str, cell: str):
        """Register a plugin: ``line`` is its name, ``cell`` is its code.

        Any exception from ``executor.register_plugin`` is reported as a
        failure response instead of being raised.
        """
        plugin_name = line
        plugin_code = cell
        try:
            self.executor.register_plugin(plugin_name, plugin_code)
            return fmt_response(True, f"Plugin {plugin_name} registered.")
        except Exception as e:
            return fmt_response(
                False,
                f"Plugin {plugin_name} failed to register: " + str(e),
            )

    @line_magic
    def _taskweaver_plugin_test(self, line: str):
        """Run ``executor.test_plugin`` for the plugin named by ``line``.

        The messages returned by the executor are newline-joined into the
        response message for both the passing and the failing case.
        """
        plugin_name = line
        is_success, messages = self.executor.test_plugin(plugin_name)
        if is_success:
            return fmt_response(
                True,
                f"Plugin {plugin_name} passed tests: " + "\n".join(messages),
            )
        return fmt_response(
            False,
            f"Plugin {plugin_name} failed to test: " + "\n".join(messages),
        )

    @needs_local_scope
    @line_cell_magic
    def _taskweaver_plugin_load(self, line: str, cell: str, local_ns: Dict[str, Any]):
        """Configure a plugin and bind its instance in the caller's namespace.

        ``line`` is the plugin name and ``cell`` is its configuration as JSON.
        On success the instance from ``executor.get_plugin_instance`` is stored
        in ``local_ns`` under the plugin name. Any failure, including a
        malformed JSON configuration, is returned as a failure response.
        """
        plugin_name = line
        try:
            # Parse inside the try so an invalid config is reported through the
            # same failure response as every other load error, not raised.
            plugin_config: Any = json.loads(cell)
            self.executor.config_plugin(plugin_name, plugin_config)
            local_ns[plugin_name] = self.executor.get_plugin_instance(plugin_name)
            return fmt_response(True, f"Plugin {plugin_name} loaded.")
        except Exception as e:
            return fmt_response(
                False,
                f"Plugin {plugin_name} failed to load: " + str(e),
            )

    @needs_local_scope
    @line_magic
    def _taskweaver_plugin_unload(self, line: str, local_ns: Dict[str, Any]):
        """Remove the plugin named by ``line`` from the caller's namespace.

        A name that is not present is reported as a successful no-op.
        """
        plugin_name = line
        if plugin_name not in local_ns:
            return fmt_response(
                True,
                f"Plugin {plugin_name} not loaded, skipping unloading.",
            )
        del local_ns[plugin_name]
        return fmt_response(True, f"Plugin {plugin_name} unloaded.")


def load_ipython_extension(ipython: InteractiveShell):
    """Create the executor and register both magic classes on ``ipython``.

    The executor settings come from the environment variables
    ``TASKWEAVER_ENV_ID`` (default ``"local"``), ``TASKWEAVER_SESSION_ID``
    (default ``"session_temp"``) and ``TASKWEAVER_SESSION_DIR`` (default: the
    real path of the current working directory). The shell's traceback mode is
    set to ``"Plain"`` after registration.
    """
    env_id = os.environ.get("TASKWEAVER_ENV_ID", "local")
    session_id = os.environ.get("TASKWEAVER_SESSION_ID", "session_temp")
    session_dir = os.environ.get(
        "TASKWEAVER_SESSION_DIR",
        os.path.realpath(os.getcwd()),
    )

    executor = Executor(
        env_id=env_id,
        session_id=session_id,
        session_dir=session_dir,
    )

    # Both magic classes share the same executor instance.
    ctx_magic = TaskWeaverContextMagic(ipython, executor)
    plugin_magic = TaskWeaverPluginMagic(ipython, executor)

    ipython.register_magics(ctx_magic)
    ipython.register_magics(plugin_magic)
    ipython.InteractiveTB.set_mode(mode="Plain")