File size: 12,503 Bytes
369c2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c221420
 
369c2da
 
 
 
 
 
 
 
 
 
 
c221420
 
369c2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
from sentence_transformers import SentenceTransformer, util
from nltk.tokenize import sent_tokenize
from util import select_random_from_list
from math import exp
import sys

# Shared sentence-embedding model used by every function in this module.
all_MiniLM_L12_v2 = SentenceTransformer("all-MiniLM-L12-v2")
"""
all-MiniLM-L12-v2 is a sentence embedding model used for tasks involving semantic textual similarity, clustering, semantic search, and information retrieval. They convert a sentence to tensor based on their intent and then matching patterns like cos_sim can be used to compare them to other sentences.
"""

# Phrases signalling that the user is referring back to the previous query;
# when a query resembles one of these, the previous label is appended to it
# before intent matching (see query_pass).
CONNECTION_PHRASES = ["just like that", "for the same", "similarly", "similar to the previous", "for that", "for it"]
CONNECTION_ENCODE = all_MiniLM_L12_v2.encode(CONNECTION_PHRASES,convert_to_tensor=True)

# Commands that clear the stored context; matched semantically (cosine
# similarity against the encoded list), not by exact string comparison.
CLEAR_MESSAGES = ["delete", "delete context", "delete history", "clear", "clear context", "clear history", "reset", "reset context", "reset chat", "forget", "forget all"]
CLEAR_MESSAGES_ENCODE = all_MiniLM_L12_v2.encode(CLEAR_MESSAGES,convert_to_tensor=True)

# Label (node.value[1]) of the most recently matched intent; appended to
# follow-up queries that contain a connection phrase.
prev_label = ""
prev_query_data = [] # Stores previous context if queries contain ambiguous content that may map to previous responses

confidence_threshold = 0.35 # Default confidence threshold; overwritten per query by generate_confidence_threshold
def generate_confidence_threshold(query: str, base=0.6, decay=0.03, min_threshold=0.25)->float:
    """Generate and store the confidence threshold for a query.

    Longer sentences lead to lower cosine-similarity confidences, so the
    threshold decays exponentially with query length:
        threshold = max(base * e^(-decay * word_count), min_threshold)
    (The original docstring said 0.8 for the base while the default is 0.6.)

    Parameters:
    1. query: Modify threshold based on this sentence
    2. base, decay: Parameters of the decay curve base * e^(-decay * no. of words)
    3. min_threshold: Clamp to this minimum to avoid much lower threshold values

    Side effect: updates the module-level `confidence_threshold`.
    Returns the computed threshold (the original returned None despite the
    `-> float` annotation; callers that ignore the return are unaffected)."""
    global confidence_threshold
    length = len(query.split())
    confidence_threshold = max(base * exp(-decay * length), min_threshold)
    return confidence_threshold

# The value of each node contains the following data
# node.value[0] -> intent
# node.value[1] -> label
# node.value[2] -> examples
# node.value[3] -> response
# node.value[4] -> children

def cache_embeddings(tree, model = all_MiniLM_L12_v2)->None:
    """Store the encoded examples as part of the tree itself to avoid repetitive computations.

    Walks the whole tree depth-first and attaches `embedding_cache` (a tensor
    of the encoded example sentences) to every node carrying intent data.

    Parameters:
    1. tree: Tree to cache embeddings
    2. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)"""

    def _cache_node_embeddings(n):
        # node.value layout: (intent, label, examples, response, children) —
        # examples live at index 2, so at least 3 entries are required before
        # indexing. The original `>= 2` check allowed an IndexError on
        # 2-element tuples.
        if isinstance(n.value, tuple) and len(n.value) >= 3:
            examples = n.value[2]
            n.embedding_cache = model.encode(examples, convert_to_tensor=True)
        for child in n.children:
            _cache_node_embeddings(child)
    _cache_node_embeddings(tree.root)

# SECOND_PERSON_MENTIONS = ["you", "youre", "your", "yours", "yourself", "y'all", "y'all's", "y'all'self", "you're", "your'e""u", "ur", "urs", "urself"]
MY_NAME = ["Dwarakesh","Dwarak","Dwara","Dwaraka"] # Delete for production purposes
def get_user_query(message="", model = all_MiniLM_L12_v2)->str:
    """Separate function to get input from user.

    Re-prompts until a non-empty line is entered, normalizes whitespace,
    strips the bot's name, and detects "clear context" commands.

    Parameters:
    1. message: Show message to user before receiving input (Default: empty)
    2. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)

    Returns the normalized query, or None when the query semantically matches
    one of CLEAR_MESSAGES (the caller should clear context)."""

    query = input(message).lower().strip()
    while query == "":
        query = input(message).lower().strip()

    query = " ".join(query.split()) # Collapse all whitespace runs (the old single replace("  "," ") left 4+ spaces only partially collapsed)
    for name in MY_NAME: # Remove mentions of the bot's name
        # The query was lower-cased above, so the capitalized names must be
        # lower-cased too — the original replace() could never match.
        query = query.replace(name.lower(), "")
    generate_confidence_threshold(query)
    query_encode = model.encode(query, convert_to_tensor=True)
    clear_intent = util.cos_sim(query_encode,CLEAR_MESSAGES_ENCODE).max().item()
    if clear_intent > confidence_threshold:
        return None
    return query

