# Scrape residue converted to a comment to keep the file importable:
# author Dwarakesh-V, commit c221420 — "Deleted unwanted files, updated
# portfolio info, and added 3rd person mentions."
from sentence_transformers import SentenceTransformer, util
from nltk.tokenize import sent_tokenize
from util import select_random_from_list
from math import exp
import sys
all_MiniLM_L12_v2 = SentenceTransformer("all-MiniLM-L12-v2")
"""
all-MiniLM-L12-v2 is a sentence embedding model used for tasks involving semantic textual similarity, clustering, semantic search, and information retrieval. They convert a sentence to tensor based on their intent and then matching patterns like cos_sim can be used to compare them to other sentences.
"""
CONNECTION_PHRASES = ["just like that", "for the same", "similarly", "similar to the previous", "for that", "for it"]
CONNECTION_ENCODE = all_MiniLM_L12_v2.encode(CONNECTION_PHRASES,convert_to_tensor=True)
CLEAR_MESSAGES = ["delete", "delete context", "delete history", "clear", "clear context", "clear history", "reset", "reset context", "reset chat", "forget", "forget all"]
CLEAR_MESSAGES_ENCODE = all_MiniLM_L12_v2.encode(CLEAR_MESSAGES,convert_to_tensor=True)
prev_label = ""
prev_query_data = [] # Stores previous context if queries contain ambiguous content that may map to previous responses
confidence_threshold = 0.35 # Default confidence threshold
def generate_confidence_threshold(query: str, base=0.6, decay=0.03, min_threshold=0.25)->float:
    """Recompute the global confidence threshold from the query's length.

    Longer sentences naturally score lower cosine similarities, so the
    acceptance threshold decays exponentially with word count:
        threshold = max(base * e^(-decay * word_count), min_threshold)
    (The old docstring said 0.8 for the coefficient; the actual default base is 0.6.)

    Parameters:
    1. query: sentence whose word count drives the threshold
    2. base, decay: coefficients of base * e^(-decay * word_count)
    3. min_threshold: lower clamp so the threshold never collapses toward 0

    Returns the new threshold (matching the declared ->float; previously the
    function returned None). It is also stored in the module-level
    confidence_threshold global, which existing callers read.
    """
    global confidence_threshold
    word_count = len(query.split())
    confidence_threshold = max(base * exp(-decay * word_count), min_threshold)
    return confidence_threshold
# The value of each node contains the following data
# node.value[0] -> intent
# node.value[1] -> label
# node.value[2] -> examples
# node.value[3] -> response
# node.value[4] -> children
def cache_embeddings(tree, model = None)->None:
    """Pre-encode every node's example sentences and store the tensor on the
    node itself (node.embedding_cache) to avoid repeated encoding per query.

    Parameters:
    1. tree: tree whose nodes carry (intent, label, examples, response, children) value tuples
    2. model: encoder to use; defaults to the module-level all_MiniLM_L12_v2,
       resolved lazily at call time rather than at definition time
    """
    if model is None:
        model = all_MiniLM_L12_v2
    def _cache_node_embeddings(n):
        # Bug fix: examples live at value[2], so the tuple needs at least
        # 3 elements — the old check only required 2, allowing an IndexError.
        if isinstance(n.value, tuple) and len(n.value) >= 3:
            n.embedding_cache = model.encode(n.value[2], convert_to_tensor=True)
        else:
            # Explicit None lets _calculate_single_level raise its descriptive
            # "cache missing" ValueError instead of an AttributeError.
            n.embedding_cache = None
        for child in n.children:
            _cache_node_embeddings(child)
    _cache_node_embeddings(tree.root)
# Disabled alternative: strip second-person mentions rather than the owner's name.
# NOTE(review): the list below contains the adjacent literals "your'e""u", which would
# silently concatenate to "your'eu" if re-enabled — confirm intent before restoring.
# SECOND_PERSON_MENTIONS = ["you", "youre", "your", "yours", "yourself", "y'all", "y'all's", "y'all'self", "you're", "your'e""u", "ur", "urs", "urself"]
MY_NAME = ["Dwarakesh","Dwarak","Dwara","Dwaraka"] # Owner-name mentions removed from queries. Delete for production purposes
def get_user_query(message="", model = None)->str:
    """Prompt the user until a non-empty line is entered, normalize it, and
    detect clear/reset commands.

    Parameters:
    1. message: prompt shown to the user before receiving input (Default: empty)
    2. model: encoder to use; defaults to the module-level all_MiniLM_L12_v2

    Returns the cleaned query string, or None when the input is a
    clear/reset command (callers treat None as "wipe context").
    """
    if model is None:
        model = all_MiniLM_L12_v2
    query = input(message).lower().strip()
    while query == "":
        query = input(message).lower().strip()
    # Bug fix: the query was lowercased above, so the capitalized MY_NAME
    # entries must be lowercased too — previously replace() never matched.
    for name in MY_NAME:
        query = query.replace(name.lower(), "")
    # Collapse any whitespace runs left behind by the removals; the old
    # single replace(" ", " ") pass could not do this.
    query = " ".join(query.split())
    generate_confidence_threshold(query)
    query_encode = model.encode(query, convert_to_tensor=True)
    clear_intent = util.cos_sim(query_encode, CLEAR_MESSAGES_ENCODE).max().item()
    if clear_intent > confidence_threshold:
        return None  # sentinel: caller clears stored context
    return query
