|
|
import os |
|
|
from dataclasses import dataclass, field |
|
|
from datetime import timedelta |
|
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
|
|
from injector import Module, provider |
|
|
|
|
|
from taskweaver.config.module_config import ModuleConfig |
|
|
from taskweaver.misc.component_registry import ComponentRegistry |
|
|
from taskweaver.utils import read_yaml, validate_yaml |
|
|
|
|
|
|
|
|
@dataclass
class PluginParameter:
    """PluginParameter is the data structure for plugin parameters (including arguments and return values)."""

    # declared parameter name
    name: str = ""
    # declared type name, stored exactly as given (PluginSpec normalizes it when rendering)
    type: str = "None"
    required: bool = False
    description: Optional[str] = None

    @staticmethod
    def from_dict(d: Dict[str, Any]):
        """Build a parameter from one entry of a plugin spec's ``parameters``/``returns`` list.

        Only ``name`` is mandatory. Missing keys fall back to ``description=None``,
        ``required=False`` and ``type="Any"``.

        Raises:
            KeyError: if ``name`` is missing.
        """
        return PluginParameter(
            name=d["name"],
            # the field is Optional, so a missing description is None rather than an error
            description=d.get("description"),
            required=d.get("required", False),
            type=d.get("type", "Any"),
        )

    def format_prompt(self, indent: int = 0) -> str:
        """Render the parameter as a YAML-style list item.

        Args:
            indent: number of spaces prepended to every output line.

        Returns:
            The lines joined by ``"\\n"`` (no trailing newline).
        """
        prefix = " " * indent
        rows = [
            f"- name: {self.name}",
            # continuation keys are indented two spaces so they align with "name" after "- "
            f"  type: {self.type}",
            f"  required: {self.required}",
            f"  description: {self.description}",
        ]
        return "\n".join(prefix + row for row in rows)
|
|
|
|
|
|
|
|
@dataclass
class PluginSpec:
    """PluginSpec is the data structure for plugin specification defined in the yaml files."""

    name: str = ""
    description: str = ""
    args: List[PluginParameter] = field(default_factory=list)
    returns: List[PluginParameter] = field(default_factory=list)
    embedding: List[float] = field(default_factory=list)

    @staticmethod
    def from_dict(d: Dict[str, Any]):
        """Build a spec from a parsed plugin YAML mapping.

        Raises:
            KeyError: if ``name``, ``description``, ``parameters`` or ``returns`` is missing.
        """
        return PluginSpec(
            name=d["name"],
            description=d["description"],
            args=[PluginParameter.from_dict(p) for p in d["parameters"]],
            returns=[PluginParameter.from_dict(p) for p in d["returns"]],
            embedding=[],
        )

    def format_prompt(self) -> str:
        """Render the plugin as a Python function stub for use in a prompt.

        The layout is:
        - The spec description as a leading ``#`` comment.
        - Each argument on its own line, preceded by its description comment when it
          has a non-empty one.
        - Argument annotations: required arguments carry their (normalized) type, and
          optional ones are wrapped in ``Optional[...]``. ``Any`` is left untouched.
        - Return annotation: several return values become a ``Tuple[...]`` with
          per-item comments. A single return value is annotated with its bare type.
          No returns gives ``None``.
        """

        def normalize_type(t: str) -> str:
            # map YAML-style type names to their Python spelling
            if t.lower() == "string":
                return "str"
            if t.lower() == "integer":
                return "int"
            return t

        def normalize_description(d: str) -> str:
            # prefix every continuation line with "# " so multi-line text stays a comment
            return d.strip().replace("\n", "\n# ")

        def normalize_value(v: PluginParameter) -> PluginParameter:
            return PluginParameter(
                name=v.name,
                type=normalize_type(v.type),
                required=v.required,
                description=normalize_description(v.description or ""),
            )

        def format_arg_val(val: PluginParameter) -> str:
            val = normalize_value(val)
            if val.required or val.type == "Any":
                type_val = val.type
            else:
                type_val = f"Optional[{val.type}]"
            # normalize_value turns a missing description into "", so test truthiness
            if val.description:
                return f"\n# {val.description}\n{val.name}: {type_val}"
            return f"{val.name}: {type_val}"

        def format_return_val(val: PluginParameter) -> str:
            val = normalize_value(val)
            if val.description:
                return f"\n# {val.name}: {val.description}\n{val.type}"
            return val.type

        param_list = ",".join([format_arg_val(p) for p in self.args])

        if len(self.returns) > 1:
            # comment lines are legal here because they sit inside the Tuple[...] brackets
            return_type = f"Tuple[{','.join([format_return_val(r) for r in self.returns])}]"
        elif len(self.returns) == 1:
            # a bare annotation is not bracketed, so no comment can be embedded in it
            return_type = normalize_value(self.returns[0]).type
        else:
            return_type = "None"

        description = normalize_description(self.description or "")
        return f"# {description}\ndef {self.name}({param_list}) -> {return_type}:...\n"
