Spaces:
Running
Running
File size: 12,503 Bytes
369c2da c221420 369c2da c221420 369c2da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
from sentence_transformers import SentenceTransformer, util
from nltk.tokenize import sent_tokenize
from util import select_random_from_list
from math import exp
import sys
# Shared sentence-embedding model; used as the default encoder for every function below.
all_MiniLM_L12_v2 = SentenceTransformer("all-MiniLM-L12-v2")
"""
all-MiniLM-L12-v2 is a sentence embedding model used for tasks involving semantic textual similarity, clustering, semantic search, and information retrieval. They convert a sentence to tensor based on their intent and then matching patterns like cos_sim can be used to compare them to other sentences.
"""
# Phrases that signal the current sentence refers back to the previous query's context.
CONNECTION_PHRASES = ["just like that", "for the same", "similarly", "similar to the previous", "for that", "for it"]
CONNECTION_ENCODE = all_MiniLM_L12_v2.encode(CONNECTION_PHRASES,convert_to_tensor=True)  # Pre-encoded once at import time
# Messages that should clear the stored conversation context/history.
CLEAR_MESSAGES = ["delete", "delete context", "delete history", "clear", "clear context", "clear history", "reset", "reset context", "reset chat", "forget", "forget all"]
CLEAR_MESSAGES_ENCODE = all_MiniLM_L12_v2.encode(CLEAR_MESSAGES,convert_to_tensor=True)  # Pre-encoded once at import time
prev_label = ""  # Label of the most recently matched intent node; appended to follow-up queries
prev_query_data = [] # Stores previous context if queries contain ambiguous content that may map to previous responses
confidence_threshold = 0.35 # Default confidence threshold; overwritten per-query by generate_confidence_threshold()
def generate_confidence_threshold(query: str, base=0.6, decay=0.03, min_threshold=0.25)->float:
"""Generate confidence threshold based on the sentence. Longer sentence lead to lower confidences, so confidence threshold is adjusted based on that.
Parameters:
1. query: Modify threshold based on this sentence
2. base, decay: 0.8*e^(-(decay * no. of words in string))
3. min_threshold: Clamp to minimum to avoid much lower confidence values"""
global confidence_threshold
length = len(query.split())
confidence_threshold = max(base * exp(-decay * length), min_threshold)
# The value of each node contains the following data
# node.value[0] -> intent
# node.value[1] -> label
# node.value[2] -> examples
# node.value[3] -> response
# node.value[4] -> children
def cache_embeddings(tree, model = all_MiniLM_L12_v2)->None:
    """Store the encoded examples on each tree node to avoid repetitive computations.

    Parameters:
    1. tree: Tree whose nodes get an `embedding_cache` attribute
    2. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)"""
    def _cache_node_embeddings(n):
        # node.value is (intent, label, examples, response, children); examples
        # live at index 2, so the tuple needs at least 3 items (the original
        # checked only >= 2, which would IndexError on a 2-tuple).
        if isinstance(n.value, tuple) and len(n.value) >= 3:
            examples = n.value[2]
            n.embedding_cache = model.encode(examples, convert_to_tensor=True)
        else:
            # Mark uncacheable nodes explicitly so downstream checks see None
            # instead of raising AttributeError on a missing attribute.
            n.embedding_cache = None
        for child in n.children:
            _cache_node_embeddings(child)
    _cache_node_embeddings(tree.root)
# SECOND_PERSON_MENTIONS = ["you", "youre", "your", "yours", "yourself", "y'all", "y'all's", "y'all'self", "you're", "your'e""u", "ur", "urs", "urself"]
MY_NAME = ["Dwarakesh","Dwarak","Dwara","Dwaraka"] # Delete for production purposes
# NOTE(review): these entries are capitalized, but get_user_query lowercases the
# query before running str.replace with them — confirm the removal ever fires.
def get_user_query(message="", model = all_MiniLM_L12_v2)->str:
    """Read a non-empty query from stdin and normalize it.

    Parameters:
    1. message: Show message to user before receiving input (Default: empty)
    2. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)

    Returns None when the query resembles a clear/reset command, else the
    normalized query string."""
    query = input(message).lower().strip()
    while query == "":  # Re-prompt until the user types something
        query = input(message).lower().strip()
    # Collapse any run of whitespace into single spaces (the old single-pass
    # replace("  ", " ") left residue for triple-or-more spaces).
    query = " ".join(query.split())
    for name in MY_NAME:  # Strip personal-name mentions from the query
        # The query was lowercased above, so the names must be lowercased too —
        # the original compared capitalized names against lowercased text and
        # therefore never matched.
        query = query.replace(name.lower(), "")
    generate_confidence_threshold(query)
    query_encode = model.encode(query, convert_to_tensor=True)
    # High similarity to any clear/reset phrase means "wipe context": signal with None.
    clear_intent = util.cos_sim(query_encode,CLEAR_MESSAGES_ENCODE).max().item()
    if clear_intent > confidence_threshold:
        return None
    return query
