"""Class for Database handling using SQLite.""" import logging import os import sqlite3 from sqlite3 import Connection from typing import Any, Dict, List, Optional, Tuple class SQLiteDB: """ Handles SQLite database operations for an image generation comparison experiment. Attributes: db_folder_path (str): Directory where the SQLite database file is stored. experiment_name (str): Name of the experiment (used to name the database file). db_filename (str): Filename of the SQLite database. db_path (str): Full path to the SQLite database file. logger (logging.Logger): Logger instance for logging database operations. """ def __init__(self, db_folder_path: str, experiment_name: str = "arena"): """ Initializes the SQLiteDB instance with dataset and database configuration. Args: db_folder_path (str): Directory where the SQLite database file is stored. experiment_name (str, optional): Name of the experiment. Defaults to "arena". """ self.experiment_name = experiment_name self.db_folder_path = db_folder_path self.db_filename = f"{experiment_name.lower()}.db" self.db_path = os.path.join(db_folder_path, self.db_filename) self.log_path = os.path.join(db_folder_path, "log_db.txt") logging.basicConfig(filename=self.log_path, filemode="a", level=logging.DEBUG) self.logger = logging.getLogger() self.conn: Optional[Connection] = None def __del__(self): if self.conn is not None: try: self.conn.close() except Exception: pass def initialize_database(self) -> bool: """ Initializes the SQLite database and creates required tables if they do not exist. Tables: - User: Stores user information. - Preference: Stores user preferences between generated images. Returns: bool: True if successful, False if an error occurred. """ try: db_exists = os.path.exists(self.db_path) if self.conn is None: self.conn = sqlite3.connect(self.db_path, check_same_thread=False) with self.conn: # auto commit cursor = self.conn.cursor() if db_exists: self.logger.info(f"Database already exists at {self.db_path}") return True self.logger.info(f"Creating new database at {self.db_path}") os.makedirs(self.db_folder_path, exist_ok=True) cursor.execute(""" CREATE TABLE IF NOT EXISTS User ( user_id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT UNIQUE, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) cursor.execute(""" CREATE TABLE IF NOT EXISTS Preference ( preference_id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER, reference_id TEXT, model_left_id TEXT, model_right_id TEXT, preferred_side TEXT CHECK(preferred_side IN ('left', 'right', 'tie')), timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (user_id) REFERENCES User(user_id) ) """) cursor.close() except Exception as e: self.logger.info(f"Creating the database failed with following error: {e}") return False return True def create_user(self, username: str) -> Tuple[Optional[int], str]: """ Creates a new user in the database. Args: username (str): The username of the new user. Returns: Tuple[Optional[int], str]: The user_id of the newly created user, or None if creation failed. If first entry is None then the second contains the exception message. """ ret: Optional[int] = None msg = "" try: if self.conn is None: self.conn = sqlite3.connect(self.db_path) with self.conn: cursor = self.conn.cursor() cursor.execute( """ INSERT INTO User (username) VALUES (?) """, (username,), ) user_id = cursor.lastrowid cursor.close() self.logger.info(f"User '{username}' created with user_id {user_id}") ret = user_id msg = username except sqlite3.IntegrityError: msg = f"User '{username}' already exists." self.logger.warning(msg) except Exception as e: msg = f"Failed to create user '{username}': {e}" self.logger.error(msg) return ret, msg def get_user_id_by_username(self, username: str) -> Optional[int]: """ Checks if a username exists in the database and returns the associated user_id. Args: username (str): The username to look up. Returns: Optional[int]: The user_id if the username exists, None otherwise. """ try: if self.conn is None: self.conn = sqlite3.connect(self.db_path) cursor = self.conn.cursor() cursor.execute( """ SELECT user_id FROM User WHERE username = ? """, (username,), ) result = cursor.fetchone() cursor.close() if result: return result[0] else: return None except Exception as e: self.logger.error(f"Error checking username '{username}': {e}") return None def insert_preference( self, user_id: int, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str ) -> Tuple[bool, str]: """ Inserts a new preference entry into the database. Args: user_id (int): ID of the user making the preference. reference_id (str): ID of the reference image. model_left_id (str): ID of the left model's generated image. model_right_id (str): ID of the right model's generated image. preferred_side (str): The preferred side ('left', 'right', or 'tie'). Returns: Tuple[bool, str]: True if insertion was successful, False otherwise with a string message describing the exception. """ msg = "" if preferred_side not in {"left", "right", "tie"}: msg = f"Invalid preferred_side value: {preferred_side}" self.logger.error(msg) return False, msg try: if self.conn is None: self.conn = sqlite3.connect(self.db_path) with self.conn: cursor = self.conn.cursor() cursor.execute( """ INSERT INTO Preference ( user_id, reference_id, model_left_id, model_right_id, preferred_side ) VALUES (?, ?, ?, ?, ?) """, (user_id, reference_id, model_left_id, model_right_id, preferred_side), ) cursor.close() self.logger.info(f"Preference inserted for user_id {user_id}") return True, msg except Exception as e: msg = f"Failed to insert preference: {e}" self.logger.error(msg) return False, msg def get_all_preferences(self) -> List[Tuple]: """ Retrieves all preference entries from the database. Returns: List[Tuple]: A list of tuples representing all preference entries. """ preferences = [] try: if self.conn is None: self.conn = sqlite3.connect(self.db_path) with self.conn: cursor = self.conn.cursor() cursor.execute("SELECT * FROM Preference") preferences = cursor.fetchall() cursor.close() except Exception as e: self.logger.error(f"Failed to retrieve preferences: {e}") return preferences def get_preferences_by_user(self, user_id: int) -> List[Tuple]: """ Retrieves all preference entries for a specific user. Args: user_id (int): The ID of the user. Returns: List[Tuple]: A list of tuples representing the user's preference entries. """ preferences = [] try: if self.conn is None: self.conn = sqlite3.connect(self.db_path) with self.conn: cursor = self.conn.cursor() cursor.execute("SELECT * FROM Preference WHERE user_id = ?", (user_id,)) preferences = cursor.fetchall() cursor.close() self.logger.info(f"Retrieved {len(preferences)} preferences for user_id {user_id}.") except Exception as e: self.logger.error(f"Failed to retrieve preferences for user_id {user_id}: {e}") return preferences def map_preferences_to_dicts(self, preferences: List[Tuple]) -> List[Dict[str, Any]]: """ Maps a list of preference tuples to a list of dictionaries using the Preference schema. Args: preferences (List[Tuple]): List of tuples from the Preference table. Returns: List[Dict[str, Any]]: List of dictionaries with keys matching the Preference schema. """ keys = [ "preference_id", "user_id", "reference_id", "model_left_id", "model_right_id", "preferred_side", "timestamp", ] return [dict(zip(keys, row)) for row in preferences]