Spaces:
Build error
Build error
| import datetime | |
| import numpy as np | |
| from src.apis.storage.database import db | |
| import numpy as np | |
| import random | |
| import numpy as np | |
| from src.utils.preprocessing import destinations, sorted_tags_dict, feature_names | |
| from src.utils.preprocessing import vectorizer, weights_bias_vector | |
| from src.utils.kner_onnx import onnx_predictor | |
def get_onnx_predictor(question: str):
    """
    Run the ONNX tagger over a question.

    Parameters:
        question (str): The raw user question.

    Returns:
        tuple: ``(original_sentence, predicted_tags)`` as unpacked from
        ``onnx_predictor.predict``; that call must yield exactly two values.
    """
    sentence, tags = onnx_predictor.predict(question)
    return (sentence, tags)
def get_user_weights(user_id):
    """
    Return a copy of the weights stored for ``user_id``.

    The update_* functions in this module modify the returned weights in place
    before saving them; the copy presumably keeps those edits from touching the
    object held by ``db`` until it is saved.
    NOTE(review): ``.copy()`` is shallow for a list of lists; fine if ``db``
    returns a numpy array — confirm.
    """
    stored_weights = db.get_user_weights(user_id)
    return stored_weights.copy()
def get_destination_index(destination_name):
    """
    Return the index label of the first ``destinations`` row with this name.

    Raises:
        IndexError: if no row's "name" equals ``destination_name``.
    """
    matching_rows = destinations[destinations["name"] == destination_name]
    first_label = matching_rows.index[0]
    return first_label
def get_tag_index(tag_name):
    """
    Return the position of ``tag_name`` (lower-cased) in ``feature_names``.

    Raises:
        ValueError: if the lower-cased tag is not a known feature name.
    """
    return feature_names.tolist().index(tag_name.lower())
def user_weights_exist(user_id):
    """
    Tell whether personalised weights are stored for a user.

    Parameters:
        user_id (str): The ID of the user.

    Returns:
        bool: Whatever ``db.user_weights_exist`` reports (True when the user
        has stored weights).
    """
    has_weights = db.user_weights_exist(user_id)
    return has_weights
def save_user_weights(user_id, weights):
    """
    Persist a user's weights through ``db``.

    Parameters:
        user_id (str): The ID of the user.
        weights (numpy.ndarray): The weights to store.

    Returns:
        bool: The result of ``db.save_user_weights`` (True on success).
    """
    outcome = db.save_user_weights(user_id, weights)
    return outcome
def update_weights_from_query(user_id, query_tags, feature_names):
    """
    Reinforce a user's weights for the tags mentioned in a query.

    Each query tag is lower-cased and looked up in ``feature_names``; the
    first matching position selects a column. In every selected column, each
    destination row whose weight is currently > 0 is increased by 5. Tags with
    no match are ignored; a tag repeated in the query is applied once per
    occurrence.

    Parameters:
        user_id (str): The ID of the user.
        query_tags (list): The tags from the user's query.
        feature_names (numpy.ndarray): The names of all features (tags). This
            parameter shadows the module-level ``feature_names``.

    Returns:
        bool: The result of save_user_weights for the updated weights.
    """
    weights = get_user_weights(user_id)

    # Elementwise comparison against the array; each entry holds the positions
    # where the lower-cased tag appears (possibly none).
    hits_per_tag = [np.where(feature_names == tag.lower())[0] for tag in query_tags]
    matched_columns = [hits[0] for hits in hits_per_tag if len(hits) > 0]

    row_count = len(weights)
    for column in matched_columns:
        for row in range(row_count):
            # Only strengthen associations that already exist (weight > 0).
            if weights[row][column] > 0:
                weights[row][column] += 5

    return save_user_weights(user_id, weights)