|
|
|
|
|
|
|
|
@dataclass
class PluginEntry:
    """A plugin loaded from YAML: its spec plus the loader-level settings read alongside it."""

    name: str
    plugin_only: bool
    # implementation name, read from the YAML "code" key (defaults to the plugin name)
    impl: str
    spec: PluginSpec
    # the YAML "configurations" mapping
    config: Dict[str, Any]
    required: bool
    enabled: bool = True

    @staticmethod
    def from_yaml_file(path: str) -> Optional["PluginEntry"]:
        """Read the YAML file at ``path`` and build an entry from its content."""
        content = read_yaml(path)
        return PluginEntry.from_yaml_content(content)

    @staticmethod
    def from_yaml_content(content: Dict, do_validate: bool = False) -> Optional["PluginEntry"]:
        """Build an entry from an already-parsed plugin YAML mapping.

        Args:
            content: the parsed YAML mapping.
            do_validate: when True, check ``content`` against the ``plugin_schema``
                schema first. Defaults to False, the previously hard-coded behaviour.

        Returns:
            The entry, or None when validation was requested and did not pass.
        """
        if do_validate and not validate_yaml(content, schema="plugin_schema"):
            return None
        spec: PluginSpec = PluginSpec.from_dict(content)
        return PluginEntry(
            name=spec.name,
            impl=content.get("code", spec.name),
            spec=spec,
            config=content.get("configurations", {}),
            required=content.get("required", False),
            enabled=content.get("enabled", True),
            plugin_only=content.get("plugin_only", False),
        )

    def format_prompt(self) -> str:
        """Return the spec rendered as a prompt function stub."""
        return self.spec.format_prompt()

    def to_dict(self):
        """Return the entry's fields as a dict (``spec`` is kept as a PluginSpec object)."""
        return {
            "name": self.name,
            "impl": self.impl,
            "spec": self.spec,
            "config": self.config,
            "required": self.required,
            "enabled": self.enabled,
            # previously omitted although it is a dataclass field
            "plugin_only": self.plugin_only,
        }

    def format_function_calling(self) -> Dict:
        """Describe this plugin as a ``{"type": "function", ...}`` tool definition.

        Argument types are mapped to JSON Schema type names. Arguments marked
        ``required`` are listed under ``parameters.required``.

        Raises:
            AssertionError: if the plugin is not ``plugin_only``.
            Exception: if an argument type has no JSON Schema mapping.
        """
        # explicit raise instead of ``assert`` so the guard survives ``python -O``
        if self.plugin_only is not True:
            raise AssertionError("Only `plugin_only` plugins can be called in this way.")

        type_map = {
            "string": "string",
            "str": "string",
            "text": "string",
            "integer": "integer",
            "int": "integer",
            "float": "number",
            "double": "number",
            "number": "number",
            "boolean": "boolean",
            "bool": "boolean",
            "null": "null",
            "none": "null",
        }

        def map_type(t: str) -> str:
            mapped = type_map.get(t.lower())
            if mapped is None:
                raise Exception(f"unknown type {t}")
            return mapped

        properties: Dict[str, Any] = {}
        required_params: List[str] = []
        for arg in self.spec.args:
            prop: Dict[str, Any] = {"type": map_type(arg.type)}
            # JSON Schema "description" must be a string; omit it instead of sending null
            if arg.description is not None:
                prop["description"] = arg.description
            properties[arg.name] = prop
            if arg.required:
                required_params.append(arg.name)

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.spec.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required_params,
                },
            },
        }
|
|
|
|
|
|
|
|
class PluginRegistry(ComponentRegistry[PluginEntry]):
    """Component registry whose components are PluginEntry objects loaded from YAML files."""

    def __init__(
        self,
        file_glob: str,
        ttl: Optional[timedelta] = None,
    ) -> None:
        """Forward ``file_glob`` and ``ttl`` unchanged to ComponentRegistry."""
        super().__init__(file_glob, ttl)

    def _load_component(self, path: str) -> Tuple[str, PluginEntry]:
        """Load the plugin YAML at ``path`` and return ``(plugin name, entry)``.

        Raises:
            Exception: if no entry could be built from the file, or the plugin is
                disabled. (Presumably the base registry handles this per file —
                NOTE(review): confirm in ComponentRegistry.)
        """
        entry: Optional[PluginEntry] = PluginEntry.from_yaml_file(path)
        if entry is None:
            raise Exception(f"failed to load plugin from {path}")
        if not entry.enabled:
            raise Exception(f"plugin {entry.name} is disabled")
        return entry.name, entry
|
|
|
|
|
|
|
|
class PluginModuleConfig(ModuleConfig):
    """Settings for the "plugin" config section."""

    def _configure(self) -> None:
        self._set_name("plugin")
        # default plugin directory: "<app base path>/plugins"
        default_plugin_dir = os.path.join(self.src.app_base_path, "plugins")
        self.base_path = self._get_path("base_path", default_plugin_dir)
|
|
|
|
|
|
|
|
class PluginModule(Module):
    """Injector module that supplies the PluginRegistry."""

    @provider
    def provide_plugin_registry(
        self,
        config: PluginModuleConfig,
    ) -> PluginRegistry:
        # every *.yaml file directly under the configured plugin directory,
        # registered with a 10-minute ttl (semantics defined by the registry base class)
        yaml_pattern = os.path.join(config.base_path, "*.yaml")
        registry_ttl = timedelta(minutes=10)
        return PluginRegistry(file_glob=yaml_pattern, ttl=registry_ttl)
|
|
|