def _calculate_single_level(user_embed,predicted_intent):
    """Score the children of one node against the user query and pick the best.

    Parameters:
    1. user_embed: user query encoded as a tensor
    2. predicted_intent: node whose children (a list of node objects) are scored

    Returns (best_child_node, best_confidence); (None, 0) when the node has
    no children or nothing scores above 0.
    Raises ValueError when a child has no cached embedding.
    """
    best_node = None
    best_score = 0
    for category in predicted_intent.children:
        # getattr guard: nodes that never went through cache_embeddings()
        # lack the attribute entirely, which previously raised AttributeError
        # instead of this intended, descriptive error.
        cache = getattr(category, "embedding_cache", None)
        if cache is None:
            raise ValueError("Embedding cache missing. Call cache_embeddings() on the tree first")
        score = util.cos_sim(user_embed, cache).max().item()
        if score > best_score:
            best_score = score
            best_node = category
    return (best_node, best_score) # Child node with the highest confidence, and that confidence
def _store_prev_data(predicted_intent):
    """Record the path from predicted_intent up to (but excluding) the tree
    root into the shared prev_query_data context list.

    Parameters:
    1. predicted_intent: node whose ancestry becomes the new stored context
    """
    # Collect the ancestor chain first, then swap it into the global list.
    chain = []
    node = predicted_intent
    while node.parent:  # the root node has no parent, so it is excluded
        chain.append(node)
        node = node.parent
    prev_query_data.clear()
    prev_query_data.extend(chain)
def h_pass(tree, user_embed, predicted_intent = None)->tuple:
    """Hierarchically descend the intent tree, at each level keeping the child
    most similar to the user query, until a leaf is reached or the confidence
    falls below the global confidence_threshold.

    Parameters:
    1. tree: Which tree to pass through hierarchically
    2. user_embed: User input converted to tensor
    3. predicted_intent: Where to start the pass from (Default: Root of the tree)

    Returns a 6-tuple:
    (final node, its parent node, last confidence value,
     descended at least one level?, reached a leaf?,
     {node: confidence} trace for debugging).
    Side effects: updates the globals prev_label and (via _store_prev_data)
    prev_query_data."""
    global prev_label
    # NOTE(review): "== None" works here, but "is None" is the idiomatic check.
    predicted_intent = tree.root if predicted_intent == None else predicted_intent
    predicted_intent_parent = None
    high_intent = 0
    passed_once = False
    pass_through_intent = {}  # per-level confidence trace, keyed by node (debugging)
    while predicted_intent.children: # If the node has children, check for the child with the highest confidence value
        predicted_intent_parent = predicted_intent
        predicted_intent, high_intent = _calculate_single_level(user_embed,predicted_intent)
        pass_through_intent[predicted_intent] = high_intent # Store the confidence value of the current node
        if passed_once: # If the data didn't pass even once, then don't store it
            _store_prev_data(predicted_intent_parent) # Storing previous data w.r.t parent node as context is changed from current node
        if high_intent < confidence_threshold: # If highest confidence value is still too low, stop.
            prev_label = predicted_intent_parent.value[1]
            return (predicted_intent, predicted_intent_parent, high_intent, passed_once, False, pass_through_intent) # If the confidence value is low, stop
        passed_once = True
    # Loop exhausted: predicted_intent has no children, i.e. we reached a leaf.
    _store_prev_data(predicted_intent)
    prev_label = predicted_intent.value[1]
    return (predicted_intent, predicted_intent_parent, high_intent, passed_once, True, pass_through_intent)