def update_weight_from_destination(user_id, destination_id):
    """
    Boost a user's weights for all tags of a destination and log those tags.

    For every tag of the destination that is a known feature name
    (case-insensitive), the user's weight at (destination row, tag column) is
    increased by 5. The destination's tags are then added as the newest entry
    of the user's ``recent_tags`` history via track_question_tags.

    Parameters:
        user_id (str): The ID of the user.
        destination_id: The ID of the destination, as accepted by ``db``.

    Returns:
        The result of save_user_weights, or False when the destination's name
        is not present in ``destinations``; in that case nothing is modified.
    """
    destination_tags = db.get_destination_tags(destination_id)
    destination_name = db.get_destination_name(destination_id)

    # Positional row of the destination. Do not fall back to row 0 when the
    # name is missing: that would silently boost an unrelated destination.
    destination_index = next(
        (
            index
            for index, name in enumerate(destinations["name"])
            if name == destination_name
        ),
        None,
    )
    if destination_index is None:
        return False
    # NOTE(review): get_destinations_list drops the first row of
    # `destinations` before mapping weight rows to names (weight row i ->
    # destinations row i + 1), while this function uses weight row i for
    # destinations row i. Confirm which alignment is intended.

    all_tags = feature_names.tolist()
    user_weights = get_user_weights(user_id)
    for tag in destination_tags:
        tag_key = tag.lower()
        if tag_key not in all_tags:
            # Skip unknown tags (as update_weights_from_query does) instead of
            # aborting the whole update with ValueError.
            continue
        user_weights[destination_index][all_tags.index(tag_key)] += 5
    saved = save_user_weights(user_id, user_weights)

    # Log the tags in the {"timestamp", "tags"} entry format that
    # track_question_tags writes and get_recent_tags reads. Writing a bare
    # list of tag strings under "recent_tags" breaks get_recent_tags, which
    # calls .get("tags") on every entry.
    track_question_tags(user_id, destination_tags)
    return saved
def update_weights_from_feedback(user_id, destination_id, tags, rating):
    """
    Apply a feedback rating to the user's weights for one destination.

    For every tag in ``tags``, the weight at (destination row, tag column) is
    shifted according to ``rating``: 5 -> +5, 4 -> +3, 3 -> +1, 2 -> -3,
    1 -> -5. Any other rating changes nothing, but the weights are still saved.

    Parameters:
        user_id (str): The ID of the user.
        destination_id: The ID of the rated destination, as accepted by ``db``.
        tags (list): Tag names to adjust; get_tag_index raises ValueError for
            a tag that is not a known feature name.
        rating: The feedback score, compared with ``==`` against 5, 4, 3, 2, 1.

    Returns:
        The result of save_user_weights.
    """
    # Ordered (score, delta) pairs, checked with == exactly like an elif chain.
    rating_to_delta = ((5, 5), (4, 3), (3, 1), (2, -3), (1, -5))

    weights = get_user_weights(user_id)
    row = get_destination_index(db.get_destination_name(destination_id))

    for tag in tags:
        column = get_tag_index(tag)
        for score, delta in rating_to_delta:
            if rating == score:
                weights[row][column] += delta
                break

    return save_user_weights(user_id, weights)
def get_user_metadata(user_id):
    """
    Fetch the stored metadata dictionary for a user.

    Parameters:
        user_id (str): The ID of the user.

    Returns:
        dict: Whatever ``db.get_user_metadata`` returns for this user.
    """
    metadata = db.get_user_metadata(user_id)
    return metadata
def update_user_metadata(user_id, metadata):
    """
    Merge new keys into a user's stored metadata and persist the result.

    Keys present in ``metadata`` replace stored keys of the same name; other
    stored keys are kept. The dict returned by ``db.get_user_metadata`` is
    updated in place and then handed to ``db.save_user_metadata``.

    Parameters:
        user_id (str): The ID of the user.
        metadata (dict): Keys/values to merge in.

    Returns:
        bool: The result of ``db.save_user_metadata`` (True on success).
    """
    merged = db.get_user_metadata(user_id)
    merged.update(metadata)
    saved = db.save_user_metadata(user_id, merged)
    return saved
def track_question_tags(user_id, question_tags):
    """
    Record a question's tags as the newest entry of the user's tag history.

    The history lives under the "recent_tags" metadata key as a list of
    ``{"timestamp": str, "tags": question_tags}`` dicts, newest first. Only
    the five newest entries are kept.

    Parameters:
        user_id (str): The ID of the user.
        question_tags (list): The tags from the user's question.

    Returns:
        bool: The result of ``db.save_user_metadata`` (True on success).
    """
    metadata = db.get_user_metadata(user_id)
    history = metadata.setdefault("recent_tags", [])

    entry = {"timestamp": str(datetime.datetime.now()), "tags": question_tags}
    history.insert(0, entry)

    # Newest entries are at the front, so the first five are the most recent.
    metadata["recent_tags"] = history[:5]
    return db.save_user_metadata(user_id, metadata)
def get_all_users():
    """
    List every user known to ``db``.

    Returns:
        list: The result of ``db.get_all_users`` (a list of user IDs).
    """
    users = db.get_all_users()
    return users
