Spaces:
Running
Running
| from sentence_transformers import SentenceTransformer, util | |
| from nltk.tokenize import sent_tokenize | |
| from util import select_random_from_list | |
| from math import exp | |
| import sys | |
# Shared sentence-embedding model; loaded once at import time and used as the
# default encoder by every function in this module.
all_MiniLM_L12_v2 = SentenceTransformer("all-MiniLM-L12-v2")
"""
all-MiniLM-L12-v2 is a sentence embedding model used for tasks involving semantic textual similarity, clustering, semantic search, and information retrieval. They convert a sentence to tensor based on their intent and then matching patterns like cos_sim can be used to compare them to other sentences.
"""
# Phrases that signal the user is referring back to the previous query
# ("anaphora"); a query similar to one of these gets the previous label appended.
CONNECTION_PHRASES = ["just like that", "for the same", "similarly", "similar to the previous", "for that", "for it"]
CONNECTION_ENCODE = all_MiniLM_L12_v2.encode(CONNECTION_PHRASES,convert_to_tensor=True)  # pre-computed once to avoid re-encoding per query
# Commands that clear the stored conversation context.
CLEAR_MESSAGES = ["delete", "delete context", "delete history", "clear", "clear context", "clear history", "reset", "reset context", "reset chat", "forget", "forget all"]
CLEAR_MESSAGES_ENCODE = all_MiniLM_L12_v2.encode(CLEAR_MESSAGES,convert_to_tensor=True)  # pre-computed once at import
# Label of the most recently matched intent node; appended to anaphoric queries.
prev_label: str = ""
prev_query_data: list = [] # Stores previous context if queries contain ambiguous content that may map to previous responses
confidence_threshold: float = 0.35 # Default confidence threshold; rewritten per query by generate_confidence_threshold()
def generate_confidence_threshold(query: str, base=0.6, decay=0.03, min_threshold=0.25)->float:
    """Adapt the global confidence threshold to the query length.

    Longer sentences produce lower cosine-similarity scores, so the threshold
    decays exponentially with the number of words:
    threshold = max(base * e^(-decay * word_count), min_threshold).

    Parameters:
    1. query: Sentence whose word count drives the threshold
    2. base, decay: Parameters of base * e^(-(decay * no. of words in string))
    3. min_threshold: Clamp to minimum to avoid much lower confidence values

    Returns the new threshold (also stored in the global confidence_threshold,
    which the rest of the module reads). The original returned None despite
    the ->float annotation.
    """
    global confidence_threshold
    length = len(query.split())
    confidence_threshold = max(base * exp(-decay * length), min_threshold)
    return confidence_threshold
| # The value of each node contains the following data | |
| # node.value[0] -> intent | |
| # node.value[1] -> label | |
| # node.value[2] -> examples | |
| # node.value[3] -> response | |
| # node.value[4] -> children | |
def cache_embeddings(tree, model = all_MiniLM_L12_v2)->None:
    """Store the encoded examples as part of the tree itself to avoid repetitive computations.

    Walks the whole tree depth-first and attaches an `embedding_cache`
    attribute to every node.

    Parameters:
    1. tree: Tree to cache embeddings
    2. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)
    """
    def _cache_node_embeddings(n):
        # node.value layout: (intent, label, examples, response, children),
        # so examples live at index 2 — the tuple needs at least 3 items.
        # Fix: the guard was `len(n.value) >= 2`, an off-by-one vs n.value[2].
        if isinstance(n.value, tuple) and len(n.value) >= 3:
            examples = n.value[2]
            n.embedding_cache = model.encode(examples, convert_to_tensor=True)
        else:
            # Make the attribute exist even for non-cacheable nodes so
            # _calculate_single_level raises its explicit ValueError instead
            # of an AttributeError.
            n.embedding_cache = None
        for child in n.children:
            _cache_node_embeddings(child)
    _cache_node_embeddings(tree.root)
# SECOND_PERSON_MENTIONS = ["you", "youre", "your", "yours", "yourself", "y'all", "y'all's", "y'all'self", "you're", "your'e""u", "ur", "urs", "urself"] # NOTE(review): dead alternative kept for reference
# Names stripped from user queries by get_user_query().
MY_NAME = ["Dwarakesh","Dwarak","Dwara","Dwaraka"] # Delete for production purposes
def get_user_query(message="", model = all_MiniLM_L12_v2)->str:
    """Prompt on stdin until a non-empty query is entered, then normalise it.

    Parameters:
    1. message: Show message to user before receiving input (Default: empty)
    2. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)

    Returns the normalised query string, or None when the query is similar
    enough to a CLEAR_MESSAGES entry (the caller treats None as a clear/reset
    command).
    """
    query = input(message).lower().strip()
    while query == "":
        query = input(message).lower().strip()
    # Collapse any run of whitespace to a single space. The old
    # `replace("  ", " ")` only handled exactly-double spaces, leaving
    # residue for triple or more.
    query = " ".join(query.split())
    for name in MY_NAME:  # Remove the bot owner's names from the query
        # Fix: query is lower-cased above, so the capitalized MY_NAME entries
        # never matched; compare lower-case forms.
        query = query.replace(name.lower(), "")
    generate_confidence_threshold(query)
    query_encode = model.encode(query, convert_to_tensor=True)
    # None is the "clear context" sentinel understood by callers.
    clear_intent = util.cos_sim(query_encode,CLEAR_MESSAGES_ENCODE).max().item()
    if clear_intent > confidence_threshold:
        return None
    return query
