File size: 4,887 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from taskweaver.plugin.context import ArtifactType, LogErrorLevel, PluginContext

if TYPE_CHECKING:
    from taskweaver.ces.runtime.executor import Executor


class ExecutorPluginContext(PluginContext):
    """PluginContext implementation that runs inside the code executor.

    Identity information (session, environment, execution id/index) and file
    locations are read from the wrapped executor. Artifacts, log messages and
    outputs produced by a plugin are accumulated on this object so that the
    caller can collect them after the plugin has run.
    """

    # Artifact types whose value is written verbatim to a text file.
    _TEXT_ARTIFACT_TYPES: Tuple[str, ...] = ("chart", "file", "txt", "svg", "html")

    def __init__(self, executor: Any) -> None:
        self.executor: Executor = executor

        # Metadata dicts for every artifact registered via create_artifact_path().
        self.artifact_list: List[Dict[str, str]] = []
        # (level, tag, message) entries recorded through log().
        self.log_messages: List[Tuple[LogErrorLevel, str, str]] = []
        # Raw output entries accumulated by set_output(); see get_normalized_output().
        self.output: List[Tuple[str, str]] = []

    @property
    def execution_id(self) -> str:
        """Id of the current execution, as reported by the executor."""
        return self.executor.cur_execution_id

    @property
    def session_id(self) -> str:
        """Id of the session the executor belongs to."""
        return self.executor.session_id

    @property
    def env_id(self) -> str:
        """Id of the execution environment, as reported by the executor."""
        return self.executor.env_id

    @property
    def execution_idx(self) -> int:
        """Execution counter of the executor; embedded in artifact ids."""
        return self.executor.cur_execution_count

    def add_artifact(
        self,
        name: str,
        file_name: str,
        type: ArtifactType,
        val: Any,
        desc: Optional[str] = None,
    ) -> str:
        """Register an artifact and write its value into the session working directory.

        Args:
            name: Display name of the artifact.
            file_name: Original file name; the stored file name is prefixed
                with the generated artifact id.
            type: Artifact type. ``"df"`` values are written with
                ``val.to_csv(path, index=False)``; ``"chart"``, ``"file"``,
                ``"txt"``, ``"svg"`` and ``"html"`` values are written as
                UTF-8 text.
            val: The artifact content (a DataFrame-like object for ``"df"``,
                a string for the text types).
            desc: Preview text; when omitted, one is derived from ``type``
                and ``val``.

        Returns:
            The generated artifact id.

        Raises:
            Exception: If ``type`` is not a supported artifact type.
        """
        # Validate before registering so that an unsupported type does not
        # leave a metadata entry in artifact_list for a file never written.
        if type != "df" and type not in self._TEXT_ARTIFACT_TYPES:
            raise Exception("unsupported data type")

        desc_preview = desc if desc is not None else self._get_preview_by_type(type, val)

        artifact_id, path = self.create_artifact_path(name, file_name, type, desc=desc_preview)
        if type == "df":
            val.to_csv(path, index=False)
        else:
            # Explicit encoding so non-ASCII content does not depend on the
            # platform's locale encoding.
            with open(path, "w", encoding="utf-8") as f:
                f.write(val)

        return artifact_id

    def _get_preview_by_type(self, type: str, val: Any) -> str:
        """Build a short textual preview of an artifact value based on its type.

        ``"file"``/``"txt"`` values are truncated to 100 characters; types not
        handled explicitly (e.g. ``"svg"``) fall back to ``str(val)`` in full.
        """
        if type == "chart":
            preview = "chart"
        elif type == "df":
            preview = f"DataFrame in shape {val.shape} with columns {list(val.columns)}"
        elif type == "file" or type == "txt":
            preview = str(val)[:100]
        elif type == "html":
            preview = "Web Page"
        else:
            preview = str(val)
        return preview

    def create_artifact_path(
        self,
        name: str,
        file_name: str,
        type: ArtifactType,
        desc: str,
    ) -> Tuple[str, str]:
        """Allocate an id and file path for a new artifact and record its metadata.

        The id has the form ``obj_<execution_idx>_<type>_<n>`` where ``<n>`` is
        the artifact's position in ``artifact_list`` as 4-digit hex. This method
        does not create the file; the caller is responsible for writing it.

        Returns:
            ``(artifact_id, full_file_path)`` where the path lies in the
            session's ``cwd`` directory.
        """
        artifact_id = f"obj_{self.execution_idx}_{type}_{len(self.artifact_list):04x}"

        file_path = f"{artifact_id}_{file_name}"
        full_file_path = self._get_obj_path(file_path)

        self.artifact_list.append(
            {
                "name": name,
                "type": type,
                "original_name": file_name,
                "file": file_path,
                "preview": desc,
            },
        )
        return artifact_id, full_file_path

    def set_output(self, output: List[Tuple[str, str]]) -> None:
        """Record plugin output.

        A list is appended element-by-element as given; any other value is
        stored as the single entry ``(str(output), "")``.
        """
        if isinstance(output, list):
            self.output.extend(output)
        else:
            self.output.append((str(output), ""))

    def get_normalized_output(self) -> List[Tuple[str, str]]:
        """Return the recorded outputs as ``(name, value)`` string pairs.

        For a tuple/list entry, the first element is the name and the second
        the value; if there are more than two elements, the value is the string
        form of everything after the name. Other entries are named
        ``execution_result_<i>`` (1-based). Names and values are truncated to
        5000 characters.
        """

        def to_str(v: Any) -> str:
            # TODO: configure/tune value length limit
            # TODO: handle known/common data types explicitly
            return str(v)[:5000]

        def normalize_tuple(i: int, v: Any) -> Tuple[str, str]:
            default_name = f"execution_result_{i + 1}"
            if isinstance(v, (tuple, list)):
                list_value: Any = v
                name = to_str(list_value[0]) if len(list_value) > 0 else default_name
                if len(list_value) <= 2:
                    val = to_str(list_value[1]) if len(list_value) > 1 else to_str(None)
                else:
                    val = to_str(list_value[1:])
                return (name, val)

            return (default_name, to_str(v))

        return [normalize_tuple(i, o) for i, o in enumerate(self.output)]

    def log(self, level: LogErrorLevel, tag: str, message: str) -> None:
        """Record a log message; messages are only buffered in ``log_messages``."""
        self.log_messages.append((level, tag, message))

    def _get_obj_path(self, name: str) -> str:
        """Return the path of ``name`` inside the session's ``cwd`` directory."""
        return os.path.join(self.executor.session_dir, "cwd", name)

    def call_llm_api(self, messages: List[Dict[str, str]], **args: Any) -> Any:
        """Placeholder: LLM calls are not supported on the executor side; returns None."""
        # TODO: use llm_api from handle side
        return None

    def get_env(self, plugin_name: str, variable_name: str) -> str:
        """Read the environment variable ``PLUGIN_<plugin_name>_<variable_name>``.

        Raises:
            Exception: If the variable is not set.
        """
        # Prefix with the plugin name so variables of different plugins cannot collide.
        name = f"PLUGIN_{plugin_name}_{variable_name}"
        value = os.environ.get(name)
        if value is None:
            raise Exception(
                "Environment variable " + name + " is required to be specified in environment",
            )
        return value

    def get_session_var(
        self,
        variable_name: str,
        default: Optional[str] = None,
    ) -> Optional[str]:
        """Return a session variable from the executor, or ``default`` if it is absent."""
        if variable_name in self.executor.session_var:
            return self.executor.session_var[variable_name]
        return default