File size: 4,887 Bytes
3d3d712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from taskweaver.plugin.context import ArtifactType, LogErrorLevel, PluginContext

if TYPE_CHECKING:
    from taskweaver.ces.runtime.executor import Executor

class ExecutorPluginContext(PluginContext):
    """Plugin context backed by an ``Executor``.

    Records artifacts, log messages and outputs produced by a plugin, and
    exposes identifiers and session state read from the owning executor.
    """

    # Artifact types whose value is written verbatim to the artifact file as text.
    # A tuple (not a set) so membership uses ``==``, matching the original comparisons.
    _TEXT_ARTIFACT_TYPES = ("chart", "file", "txt", "svg", "html")

    def __init__(self, executor: Any) -> None:
        self.executor: Executor = executor
        # Artifact entries appended by create_artifact_path, in creation order.
        self.artifact_list: List[Dict[str, str]] = []
        # (level, tag, message) entries recorded by log().
        self.log_messages: List[Tuple[LogErrorLevel, str, str]] = []
        # Raw outputs accumulated by set_output(); see get_normalized_output().
        self.output: List[Tuple[str, str]] = []

    @property
    def execution_id(self) -> str:
        """The executor's ``cur_execution_id``."""
        return self.executor.cur_execution_id

    @property
    def session_id(self) -> str:
        """The executor's ``session_id``."""
        return self.executor.session_id

    @property
    def env_id(self) -> str:
        """The executor's ``env_id``."""
        return self.executor.env_id

    @property
    def execution_idx(self) -> int:
        """The executor's ``cur_execution_count``; embedded in artifact ids."""
        return self.executor.cur_execution_count

    def add_artifact(
        self,
        name: str,
        file_name: str,
        type: ArtifactType,
        val: Any,
        desc: Optional[str] = None,
    ) -> str:
        """Write ``val`` to a new artifact file and register it.

        Args:
            name: Display name stored in the artifact entry.
            file_name: Original file name; the stored file name is prefixed
                with the artifact id.
            type: ``"df"`` values are written via ``val.to_csv``;
                ``"chart"``, ``"file"``, ``"txt"``, ``"svg"`` and ``"html"``
                values are written as UTF-8 text.
            val: The artifact content.
            desc: Preview text; derived from ``type``/``val`` when omitted.

        Returns:
            The id assigned to the artifact.

        Raises:
            Exception: If ``type`` is unsupported. This is raised before
                anything is registered, so ``artifact_list`` is unchanged.
        """
        # Validate first: create_artifact_path appends to artifact_list, so an
        # unsupported type must not leave behind an entry whose file is never written.
        if type != "df" and type not in self._TEXT_ARTIFACT_TYPES:
            raise Exception("unsupported data type")

        desc_preview = desc if desc is not None else self._get_preview_by_type(type, val)
        artifact_id, path = self.create_artifact_path(name, file_name, type, desc=desc_preview)

        if type == "df":
            # presumably a pandas DataFrame (preview reads .shape/.columns) — TODO confirm
            val.to_csv(path, index=False)
        else:
            # Explicit UTF-8 so non-ASCII content does not depend on the platform default.
            with open(path, "w", encoding="utf-8") as f:
                f.write(val)

        return artifact_id

    def _get_preview_by_type(self, type: str, val: Any) -> str:
        """Return a short textual preview of ``val`` for the artifact entry.

        ``"file"``/``"txt"`` are truncated to 100 characters. ``"svg"`` and any
        unlisted type fall back to the full ``str(val)``.
        """
        if type == "chart":
            return "chart"
        if type == "df":
            return f"DataFrame in shape {val.shape} with columns {list(val.columns)}"
        if type in ("file", "txt"):
            return str(val)[:100]
        if type == "html":
            return "Web Page"
        return str(val)

    def create_artifact_path(
        self,
        name: str,
        file_name: str,
        type: ArtifactType,
        desc: str,
    ) -> Tuple[str, str]:
        """Register a new artifact entry and return ``(artifact_id, full_path)``.

        The id combines the execution index, the type, and the entry's position
        in ``artifact_list`` (zero-padded hex, at least 4 digits). The stored
        file name is ``"<id>_<file_name>"`` under the session's ``cwd``
        directory. No file is created here.
        """
        artifact_id = f"obj_{self.execution_idx}_{type}_{len(self.artifact_list):04x}"
        file_path = f"{artifact_id}_{file_name}"
        self.artifact_list.append(
            {
                "name": name,
                "type": type,
                "original_name": file_name,
                "file": file_path,
                "preview": desc,
            },
        )
        return artifact_id, self._get_obj_path(file_path)

    def set_output(self, output: List[Tuple[str, str]]):
        """Accumulate plugin output.

        A list is appended item by item. Any other value is stored as
        ``(str(output), "")``.
        """
        if isinstance(output, list):
            self.output.extend(output)
        else:
            self.output.append((str(output), ""))

    def get_normalized_output(self) -> List[Tuple[str, str]]:
        """Return accumulated outputs as ``(name, value)`` string pairs.

        Every stringified piece is truncated to 5000 characters. For each item
        at 0-based index ``i`` (default name ``execution_result_{i + 1}``):

        - a tuple/list of length 0 gives ``(default name, "None")``;
        - length 1 gives ``(str(item[0]), "None")``;
        - length 2 gives ``(str(item[0]), str(item[1]))``;
        - a longer item gives ``(str(item[0]), str(item[1:]))``;
        - any other value gives ``(default name, str(item))``.
        """

        def to_str(v: Any) -> str:
            # TODO: configure/tune value length limit
            # TODO: handle known/common data types explicitly
            return str(v)[:5000]

        def normalize_tuple(i: int, v: Any) -> Tuple[str, str]:
            default_name = f"execution_result_{i + 1}"
            if isinstance(v, (tuple, list)):
                name = to_str(v[0]) if len(v) > 0 else default_name
                if len(v) <= 2:
                    val = to_str(v[1]) if len(v) > 1 else to_str(None)
                else:
                    val = to_str(v[1:])
                return (name, val)
            return (default_name, to_str(v))

        return [normalize_tuple(i, o) for i, o in enumerate(self.output)]

    def log(self, level: LogErrorLevel, tag: str, message: str) -> None:
        """Record a ``(level, tag, message)`` entry in ``log_messages``."""
        self.log_messages.append((level, tag, message))

    def _get_obj_path(self, name: str) -> str:
        """Return the path of ``name`` inside the executor session's ``cwd`` directory."""
        return os.path.join(self.executor.session_dir, "cwd", name)

    def call_llm_api(self, messages: List[Dict[str, str]], **args: Any) -> Any:
        """Placeholder: not implemented in this context and always returns ``None``."""
        # TODO: use llm_api from handle side
        return None

    def get_env(self, plugin_name: str, variable_name: str):
        """Return environment variable ``PLUGIN_<plugin_name>_<variable_name>``.

        Raises:
            Exception: If the variable is not set.
        """
        # The plugin name is part of the key so that different plugins can use
        # the same variable name without colliding.
        name = f"PLUGIN_{plugin_name}_{variable_name}"
        if name in os.environ:
            return os.environ[name]
        raise Exception(
            f"Environment variable {name} is required to be specified in environment",
        )

    def get_session_var(
        self,
        variable_name: str,
        default: Optional[str] = None,
    ) -> Optional[str]:
        """Return ``executor.session_var[variable_name]``, or ``default`` if absent."""
        if variable_name in self.executor.session_var:
            return self.executor.session_var[variable_name]
        return default
|