tskwvr / taskweaver /ces /kernel /ctx_magic.py
TRaw's picture
Upload 297 files
3d3d712
import json
import os
from typing import Any, Dict, Optional

from IPython.core.interactiveshell import InteractiveShell
from IPython.core.magic import Magics, cell_magic, line_cell_magic, line_magic, magics_class, needs_local_scope

from taskweaver.ces.runtime.executor import Executor
def fmt_response(is_success: bool, message: str, data: Any = None):
    """Build the uniform reply dict that every TaskWeaver magic returns.

    :param is_success: whether the magic completed successfully.
    :param message: human-readable status text.
    :param data: optional payload; ``None`` when there is nothing to return.
    :return: a dict with exactly the keys ``is_success``, ``message`` and
        ``data``, in that order.
    """
    return dict(is_success=is_success, message=message, data=data)
@magics_class
class TaskWeaverContextMagic(Magics):
    """IPython magics that manage TaskWeaver session context inside the kernel.

    Every magic returns the dict produced by ``fmt_response``. Session state is
    delegated to the ``Executor`` passed to the constructor.
    """

    def __init__(self, shell: InteractiveShell, executor: Executor, **kwargs: Any):
        super(TaskWeaverContextMagic, self).__init__(shell, **kwargs)
        self.executor = executor

    @needs_local_scope
    @line_magic
    def _taskweaver_session_init(self, line: str, local_ns: Dict[str, Any]):
        """Pass the caller's namespace to ``executor.load_lib``.

        ``line`` is required by the magic calling convention but is not used.
        """
        self.executor.load_lib(local_ns)
        return fmt_response(True, "TaskWeaver context initialized.")

    @cell_magic
    def _taskweaver_update_session_var(self, line: str, cell: str):
        """Parse the cell body as JSON and hand it to ``executor.update_session_var``.

        The response payload is ``executor.session_var`` after the update.
        Raises ``json.JSONDecodeError`` if the cell body is not valid JSON.
        """
        session_var_dict = json.loads(cell)
        self.executor.update_session_var(session_var_dict)
        return fmt_response(True, "Session var updated.", self.executor.session_var)

    @cell_magic
    def _taskweaver_convert_path(self, line: str, cell: str):
        """Return the absolute form of the path given in the cell body.

        When the magic is invoked through IPython's input transformer, the cell
        body normally ends with a newline. Line-break characters at either end
        are stripped so they do not become part of the returned path. Other
        whitespace is preserved.
        """
        raw_path_str = cell.strip("\r\n")
        full_path = os.path.abspath(raw_path_str)
        return fmt_response(True, "Path converted.", full_path)

    @line_magic
    def _taskweaver_exec_pre_check(self, line: str):
        """Run ``executor.pre_execution`` for the execution described by ``line``.

        ``line`` must be ``"<exec_idx> <exec_id>"``, where ``exec_idx`` parses as
        an int. Fields are split on any run of whitespace, so extra spaces are
        tolerated. Raises ``ValueError`` if there are not exactly two fields or
        if the index is not an integer.
        """
        exec_idx_str, exec_id = line.split()
        exec_idx = int(exec_idx_str)
        return fmt_response(True, "", self.executor.pre_execution(exec_idx, exec_id))

    @needs_local_scope
    @line_magic
    def _taskweaver_exec_post_check(self, line: str, local_ns: Dict[str, Any]):
        """Forward the last output (if any) and return the post-execution state.

        In IPython, ``_`` holds the most recent output value. When it is present
        in the namespace, it is handed to ``executor.ctx.set_output`` before
        ``executor.get_post_execution_state()`` is returned as the payload.
        """
        if "_" in local_ns:
            self.executor.ctx.set_output(local_ns["_"])
        return fmt_response(True, "", self.executor.get_post_execution_state())
