|
|
import os |
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple |
|
|
|
|
|
from taskweaver.plugin.context import ArtifactType, LogErrorLevel, PluginContext |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from taskweaver.ces.runtime.executor import Executor |
|
|
|
|
|
|
|
|
class ExecutorPluginContext(PluginContext): |
|
|
def __init__(self, executor: Any) -> None: |
|
|
self.executor: Executor = executor |
|
|
|
|
|
self.artifact_list: List[Dict[str, str]] = [] |
|
|
self.log_messages: List[Tuple[LogErrorLevel, str, str]] = [] |
|
|
self.output: List[Tuple[str, str]] = [] |
|
|
|
|
|
@property |
|
|
def execution_id(self) -> str: |
|
|
return self.executor.cur_execution_id |
|
|
|
|
|
@property |
|
|
def session_id(self) -> str: |
|
|
return self.executor.session_id |
|
|
|
|
|
@property |
|
|
def env_id(self) -> str: |
|
|
return self.executor.env_id |
|
|
|
|
|
@property |
|
|
def execution_idx(self) -> int: |
|
|
return self.executor.cur_execution_count |
|
|
|
|
|
def add_artifact( |
|
|
self, |
|
|
name: str, |
|
|
file_name: str, |
|
|
type: ArtifactType, |
|
|
val: Any, |
|
|
desc: Optional[str] = None, |
|
|
) -> str: |
|
|
desc_preview = desc if desc is not None else self._get_preview_by_type(type, val) |
|
|
|
|
|
id, path = self.create_artifact_path(name, file_name, type, desc=desc_preview) |
|
|
if type == "chart": |
|
|
with open(path, "w") as f: |
|
|
f.write(val) |
|
|
elif type == "df": |
|
|
val.to_csv(path, index=False) |
|
|
elif type == "file" or type == "txt" or type == "svg" or type == "html": |
|
|
with open(path, "w") as f: |
|
|
f.write(val) |
|
|
else: |
|
|
raise Exception("unsupported data type") |
|
|
|
|
|
return id |
|
|
|
|
|
def _get_preview_by_type(self, type: str, val: Any) -> str: |
|
|
if type == "chart": |
|
|
preview = "chart" |
|
|
elif type == "df": |
|
|
preview = f"DataFrame in shape {val.shape} with columns {list(val.columns)}" |
|
|
elif type == "file" or type == "txt": |
|
|
preview = str(val)[:100] |
|
|
elif type == "html": |
|
|
preview = "Web Page" |
|
|
else: |
|
|
preview = str(val) |
|
|
return preview |
|
|
|
|
|
def create_artifact_path( |
|
|
self, |
|
|
name: str, |
|
|
file_name: str, |
|
|
type: ArtifactType, |
|
|
desc: str, |
|
|
) -> Tuple[str, str]: |
|
|
id = f"obj_{self.execution_idx}_{type}_{len(self.artifact_list):04x}" |
|
|
|
|
|
file_path = f"{id}_{file_name}" |
|
|
full_file_path = self._get_obj_path(file_path) |
|
|
|
|
|
self.artifact_list.append( |
|
|
{ |
|
|
"name": name, |
|
|
"type": type, |
|
|
"original_name": file_name, |
|
|
"file": file_path, |
|
|
"preview": desc, |
|
|
}, |
|
|
) |
|
|
return id, full_file_path |
|
|
|
|
|
def set_output(self, output: List[Tuple[str, str]]): |
|
|
if isinstance(output, list): |
|
|
self.output.extend(output) |
|
|
else: |
|
|
self.output.append((str(output), "")) |
|
|
|
|
|
def get_normalized_output(self): |
|
|
def to_str(v: Any) -> str: |
|
|
|
|
|
|
|
|
return str(v)[:5000] |
|
|
|
|
|
def normalize_tuple(i: int, v: Any) -> Tuple[str, str]: |
|
|
default_name = f"execution_result_{i + 1}" |
|
|
if isinstance(v, tuple) or isinstance(v, list): |
|
|
list_value: Any = v |
|
|
name = to_str(list_value[0]) if len(list_value) > 0 else default_name |
|
|
if len(list_value) <= 2: |
|
|
val = to_str(list_value[1]) if len(list_value) > 1 else to_str(None) |
|
|
else: |
|
|
val = to_str(list_value[1:]) |
|
|
return (name, val) |
|
|
|
|
|
return (default_name, to_str(v)) |
|
|
|
|
|
return [normalize_tuple(i, o) for i, o in enumerate(self.output)] |
|
|
|
|
|
def log(self, level: LogErrorLevel, tag: str, message: str): |
|
|
self.log_messages.append((level, tag, message)) |
|
|
|
|
|
def _get_obj_path(self, name: str) -> str: |
|
|
return os.path.join(self.executor.session_dir, "cwd", name) |
|
|
|
|
|
def call_llm_api(self, messages: List[Dict[str, str]], **args: Any) -> Any: |
|
|
|
|
|
return None |
|
|
|
|
|
def get_env(self, plugin_name: str, variable_name: str): |
|
|
|
|
|
name = f"PLUGIN_{plugin_name}_{variable_name}" |
|
|
if name in os.environ: |
|
|
return os.environ[name] |
|
|
raise Exception( |
|
|
"Environment variable " + name + " is required to be specified in environment", |
|
|
) |
|
|
|
|
|
def get_session_var( |
|
|
self, |
|
|
variable_name: str, |
|
|
default: Optional[str] = None, |
|
|
) -> Optional[str]: |
|
|
if variable_name in self.executor.session_var: |
|
|
return self.executor.session_var[variable_name] |
|
|
return default |
|
|
|