Spaces:
Sleeping
Sleeping
File size: 6,317 Bytes
a30d85d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 | import datetime
import json
import numpy as np
from database import db
def get_user_weights(user_id, default_weights):
    """
    Return a private copy of the weights stored for ``user_id``.

    The database lookup also receives ``default_weights``. If the lookup
    still yields ``None``, the defaults are used instead. Either way the
    result is ``.copy()``-ed, so callers can modify it without mutating the
    object the lookup returned or ``default_weights`` itself.

    Parameters:
    user_id (str): The ID of the user.
    default_weights (numpy.ndarray): Fallback weights used when the lookup returns None.

    Returns:
    numpy.ndarray: A copy of the user's weights, or of the defaults.
    """
    stored = db.get_user_weights(user_id, default_weights)
    source = default_weights if stored is None else stored
    return source.copy()
def user_weights_exist(user_id):
    """
    Report whether the database holds weights for ``user_id``.

    Parameters:
    user_id (str): The ID of the user.

    Returns:
    bool: The value reported by ``db.user_weights_exist`` -- True if
    weights exist for the user, False otherwise.
    """
    has_weights = db.user_weights_exist(user_id)
    return has_weights
def save_user_weights(user_id, weights):
    """
    Persist ``weights`` for ``user_id`` through the database layer.

    Parameters:
    user_id (str): The ID of the user.
    weights (numpy.ndarray): The weights to save.

    Returns:
    bool: The value returned by ``db.save_user_weights`` -- True if the
    weights were saved successfully, False otherwise.
    """
    saved = db.save_user_weights(user_id, weights)
    return saved
def update_user_weights(user_id, tag_indices, new_weights, default_weights):
    """
    Overwrite whole tag columns of a user's weight matrix, then save it.

    ``tag_indices`` and ``new_weights`` are paired by position. For the k-th
    tag index, every row of the matrix gets ``new_weights[k]`` in that
    column. ``new_weights`` therefore needs at least as many entries as
    ``tag_indices``; any extra entries are ignored.

    Parameters:
    user_id (str): The ID of the user.
    tag_indices (list): Column indices of the tags to overwrite.
    new_weights (list): The value to write into each corresponding column.
    default_weights (numpy.ndarray): The default weights to use if the user doesn't have weights.

    Returns:
    bool: The result of ``save_user_weights`` -- True if the weights were
    updated successfully, False otherwise.
    """
    weights = get_user_weights(user_id, default_weights)

    for position, tag_index in enumerate(tag_indices):
        # Rows are presumably one per destination (see the query/feedback
        # updaters); every row receives the same value for this tag.
        for row in weights:
            row[tag_index] = new_weights[position]

    return save_user_weights(user_id, weights)
def update_weights_from_query(user_id, query_tags, feature_names, default_weights):
    """
    Boost a user's weights for every tag mentioned in a query, then save them.

    Each query tag is lower-cased and looked up in ``feature_names``. The
    first matching position is used as that tag's column index, and tags
    with no match are ignored. For every matched column, each row (one per
    destination) whose weight in that column is > 0 is increased by 5. A tag
    that occurs more than once in ``query_tags`` is applied once per
    occurrence.

    Parameters:
    user_id (str): The ID of the user.
    query_tags (list): The tag strings from the user's query.
    feature_names (numpy.ndarray or sequence of str): The names of all
        features (tags). Any sequence is accepted; it is converted to an
        array before matching. Matching is done against ``tag.lower()``, so
        the names are presumably stored lower-case -- TODO confirm.
    default_weights (numpy.ndarray): The default weights to use if the user doesn't have weights.

    Returns:
    bool: The result of ``save_user_weights`` -- True if the weights were
    updated successfully, False otherwise.
    """
    import logging

    weights = get_user_weights(user_id, default_weights)
    # Previously an unconditional print() of the full matrix on every call.
    # Keep the diagnostic, but only when debug logging is enabled
    # (%-args are formatted lazily).
    logging.getLogger(__name__).debug("weights: %s", weights)

    # Comparing a plain Python list to a string yields one scalar False
    # (never a match). Convert once so the comparison is element-wise for
    # any sequence type; this is a no-op for an existing ndarray.
    names = np.asarray(feature_names)

    # Resolve each query tag to the column index of its first match.
    tag_indices = []
    for tag in query_tags:
        # flatnonzero ravels first, so it also copes with a scalar result.
        matches = np.flatnonzero(names == tag.lower())
        if matches.size > 0:
            tag_indices.append(matches[0])

    # Only boost entries that already carry this tag (weight > 0).
    for tag_index in tag_indices:
        for dest_index in range(len(weights)):
            if weights[dest_index][tag_index] > 0:
                weights[dest_index][tag_index] += 5

    return save_user_weights(user_id, weights)
