File size: 3,616 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from typing import Dict, List

import numpy as np
from injector import inject
from sklearn.metrics.pairwise import cosine_similarity

from taskweaver.llm import LLMApi
from taskweaver.memory.plugin import PluginEntry, PluginRegistry


class SelectedPluginPool:
    """Accumulate the plugins selected for code generation.

    The pool keeps two lists: the plugins currently selected, and a cache of
    the plugins whose names were found in previously generated code. Plugins
    are identified by their ``name``; the first entry seen for a name wins.
    """

    def __init__(self):
        self.selected_plugin_pool = []
        self._previous_used_plugin_cache = []  # cache the plugins used in the previous code generation

    def add_selected_plugins(self, external_plugin_pool: List[PluginEntry]):
        """
        Add selected plugins to the pool

        Plugins whose name is already in the pool are skipped; new ones are
        appended after the existing entries.
        """
        self.selected_plugin_pool = self.merge_plugin_pool(self.selected_plugin_pool, external_plugin_pool)

    def __len__(self) -> int:
        """Return the number of plugins currently in the pool."""
        return len(self.selected_plugin_pool)

    def filter_unused_plugins(self, code: str):
        """
        Filter out plugins that are not used in the code generated by LLM

        A plugin counts as used when its name occurs anywhere in ``code`` as a
        substring. Plugins found this way are added to the used-plugin cache,
        and the pool is then reset to the content of that cache, so plugins
        used in earlier rounds are retained.
        """
        plugins_used_in_code = [p for p in self.selected_plugin_pool if p.name in code]
        self._previous_used_plugin_cache = self.merge_plugin_pool(
            self._previous_used_plugin_cache,
            plugins_used_in_code,
        )
        # Copy instead of aliasing: the pool list is handed out by
        # get_plugins(), and a caller mutating it must not alter the cache.
        self.selected_plugin_pool = list(self._previous_used_plugin_cache)

    def get_plugins(self) -> List[PluginEntry]:
        """Return the list of plugins currently in the pool."""
        return self.selected_plugin_pool

    @staticmethod
    def merge_plugin_pool(pool1: List[PluginEntry], pool2: List[PluginEntry]) -> List[PluginEntry]:
        """
        Merge two plugin pools and remove duplicates

        Entries are compared by ``name``. The result is a new list that keeps
        the first occurrence of every name, in ``pool1``-then-``pool2`` order;
        neither input list is modified.
        """
        seen_names = set()
        result = []
        for item in pool1 + pool2:
            # A set lookup replaces a rescan of ``result`` for every item.
            if item.name in seen_names:
                continue
            seen_names.add(item.name)
            result.append(item)
        return result


class PluginSelector:
    """Select the plugins most similar to a user query by embedding similarity.

    Plugin embeddings are stored in ``plugin_embedding_dict`` keyed by plugin
    name and must be populated (see ``generate_plugin_embeddings``) before
    ``plugin_select`` has to rank plugins.
    """

    @inject
    def __init__(
        self,
        plugin_registry: PluginRegistry,
        llm_api: LLMApi,
        plugin_only: bool = False,
    ):
        """
        Collect the candidate plugins from the registry.

        Args:
            plugin_registry: registry whose ``get_list()`` yields the plugins.
            llm_api: client used to compute embeddings.
            plugin_only: when true, keep only plugins whose ``plugin_only``
                attribute is ``True``.
        """
        if plugin_only:
            self.available_plugins = [p for p in plugin_registry.get_list() if p.plugin_only is True]
        else:
            self.available_plugins = plugin_registry.get_list()
        self.llm_api = llm_api
        self.plugin_embedding_dict: Dict[str, List[float]] = {}

    def generate_plugin_embeddings(self):
        """
        Embed every available plugin and store the result by plugin name.

        The embedded text for a plugin is ``"<name>: <description>"``. All
        texts are sent in one ``get_embedding_list`` call; nothing is sent
        when there are no plugins.
        """
        if not self.available_plugins:
            return
        plugin_intro_text_list: List[str] = [p.name + ": " + p.spec.description for p in self.available_plugins]
        plugin_embeddings = self.llm_api.get_embedding_list(plugin_intro_text_list)
        for p, embedding in zip(self.available_plugins, plugin_embeddings):
            self.plugin_embedding_dict[p.name] = embedding

    def plugin_select(self, user_query: str, top_k: int = 5) -> List[PluginEntry]:
        """
        Return the ``top_k`` plugins most similar to ``user_query``.

        Args:
            user_query: text to compare against the plugin embeddings.
            top_k: maximum number of plugins to return.

        Returns:
            All available plugins (unranked) when ``top_k`` covers the whole
            list, an empty list when ``top_k`` is not positive, otherwise the
            ``top_k`` plugins in descending cosine-similarity order. Plugins
            with equal similarity keep their ``available_plugins`` order.

        Raises:
            KeyError: if a plugin has no entry in ``plugin_embedding_dict``.
        """
        # Decide the trivial cases first so no embedding request is made when
        # the answer does not depend on the query.
        if top_k >= len(self.available_plugins):
            return self.available_plugins
        if top_k <= 0:
            # A negative value would otherwise slice from the end of the ranking.
            return []

        user_query_embedding = np.array(self.llm_api.get_embedding(user_query)).reshape(1, -1)
        plugin_embedding_matrix = np.array(
            [self.plugin_embedding_dict[p.name] for p in self.available_plugins],
        )
        # One (1, n_plugins) similarity row instead of one call per plugin.
        similarities = cosine_similarity(user_query_embedding, plugin_embedding_matrix)[0]

        plugins_rank = sorted(
            zip(self.available_plugins, similarities),
            key=lambda pair: pair[1],
            reverse=True,
        )[:top_k]

        return [p for p, _ in plugins_rank]