| def _calculate_single_level(user_embed,predicted_intent): | |
| """Calculate predictions for the children of a single node. Each node contains a list of nodes as its children. | |
| Parameters: | |
| 1. user_embed: User query converted to tensor | |
| 2. predicted_intent: Calculate for children of this node""" | |
| categories = predicted_intent.children # List of node objects | |
| predicted_intent = None | |
| high_intent = 0 | |
| for category in categories: | |
| if category.embedding_cache is None: | |
| raise ValueError("Embedding cache missing. Call cache_embeddings() on the tree first") | |
| score = util.cos_sim(user_embed, category.embedding_cache).max().item() | |
| if score > high_intent: | |
| high_intent = score | |
| predicted_intent = category # Node object | |
| return (predicted_intent,high_intent) # Returns the child node with the highest prediction confidence and the confidence value | |
def _store_prev_data(predicted_intent):
    """Record the ancestor chain of a node as the current context.

    Parameters:
    1. predicted_intent: Store previous data w.r.t this node

    Rebuilds the global prev_query_data list in place with the nodes from
    *predicted_intent* up to — but excluding — the tree root, ordered from
    the node upward.
    """
    # Mutating global prev_query_data
    del prev_query_data[:]
    node = predicted_intent
    while node.parent:  # the root (no parent) is never stored
        prev_query_data.append(node)
        node = node.parent
def h_pass(tree, user_embed, predicted_intent = None)->tuple:
    """Descend the intent tree greedily, following the highest-confidence child at each level.

    Parameters:
    1. tree: Which tree to pass through hierarchically
    2. user_embed: User input converted to tensor
    3. predicted_intent: Where to start the pass from (Default: Root of the tree)

    Returns a 6-tuple:
      [0] final predicted node (None if no child ever scored above 0)
      [1] parent of the final predicted node
      [2] confidence of the final step
      [3] whether at least one level was descended with sufficient confidence
      [4] whether a leaf node was reached
      [5] {node: confidence} for every level visited (debugging)

    Side effects: updates the globals prev_label and, via _store_prev_data,
    prev_query_data.
    """
    global prev_label
    predicted_intent = tree.root if predicted_intent == None else predicted_intent
    predicted_intent_parent = None
    high_intent = 0
    passed_once = False
    pass_through_intent = {}  # confidence trace of the whole descent, for debugging
    while predicted_intent.children: # If the node has children, check for the child with the highest confidence value
        predicted_intent_parent = predicted_intent
        predicted_intent, high_intent = _calculate_single_level(user_embed,predicted_intent)
        pass_through_intent[predicted_intent] = high_intent # Store the confidence value of the current node
        if passed_once: # If the data didn't pass even once, then don't store it
            _store_prev_data(predicted_intent_parent) # Storing previous data w.r.t parent node as context is changed from current node
        if high_intent < confidence_threshold: # If highest confidence value is still too low, stop.
            prev_label = predicted_intent_parent.value[1]  # value[1] is the node label
            return (predicted_intent, predicted_intent_parent, high_intent, passed_once, False, pass_through_intent) # If the confidence value is low, stop
        passed_once = True
    # Loop exits normally only when a leaf (no children) is reached.
    _store_prev_data(predicted_intent)
    prev_label = predicted_intent.value[1]
    return (predicted_intent, predicted_intent_parent, high_intent, passed_once, True, pass_through_intent)
