File size: 7,877 Bytes
3d3d712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
import os
import tempfile
import traceback
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type
from taskweaver.ces.common import EnvPlugin
from taskweaver.ces.runtime.context import ExecutorPluginContext, LogErrorLevel
from taskweaver.plugin.base import Plugin
from taskweaver.plugin.context import PluginContext
@dataclass
class PluginTestEntry:
    """One test case that a plugin module registered for its implementation.

    Built positionally as ``PluginTestEntry(name, description, test)``.
    """

    # Identifier of the test case; printed when the test is run.
    name: str
    # Human-readable explanation of what the test checks; printed with the name.
    description: str
    # Called with a plugin instance; raising an exception marks the test as failed.
    test: Callable[[Plugin], None]
@dataclass
class RuntimePlugin(EnvPlugin):
    """Plugin whose implementation source is executed inside this process.

    ``load_impl`` writes the plugin source (``self.impl``) to a temporary file,
    executes it as module ``taskweaver_ext.plugin.<name>`` and records what the
    module registers through the hooks in ``taskweaver.plugin.register``:
    exactly one ``Plugin`` implementation (kept in ``initializer``) and any
    number of test cases (appended to ``test_cases``).
    """

    # Plugin class registered by the plugin module; None until loaded.
    initializer: Optional[type[Plugin]] = None
    # Test cases registered by the plugin module through the test hook.
    test_cases: List[PluginTestEntry] = field(default_factory=list)

    @property
    def module_name(self) -> str:
        """Name under which the plugin source is registered in ``sys.modules``."""
        return f"taskweaver_ext.plugin.{self.name}"

    def load_impl(self):
        """Execute the plugin source and record its registrations.

        Does nothing when the plugin is already loaded. If loading fails, the
        state changed by the partial load (registered implementation, appended
        test cases, the ``sys.modules`` entry installed by this call) is rolled
        back so that a later retry starts from a clean slate.

        Raises:
            Exception: if the source raises while executing, registers more
                than one implementation, or registers none.
        """
        if self.loaded:
            return

        # `import importlib` alone does not guarantee the `util` submodule is loaded.
        import importlib.util
        import sys

        def register_plugin(impl: Type[Plugin]):
            if self.initializer is not None:
                raise Exception(
                    f"duplicated plugin impl registration for plugin {self.name}",
                )
            self.initializer = impl

        def register_plugin_test(
            test_name: str,
            test_desc: str,
            test_impl: Callable[[Plugin], None],
        ):
            self.test_cases.append(
                PluginTestEntry(
                    test_name,
                    test_desc,
                    test_impl,
                ),
            )

        # Snapshot of the state that a failed load has to restore.
        prev_initializer = self.initializer
        prev_test_count = len(self.test_cases)
        module_name = self.module_name
        module = None

        try:
            # the following code is to load the plugin module and register the plugin
            from taskweaver.plugin import register

            with tempfile.TemporaryDirectory() as temp_dir:
                module_path = os.path.join(temp_dir, f"{self.name}.py")
                # Python decodes source files as UTF-8 by default (PEP 3120); write
                # with that encoding rather than the locale default so non-ASCII
                # plugin code survives the round trip on every platform.
                with open(module_path, "w", encoding="utf-8") as f:
                    f.write(self.impl)

                spec = importlib.util.spec_from_file_location(  # type: ignore
                    module_name,
                    module_path,
                )
                module = importlib.util.module_from_spec(spec)  # type: ignore
                sys.modules[module_name] = module  # type: ignore

                register.register_plugin_inner = register_plugin
                register.register_plugin_test_inner = register_plugin_test
                try:
                    spec.loader.exec_module(module)  # type: ignore
                finally:
                    # Detach the hooks even if the module raised; otherwise they
                    # would keep pointing at this plugin's closures.
                    register.register_plugin_inner = None
                    register.register_plugin_test_inner = None

            if self.initializer is None:
                raise Exception("no registration found")
        except Exception as e:
            traceback.print_exc()
            # Roll back the partial registration so a retry does not report a
            # duplicated implementation or duplicated test cases.
            self.initializer = prev_initializer
            del self.test_cases[prev_test_count:]
            # Only drop the sys.modules entry that this call installed.
            if module is not None and sys.modules.get(module_name) is module:
                del sys.modules[module_name]
            raise Exception(f"failed to load plugin {self.name} {str(e)}") from e

        self.loaded = True

    def unload_impl(self):
        """Best-effort unload of the plugin module.

        Forgets the registered implementation and test cases and removes the
        module from ``sys.modules``. Objects elsewhere that still reference code
        from the module keep it alive, so memory is not guaranteed to be freed.
        """
        if not self.loaded:
            return
        import sys

        self.initializer = None
        # The test cases were registered by the module being unloaded; keeping
        # them would duplicate every test on the next load_impl.
        self.test_cases.clear()
        sys.modules.pop(self.module_name, None)
        self.loaded = False

    def get_instance(self, context: PluginContext) -> Plugin:
        """Instantiate the registered plugin class.

        The class is called as ``initializer(name, context, config)``; a missing
        config is passed as an empty dict.

        Raises:
            Exception: if no implementation is registered, or the constructor fails.
        """
        if self.initializer is None:
            raise Exception(f"plugin {self.name} is not loaded")
        try:
            return self.initializer(self.name, context, self.config or {})
        except Exception as e:
            raise Exception(
                f"failed to create instance for plugin {self.name} {str(e)}",
            ) from e

    def test_impl(self) -> tuple[bool, list[str]]:
        """Run every registered test case, each against a fresh plugin instance.

        Each test runs inside its own ``temp_context()``. A test fails when
        creating the instance or calling the test raises; the traceback is
        printed and a message is collected.

        Returns:
            ``(all_passed, error_messages)``.
        """
        error_list: List[str] = []

        from taskweaver.plugin.context import temp_context

        for test in self.test_cases:
            try:
                with temp_context() as ctx:
                    print("=====================================================")
                    print("Test Name:", test.name)
                    print("Test Description:", test.description)
                    print("Running Test...")
                    inst = self.get_instance(ctx)
                    test.test(inst)
                    print()
            except Exception as e:
                traceback.print_exc()
                error_list.append(
                    f"failed to test plugin {self.name} on {test.name} ({test.description}) \n {str(e)}",
                )

        return len(error_list) == 0, error_list