def _calculate_single_level(user_embed,predicted_intent):
    """Calculate predictions for the children of a single node. Each node contains a list of nodes as its children.

    Parameters:
    1. user_embed: User query converted to tensor
    2. predicted_intent: Calculate for children of this node

    Returns (best_child_node, best_confidence). best_child_node is None when
    the node has no children or every child scores 0.

    Raises ValueError when a child has no cached embeddings."""

    categories = predicted_intent.children # List of node objects
    predicted_intent = None
    high_intent = 0
    for category in categories:
        # getattr guards nodes that were never cached at all: without it an
        # AttributeError would mask the intended, actionable ValueError.
        if getattr(category, "embedding_cache", None) is None:
            raise ValueError("Embedding cache missing. Call cache_embeddings() on the tree first")
        score = util.cos_sim(user_embed, category.embedding_cache).max().item()

        if score > high_intent:
            high_intent = score
            predicted_intent = category # Node object
    return (predicted_intent,high_intent) # Child node with the highest prediction confidence and that confidence value

def _store_prev_data(predicted_intent):
    """Store the previous computed data path.

    Rebuilds the module-level prev_query_data list in place with the chain of
    nodes from `predicted_intent` up to (but excluding) the tree root.

    Parameters:
    1. predicted_intent: Store previous data w.r.t this node"""
    # Mutating global prev_query_data in place
    del prev_query_data[:]
    node = predicted_intent
    while node.parent: # The root has no parent, so it is never stored
        prev_query_data.append(node)
        node = node.parent

def h_pass(tree, user_embed, predicted_intent = None)->tuple:
    """Use the model to pass through the tree to compare it with the user query in a hierarchical manner and return an output.

    Parameters:
    1. tree: Which tree to pass through hierarchically
    2. user_embed: User input converted to tensor
    3. predicted_intent: Where to start the pass from (Default: Root of the tree)

    Returns a 6-tuple:
    (best node, its parent, confidence, passed at least one level?,
     reached a leaf?, {node: confidence} for each level traversed).

    Side effects: updates the module-level prev_label and (via
    _store_prev_data) prev_query_data."""
    global prev_label
    predicted_intent = tree.root if predicted_intent is None else predicted_intent # `is None`, not `== None`
    predicted_intent_parent = None
    high_intent = 0
    passed_once = False
    pass_through_intent = {}
    # NOTE(review): _calculate_single_level can return None when every child
    # scores 0 — the loop condition below would then raise; confirm scores
    # are always > 0 in practice.
    while predicted_intent.children: # If the node has children, check for the child with the highest confidence value
        predicted_intent_parent = predicted_intent
        predicted_intent, high_intent = _calculate_single_level(user_embed,predicted_intent)
        pass_through_intent[predicted_intent] = high_intent # Store the confidence value of the current node
        if passed_once: # If the data didn't pass even once, then don't store it
            _store_prev_data(predicted_intent_parent) # Storing previous data w.r.t parent node as context is changed from current node
        if high_intent < confidence_threshold: # If highest confidence value is still too low, stop.
            prev_label = predicted_intent_parent.value[1]
            return (predicted_intent, predicted_intent_parent, high_intent, passed_once, False, pass_through_intent) # If the confidence value is low, stop
        passed_once = True

    _store_prev_data(predicted_intent)
    prev_label = predicted_intent.value[1]
    return (predicted_intent, predicted_intent_parent, high_intent, passed_once, True, pass_through_intent)

