TriVenture-Personalize / database.py
ABAO77's picture
Upload 3 files
9173a54 verified
import json
import os
import datetime
import numpy as np
import pymongo
from pymongo import MongoClient
from bson.binary import Binary
from dotenv import load_dotenv
from bson import ObjectId
from loguru import logger
# Load variables from a .env file at import time so the os.environ lookups in
# Database.connect() can see them. With override=True, values from the .env
# file replace variables that are already set in the process environment.
load_dotenv(override=True)
class Database:
    """
    Thin wrapper around a MongoDB connection for per-user personalization data.

    User data is stored in the ``user`` collection, one document per user keyed
    by ``_id``. Destination lookups read the ``destination`` collection. The
    data-access methods catch exceptions, log them through ``logger`` and
    return a failure value (``False``, ``{}``, ``[]`` or the supplied default)
    instead of raising.
    """

    def __init__(self):
        """
        Create an unconnected wrapper. Call ``connect()`` before any query.
        """
        self.client = None
        self.db = None

    def _ensure_user_collection(self):
        """
        Create the ``user`` collection if the database does not have one yet.
        """
        if "user" not in self.db.list_collection_names():
            self.db.create_collection("user")

    def connect(self):
        """
        Connect to the MongoDB database.

        The connection string is read from the ``MONGODB_URL`` environment
        variable and the database name from ``DB_NAME`` (default
        ``"scheduling"``). The ``user`` collection is created if missing.

        Returns:
            bool: True if the connection was set up, False if any step raised.
        """
        try:
            self.client = MongoClient(os.environ.get("MONGODB_URL"))
            self.db = self.client[os.environ.get("DB_NAME", "scheduling")]
            self._ensure_user_collection()
            return True
        except Exception as e:
            logger.error(f"Error connecting to database: {e}")
            return False

    def close(self):
        """
        Close the database connection. Does nothing if no client was created.
        """
        if self.client:
            self.client.close()

    def create_tables(self):
        """
        Create the necessary collections if they don't exist.

        Note: This method is kept for backward compatibility. It only ensures
        that the shared ``user`` collection exists; no per-user collections
        are created.

        Returns:
            bool: True if the collection exists or was created, False on error
            (including when ``connect()`` has not been called yet).
        """
        try:
            self._ensure_user_collection()
            return True
        except Exception as e:
            logger.error(f"Error creating collections: {e}")
            return False

    def create_user_table(self, user_id):
        """
        Ensure a document exists for the given user in the ``user`` collection.

        Despite the name, no collection is created: the user's document is
        upserted, and ``created_at`` is only written when the document is
        inserted.

        Parameters:
            user_id (str): The ID of the user.
        Returns:
            bool: True if the upsert succeeded, False otherwise.
        """
        try:
            self.db.user.update_one(
                {"_id": ObjectId(user_id)},
                {"$setOnInsert": {"created_at": datetime.datetime.now()}},
                upsert=True,
            )
            return True
        except Exception as e:
            logger.error(f"Error creating user collection: {e}")
            return False

    def save_user_weights(self, user_id, weights):
        """
        Save user weights to the database.

        Only the raw bytes of the array are stored; dtype and shape are not
        persisted, so ``get_user_weights`` recovers them from its
        ``default_weights`` argument.

        Parameters:
            user_id (str): The ID of the user.
            weights (numpy.ndarray): The weights to save.
        Returns:
            bool: True if the weights were written to an existing user
            document, False if no such user exists or an error occurred.
        """
        try:
            result = self.db.user.update_one(
                {"_id": ObjectId(user_id)},
                {
                    "$set": {
                        "weights": Binary(weights.tobytes()),
                        "weights_updated_at": datetime.datetime.now(),
                    }
                },
            )
            # No upsert here: a user that does not exist is not created, so
            # report that nothing was saved instead of claiming success.
            if result.matched_count == 0:
                logger.warning(f"Cannot save weights: user {user_id} not found")
                return False
            return True
        except Exception as e:
            logger.error(f"Error saving user weights: {e}")
            return False

    def get_user_weights(self, user_id, default_weights):
        """
        Get user weights from the database.

        Parameters:
            user_id (str): The ID of the user.
            default_weights (numpy.ndarray): The weights to return if the user
                has no stored weights. Its dtype and shape are also used to
                decode the stored bytes.
        Returns:
            numpy.ndarray: A writable array with the user's weights, or
            ``default_weights`` itself when nothing is stored, the stored bytes
            do not fit the expected shape, or an error occurred.
        """
        try:
            doc = self.db.user.find_one({"_id": ObjectId(user_id)}, {"weights": 1})
            if doc and "weights" in doc:
                stored = np.frombuffer(doc["weights"], dtype=default_weights.dtype)
                # frombuffer yields a read-only view over the bytes object;
                # copy so callers can update the weights in place.
                return stored.reshape(default_weights.shape).copy()
            return default_weights
        except Exception as e:
            logger.error(f"Error getting user weights: {e}")
            return default_weights

    def user_weights_exist(self, user_id):
        """
        Check if weights exist for the given user.

        Parameters:
            user_id (str): The ID of the user.
        Returns:
            bool: True if the user's document has a ``weights`` field, False if
            the user or the field is missing or an error occurred.
        """
        try:
            # Match on the field itself: a user document without weights must
            # not count. Project only _id to avoid loading the weights blob.
            doc = self.db.user.find_one(
                {"_id": ObjectId(user_id), "weights": {"$exists": True}},
                {"_id": 1},
            )
            return doc is not None
        except Exception as e:
            logger.error(f"Error checking if user weights exist: {e}")
            return False

    def save_user_metadata(self, user_id, metadata):
        """
        Save user metadata to the database.

        Parameters:
            user_id (str): The ID of the user.
            metadata (dict): The metadata to save.
        Returns:
            bool: True if the metadata was written to an existing user
            document, False if no such user exists or an error occurred.
        """
        try:
            result = self.db.user.update_one(
                {"_id": ObjectId(user_id)},
                {
                    "$set": {
                        "metadata": metadata,
                        # NOTE(review): this stamps "weights_updated_at" even
                        # though only metadata changes; kept as-is because other
                        # readers of the collection may rely on it - confirm.
                        "weights_updated_at": datetime.datetime.now(),
                    }
                },
            )
            # No upsert here: see save_user_weights for the same contract.
            if result.matched_count == 0:
                logger.warning(f"Cannot save metadata: user {user_id} not found")
                return False
            return True
        except Exception as e:
            logger.error(f"Error saving user metadata: {e}")
            return False

    def get_user_metadata(self, user_id):
        """
        Get metadata for a specific user from the database.

        Parameters:
            user_id (str): The ID of the user.
        Returns:
            dict: The metadata for the user, or an empty dict if the user or
            the ``metadata`` field is missing or an error occurred.
        """
        try:
            doc = self.db.user.find_one({"_id": ObjectId(user_id)}, {"metadata": 1})
            if doc and "metadata" in doc:
                return doc["metadata"]
            return {}
        except Exception as e:
            logger.error(f"Error getting user metadata: {e}")
            return {}

    def get_all_user_metadata(self):
        """
        Get metadata for all users from the database.

        Returns:
            dict: A dictionary mapping each user document's ``_id`` (as
            returned by the driver, not converted to ``str``) to its metadata.
            Users without a ``metadata`` field are omitted. Empty on error.
        """
        try:
            cursor = self.db.user.find(
                {"metadata": {"$exists": True}}, {"metadata": 1}
            )
            return {doc["_id"]: doc["metadata"] for doc in cursor}
        except Exception as e:
            logger.error(f"Error getting all user metadata: {e}")
            return {}

    def get_all_users(self):
        """
        Get a list of all users from the database.

        Returns:
            list: The ``_id`` of every user document, converted to ``str``.
            Empty on error.
        """
        try:
            user_ids = [str(doc["_id"]) for doc in self.db.user.find({}, {"_id": 1})]
            logger.debug(f"user_ids {user_ids}")
            return user_ids
        except Exception as e:
            logger.error(f"Error getting all users: {e}")
            return []

    def get_destination_ids(self, destination_names):
        """
        Get destination IDs from the database.

        Parameters:
            destination_names (list): A list of destination names.
        Returns:
            list: The ``_id`` (as ``str``) of every destination document whose
            ``name`` is in ``destination_names``, in the order the database
            returns them. Empty on error.
        """
        try:
            cursor = self.db.destination.find(
                {"name": {"$in": destination_names}}, {"_id": 1}
            )
            return [str(doc["_id"]) for doc in cursor]
        except Exception as e:
            logger.error(f"Error getting destination IDs: {e}")
            return []
# Create a singleton instance at module level. It starts unconnected
# (client and db are None), so connect() must be called before any query.
db = Database()