def get_des_accumulation(question_vector, weights_bias_vector):
    """
    Score a single destination against a question.

    Sums ``weights_bias_vector[i]`` over every position ``i`` where
    ``question_vector[i] == 1``; any other value (0, or a count above 1)
    contributes nothing. Positions run over ``len(weights_bias_vector)``, so
    ``question_vector`` must be at least that long. Returns 0 when nothing
    matches.
    """
    total = 0
    for position, weight in enumerate(weights_bias_vector):
        if question_vector[position] == 1:
            total += weight
    return total
def get_destinations_list(question_vector, top_k, user_id):
    """
    Rank destinations for a question and return the names of the best ones.

    Each destination's score is get_des_accumulation of the question row and
    that destination's weight row, i.e. the sum of its weights over positions
    where the question vector equals 1. Only destinations scoring strictly
    above the 25th percentile of all scores are kept. They are ordered from
    highest to lowest score, with ties kept in ascending row order, and at
    most ``top_k`` names are returned.

    Parameters:
        question_vector (numpy.ndarray): 2D array as produced by
            get_question_vector; only its first row is scored.
        top_k (int): Maximum number of names to return.
        user_id (str | None): When not None, the user's stored weights replace
            the default ``weights_bias_vector``.

    Returns:
        list: Destination names, best first. Empty when ``question_vector`` is
        empty, when there are no weight rows, or when no score exceeds the
        25th percentile (e.g. all scores equal).
    """
    # Check for an empty question first so no user weights are fetched for it.
    if question_vector.size == 0:
        return []
    # NOTE(review): only the first row is scored. vectorizer.transform yields
    # one row per input element, so confirm callers pass a single document
    # rather than one tag per element.
    question_row = question_vector[0]

    weights_vector = weights_bias_vector
    if user_id is not None:
        weights_vector = get_user_weights(user_id)

    scores = [
        get_des_accumulation(question_row, weights_vector[row])
        for row in range(len(weights_vector))
    ]
    if not scores:
        # np.percentile raises on an empty sequence.
        return []

    # NOTE(review): the first row of `destinations` is dropped, so weight row i
    # maps to destinations row i + 1, whereas update_weight_from_destination
    # and update_weights_from_feedback write to row i. Confirm the intended
    # alignment.
    names = destinations[1:].reset_index(drop=True)["name"]

    q1_score = np.percentile(scores, 25)
    # sorted() is stable even with reverse=True, so equal scores keep
    # ascending row order.
    ranked_rows = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    destinations_list = [names[row] for row in ranked_rows if scores[row] > q1_score]
    return destinations_list[:top_k]
def get_question_vector(question_tags):
    """
    Encode question tags with the module's fitted ``vectorizer``.

    Parameters:
        question_tags (list): Strings to encode.

    Returns:
        numpy.ndarray: Dense 2D array from ``vectorizer.transform(...)``. It
        presumably has one row per element of ``question_tags`` and one column
        per feature name. Whether cells are 0/1 or counts depends on how
        ``vectorizer`` was configured — TODO confirm.
    """
    sparse_matrix = vectorizer.transform(question_tags)
    dense_matrix = sparse_matrix.toarray()
    return dense_matrix
def get_recent_tags(user_id):
    """
    Pick up to five tags from the user's recent tag history.

    Tags from every stored "recent_tags" entry are pooled and ranked by how
    often they occur; ties keep first-seen order. If there are at most five
    distinct tags, all are returned. Otherwise the three most frequent tags
    are returned, followed by two picked at random from the remaining ones.

    Parameters:
        user_id (str): The ID of the user. A falsy value yields [].

    Returns:
        list: Up to five tags.
    """
    from collections import Counter

    if not user_id:
        return []
    recent_tags = db.get_user_metadata(user_id).get("recent_tags", [])
    if not recent_tags:
        return []

    tags = []
    for item in recent_tags:
        if isinstance(item, str):
            # Tolerate entries that are bare tag strings:
            # update_weight_from_destination may have stored "recent_tags" as
            # a plain list of tags, and calling .get() on a str would raise.
            tags.append(item)
            continue
        entry_tags = item.get("tags", [])
        if entry_tags:
            tags.extend(entry_tags)

    # Counter.most_common sorts by frequency in O(n log n) with deterministic
    # tie order. The previous key=tags.count was quadratic, and ties followed
    # set iteration order, which varies between runs for strings.
    unique_tags = [tag for tag, _ in Counter(tags).most_common()]
    if len(unique_tags) <= 5:
        return unique_tags
    return unique_tags[:3] + random.sample(unique_tags[3:], 2)