File size: 3,616 Bytes
3d3d712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
from typing import Dict, List
import numpy as np
from injector import inject
from sklearn.metrics.pairwise import cosine_similarity
from taskweaver.llm import LLMApi
from taskweaver.memory.plugin import PluginEntry, PluginRegistry
class SelectedPluginPool:
    """Pool of plugins selected for code generation.

    Holds the currently selected plugins together with a cache of the plugins
    whose names appeared in previously generated code. Both lists are kept
    free of duplicates, where two plugins count as duplicates when their
    ``name`` attributes are equal.
    """

    def __init__(self):
        self.selected_plugin_pool = []
        self._previous_used_plugin_cache = []  # cache the plugins used in the previous code generation

    def add_selected_plugins(self, external_plugin_pool: List[PluginEntry]):
        """
        Add selected plugins to the pool.

        Plugins whose name is already present in the pool are skipped.
        """
        self.selected_plugin_pool = self.merge_plugin_pool(self.selected_plugin_pool, external_plugin_pool)

    def __len__(self) -> int:
        """Return the number of plugins currently in the selected pool."""
        return len(self.selected_plugin_pool)

    def filter_unused_plugins(self, code: str):
        """
        Filter out plugins that are not used in the code generated by LLM.

        A plugin counts as used when its name occurs as a substring of
        ``code``. Used plugins are merged into the cache of previously used
        plugins, and the selected pool is reset to the content of that cache.
        """
        plugins_used_in_code = [p for p in self.selected_plugin_pool if p.name in code]
        self._previous_used_plugin_cache = self.merge_plugin_pool(
            self._previous_used_plugin_cache,
            plugins_used_in_code,
        )
        # Hand out a copy: get_plugins() exposes selected_plugin_pool, and a
        # caller mutating that list must not silently alter the cache.
        self.selected_plugin_pool = list(self._previous_used_plugin_cache)

    def get_plugins(self) -> List[PluginEntry]:
        """Return the list of currently selected plugins."""
        return self.selected_plugin_pool

    @staticmethod
    def merge_plugin_pool(pool1: List[PluginEntry], pool2: List[PluginEntry]) -> List[PluginEntry]:
        """
        Merge two plugin pools and remove duplicates.

        Order is preserved (``pool1`` entries first, then ``pool2``) and, for
        plugins sharing a name, only the first occurrence is kept. Neither
        input list is modified; a new list is returned.
        """
        seen_names = set()
        result = []
        for pool in (pool1, pool2):
            for item in pool:
                if item.name in seen_names:
                    continue
                seen_names.add(item.name)
                result.append(item)
        return result
class PluginSelector:
    """Select the plugins most relevant to a user query.

    Relevance is the cosine similarity between the embedding of the user
    query and the embedding of each plugin's ``"<name>: <description>"`` text.
    """

    @inject
    def __init__(
        self,
        plugin_registry: PluginRegistry,
        llm_api: LLMApi,
        plugin_only: bool = False,
    ):
        """
        Collect the candidate plugins from the registry.

        :param plugin_registry: registry whose ``get_list()`` provides the plugins.
        :param llm_api: API object used to compute embeddings.
        :param plugin_only: when True, keep only plugins whose ``plugin_only``
            attribute is exactly ``True``.
        """
        if plugin_only:
            self.available_plugins = [p for p in plugin_registry.get_list() if p.plugin_only is True]
        else:
            self.available_plugins = plugin_registry.get_list()
        self.llm_api = llm_api
        # Maps plugin name -> embedding; filled by generate_plugin_embeddings().
        self.plugin_embedding_dict: Dict[str, List[float]] = {}

    def generate_plugin_embeddings(self):
        """
        Compute and store an embedding for every available plugin.

        The embedded text is ``"<name>: <description>"``. All texts are sent
        in a single ``get_embedding_list`` call and the results are stored in
        ``plugin_embedding_dict`` keyed by plugin name.
        """
        plugin_intro_text_list: List[str] = [p.name + ": " + p.spec.description for p in self.available_plugins]
        plugin_embeddings = self.llm_api.get_embedding_list(plugin_intro_text_list)
        for i, p in enumerate(self.available_plugins):
            self.plugin_embedding_dict[p.name] = plugin_embeddings[i]

    def plugin_select(self, user_query: str, top_k: int = 5) -> List[PluginEntry]:
        """
        Return the ``top_k`` plugins most similar to ``user_query``.

        :param user_query: text to match the plugins against.
        :param top_k: maximum number of plugins to return.
        :return: plugins ordered by descending cosine similarity. When
            ``top_k`` is not smaller than the number of available plugins, the
            ``available_plugins`` list itself is returned unranked. A
            non-positive ``top_k`` yields an empty list.
        :raises KeyError: if a plugin has no entry in ``plugin_embedding_dict``
            (``generate_plugin_embeddings`` fills that dict).
        """
        # Decide the trivial cases first so that no embedding request is
        # issued when ranking is not needed.
        if top_k >= len(self.available_plugins):
            return self.available_plugins
        if top_k <= 0:
            # A negative slice bound would otherwise drop plugins from the
            # tail instead of selecting none.
            return []

        query_vector = np.array(self.llm_api.get_embedding(user_query)).reshape(1, -1)

        similarities = []
        for p in self.available_plugins:
            plugin_vector = np.array(self.plugin_embedding_dict[p.name]).reshape(1, -1)
            # cosine_similarity returns a (1, 1) array here; keep the scalar
            # so the sort compares plain floats.
            score = float(cosine_similarity(query_vector, plugin_vector)[0, 0])
            similarities.append((p, score))

        plugins_rank = sorted(
            similarities,
            key=lambda x: x[1],
            reverse=True,
        )[:top_k]
        return [p for p, _ in plugins_rank]
|