# TRaw's picture
# Upload 297 files
# 3d3d712
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from taskweaver.plugin.context import ArtifactType, LogErrorLevel, PluginContext
if TYPE_CHECKING:
from taskweaver.ces.runtime.executor import Executor
class ExecutorPluginContext(PluginContext):
    """Plugin context backed by an executor object.

    Keeps in-memory lists of the artifacts, log messages and output values
    reported through this context, and resolves artifact file paths under the
    ``cwd`` folder of the executor's session directory.
    """

    def __init__(self, executor: Any) -> None:
        """Bind the context to ``executor`` and start with empty collections."""
        self.executor: Executor = executor
        # One metadata dict per artifact, appended by ``create_artifact_path``.
        self.artifact_list: List[Dict[str, str]] = []
        # (level, tag, message) triples appended by ``log``.
        self.log_messages: List[Tuple[LogErrorLevel, str, str]] = []
        # Raw output entries; ``get_normalized_output`` turns them into string pairs.
        self.output: List[Tuple[str, str]] = []

    @property
    def execution_id(self) -> str:
        """The executor's ``cur_execution_id``."""
        return self.executor.cur_execution_id

    @property
    def session_id(self) -> str:
        """The executor's ``session_id``."""
        return self.executor.session_id

    @property
    def env_id(self) -> str:
        """The executor's ``env_id``."""
        return self.executor.env_id

    @property
    def execution_idx(self) -> int:
        """The executor's ``cur_execution_count``."""
        return self.executor.cur_execution_count

    def add_artifact(
        self,
        name: str,
        file_name: str,
        type: ArtifactType,
        val: Any,
        desc: Optional[str] = None,
    ) -> str:
        """Register an artifact and write its content to the artifact path.

        Args:
            name: Name stored in the artifact metadata.
            file_name: Original file name; appended to the generated id to
                form the on-disk file name.
            type: Artifact type. ``"df"`` is written via ``val.to_csv``;
                ``"chart"``, ``"file"``, ``"txt"``, ``"svg"`` and ``"html"``
                are written as text.
            val: The artifact content (an object with ``to_csv`` for ``"df"``,
                otherwise text accepted by ``file.write``).
            desc: Preview text; when ``None`` a preview is derived from
                ``type`` and ``val``.

        Returns:
            The generated artifact id.

        Raises:
            Exception: If ``type`` is not one of the supported types. Note
                that the artifact metadata has already been appended to
                ``artifact_list`` at that point.
        """
        desc_preview = desc if desc is not None else self._get_preview_by_type(type, val)

        artifact_id, path = self.create_artifact_path(name, file_name, type, desc=desc_preview)
        if type == "df":
            val.to_csv(path, index=False)
        elif type in ("chart", "file", "txt", "svg", "html"):
            # Explicit UTF-8 so the written bytes do not depend on the platform
            # locale (the default text encoding can fail on non-ASCII content).
            with open(path, "w", encoding="utf-8") as f:
                f.write(val)
        else:
            raise Exception("unsupported data type")

        return artifact_id

    def _get_preview_by_type(self, type: str, val: Any) -> str:
        """Build the default preview text for an artifact of the given type.

        ``"chart"`` and ``"html"`` get fixed labels, ``"df"`` reports shape and
        columns, ``"file"``/``"txt"`` use the first 100 characters of
        ``str(val)``, and any other type uses the full ``str(val)``.
        """
        if type == "chart":
            return "chart"
        if type == "df":
            return f"DataFrame in shape {val.shape} with columns {list(val.columns)}"
        if type in ("file", "txt"):
            return str(val)[:100]
        if type == "html":
            return "Web Page"
        return str(val)

    def create_artifact_path(
        self,
        name: str,
        file_name: str,
        type: ArtifactType,
        desc: str,
    ) -> Tuple[str, str]:
        """Allocate an id and file path for a new artifact and record its metadata.

        The id combines the current execution index, the artifact type and the
        current length of ``artifact_list`` (4-digit, zero-padded hex).

        Returns:
            ``(artifact_id, full_file_path)`` where the path is resolved by
            ``_get_obj_path``. Nothing is written to disk here.
        """
        artifact_id = f"obj_{self.execution_idx}_{type}_{len(self.artifact_list):04x}"

        file_path = f"{artifact_id}_{file_name}"
        full_file_path = self._get_obj_path(file_path)

        self.artifact_list.append(
            {
                "name": name,
                "type": type,
                "original_name": file_name,
                "file": file_path,
                "preview": desc,
            },
        )
        return artifact_id, full_file_path

    def set_output(self, output: List[Tuple[str, str]]) -> None:
        """Add output entries.

        A list is appended element by element; any other value is stored as a
        single ``(str(output), "")`` entry.
        """
        if isinstance(output, list):
            self.output.extend(output)
        else:
            self.output.append((str(output), ""))

    def get_normalized_output(self) -> List[Tuple[str, str]]:
        """Return the collected output as a list of ``(name, value)`` string pairs.

        Every string is ``str(...)`` of the source value cut to 5000 characters.
        For a tuple/list entry the first element is the name and the second the
        value (``"None"`` when missing); with more than two elements the value
        is the string form of everything after the first. An empty tuple/list,
        or an entry that is not a tuple/list, gets the name
        ``execution_result_<position>`` (1-based).
        """

        def to_str(v: Any) -> str:
            # TODO: configure/tune value length limit
            # TODO: handle known/common data types explicitly
            return str(v)[:5000]

        def normalize_tuple(i: int, v: Any) -> Tuple[str, str]:
            default_name = f"execution_result_{i + 1}"
            if isinstance(v, (tuple, list)):
                list_value: Any = v
                name = to_str(list_value[0]) if len(list_value) > 0 else default_name
                if len(list_value) <= 2:
                    val = to_str(list_value[1]) if len(list_value) > 1 else to_str(None)
                else:
                    val = to_str(list_value[1:])
                return (name, val)

            return (default_name, to_str(v))

        return [normalize_tuple(i, o) for i, o in enumerate(self.output)]

    def log(self, level: LogErrorLevel, tag: str, message: str) -> None:
        """Record a ``(level, tag, message)`` entry in ``log_messages``."""
        self.log_messages.append((level, tag, message))

    def _get_obj_path(self, name: str) -> str:
        """Return the path of ``name`` inside ``<session_dir>/cwd``."""
        return os.path.join(self.executor.session_dir, "cwd", name)

    def call_llm_api(self, messages: List[Dict[str, str]], **args: Any) -> Any:
        """Placeholder for an LLM call; currently always returns ``None``."""
        # TODO: use llm_api from handle side
        return None

    def get_env(self, plugin_name: str, variable_name: str) -> str:
        """Return the value of environment variable ``PLUGIN_<plugin_name>_<variable_name>``.

        Raises:
            Exception: If that environment variable is not set.
        """
        # To avoid duplicate env variables, plugin_name and variable_name are
        # combined to compose the final environment variable name.
        name = f"PLUGIN_{plugin_name}_{variable_name}"
        if name in os.environ:
            return os.environ[name]
        raise Exception(
            "Environment variable " + name + " is required to be specified in environment",
        )

    def get_session_var(
        self,
        variable_name: str,
        default: Optional[str] = None,
    ) -> Optional[str]:
        """Look up ``variable_name`` in the executor's ``session_var``.

        Returns ``default`` when the variable is not present.
        """
        if variable_name in self.executor.session_var:
            return self.executor.session_var[variable_name]
        return default