File size: 4,871 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import json
import os
from typing import Any, Dict

from IPython.core.interactiveshell import InteractiveShell
from IPython.core.magic import Magics, cell_magic, line_cell_magic, line_magic, magics_class, needs_local_scope

from taskweaver.ces.runtime.executor import Executor


def fmt_response(is_success: bool, message: str, data: Any = None):
    """Package the outcome of a magic command as a plain dict.

    The result always carries exactly the keys ``is_success``, ``message``
    and ``data`` (in that order); ``data`` is ``None`` unless supplied.
    """
    response = dict(is_success=is_success, message=message, data=data)
    return response


@magics_class
class TaskWeaverContextMagic(Magics):
    """IPython magics that manage the TaskWeaver session context in the kernel.

    Every magic returns a dict built by ``fmt_response`` so callers can
    inspect ``is_success`` / ``message`` / ``data`` uniformly.
    """

    def __init__(self, shell: InteractiveShell, executor: Executor, **kwargs: Any):
        super().__init__(shell, **kwargs)
        # Executor shared with the other magic classes; holds the session state.
        self.executor = executor

    @needs_local_scope
    @line_magic
    def _taskweaver_session_init(self, line: str, local_ns: Dict[str, Any]):
        """Load the executor's library into ``local_ns``; ``line`` is unused."""
        self.executor.load_lib(local_ns)
        return fmt_response(True, "TaskWeaver context initialized.")

    @cell_magic
    def _taskweaver_update_session_var(self, line: str, cell: str):
        """Update session variables from a JSON object given as the cell body.

        The response data is ``self.executor.session_var`` after the update.
        ``json.loads`` raises ``ValueError`` (``JSONDecodeError``) on invalid
        JSON; that error is not caught here.
        """
        session_var_dict = json.loads(cell)
        self.executor.update_session_var(session_var_dict)
        return fmt_response(True, "Session var updated.", self.executor.session_var)

    @cell_magic
    def _taskweaver_convert_path(self, line: str, cell: str):
        """Return the absolute form of the path given as the cell body."""
        # IPython hands the cell body over with its trailing newline; strip
        # surrounding whitespace so it does not end up inside the path.
        raw_path_str = cell.strip()
        full_path = os.path.abspath(raw_path_str)
        return fmt_response(True, "Path converted.", full_path)

    @line_magic
    def _taskweaver_exec_pre_check(self, line: str):
        """Run the executor's pre-execution hook.

        ``line`` holds the integer execution index and the execution id,
        separated by whitespace. Raises ``ValueError`` if it does not contain
        exactly two fields or the index is not an integer.
        """
        # split() rather than split(" ") so repeated or trailing spaces are tolerated.
        exec_idx_str, exec_id = line.split()
        exec_idx = int(exec_idx_str)
        return fmt_response(True, "", self.executor.pre_execution(exec_idx, exec_id))

    @needs_local_scope
    @line_magic
    def _taskweaver_exec_post_check(self, line: str, local_ns: Dict[str, Any]):
        """Record the last output (if any) and return the post-execution state."""
        # IPython's output caching stores the last expression result as ``_``.
        if "_" in local_ns:
            self.executor.ctx.set_output(local_ns["_"])
        return fmt_response(True, "", self.executor.get_post_execution_state())


@magics_class
class TaskWeaverPluginMagic(Magics):
    """IPython magics for registering, testing, loading and unloading plugins.

    Every magic returns a ``fmt_response`` dict. Errors raised while
    registering or loading a plugin are reported in that dict
    (``is_success=False``) instead of propagating out of the magic.
    """

    def __init__(self, shell: InteractiveShell, executor: Executor, **kwargs: Any):
        super().__init__(shell, **kwargs)
        # Executor shared with the other magic classes; owns plugin state.
        self.executor = executor

    # NOTE(review): declared line_cell_magic, but ``cell`` has no default, so
    # invoking it as a line magic (where IPython passes no cell) raises
    # TypeError. IPython's docs use ``cell=None`` for such magics.
    @line_cell_magic
    def _taskweaver_plugin_register(self, line: str, cell: str):
        """Register plugin source code with the executor.

        ``line`` is the plugin name and ``cell`` the plugin's source code.
        """
        plugin_name = line
        plugin_code = cell
        try:
            self.executor.register_plugin(plugin_name, plugin_code)
        except Exception as e:
            return fmt_response(
                False,
                f"Plugin {plugin_name} failed to register: " + str(e),
            )
        return fmt_response(True, f"Plugin {plugin_name} registered.")

    @line_magic
    def _taskweaver_plugin_test(self, line: str):
        """Run the executor's tests for the plugin named by ``line``.

        The messages returned by the executor are joined with newlines into
        the response message.
        """
        plugin_name = line
        is_success, messages = self.executor.test_plugin(plugin_name)
        detail = "\n".join(messages)
        if is_success:
            return fmt_response(True, f"Plugin {plugin_name} passed tests: " + detail)
        return fmt_response(False, f"Plugin {plugin_name} failed to test: " + detail)

    @needs_local_scope
    @line_cell_magic
    def _taskweaver_plugin_load(self, line: str, cell: str, local_ns: Dict[str, Any]):
        """Configure a registered plugin and bind its instance in ``local_ns``.

        ``line`` is the plugin name and ``cell`` a JSON document holding the
        plugin configuration. On success the executor's plugin instance is
        stored in ``local_ns`` under the plugin name.
        """
        plugin_name = line
        try:
            # Parse inside the try so malformed config JSON is reported as a
            # failed load (like executor errors) instead of escaping the magic.
            plugin_config: Any = json.loads(cell)
            self.executor.config_plugin(plugin_name, plugin_config)
            local_ns[plugin_name] = self.executor.get_plugin_instance(plugin_name)
        except Exception as e:
            return fmt_response(
                False,
                f"Plugin {plugin_name} failed to load: " + str(e),
            )
        return fmt_response(True, f"Plugin {plugin_name} loaded.")

    @needs_local_scope
    @line_magic
    def _taskweaver_plugin_unload(self, line: str, local_ns: Dict[str, Any]):
        """Remove the plugin named by ``line`` from ``local_ns``.

        Unloading a name that is not present is reported as a success.
        """
        plugin_name = line
        if plugin_name not in local_ns:
            return fmt_response(
                True,
                f"Plugin {plugin_name} not loaded, skipping unloading.",
            )
        del local_ns[plugin_name]
        return fmt_response(True, f"Plugin {plugin_name} unloaded.")


def load_ipython_extension(ipython: InteractiveShell):
    """Hook called by IPython when this module is loaded as an extension.

    Builds a single ``Executor`` from the ``TASKWEAVER_ENV_ID``,
    ``TASKWEAVER_SESSION_ID`` and ``TASKWEAVER_SESSION_DIR`` environment
    variables (falling back to ``"local"``, ``"session_temp"`` and the real
    path of the current directory). It then registers both TaskWeaver magic
    classes against that executor and switches tracebacks to plain mode.
    """
    env = os.environ
    executor = Executor(
        env_id=env.get("TASKWEAVER_ENV_ID", "local"),
        session_id=env.get("TASKWEAVER_SESSION_ID", "session_temp"),
        session_dir=env.get("TASKWEAVER_SESSION_DIR", os.path.realpath(os.getcwd())),
    )

    # Both magic sets share one executor so they operate on the same session.
    magic_instances = [
        magic_cls(ipython, executor)
        for magic_cls in (TaskWeaverContextMagic, TaskWeaverPluginMagic)
    ]
    for magic in magic_instances:
        ipython.register_magics(magic)

    ipython.InteractiveTB.set_mode(mode="Plain")