Spaces:
Running
Running
| """Class for Database handling using SQLite.""" | |
| import logging | |
| import os | |
| import sqlite3 | |
| from sqlite3 import Connection | |
| from typing import Any, Dict, List, Optional, Tuple | |
class SQLiteDB:
    """
    Handles SQLite database operations for an image generation comparison experiment.

    Every method shares one lazily-opened connection (``self.conn``). It is always
    opened with ``check_same_thread=False`` so it can be reused from a thread other
    than the one that first opened it.

    Attributes:
        db_folder_path (str): Directory where the SQLite database file is stored.
        experiment_name (str): Name of the experiment (used to name the database file).
        db_filename (str): Filename of the SQLite database (lower-cased experiment name + ".db").
        db_path (str): Full path to the SQLite database file.
        log_path (str): Path of the log file handed to ``logging.basicConfig``.
        logger (logging.Logger): Logger used for database operations (the root logger).
        conn (Optional[Connection]): Shared connection, or None until first use.
    """

    # Column order of the Preference table as created in initialize_database;
    # rows from ``SELECT * FROM Preference`` come back in this order.
    _PREFERENCE_COLUMNS = (
        "preference_id",
        "user_id",
        "reference_id",
        "model_left_id",
        "model_right_id",
        "preferred_side",
        "timestamp",
    )

    def __init__(self, db_folder_path: str, experiment_name: str = "arena"):
        """
        Initializes the SQLiteDB instance with dataset and database configuration.

        No connection is opened here; it is opened on first use.

        Args:
            db_folder_path (str): Directory where the SQLite database file is stored.
            experiment_name (str, optional): Name of the experiment. Defaults to "arena".
        """
        self.experiment_name = experiment_name
        self.db_folder_path = db_folder_path
        self.db_filename = f"{experiment_name.lower()}.db"
        self.db_path = os.path.join(db_folder_path, self.db_filename)
        self.log_path = os.path.join(db_folder_path, "log_db.txt")
        # basicConfig only takes effect if the root logger has no handlers yet,
        # so log_path is used only by the first configuration in the process.
        logging.basicConfig(filename=self.log_path, filemode="a", level=logging.DEBUG)
        self.logger = logging.getLogger()
        self.conn: Optional[Connection] = None

    def __del__(self):
        """Best-effort close of the shared connection when the object is collected."""
        # getattr: __init__ may have raised (e.g. in basicConfig) before self.conn existed.
        conn = getattr(self, "conn", None)
        if conn is not None:
            try:
                conn.close()
            except Exception:
                # Deliberately best-effort: a finalizer must never raise.
                pass

    def _get_connection(self) -> Connection:
        """Return the shared connection, opening it on first use."""
        if self.conn is None:
            # Same flags everywhere: a connection opened with the default
            # check_same_thread=True raises ProgrammingError when used from another thread.
            self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        return self.conn

    def initialize_database(self) -> bool:
        """
        Initializes the SQLite database and creates required tables if they do not exist.

        The CREATE statements always run (they are ``IF NOT EXISTS``), so a database
        file that exists but lacks the tables is completed rather than skipped.

        Tables:
            - User: Stores user information.
            - Preference: Stores user preferences between generated images.

        Returns:
            bool: True if successful, False if an error occurred.
        """
        try:
            if os.path.exists(self.db_path):
                self.logger.info(f"Database already exists at {self.db_path}")
            else:
                self.logger.info(f"Creating new database at {self.db_path}")
                # Must happen before connecting, or sqlite3 cannot create the file.
                # An empty folder path means the current directory, which already exists
                # (and os.makedirs("") would raise).
                if self.db_folder_path:
                    os.makedirs(self.db_folder_path, exist_ok=True)
            conn = self._get_connection()
            with conn:  # commits on success, rolls back on exception
                cursor = conn.cursor()
                try:
                    cursor.execute("""
                        CREATE TABLE IF NOT EXISTS User (
                            user_id INTEGER PRIMARY KEY AUTOINCREMENT,
                            username TEXT UNIQUE,
                            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                        )
                    """)
                    # NOTE(review): SQLite only enforces this FOREIGN KEY when
                    # "PRAGMA foreign_keys = ON" is issued on the connection, which
                    # this class never does.
                    cursor.execute("""
                        CREATE TABLE IF NOT EXISTS Preference (
                            preference_id INTEGER PRIMARY KEY AUTOINCREMENT,
                            user_id INTEGER,
                            reference_id TEXT,
                            model_left_id TEXT,
                            model_right_id TEXT,
                            preferred_side TEXT CHECK(preferred_side IN ('left', 'right', 'tie')),
                            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                            FOREIGN KEY (user_id) REFERENCES User(user_id)
                        )
                    """)
                finally:
                    cursor.close()
        except Exception as e:
            self.logger.error(f"Creating the database failed with following error: {e}")
            return False
        return True

    def create_user(self, username: str) -> Tuple[Optional[int], str]:
        """
        Creates a new user in the database.

        Args:
            username (str): The username of the new user.

        Returns:
            Tuple[Optional[int], str]:
                The user_id of the newly created user (with the username as message),
                or None if creation failed, in which case the second entry describes
                the failure (duplicate username or exception message).
        """
        ret: Optional[int] = None
        msg = ""
        try:
            conn = self._get_connection()
            with conn:  # commits on success, rolls back on exception
                cursor = conn.cursor()
                try:
                    cursor.execute("INSERT INTO User (username) VALUES (?)", (username,))
                    user_id = cursor.lastrowid
                finally:
                    cursor.close()
            self.logger.info(f"User '{username}' created with user_id {user_id}")
            ret = user_id
            msg = username
        except sqlite3.IntegrityError:
            # Raised by the UNIQUE constraint on User.username.
            msg = f"User '{username}' already exists."
            self.logger.warning(msg)
        except Exception as e:
            msg = f"Failed to create user '{username}': {e}"
            self.logger.error(msg)
        return ret, msg

    def get_user_id_by_username(self, username: str) -> Optional[int]:
        """
        Checks if a username exists in the database and returns the associated user_id.

        Args:
            username (str): The username to look up.

        Returns:
            Optional[int]: The user_id if the username exists, None otherwise
            (also None if the query fails; the error is logged).
        """
        try:
            cursor = self._get_connection().cursor()
            try:
                cursor.execute("SELECT user_id FROM User WHERE username = ?", (username,))
                result = cursor.fetchone()
            finally:
                cursor.close()
            return result[0] if result else None
        except Exception as e:
            self.logger.error(f"Error checking username '{username}': {e}")
            return None

    def insert_preference(
        self, user_id: int, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str
    ) -> Tuple[bool, str]:
        """
        Inserts a new preference entry into the database.

        Args:
            user_id (int): ID of the user making the preference.
            reference_id (str): ID of the reference image.
            model_left_id (str): ID of the left model's generated image.
            model_right_id (str): ID of the right model's generated image.
            preferred_side (str): The preferred side ('left', 'right', or 'tie').

        Returns:
            Tuple[bool, str]: (True, "") if insertion was successful, otherwise False
            with a string message describing the problem.
        """
        msg = ""
        # Validate up front so the caller gets a clear message instead of the
        # CHECK-constraint error from SQLite.
        if preferred_side not in {"left", "right", "tie"}:
            msg = f"Invalid preferred_side value: {preferred_side}"
            self.logger.error(msg)
            return False, msg
        try:
            conn = self._get_connection()
            with conn:  # commits on success, rolls back on exception
                cursor = conn.cursor()
                try:
                    cursor.execute(
                        """
                        INSERT INTO Preference (
                            user_id, reference_id, model_left_id, model_right_id, preferred_side
                        ) VALUES (?, ?, ?, ?, ?)
                        """,
                        (user_id, reference_id, model_left_id, model_right_id, preferred_side),
                    )
                finally:
                    cursor.close()
            self.logger.info(f"Preference inserted for user_id {user_id}")
            return True, msg
        except Exception as e:
            msg = f"Failed to insert preference: {e}"
            self.logger.error(msg)
            return False, msg

    def get_all_preferences(self) -> List[Tuple]:
        """
        Retrieves all preference entries from the database.

        Returns:
            List[Tuple]: A list of tuples representing all preference entries
            (empty if the query fails; the error is logged).
        """
        preferences: List[Tuple] = []
        try:
            conn = self._get_connection()
            with conn:
                cursor = conn.cursor()
                try:
                    cursor.execute("SELECT * FROM Preference")
                    preferences = cursor.fetchall()
                finally:
                    cursor.close()
        except Exception as e:
            self.logger.error(f"Failed to retrieve preferences: {e}")
        return preferences

    def get_preferences_by_user(self, user_id: int) -> List[Tuple]:
        """
        Retrieves all preference entries for a specific user.

        Args:
            user_id (int): The ID of the user.

        Returns:
            List[Tuple]: A list of tuples representing the user's preference entries
            (empty if the query fails; the error is logged).
        """
        preferences: List[Tuple] = []
        try:
            conn = self._get_connection()
            with conn:
                cursor = conn.cursor()
                try:
                    cursor.execute("SELECT * FROM Preference WHERE user_id = ?", (user_id,))
                    preferences = cursor.fetchall()
                finally:
                    cursor.close()
            self.logger.info(f"Retrieved {len(preferences)} preferences for user_id {user_id}.")
        except Exception as e:
            self.logger.error(f"Failed to retrieve preferences for user_id {user_id}: {e}")
        return preferences

    def map_preferences_to_dicts(self, preferences: List[Tuple]) -> List[Dict[str, Any]]:
        """
        Maps a list of preference tuples to a list of dictionaries using the Preference schema.

        Args:
            preferences (List[Tuple]): List of tuples from the Preference table
                (as returned by ``SELECT *``).

        Returns:
            List[Dict[str, Any]]: List of dictionaries with keys matching the Preference schema.
        """
        return [dict(zip(self._PREFERENCE_COLUMNS, row)) for row in preferences]