from typing import Dict, List

import numpy as np
from injector import inject
from sklearn.metrics.pairwise import cosine_similarity

from taskweaver.llm import LLMApi
from taskweaver.memory.plugin import PluginEntry, PluginRegistry


class SelectedPluginPool:
    """Pool of plugins currently offered to code generation.

    The pool grows through :meth:`add_selected_plugins`. It is narrowed by
    :meth:`filter_unused_plugins` to the plugins that generated code has
    actually referenced, in this call or any earlier call.
    """

    def __init__(self):
        # Plugins currently in the pool, de-duplicated by ``name``.
        self.selected_plugin_pool: List[PluginEntry] = []
        # Plugins used in previous code generations. It is only ever extended
        # (via merge_plugin_pool), never shrunk.
        self._previous_used_plugin_cache: List[PluginEntry] = []

    def add_selected_plugins(self, external_plugin_pool: List[PluginEntry]):
        """
        Add selected plugins to the pool.

        A plugin whose ``name`` is already in the pool is ignored, so the
        existing entry wins.
        """
        self.selected_plugin_pool = self.merge_plugin_pool(self.selected_plugin_pool, external_plugin_pool)

    def __len__(self) -> int:
        return len(self.selected_plugin_pool)

    def filter_unused_plugins(self, code: str):
        """
        Filter out plugins that are not used in the code generated by LLM.

        A plugin counts as "used" when its ``name`` occurs anywhere in ``code``.
        This is a plain substring test, so a plugin named ``sql`` also matches
        ``sql_pull_data``. After filtering, the pool is the union of the plugins
        used now and those used in earlier calls.
        """
        plugins_used_in_code = [p for p in self.selected_plugin_pool if p.name in code]
        self._previous_used_plugin_cache = self.merge_plugin_pool(
            self._previous_used_plugin_cache,
            plugins_used_in_code,
        )
        # Copy so the pool and the cache are distinct lists. Otherwise a caller
        # mutating the list returned by get_plugins() would corrupt the cache.
        self.selected_plugin_pool = list(self._previous_used_plugin_cache)

    def get_plugins(self) -> List[PluginEntry]:
        return self.selected_plugin_pool

    @staticmethod
    def merge_plugin_pool(pool1: List[PluginEntry], pool2: List[PluginEntry]) -> List[PluginEntry]:
        """
        Merge two plugin pools and remove duplicates.

        Plugins are compared by ``name``. The first occurrence wins (``pool1``
        before ``pool2``) and the relative order is preserved. Neither input
        list is modified; a new list is returned.
        """
        seen_names = set()
        result: List[PluginEntry] = []
        for item in [*pool1, *pool2]:
            if item.name in seen_names:
                continue
            seen_names.add(item.name)
            result.append(item)
        return result


class PluginSelector:
    """Rank available plugins by embedding similarity to a user query."""

    @inject
    def __init__(
        self,
        plugin_registry: PluginRegistry,
        llm_api: LLMApi,
        plugin_only: bool = False,
    ):
        if plugin_only:
            # Keep only plugins whose ``plugin_only`` attribute is exactly True.
            self.available_plugins = [p for p in plugin_registry.get_list() if p.plugin_only is True]
        else:
            self.available_plugins = plugin_registry.get_list()
        self.llm_api = llm_api
        # plugin name -> embedding of the text "<name>: <description>"
        self.plugin_embedding_dict: Dict[str, List[float]] = {}

    def generate_plugin_embeddings(self):
        """Embed "<name>: <spec.description>" for every available plugin and cache it by name."""
        if not self.available_plugins:
            # Nothing to embed; avoid sending an empty embedding request.
            return
        plugin_intro_text_list: List[str] = [p.name + ": " + p.spec.description for p in self.available_plugins]
        plugin_embeddings = self.llm_api.get_embedding_list(plugin_intro_text_list)
        # NOTE(review): assumes get_embedding_list returns one embedding per
        # input text, in input order.
        for i, p in enumerate(self.available_plugins):
            self.plugin_embedding_dict[p.name] = plugin_embeddings[i]

    def plugin_select(self, user_query: str, top_k: int = 5) -> List[PluginEntry]:
        """Return up to ``top_k`` available plugins most similar to ``user_query``.

        Args:
            user_query: Text embedded via ``llm_api.get_embedding``. Its
                embedding is compared by cosine similarity with each plugin's
                cached embedding.
            top_k: Maximum number of plugins to return.
                - If it is at least the number of available plugins, every
                  plugin is returned unranked, in ``available_plugins`` order,
                  without calling the embedding API.
                - If it is ``<= 0``, an empty list is returned.

        Returns:
            The selected plugins, ordered by descending cosine similarity.
            Plugins with equal scores keep their ``available_plugins`` order.
        """
        if top_k >= len(self.available_plugins):
            return self.available_plugins
        if top_k <= 0:
            # A negative top_k would otherwise slice from the end ([:-k]) and
            # return almost every plugin.
            return []

        # Embed lazily if generate_plugin_embeddings() has not run, or has not
        # covered every plugin, instead of failing with a KeyError below.
        if any(p.name not in self.plugin_embedding_dict for p in self.available_plugins):
            self.generate_plugin_embeddings()

        user_query_embedding = np.array(self.llm_api.get_embedding(user_query)).reshape(1, -1)
        plugin_embedding_matrix = np.array(
            [self.plugin_embedding_dict[p.name] for p in self.available_plugins],
        )
        # A single call scores all plugins. The result has shape (1, n_plugins);
        # take row 0.
        similarities = cosine_similarity(user_query_embedding, plugin_embedding_matrix)[0]

        # sorted() is stable even with reverse=True, so ties keep their
        # original order.
        ranked_indices = sorted(
            range(len(self.available_plugins)),
            key=lambda i: float(similarities[i]),
            reverse=True,
        )
        return [self.available_plugins[i] for i in ranked_indices[:top_k]]