Spaces:
Paused
Paused
| from core.tools.entities.values import default_tool_label_name_list | |
| from core.tools.provider.api_tool_provider import ApiToolProviderController | |
| from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController | |
| from core.tools.provider.tool_provider import ToolProviderController | |
| from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController | |
| from extensions.ext_database import db | |
| from models.tools import ToolLabelBinding | |
class ToolLabelManager:
    """Read and write the labels attached to tool providers.

    Labels of API and workflow tool providers are persisted as
    ``ToolLabelBinding`` rows. Builtin tool providers expose their labels
    directly on the controller (``controller.tool_labels``) and can only be
    read through this class, not updated.
    """

    @classmethod
    def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
        """
        Filter tool labels.

        Keep only names that appear in ``default_tool_label_name_list`` and
        drop duplicates, preserving the order in which labels first appear.

        :param tool_labels: candidate label names
        :return: de-duplicated list of recognised label names
        """
        # dict.fromkeys de-duplicates while keeping first-seen order; a plain
        # set() would return labels in hash-dependent (per-process) order.
        return list(dict.fromkeys(label for label in tool_labels if label in default_tool_label_name_list))

    @classmethod
    def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
        """
        Update tool labels.

        Replace the stored label bindings of an API or workflow tool provider
        with ``labels`` (after :meth:`filter_tool_labels`), then commit the
        database session.

        :param controller: provider whose labels are replaced
        :param labels: new label names; unknown names and duplicates are dropped
        :raises ValueError: if ``controller`` is not an API or workflow provider
        """
        labels = cls.filter_tool_labels(labels)

        if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
            provider_id = controller.provider_id
        else:
            raise ValueError("Unsupported tool type")

        tool_type = controller.provider_type.value

        # Delete old labels.
        # NOTE(review): this delete matches on tool_id only, while
        # get_tool_labels also filters on tool_type; confirm that removing
        # bindings of other tool types sharing this id is intended.
        db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()

        # Insert the new labels; the delete and inserts are committed together below.
        db.session.add_all(
            [
                ToolLabelBinding(
                    tool_id=provider_id,
                    tool_type=tool_type,
                    label_name=label,
                )
                for label in labels
            ]
        )

        db.session.commit()

    @classmethod
    def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
        """
        Get tool labels.

        Builtin providers return ``controller.tool_labels`` directly. API and
        workflow providers are looked up in ``ToolLabelBinding`` by both
        provider id and provider type.

        :param controller: provider whose labels are requested
        :return: list of label names
        :raises ValueError: if the provider type is not supported
        """
        if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
            provider_id = controller.provider_id
        elif isinstance(controller, BuiltinToolProviderController):
            return controller.tool_labels
        else:
            raise ValueError("Unsupported tool type")

        # Only the label_name column is selected, so these are result rows,
        # not ToolLabelBinding instances.
        rows = (
            db.session.query(ToolLabelBinding.label_name)
            .filter(
                ToolLabelBinding.tool_id == provider_id,
                ToolLabelBinding.tool_type == controller.provider_type.value,
            )
            .all()
        )

        return [row.label_name for row in rows]

    @classmethod
    def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
        """
        Get tools labels.

        Fetch the stored labels of several API/workflow providers in a single
        query.

        :param tool_providers: list of tool providers (API or workflow only)
        :return: dict of tool labels
        :key: tool id
        :value: list of tool labels
            Providers with no stored label binding are absent from the
            result, so callers should look ids up with ``.get(id, [])``.
        :raises ValueError: if any provider is not an API or workflow provider
        """
        if not tool_providers:
            return {}

        # Validate every provider before touching the database.
        for controller in tool_providers:
            if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
                raise ValueError("Unsupported tool type")

        provider_ids = [controller.provider_id for controller in tool_providers]

        labels: list[ToolLabelBinding] = (
            db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
        )

        # Group label names by provider id in a single pass.
        tool_labels: dict[str, list[str]] = {}
        for label in labels:
            tool_labels.setdefault(label.tool_id, []).append(label.label_name)

        return tool_labels