def query_pass(tree, user_input, model=all_MiniLM_L12_v2)->list:
    """Separate multiple queries into separate single ones, analyze relation between them if any, and process them to give an output while storing incomplete query outputs in non-leaf list, which contains the current level of context.

    Parameters:
    1. tree: Which tree to pass through hierarchically
    2. user_input: User input that may contain one or more queries as a string
    3. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)

    Returns a list of response strings, one per sentence in user_input."""

    queries = sent_tokenize(user_input)
    user_embeddings = [model.encode(query,convert_to_tensor=True) for query in queries]
    result = []
    label = prev_label # Seed with the previous turn's label for anaphora resolution

    for i in range(len(queries)):
        generate_confidence_threshold(queries[i]) # Threshold adapts to each sentence's length
        pass_value = (None, None, 0, False, False, None)
        # pass_value[0] -> current predicted intention (node)
        # pass_value[1] -> parent node of current predicted intention
        # pass_value[2] -> confidence level
        # pass_value[3] -> has the query passed through the model atleast once?
        # pass_value[4] -> has the query reached a leaf node?
        # pass_value[5] -> confidence values of traversal for query [DEBUGGING PURPOSES]

        # Acquiring data from previous query if the query has words matching with connecting phrases
        conn_sim = util.cos_sim(user_embeddings[i], CONNECTION_ENCODE).max().item()
        if conn_sim > confidence_threshold:
            # Append the previous label so the follow-up query carries its
            # context, then re-encode the augmented sentence.
            queries[i] = queries[i] + label
            user_embeddings[i] = model.encode(queries[i], convert_to_tensor=True)

        # Pass values through the root node and the nodes that have the current context
        # NOTE(review): h_pass mutates prev_query_data (via _store_prev_data)
        # while the comprehension below iterates it — contexts may be skipped
        # or duplicated; confirm intended.
        pass_value_root = h_pass(tree,user_embeddings[i]) # Passing through root node
        pass_value_nonleaf = [h_pass(tree,user_embeddings[i],j) for j in prev_query_data] # Passing through nodes that have current context
        all_nodes = [pass_value_root] + pass_value_nonleaf # List of all nodes that have been passed through
        pass_value = max(all_nodes, key=lambda x: x[2]) # Maximum confidence node for available context. Root is always a context.
        # The comprehension's `i` is scoped to the comprehension (Python 3) and does not clobber the loop index.
        print(f"Query reach confidence: {[i[5] for i in all_nodes]}", file=sys.stderr) # DEBUGGING PURPOSES

        if pass_value[3]: # If the query has passed at least once, ask for data and store current result
            if not pass_value[4]: # If pass has not reached a leaf node, then ask for more data from the user and keep parent context
                label = pass_value[1].value[1]
                result.append(f"{pass_value[1].value[3]}")
                # continue

            else: # Query has reached a leaf node
                label = pass_value[0].value[1]
                result.append(pass_value[0].value[3])
                # continue

        else: # Query has not passed even once. Check if it works when previous context is available
            for parent_context in prev_query_data:
                pass_value_context = h_pass(tree, user_embeddings[i], parent_context)
                if pass_value_context[3]: # Check if it has passed at least once
                    # If it has passed, then the query is valid
                    if not pass_value_context[4]: # If pass has not reached a leaf node, then ask for more data from the user and keep parent context
                        label = pass_value_context[1].value[1]
                        result.append(f"What are you looking for in {pass_value_context[1].value[0]}? {pass_value_context[1].value[3]}")
                    else:
                        label = pass_value_context[0].value[1]
                        result.append(pass_value_context[0].value[3])
                    break # The else block won't be executed if code reaches here
            else: # The else statement of a for loop will execute only if the loop completes, and won't execute when broken by "break"
                result.append(f"I don't quite understand what you are trying to ask by \"{queries[i]}\"")
                # continue
    # End of "for" loop processing queries

    # Finally, return result. A list of responses same as the length of queries.
    return result

def process_user_query(query: str, model = all_MiniLM_L12_v2)->str:
    """Normalize a user query and detect "clear context" commands.

    (The original docstring was copied from get_user_query and described a
    `message` prompt parameter this function does not have.)

    Parameters:
    1. query: Raw user query string
    2. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)

    Returns the lower-cased, stripped query, or None when it semantically
    matches one of CLEAR_MESSAGES (the caller should clear context)."""

    query = query.lower().strip()
    generate_confidence_threshold(query)
    query_encode = model.encode(query, convert_to_tensor=True)
    clear_intent = util.cos_sim(query_encode,CLEAR_MESSAGES_ENCODE).max().item()
    if clear_intent > confidence_threshold:
        return None
    return query

def interact_with_user(tree_data, user_input: str) -> list:
    """Handle a single user message and return a list of response strings.

    (Annotation corrected: the function always returned a list, never a str.)

    Parameters:
    1. tree_data: Intent tree to pass the queries through
    2. user_input: Raw user message (may contain several sentences)

    Returns one response string per detected query, or
    ["Cleared previous context"] when the message was a clear command."""
    user_input = process_user_query(user_input)
    if user_input:  # A real query: not empty and not a clear command
        results = query_pass(tree_data, user_input)
        all_results = [f"{select_random_from_list(result)}" for result in results]
        print(f"Previous query data: {prev_query_data}", file=sys.stderr)
        return all_results
    else:
        # Mutating global variables: Clearing context and recent history on command
        prev_query_data.clear()
        print(f"Previous query data: {prev_query_data}", file=sys.stderr)
        return ["Cleared previous context"]