def _calculate_single_level(user_embed,predicted_intent):
    """Pick the best-matching child of a node for the given query embedding.

    Parameters:
    1. user_embed: User query converted to tensor
    2. predicted_intent: Score the children of this node

    Returns (best_child, best_score): the child node with the highest cosine
    similarity against its cached examples, or (None, 0) when no child scores
    above zero."""
    categories = predicted_intent.children # List of node objects
    best_child = None
    best_score = 0
    for category in categories:
        # getattr guards nodes that never got an embedding_cache attribute, so
        # the intended ValueError is raised instead of a raw AttributeError.
        if getattr(category, "embedding_cache", None) is None:
            raise ValueError("Embedding cache missing. Call cache_embeddings() on the tree first")
        score = util.cos_sim(user_embed, category.embedding_cache).max().item()
        if score > best_score:
            best_score = score
            best_child = category # Node object
    return (best_child, best_score) # Child node with the highest prediction confidence and the confidence value
def _store_prev_data(predicted_intent):
    """Record the path from the given node up to (but excluding) the tree root.

    Parameters:
    1. predicted_intent: Node whose ancestry becomes the new stored context"""
    def _ancestry(node):
        # Yield the node and each ancestor, stopping before the parentless root.
        while node.parent:
            yield node
            node = node.parent
    # Mutate the global list in place so other references to it stay valid.
    prev_query_data.clear()
    prev_query_data.extend(_ancestry(predicted_intent))
def h_pass(tree, user_embed, predicted_intent = None)->tuple:
    """Walk the intent tree from a starting node, descending into the
    highest-confidence child at each level, and report where the walk stopped.

    Parameters:
    1. tree: Which tree to pass through hierarchically
    2. user_embed: User input converted to tensor
    3. predicted_intent: Where to start the pass from (Default: Root of the tree)

    Returns a 6-tuple:
    (final node, its parent, last confidence value, passed at least one level,
     reached a leaf node, {node: confidence} trace of the traversal).
    Side effects: updates the globals prev_label and (via _store_prev_data)
    prev_query_data."""
    global prev_label
    # `is None` instead of `== None`: identity test is the correct idiom and
    # avoids surprises if nodes ever define __eq__.
    predicted_intent = tree.root if predicted_intent is None else predicted_intent
    predicted_intent_parent = None
    high_intent = 0
    passed_once = False
    pass_through_intent = {}
    while predicted_intent.children: # If the node has children, check for the child with the highest confidence value
        predicted_intent_parent = predicted_intent
        predicted_intent, high_intent = _calculate_single_level(user_embed,predicted_intent)
        pass_through_intent[predicted_intent] = high_intent # Store the confidence value of the current node
        if passed_once: # If the data didn't pass even once, then don't store it
            _store_prev_data(predicted_intent_parent) # Storing previous data w.r.t parent node as context is changed from current node
        if high_intent < confidence_threshold: # If highest confidence value is still too low, stop.
            prev_label = predicted_intent_parent.value[1]
            return (predicted_intent, predicted_intent_parent, high_intent, passed_once, False, pass_through_intent) # If the confidence value is low, stop
        passed_once = True
    # Reached a leaf node with sufficient confidence at every level.
    _store_prev_data(predicted_intent)
    prev_label = predicted_intent.value[1]
    return (predicted_intent, predicted_intent_parent, high_intent, passed_once, True, pass_through_intent)