def query_pass(tree, user_input, model=all_MiniLM_L12_v2)->list:
    """Split user input into sentences, resolve references to previous context, and answer each.

    Parameters:
    1. tree: Which tree to pass through hierarchically
    2. user_input: User input that may contain one or more queries as a string
    3. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)

    Returns a list with one response entry per sentence in user_input.
    """
    queries = sent_tokenize(user_input)
    user_embeddings = [model.encode(query,convert_to_tensor=True) for query in queries]
    result = []
    label = prev_label  # label of the last matched intent; appended to anaphoric queries
    # Index loop on purpose: queries[i] and user_embeddings[i] may be rewritten below.
    for i in range(len(queries)):
        generate_confidence_threshold(queries[i])
        # h_pass return layout (pass_value):
        # pass_value[0] -> current predicted intention (node)
        # pass_value[1] -> parent node of current predicted intention
        # pass_value[2] -> confidence level
        # pass_value[3] -> has the query passed through the model atleast once?
        # pass_value[4] -> has the query reached a leaf node?
        # pass_value[5] -> confidence values of traversal for query [DEBUGGING PURPOSES]
        # Acquiring data from previous query if the query has words matching with connecting phrases
        conn_sim = util.cos_sim(user_embeddings[i], CONNECTION_ENCODE).max().item()
        if conn_sim > confidence_threshold:
            queries[i] = queries[i] + label  # NOTE(review): no separating space before label — confirm intended
            user_embeddings[i] = model.encode(queries[i], convert_to_tensor=True)
        # Pass values through the root node and the nodes that have the current context
        pass_value_root = h_pass(tree,user_embeddings[i]) # Passing through root node
        pass_value_nonleaf = [h_pass(tree,user_embeddings[i],node) for node in prev_query_data] # Passing through nodes that have current context
        all_nodes = [pass_value_root] + pass_value_nonleaf # List of all nodes that have been passed through
        pass_value = max(all_nodes, key=lambda x: x[2]) # Maximum confidence node for available context. Root is always a context.
        # `pv` instead of `i` in the comprehension: the original shadowed the loop index.
        print(f"Query reach confidence: {[pv[5] for pv in all_nodes]}", file=sys.stderr) # DEBUGGING PURPOSES
        if pass_value[3]: # If the query has passed at least once, ask for data and store current result
            if not pass_value[4]: # If pass has not reached a leaf node, then ask for more data from the user and keep parent context
                label = pass_value[1].value[1]
                # Fix: was f"{...}", which stringified value[3] (a response
                # collection) and broke select_random_from_list downstream;
                # append it raw, consistent with the leaf branch below.
                result.append(pass_value[1].value[3])
            else: # Query has reached a leaf node
                label = pass_value[0].value[1]
                result.append(pass_value[0].value[3])
        else: # Query has not passed even once. Check if it works when previous context is available
            for parent_context in prev_query_data:
                pass_value_context = h_pass(tree, user_embeddings[i], parent_context)
                if pass_value_context[3]: # Check if it has passed at least once
                    # If it has passed, then the query is valid
                    if not pass_value_context[4]: # If pass has not reached a leaf node, then ask for more data from the user and keep parent context
                        label = pass_value_context[1].value[1]
                        result.append(f"What are you looking for in {pass_value_context[1].value[0]}? {pass_value_context[1].value[3]}")
                    else:
                        label = pass_value_context[0].value[1]
                        result.append(pass_value_context[0].value[3])
                    break # The else block won't be executed if code reaches here
            else: # The else statement of a for loop will execute only if the loop completes, and won't execute when broken by "break"
                result.append(f"I don't quite understand what you are trying to ask by \"{queries[i]}\"")
    # End of "for" loop processing queries
    # Finally, return result. A list of responses same as the length of queries.
    return result
def process_user_query(query: str, model = all_MiniLM_L12_v2)->str:
    """Normalise a query string and detect clear/reset commands.

    Parameters:
    1. query: Raw user query to normalise (lower-cased and stripped)
    2. model: Which model to use to encode (Default: Global model all_MiniLM_L12_v2)

    Returns the normalised query, or None when the query is similar enough to
    a CLEAR_MESSAGES entry — the caller treats None as a "clear context"
    command. (The original docstring was copy-pasted from get_user_query and
    documented a non-existent `message` parameter.)
    """
    query = query.lower().strip()
    generate_confidence_threshold(query)  # adapt the global threshold to query length
    query_encode = model.encode(query, convert_to_tensor=True)
    clear_intent = util.cos_sim(query_encode,CLEAR_MESSAGES_ENCODE).max().item()
    if clear_intent > confidence_threshold:
        return None
    return query
def interact_with_user(tree_data, user_input: str) -> list:
    """Handle a single user message and return a list of response strings.

    Parameters:
    1. tree_data: Intent tree, passed through to query_pass
    2. user_input: Raw user message (may contain several sentences)

    Returns one randomly-chosen response per sentence, or
    ["Cleared previous context"] when the message was a clear command.
    (Return annotation fixed: the function always returned a list, never str.)
    """
    user_input = process_user_query(user_input)
    if user_input: # If not empty or command
        results = query_pass(tree_data, user_input)
        # Each query_pass entry is a collection of candidate responses; pick one at random.
        all_results = [f"{select_random_from_list(result)}" for result in results]
        print(f"Previous query data: {prev_query_data}", file=sys.stderr)
        return all_results
    else:
        # Mutating global variables: Clearing context and recent history on command
        prev_query_data.clear()
        print(f"Previous query data: {prev_query_data}", file=sys.stderr)
        return ["Cleared previous context"]