def query_pass(tree, user_input, model=all_MiniLM_L12_v2)->list:
    """Split user_input into individual sentences, resolve follow-up references
    to the previous query via connection phrases, and run each sentence through
    the intent tree from every available context (root plus prev_query_data).

    Parameters:
    1. tree: Which tree to pass through hierarchically
    2. user_input: User input that may contain one or more queries as a string
    3. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)

    Returns a list of response strings, one per detected sentence. Reads the
    globals prev_label and prev_query_data (both mutated inside h_pass)."""
    queries = sent_tokenize(user_input)
    user_embeddings = [model.encode(query,convert_to_tensor=True) for query in queries]
    result = []
    label = prev_label
    for i in range(len(queries)):
        generate_confidence_threshold(queries[i])
        pass_value = (None, None, 0, False, False, None)
        # pass_value[0] -> current predicted intention (node)
        # pass_value[1] -> parent node of current predicted intention
        # pass_value[2] -> confidence level
        # pass_value[3] -> has the query passed through the model atleast once?
        # pass_value[4] -> has the query reached a leaf node?
        # pass_value[5] -> confidence values of traversal for query [DEBUGGING PURPOSES]
        # Acquiring data from previous query if the query has words matching with connecting phrases
        conn_sim = util.cos_sim(user_embeddings[i], CONNECTION_ENCODE).max().item()
        if conn_sim > confidence_threshold:
            # NOTE(review): label is appended with no separating space; the
            # embedding model appears to tolerate it, but confirm this is intentional.
            queries[i] = queries[i] + label
            user_embeddings[i] = model.encode(queries[i], convert_to_tensor=True)
        # Pass values through the root node and the nodes that have the current context
        pass_value_root = h_pass(tree,user_embeddings[i]) # Passing through root node
        pass_value_nonleaf = [h_pass(tree,user_embeddings[i],j) for j in prev_query_data] # Passing through nodes that have current context
        all_nodes = [pass_value_root] + pass_value_nonleaf # List of all nodes that have been passed through
        pass_value = max(all_nodes, key=lambda x: x[2]) # Maximum confidence node for available context. Root is always a context.
        # The comprehension variable "i" below shadows the loop index only inside
        # the comprehension's own scope, so this is harmless in Python 3.
        print(f"Query reach confidence: {[i[5] for i in all_nodes]}", file=sys.stderr) # DEBUGGING PURPOSES
        if pass_value[3]: # If the query has passed at least once, ask for data and store current result
            if not pass_value[4]: # If pass has not reached a leaf node, then ask for more data from the user and keep parent context
                label = pass_value[1].value[1]
                result.append(f"{pass_value[1].value[3]}")
            else: # Query has reached a leaf node
                label = pass_value[0].value[1]
                result.append(pass_value[0].value[3])
        else: # Query has not passed even once. Check if it works when previous context is available
            for parent_context in prev_query_data:
                pass_value_context = h_pass(tree, user_embeddings[i], parent_context)
                if pass_value_context[3]: # Check if it has passed at least once
                    # If it has passed, then the query is valid
                    if not pass_value_context[4]: # If pass has not reached a leaf node, then ask for more data from the user and keep parent context
                        label = pass_value_context[1].value[1]
                        result.append(f"What are you looking for in {pass_value_context[1].value[0]}? {pass_value_context[1].value[3]}")
                    else:
                        label = pass_value_context[0].value[1]
                        result.append(pass_value_context[0].value[3])
                    break # The else block won't be executed if code reaches here
            else: # The else statement of a for loop will execute only if the loop completes, and won't execute when broken by "break"
                result.append(f"I don't quite understand what you are trying to ask by \"{queries[i]}\"")
    # End of "for" loop processing queries
    # Finally, return result. A list of responses same as the length of queries.
    return result
def process_user_query(query: str, model = None)->str:
    """Normalize an already-received query string and detect clear/reset commands.

    The previous docstring was copied from get_user_query and documented a
    nonexistent `message` parameter; this is the non-interactive counterpart.

    Parameters:
    1. query: raw user input to normalize
    2. model: encoder to use; defaults to the module-level all_MiniLM_L12_v2

    Returns the lowercased, stripped query, or None when it matches a
    clear/reset command above the confidence threshold.
    """
    if model is None:
        model = all_MiniLM_L12_v2
    query = query.lower().strip()
    generate_confidence_threshold(query)
    query_encode = model.encode(query, convert_to_tensor=True)
    clear_intent = util.cos_sim(query_encode, CLEAR_MESSAGES_ENCODE).max().item()
    if clear_intent > confidence_threshold:
        return None  # sentinel: caller clears stored context
    return query
def interact_with_user(tree_data, user_input: str) -> list:
    """Handle one raw user input and return a list of response strings.

    (The return annotation was corrected from str to list: both branches
    return a list.)

    Parameters:
    1. tree_data: intent tree to answer from
    2. user_input: raw text which may contain several queries

    Returns one response per detected query, or ["Cleared previous context"]
    when the input was a clear/reset command. A clear command also wipes the
    module-level prev_query_data context.
    """
    user_input = process_user_query(user_input)
    if user_input:  # Normal query: not empty, not a clear command
        results = query_pass(tree_data, user_input)
        # Each result may hold several phrasings; pick one at random.
        all_results = [f"{select_random_from_list(result)}" for result in results]
        print(f"Previous query data: {prev_query_data}", file=sys.stderr)  # DEBUGGING PURPOSES
        return all_results
    # Clear command: wipe stored context (mutates module-level state).
    prev_query_data.clear()
    print(f"Previous query data: {prev_query_data}", file=sys.stderr)  # DEBUGGING PURPOSES
    return ["Cleared previous context"]