@magics_class
class TaskWeaverPluginMagic(Magics):
    """IPython magics that register, test, load and unload TaskWeaver plugins.

    Every magic returns the dict produced by ``fmt_response``. Errors raised by
    the executor during registration or loading are reported as
    ``is_success=False`` responses instead of being propagated.
    """

    def __init__(self, shell: InteractiveShell, executor: Executor, **kwargs: Any):
        super(TaskWeaverPluginMagic, self).__init__(shell, **kwargs)
        self.executor = executor

    @line_cell_magic
    def _taskweaver_plugin_register(self, line: str, cell: Optional[str] = None):
        """Register the plugin named by ``line`` using the source code in ``cell``.

        ``cell`` defaults to ``None`` because IPython invokes a line/cell magic in
        line form with only ``line``. Registration needs plugin code, so that
        case is reported as a failure instead of crashing with ``TypeError``.
        """
        plugin_name = line
        if cell is None:
            return fmt_response(
                False,
                f"Plugin {plugin_name} failed to register: no plugin code provided.",
            )
        try:
            self.executor.register_plugin(plugin_name, cell)
        except Exception as e:
            return fmt_response(
                False,
                f"Plugin {plugin_name} failed to register: " + str(e),
            )
        return fmt_response(True, f"Plugin {plugin_name} registered.")

    @line_magic
    def _taskweaver_plugin_test(self, line: str):
        """Run ``executor.test_plugin`` for the plugin named by ``line``.

        The messages returned by the executor are joined with newlines and
        appended to the response message.
        """
        plugin_name = line
        is_success, messages = self.executor.test_plugin(plugin_name)
        details = "\n".join(messages)
        if is_success:
            return fmt_response(
                True,
                f"Plugin {plugin_name} passed tests: " + details,
            )
        return fmt_response(
            False,
            f"Plugin {plugin_name} failed to test: " + details,
        )

    @needs_local_scope
    @line_cell_magic
    def _taskweaver_plugin_load(
        self,
        line: str,
        cell: Optional[str] = None,
        local_ns: Optional[Dict[str, Any]] = None,
    ):
        """Configure the plugin named by ``line`` and bind its instance in ``local_ns``.

        ``cell`` must contain the plugin configuration as JSON. JSON parsing is
        done inside the ``try`` block, so malformed configuration produces a
        failed-load response, just like errors raised by the executor.
        """
        plugin_name = line
        if local_ns is None:
            # IPython always supplies local_ns for needs_local_scope magics (the
            # cell form passes shell.user_ns); this fallback only covers direct calls.
            local_ns = self.shell.user_ns
        if cell is None:
            return fmt_response(
                False,
                f"Plugin {plugin_name} failed to load: no plugin configuration provided.",
            )
        try:
            plugin_config: Any = json.loads(cell)
            self.executor.config_plugin(plugin_name, plugin_config)
            local_ns[plugin_name] = self.executor.get_plugin_instance(plugin_name)
        except Exception as e:
            return fmt_response(
                False,
                f"Plugin {plugin_name} failed to load: " + str(e),
            )
        return fmt_response(True, f"Plugin {plugin_name} loaded.")

    @needs_local_scope
    @line_magic
    def _taskweaver_plugin_unload(self, line: str, local_ns: Dict[str, Any]):
        """Remove the plugin instance named by ``line`` from ``local_ns``.

        Unloading a plugin that is not bound is not an error. A success response
        noting the skip is returned instead.
        """
        plugin_name = line
        if plugin_name not in local_ns:
            return fmt_response(
                True,
                f"Plugin {plugin_name} not loaded, skipping unloading.",
            )
        del local_ns[plugin_name]
        return fmt_response(True, f"Plugin {plugin_name} unloaded.")
def load_ipython_extension(ipython: InteractiveShell):
    """Entry point IPython calls when this module is loaded as an extension.

    Builds one ``Executor`` from the ``TASKWEAVER_ENV_ID``,
    ``TASKWEAVER_SESSION_ID`` and ``TASKWEAVER_SESSION_DIR`` environment
    variables. If a variable is unset, the fallback is ``"local"``,
    ``"session_temp"`` or the resolved current working directory,
    respectively. The executor is shared by the context and plugin magics,
    both of which are registered on *ipython*. Tracebacks are then switched
    to plain mode.
    """
    env = os.environ
    # Evaluated unconditionally, as the fallback for TASKWEAVER_SESSION_DIR.
    cwd_fallback = os.path.realpath(os.getcwd())
    shared_executor = Executor(
        env_id=env.get("TASKWEAVER_ENV_ID", "local"),
        session_id=env.get("TASKWEAVER_SESSION_ID", "session_temp"),
        session_dir=env.get("TASKWEAVER_SESSION_DIR", cwd_fallback),
    )
    # Instantiate both magic sets first, then register them in the same order.
    magic_sets = [
        TaskWeaverContextMagic(ipython, shared_executor),
        TaskWeaverPluginMagic(ipython, shared_executor),
    ]
    for magic_set in magic_sets:
        ipython.register_magics(magic_set)
    ipython.InteractiveTB.set_mode(mode="Plain")