| import os | |
| from injector import Injector | |
| from taskweaver.config.config_mgt import AppConfigSource | |
| from taskweaver.logging import LoggingModule | |
| from taskweaver.memory.plugin import PluginModule, PluginRegistry | |
def test_load_plugin_yaml():
    """Verify that the plugin registry loads the specs found under ``data/plugins``.

    Builds an injector with the logging and plugin modules, points
    ``plugin.base_path`` at the ``data/plugins`` directory next to this file,
    and checks the registry size plus the name, description, implementation,
    arguments and return values recorded for the ``anomaly_detection`` plugin.
    """
    plugin_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/plugins")

    app_injector = Injector(
        [LoggingModule, PluginModule],
    )
    app_config = AppConfigSource(
        config={
            "plugin.base_path": plugin_dir,
        },
    )
    app_injector.binder.bind(AppConfigSource, to=app_config)

    plugin_registry = app_injector.get(PluginRegistry)

    assert len(plugin_registry.registry) == 4
    # Check membership first so a missing plugin fails with a clear assertion
    # rather than a KeyError on the lookup below.
    assert "anomaly_detection" in plugin_registry.registry

    plugin = plugin_registry.registry["anomaly_detection"]
    spec = plugin.spec

    assert spec.name == "anomaly_detection"
    assert spec.description.startswith(
        "anomaly_detection function identifies anomalies",
    )
    assert plugin.impl == "anomaly_detection"

    # Arguments: only the first one is inspected in detail.
    assert len(spec.args) == 3
    first_arg = spec.args[0]
    assert first_arg.name == "df"
    assert first_arg.type == "DataFrame"
    assert (
        first_arg.description == "the input data from which we can identify the "
        "anomalies with the 3-sigma algorithm."
    )
    # Identity check instead of `== True`: requires a real boolean, so a
    # truthy non-bool value (e.g. 1) no longer satisfies the assertion.
    assert first_arg.required is True

    # Return values: only the first one is inspected in detail.
    assert len(spec.returns) == 2
    first_return = spec.returns[0]
    assert first_return.name == "df"
    assert first_return.type == "DataFrame"
    assert (
        first_return.description == "This DataFrame extends the input "
        "DataFrame with a newly-added column "
        '"Is_Anomaly" containing the anomaly detection result.'
    )
def test_plugin_format_prompt():
    """Verify the prompt text rendered for the ``anomaly_detection`` plugin.

    Loads the plugin registry from the ``data/plugins`` directory next to this
    file and compares ``format_prompt()`` for ``anomaly_detection`` against the
    exact expected string.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    settings = {
        "plugin.base_path": os.path.join(here, "data/plugins"),
    }

    container = Injector([PluginModule, LoggingModule])
    container.binder.bind(AppConfigSource, to=AppConfigSource(config=settings))
    registry = container.get(PluginRegistry).registry

    # Adjacent literals are concatenated by the parser; the pieces are kept
    # exactly as the rendered prompt is expected to read.
    expected_prompt = (
        "# anomaly_detection function identifies anomalies from an input DataFrame of time series. It will add a new "
        'column "Is_Anomaly", where each entry will be marked with "True" if the value is an anomaly or "False" '
        "otherwise.\n"
        "def anomaly_detection(\n"
        "# the input data from which we can identify the anomalies with the 3-sigma algorithm.\n"
        "df: Any,\n"
        "# name of the column that contains the datetime\n"
        "time_col_name: Any,\n"
        "# name of the column that contains the numeric values.\n"
        "value_col_name: Any) -> Tuple[\n"
        '# df: This DataFrame extends the input DataFrame with a newly-added column "Is_Anomaly" containing the '
        "anomaly detection result.\n"
        "DataFrame,\n"
        "# description: This is a string describing the anomaly detection results.\n"
        "str]:...\n"
    )

    rendered_prompt = registry["anomaly_detection"].format_prompt()
    assert rendered_prompt == expected_prompt