def query_pass(tree, user_input, model=all_MiniLM_L12_v2)->list:
    """Split the raw user input into sentences, resolve references to the
    previous query's context where detected, and produce one response per
    sentence.

    Parameters:
    1. tree: Which tree to pass through hierarchically
    2. user_input: User input that may contain one or more queries as a string
    3. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)

    Returns a list of response strings, same length as the tokenized sentences.
    Reads the globals prev_label / prev_query_data and the per-sentence
    confidence_threshold set by generate_confidence_threshold."""
    queries = sent_tokenize(user_input)
    user_embeddings = [model.encode(query,convert_to_tensor=True) for query in queries]
    result = []
    label = prev_label  # Label of the last matched intent; appended to follow-up queries below
    for i in range(len(queries)):
        generate_confidence_threshold(queries[i])  # Threshold adapts to each sentence's length
        pass_value = (None, None, 0, False, False, None)
        # pass_value[0] -> current predicted intention (node)
        # pass_value[1] -> parent node of current predicted intention
        # pass_value[2] -> confidence level
        # pass_value[3] -> has the query passed through the model atleast once?
        # pass_value[4] -> has the query reached a leaf node?
        # pass_value[5] -> confidence values of traversal for query [DEBUGGING PURPOSES]
        # If the sentence resembles a connection phrase ("similarly", "for that", ...),
        # append the previous label and re-encode so earlier context carries over.
        conn_sim = util.cos_sim(user_embeddings[i], CONNECTION_ENCODE).max().item()
        if conn_sim > confidence_threshold:
            queries[i] = queries[i] + label
            user_embeddings[i] = model.encode(queries[i], convert_to_tensor=True)
        # Pass values through the root node and the nodes that have the current context
        pass_value_root = h_pass(tree,user_embeddings[i]) # Passing through root node
        pass_value_nonleaf = [h_pass(tree,user_embeddings[i],j) for j in prev_query_data] # Passing through nodes that have current context
        all_nodes = [pass_value_root] + pass_value_nonleaf # List of all nodes that have been passed through
        pass_value = max(all_nodes, key=lambda x: x[2]) # Maximum confidence node for available context. Root is always a context.
        print(f"Query reach confidence: {[i[5] for i in all_nodes]}", file=sys.stderr) # DEBUGGING PURPOSES
        if pass_value[3]: # If the query has passed at least once, ask for data and store current result
            if not pass_value[4]: # If pass has not reached a leaf node, then ask for more data from the user and keep parent context
                label = pass_value[1].value[1]
                result.append(f"{pass_value[1].value[3]}")
                # continue
            else: # Query has reached a leaf node; respond with its stored response
                label = pass_value[0].value[1]
                result.append(pass_value[0].value[3])
                # continue
        else: # Query has not passed even once. Check if it works when previous context is available
            for parent_context in prev_query_data:
                pass_value_context = h_pass(tree, user_embeddings[i], parent_context)
                if pass_value_context[3]: # Check if it has passed at least once
                    # If it has passed, then the query is valid
                    if not pass_value_context[4]: # If pass has not reached a leaf node, then ask for more data from the user and keep parent context
                        label = pass_value_context[1].value[1]
                        result.append(f"What are you looking for in {pass_value_context[1].value[0]}? {pass_value_context[1].value[3]}")
                    else:
                        label = pass_value_context[0].value[1]
                        result.append(pass_value_context[0].value[3])
                    break # The else block won't be executed if code reaches here
            else: # The else statement of a for loop will execute only if the loop completes, and won't execute when broken by "break"
                result.append(f"I don't quite understand what you are trying to ask by \"{queries[i]}\"")
                # continue
    # End of "for" loop processing queries
    # Finally, return result. A list of responses same as the length of queries.
    return result
def process_user_query(query: str, model = all_MiniLM_L12_v2)->str:
    """Normalize an already-received query string and detect clear/reset commands.

    (The original docstring was copy-pasted from get_user_query and described a
    nonexistent `message` parameter.)

    Parameters:
    1. query: Raw user query; lowercased and stripped here
    2. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)

    Returns None when the query resembles a clear/reset command, else the
    normalized query string."""
    query = query.lower().strip()
    generate_confidence_threshold(query)
    query_encode = model.encode(query, convert_to_tensor=True)
    # High similarity to any pre-encoded clear/reset phrase means "wipe context".
    clear_intent = util.cos_sim(query_encode,CLEAR_MESSAGES_ENCODE).max().item()
    if clear_intent > confidence_threshold:
        return None
    return query
def interact_with_user(tree_data, user_input: str) -> list:
    """Handle a single user message and return a list of response strings.

    (Annotation corrected from -> str: both branches return a list.)

    Parameters:
    1. tree_data: Intent tree passed through to query_pass
    2. user_input: Raw user message; may contain several sentences

    Returns one response per detected query, or ["Cleared previous context"]
    when the message was a clear/reset command (context is wiped as a side
    effect)."""
    user_input = process_user_query(user_input)
    all_results = []
    if user_input: # If not empty or command
        results = query_pass(tree_data, user_input)
        for result in results:
            # return f"{select_random_from_list(result)}\nContext window: {prev_query_data}"
            all_results.append(f"{select_random_from_list(result)}")
        print(f"Previous query data: {prev_query_data}", file=sys.stderr)
        return all_results
    else:
        # Mutating global variables: Clearing context and recent history on command
        prev_query_data.clear()
        print(f"Previous query data: {prev_query_data}", file=sys.stderr)
        return ["Cleared previous context"]
|