|
|
from typing import Dict, List |
|
|
|
|
|
import numpy as np |
|
|
from injector import inject |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
|
|
from taskweaver.llm import LLMApi |
|
|
from taskweaver.memory.plugin import PluginEntry, PluginRegistry |
|
|
|
|
|
|
|
|
class SelectedPluginPool:
    """
    Pool of the plugins currently selected, de-duplicated by plugin name.

    Two lists are maintained:

    * ``selected_plugin_pool`` - the plugins currently in the pool.
    * ``_previous_used_plugin_cache`` - the plugins whose names were found in
      code passed to :meth:`filter_unused_plugins` so far.

    Two plugins are considered the same entry when their ``name`` attributes
    are equal.
    """

    def __init__(self):
        # Plugins currently selected.
        self.selected_plugin_pool = []
        # Accumulated plugins that were seen in previously filtered code.
        self._previous_used_plugin_cache = []

    def add_selected_plugins(self, external_plugin_pool: List[PluginEntry]):
        """
        Add selected plugins to the pool

        Plugins whose name is already present in the pool are ignored; new
        ones are appended after the existing entries.
        """
        self.selected_plugin_pool = self.merge_plugin_pool(
            self.selected_plugin_pool,
            external_plugin_pool,
        )

    def __len__(self) -> int:
        """Return the number of plugins currently in the pool."""
        return len(self.selected_plugin_pool)

    def filter_unused_plugins(self, code: str):
        """
        Filter out plugins that are not used in the code generated by LLM

        A plugin counts as used when its name occurs anywhere in ``code``.
        Used plugins are merged into the cache of previously used plugins,
        and that cache then becomes the selected pool, so plugins used in
        earlier calls are retained.
        """
        # NOTE: this is a plain substring test, so a plugin named "foo" also
        # matches code that only mentions "foo_bar".
        plugins_used_in_code = [
            plugin for plugin in self.selected_plugin_pool if plugin.name in code
        ]
        self._previous_used_plugin_cache = self.merge_plugin_pool(
            self._previous_used_plugin_cache,
            plugins_used_in_code,
        )
        # Both attributes refer to the same list object from here on; every
        # later update rebinds them to a fresh list from merge_plugin_pool.
        self.selected_plugin_pool = self._previous_used_plugin_cache

    def get_plugins(self) -> List[PluginEntry]:
        """Return the list of plugins currently in the pool (not a copy)."""
        return self.selected_plugin_pool

    @staticmethod
    def merge_plugin_pool(pool1: List[PluginEntry], pool2: List[PluginEntry]) -> List[PluginEntry]:
        """
        Merge two plugin pools and remove duplicates

        Entries are compared by ``name``. The first occurrence of each name
        wins and the relative order (``pool1`` first, then ``pool2``) is
        preserved. A new list is returned; neither input is modified.

        Plugin names must be hashable (they are presumably strings - confirm
        against ``PluginEntry``).
        """
        # Track the names already emitted in a set so that each membership
        # test is O(1) instead of rescanning the result list for every item.
        seen_names = set()
        result = []
        for pool in (pool1, pool2):
            for item in pool:
                if item.name in seen_names:
                    continue
                seen_names.add(item.name)
                result.append(item)
        return result
|
|
|
|
|
|
|
|
class PluginSelector:
    """
    Select the plugins most similar to a user query.

    Similarity is the cosine similarity between the embedding of the query
    and the embedding of each plugin's ``"<name>: <description>"`` text.
    Embeddings are obtained from ``llm_api``.
    """

    @inject
    def __init__(
        self,
        plugin_registry: PluginRegistry,
        llm_api: LLMApi,
        plugin_only: bool = False,
    ):
        """
        Args:
            plugin_registry: Registry whose ``get_list()`` supplies the
                candidate plugins.
            llm_api: Used to compute embeddings.
            plugin_only: When true, only plugins whose ``plugin_only``
                attribute is ``True`` are considered.
        """
        if plugin_only:
            self.available_plugins = [p for p in plugin_registry.get_list() if p.plugin_only is True]
        else:
            self.available_plugins = plugin_registry.get_list()
        self.llm_api = llm_api
        # Maps plugin name -> embedding; filled by generate_plugin_embeddings().
        self.plugin_embedding_dict: Dict[str, List[float]] = {}

    def generate_plugin_embeddings(self):
        """
        Compute an embedding for every available plugin and store it in
        ``plugin_embedding_dict`` keyed by plugin name.

        Nothing is requested from ``llm_api`` when there are no available
        plugins.
        """
        if not self.available_plugins:
            # Nothing to embed; avoid sending an empty batch to the backend.
            return

        plugin_intro_text_list: List[str] = [
            p.name + ": " + p.spec.description for p in self.available_plugins
        ]
        plugin_embeddings = self.llm_api.get_embedding_list(plugin_intro_text_list)
        # Embeddings are matched to plugins by position in the request list.
        for i, p in enumerate(self.available_plugins):
            self.plugin_embedding_dict[p.name] = plugin_embeddings[i]

    def plugin_select(self, user_query: str, top_k: int = 5) -> List[PluginEntry]:
        """
        Return up to ``top_k`` available plugins ranked by similarity to
        ``user_query`` (most similar first).

        Args:
            user_query: Text to match plugins against.
            top_k: Maximum number of plugins to return.

        Returns:
            A new list of plugins. When ``top_k`` is at least the number of
            available plugins, all of them are returned in their original
            order and no embedding is requested. When ``top_k`` is zero or
            negative, an empty list is returned.

        Raises:
            KeyError: If an available plugin has no entry in
                ``plugin_embedding_dict`` (for example because
                ``generate_plugin_embeddings`` has not populated it).
        """
        # Decide the trivial cases before touching the embedding backend so
        # that no query embedding is requested when ranking is unnecessary.
        if top_k >= len(self.available_plugins):
            # Return a copy so callers cannot mutate the selector's own list.
            return list(self.available_plugins)
        if top_k <= 0:
            # A negative value would otherwise slice from the end of the ranking.
            return []

        query_vector = np.array(self.llm_api.get_embedding(user_query)).reshape(1, -1)
        plugin_matrix = np.array(
            [self.plugin_embedding_dict[p.name] for p in self.available_plugins],
        )
        # One call scores every plugin; row 0 holds the query-vs-plugin values.
        scores = cosine_similarity(query_vector, plugin_matrix)[0]

        # sorted() is stable, so equally similar plugins keep their original
        # relative order.
        plugins_rank = sorted(
            zip(self.available_plugins, scores),
            key=lambda pair: pair[1],
            reverse=True,
        )[:top_k]

        return [plugin for plugin, _ in plugins_rank]
|
|
|