def update_weights_from_feedback(user_id, destination_id, tag_id, rating, default_weights):
    """
    Nudge one (destination, tag) weight according to a star rating, then save.

    Rating adjustments: 5 -> +5, 4 -> +3, 3 -> +1, 2 -> -3, 1 -> -5. Any
    other rating leaves the weights unchanged, but they are still saved.

    Parameters:
    user_id (str): The ID of the user.
    destination_id (int): Row index of the destination in the weight matrix.
    tag_id (int): Column index of the tag in the weight matrix.
    rating (int): The star rating (1-5).
    default_weights (numpy.ndarray): The default weights to use if the user doesn't have weights.

    Returns:
    bool: The result of ``save_user_weights`` -- True if the weights were
    updated successfully, False otherwise.
    """
    weights = get_user_weights(user_id, default_weights)

    # Star rating -> signed adjustment for the targeted weight.
    rating_deltas = {5: 5, 4: 3, 3: 1, 2: -3, 1: -5}
    delta = rating_deltas.get(rating)
    if delta is not None:
        weights[destination_id][tag_id] += delta

    return save_user_weights(user_id, weights)
def get_user_metadata(user_id):
    """
    Fetch the metadata stored for ``user_id``.

    Parameters:
    user_id (str): The ID of the user.

    Returns:
    dict: Whatever ``db.get_user_metadata`` returns for the user.
    """
    metadata = db.get_user_metadata(user_id)
    return metadata
def update_user_metadata(user_id, metadata):
    """
    Merge ``metadata`` into the user's stored metadata and save the result.

    Keys present in ``metadata`` overwrite existing keys of the same name.
    All other stored keys are kept.

    Parameters:
    user_id (str): The ID of the user.
    metadata (dict): The key/value pairs to merge in.

    Returns:
    bool: The result of ``db.save_user_metadata`` -- True if the metadata
    was updated successfully, False otherwise.
    """
    # NOTE(review): assumes db.get_user_metadata always returns a dict
    # (never None) -- confirm against the database layer.
    merged = db.get_user_metadata(user_id)
    merged.update(metadata)  # in-place merge; incoming values win on clashes
    saved = db.save_user_metadata(user_id, merged)
    return saved
def track_question_tags(user_id, question_tags):
    """
    Record a question's tags in the user's metadata, keeping the 5 newest.

    A ``{"timestamp": ..., "tags": ...}`` entry is prepended to the
    ``"recent_tags"`` list, which is created on first use. The list is then
    truncated to its first five entries. The timestamp is
    ``str(datetime.datetime.now())``, i.e. naive local time.

    Parameters:
    user_id (str): The ID of the user.
    question_tags (list): The tags from the user's question.

    Returns:
    bool: The result of ``db.save_user_metadata`` -- True if the tags were
    tracked successfully, False otherwise.
    """
    metadata = db.get_user_metadata(user_id)

    # setdefault creates an empty history the first time this user is seen.
    history = metadata.setdefault("recent_tags", [])
    history.insert(0, {
        "timestamp": str(datetime.datetime.now()),
        "tags": question_tags,
    })
    # Newest entries come first, so truncating keeps the five most recent.
    metadata["recent_tags"] = history[:5]

    return db.save_user_metadata(user_id, metadata)
def get_all_users():
    """
    List every user known to the database.

    Returns:
    list: The user IDs returned by ``db.get_all_users``.
    """
    users = db.get_all_users()
    return users
|