"""MongoDB persistence layer for per-user weights and metadata.

Exposes the ``Database`` wrapper class and a module-level ``db`` instance.
Connection settings are read from environment variables (optionally loaded
from a ``.env`` file).
"""

import datetime
import json
import os

import numpy as np
import pymongo
from bson import ObjectId
from bson.binary import Binary
from dotenv import load_dotenv
from loguru import logger
from pymongo import MongoClient

load_dotenv(override=True)


class Database:
    """Wrapper around a pymongo client and one database.

    User records are documents in the ``user`` collection whose ``_id`` is
    ``ObjectId(user_id)``. Data-access methods catch every exception, log it
    via loguru and return a fallback value (False, the supplied default, or
    an empty container) instead of raising.
    """

    def __init__(self):
        # Both are populated by connect(); they stay None until then.
        self.client = None
        self.db = None

    def _ensure_user_collection(self):
        """Create the ``user`` collection if it does not exist yet."""
        if "user" not in self.db.list_collection_names():
            self.db.create_collection("user")

    def connect(self):
        """
        Connect to the MongoDB database.

        The connection string comes from ``MONGODB_URL`` and the database
        name from ``DB_NAME`` (default ``"scheduling"``). The ``user``
        collection is created if missing.

        Returns:
            bool: True on success, False if any error occurred (logged).
        """
        try:
            self.client = MongoClient(os.environ.get("MONGODB_URL"))
            self.db = self.client[os.environ.get("DB_NAME", "scheduling")]
            self._ensure_user_collection()
            return True
        except Exception as e:
            logger.error(f"Error connecting to database: {e}")
            return False

    def close(self):
        """
        Close the database connection, if one was opened.
        """
        if self.client:
            self.client.close()

    def create_tables(self):
        """
        Create the necessary collections if they don't exist.

        Note: This method is kept for backward compatibility. It only ensures
        the ``user`` collection exists; per-user data is stored as fields on
        documents in that collection.

        Returns:
            bool: True on success, False if any error occurred (logged).
        """
        try:
            self._ensure_user_collection()
            return True
        except Exception as e:
            logger.error(f"Error creating collections: {e}")
            return False

    def create_user_table(self, user_id):
        """
        Ensure a document for the given user exists in the ``user`` collection.

        Despite the name, no collection is created. The user document is
        upserted and ``created_at`` is set only when it is first inserted.

        Parameters:
            user_id (str): The ID of the user (must be a valid ObjectId string).

        Returns:
            bool: True if the upsert succeeded, False otherwise (logged).
        """
        try:
            self.db.user.update_one(
                {"_id": ObjectId(user_id)},
                {"$setOnInsert": {"created_at": datetime.datetime.now()}},
                upsert=True,
            )
            return True
        except Exception as e:
            logger.error(f"Error creating user collection: {e}")
            return False

    def save_user_weights(self, user_id, weights):
        """
        Save user weights to the database.

        The array is stored as raw bytes (``weights.tobytes()``); its dtype
        and shape are NOT stored. get_user_weights() must therefore be given
        a default array with the same dtype and shape to decode it.

        The update does not upsert. If no user document with this id exists
        (see create_user_table), nothing is written and a warning is logged.

        Parameters:
            user_id (str): The ID of the user.
            weights (numpy.ndarray): The weights to save.

        Returns:
            bool: True if the update was issued without error, False otherwise.
        """
        try:
            weights_bytes = Binary(weights.tobytes())
            result = self.db.user.update_one(
                {"_id": ObjectId(user_id)},
                {
                    "$set": {
                        "weights": weights_bytes,
                        "weights_updated_at": datetime.datetime.now(),
                    }
                },
            )
            if result.matched_count == 0:
                logger.warning(
                    f"No user document for {user_id}; weights were not saved"
                )
            return True
        except Exception as e:
            logger.error(f"Error saving user weights: {e}")
            return False

    def get_user_weights(self, user_id, default_weights):
        """
        Get user weights from the database.

        Stored bytes are decoded with ``default_weights.dtype`` and reshaped
        to ``default_weights.shape``. If decoding fails (e.g. size mismatch),
        the error is logged and ``default_weights`` is returned.

        Parameters:
            user_id (str): The ID of the user.
            default_weights (numpy.ndarray): Returned when the user has no
                stored weights; also defines the dtype/shape used to decode.

        Returns:
            numpy.ndarray: The weights for the user.
        """
        try:
            result = self.db.user.find_one(
                {"_id": ObjectId(user_id)}, {"weights": 1}
            )
            if result and "weights" in result:
                weights = np.frombuffer(result["weights"], dtype=default_weights.dtype)
                # frombuffer returns a read-only view over the stored bytes;
                # copy so callers can modify the weights in place.
                return weights.reshape(default_weights.shape).copy()
            return default_weights
        except Exception as e:
            logger.error(f"Error getting user weights: {e}")
            return default_weights

    def user_weights_exist(self, user_id):
        """
        Check if weights exist for the given user.

        Parameters:
            user_id (str): The ID of the user.

        Returns:
            bool: True if the user's document has a ``weights`` field,
            False otherwise (including on error).
        """
        try:
            # Filter on the field itself; the mere existence of the user
            # document does not mean weights were ever saved.
            result = self.db.user.find_one(
                {"_id": ObjectId(user_id), "weights": {"$exists": True}},
                {"_id": 1},
            )
            return result is not None
        except Exception as e:
            logger.error(f"Error checking if user weights exist: {e}")
            return False

    def save_user_metadata(self, user_id, metadata):
        """
        Save user metadata to the database.

        The update does not upsert. If no user document with this id exists
        (see create_user_table), nothing is written and a warning is logged.

        Parameters:
            user_id (str): The ID of the user.
            metadata (dict): The metadata to save.

        Returns:
            bool: True if the update was issued without error, False otherwise.
        """
        try:
            result = self.db.user.update_one(
                {"_id": ObjectId(user_id)},
                {
                    "$set": {
                        "metadata": metadata,
                        # Separate timestamp so a metadata save does not
                        # bump weights_updated_at.
                        "metadata_updated_at": datetime.datetime.now(),
                    }
                },
            )
            if result.matched_count == 0:
                logger.warning(
                    f"No user document for {user_id}; metadata was not saved"
                )
            return True
        except Exception as e:
            logger.error(f"Error saving user metadata: {e}")
            return False

    def get_user_metadata(self, user_id):
        """
        Get metadata for a specific user from the database.

        Parameters:
            user_id (str): The ID of the user.

        Returns:
            dict: The metadata for the user, or ``{}`` if none is stored or
            an error occurred.
        """
        try:
            result = self.db.user.find_one(
                {"_id": ObjectId(user_id)}, {"metadata": 1}
            )
            if result and "metadata" in result:
                return result["metadata"]
            return {}
        except Exception as e:
            logger.error(f"Error getting user metadata: {e}")
            return {}

    def get_all_user_metadata(self):
        """
        Get metadata for all users from the database.

        Returns:
            dict: Maps each user's ``_id`` (as stored, i.e. an ``ObjectId``,
            not a string) to its metadata. Users without a ``metadata``
            field are omitted. Returns ``{}`` on error.
        """
        try:
            cursor = self.db.user.find(
                {"metadata": {"$exists": True}}, {"metadata": 1}
            )
            return {doc["_id"]: doc["metadata"] for doc in cursor}
        except Exception as e:
            logger.error(f"Error getting all user metadata: {e}")
            return {}

    def get_all_users(self):
        """
        Get a list of all users from the database.

        Returns:
            list: All user IDs as strings, or ``[]`` on error.
        """
        try:
            cursor = self.db.user.find({}, {"_id": 1})
            user_ids = [str(doc["_id"]) for doc in cursor]
            logger.debug(f"user_ids {user_ids}")
            return user_ids
        except Exception as e:
            logger.error(f"Error getting all users: {e}")
            return []

    def get_destination_ids(self, destination_names):
        """
        Get destination IDs from the database.

        Parameters:
            destination_names (iterable of str): Destination names to look up.

        Returns:
            list: String IDs of the ``destination`` documents whose ``name``
            is in ``destination_names``. The order follows the query result,
            not necessarily the input order. Returns ``[]`` on empty input or
            error.
        """
        try:
            names = list(destination_names)
            if not names:
                # Nothing can match an empty $in; skip the round-trip.
                return []
            cursor = self.db.destination.find(
                {"name": {"$in": names}}, {"_id": 1}
            )
            return [str(doc["_id"]) for doc in cursor]
        except Exception as e:
            logger.error(f"Error getting destination IDs: {e}")
            return []


# Create a singleton instance
db = Database()