class Executor:
    """Per-session executor state.

    Holds the session variables, the registry of runtime plugins, the current
    execution counter/id and the ``ExecutorPluginContext`` whose artifact list,
    log messages and output are reset before and reported after each execution.
    """

    def __init__(self, env_id: str, session_id: str, session_dir: str) -> None:
        self.env_id: str = env_id
        self.session_id: str = session_id
        self.session_dir: str = session_dir

        # Session var management
        self.session_var: Dict[str, str] = {}

        # Plugin management state
        self.plugin_registry: Dict[str, RuntimePlugin] = {}

        # Execution counter and id
        self.cur_execution_count: int = 0
        self.cur_execution_id: str = ""

        self._init_session_dir()
        self.ctx: ExecutorPluginContext = ExecutorPluginContext(self)

    def _init_session_dir(self):
        """Create the session directory (and parents) if it does not exist."""
        # exist_ok avoids the race between an exists() check and makedirs();
        # it still raises if the path exists but is not a directory.
        os.makedirs(self.session_dir, exist_ok=True)

    def pre_execution(self, exec_idx: int, exec_id: str):
        """Record the new execution's index/id and clear per-execution output."""
        self.cur_execution_count = exec_idx
        self.cur_execution_id = exec_id

        self.ctx.artifact_list = []
        self.ctx.log_messages = []
        self.ctx.output = []

    def load_lib(self, local_ns: Dict[str, Any]):
        """Import recommended libraries into ``local_ns`` as ``pd``, ``np`` and ``plt``.

        pandas display options are configured for compact text output. A missing
        package is logged as a warning instead of raising.
        """
        import importlib

        try:
            pd = importlib.import_module("pandas")
            # customize pandas display options
            pd.set_option("display.html.table_schema", False)
            pd.set_option("display.notebook_repr_html", False)
            pd.set_option("display.max_rows", 4)
            pd.set_option("display.expand_frame_repr", False)
            local_ns["pd"] = pd
        except ImportError:
            self.log(
                "warning",
                "recommended package pandas not found, certain functions may not work properly",
            )

        try:
            local_ns["np"] = importlib.import_module("numpy")
        except ImportError:
            self.log(
                "warning",
                "recommended package numpy not found, certain functions may not work properly",
            )

        try:
            # __import__("matplotlib.pyplot") returns the top-level `matplotlib`
            # package, not the pyplot submodule; import_module returns pyplot
            # itself so that `plt.plot(...)` works.
            local_ns["plt"] = importlib.import_module("matplotlib.pyplot")
        except ImportError:
            self.log(
                "warning",
                "recommended package matplotlib not found, certain functions may not work properly",
            )

    def register_plugin(self, plugin_name: str, plugin_impl: str):
        """Load plugin source ``plugin_impl`` and register it under ``plugin_name``.

        The plugin is only added to the registry after it loaded successfully;
        an existing entry with the same name is replaced.
        """
        plugin = RuntimePlugin(
            plugin_name,
            plugin_impl,
            None,
            False,
        )
        plugin.load_impl()
        self.plugin_registry[plugin_name] = plugin

    def config_plugin(self, plugin_name: str, plugin_config: Dict[str, str]):
        """Set the config passed to the plugin on instantiation (KeyError if unknown)."""
        plugin = self.plugin_registry[plugin_name]
        plugin.config = plugin_config

    def get_plugin_instance(self, plugin_name: str) -> Plugin:
        """Create an instance of a registered plugin bound to this executor's context."""
        plugin = self.plugin_registry[plugin_name]
        return plugin.get_instance(self.ctx)

    def test_plugin(self, plugin_name: str) -> tuple[bool, list[str]]:
        """Run a registered plugin's test cases; returns ``(all_passed, errors)``."""
        plugin = self.plugin_registry[plugin_name]
        return plugin.test_impl()

    def get_post_execution_state(self):
        """Return the artifacts, log messages and normalized output of this execution."""
        return {
            "artifact": self.ctx.artifact_list,
            "log": self.ctx.log_messages,
            "output": self.ctx.get_normalized_output(),
        }

    def log(self, level: LogErrorLevel, message: str):
        """Log ``message`` at ``level`` through the plugin context, tagged "Engine"."""
        self.ctx.log(level, "Engine", message)

    def update_session_var(self, variables: Dict[str, str]):
        """Replace all session variables, converting keys and values to ``str``."""
        self.session_var = {str(k): str(v) for k, v in variables.items()}
|