wlp_user_study / src /gecora /db /sqlite.py
Markus Pobitzer
app
b6d1c13
"""Class for Database handling using SQLite."""
import logging
import os
import sqlite3
import threading
from sqlite3 import Connection
from typing import Any, Dict, List, Optional, Tuple
class SQLiteDB:
    """
    Handles SQLite database operations for an image generation comparison experiment.

    A single connection is opened lazily (by the first method that needs it) and is
    shared by all methods. It is opened with ``check_same_thread=False`` so the instance
    may be used from several threads. Because the connection and its transaction state
    are shared, every database access is serialized through an internal re-entrant lock.

    Attributes:
        db_folder_path (str): Directory where the SQLite database file is stored.
        experiment_name (str): Name of the experiment (used to name the database file).
        db_filename (str): Filename of the SQLite database (lower-cased experiment name + ".db").
        db_path (str): Full path to the SQLite database file.
        log_path (str): Log file path handed to ``logging.basicConfig``.
        logger (logging.Logger): Logger instance for logging database operations (the root logger).
        conn (Optional[Connection]): The shared connection, or None while closed.
    """

    # Values accepted for Preference.preferred_side; must match the CHECK constraint in the schema.
    _VALID_SIDES = frozenset({"left", "right", "tie"})

    # Column order of the Preference table, i.e. the order of values returned by ``SELECT *``.
    _PREFERENCE_COLUMNS = (
        "preference_id",
        "user_id",
        "reference_id",
        "model_left_id",
        "model_right_id",
        "preferred_side",
        "timestamp",
    )

    def __init__(self, db_folder_path: str, experiment_name: str = "arena"):
        """
        Initializes the SQLiteDB instance with dataset and database configuration.

        Creates ``db_folder_path`` if it does not exist yet, so that the log file and the
        database file can be created inside it. No database connection is opened here.

        Args:
            db_folder_path (str): Directory where the SQLite database file is stored.
                An empty string means the current working directory.
            experiment_name (str, optional): Name of the experiment. Defaults to "arena".
        """
        # Assigned first so close()/__del__ behave even if a later line raises.
        self.conn: Optional[Connection] = None
        self._lock = threading.RLock()

        self.experiment_name = experiment_name
        self.db_folder_path = db_folder_path
        self.db_filename = f"{experiment_name.lower()}.db"
        self.db_path = os.path.join(db_folder_path, self.db_filename)
        self.log_path = os.path.join(db_folder_path, "log_db.txt")

        # basicConfig may open log_path right away and sqlite3.connect creates db_path;
        # both fail when the folder is missing. os.makedirs("") raises, hence the guard.
        if db_folder_path:
            os.makedirs(db_folder_path, exist_ok=True)

        # NOTE: basicConfig configures the root logger and does nothing if the root logger
        # already has handlers, so log_path only takes effect for the first configuration.
        logging.basicConfig(filename=self.log_path, filemode="a", level=logging.DEBUG)
        self.logger = logging.getLogger()

    def __del__(self):
        # Finalizers must never raise; closing here is best-effort only.
        try:
            self.close()
        except Exception:
            pass

    def close(self) -> None:
        """
        Closes the shared connection if one is open.

        Safe to call more than once. Any later database method reopens the connection.
        """
        lock = getattr(self, "_lock", None)
        if lock is None:  # the instance was never initialized, so nothing was opened
            return
        with lock:
            conn, self.conn = self.conn, None
            if conn is not None:
                conn.close()

    def _get_connection(self) -> Connection:
        """Returns the shared connection, opening it on first use. Call with ``self._lock`` held."""
        if self.conn is None:
            # The connection may be used from threads other than the one that opened it;
            # self._lock serializes that use.
            self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        return self.conn

    def initialize_database(self) -> bool:
        """
        Initializes the SQLite database and creates required tables if they do not exist.

        The CREATE statements use IF NOT EXISTS, so they also run against an existing
        database file. This repairs a file that exists but is missing the tables.

        Tables:
            - User: Stores user information.
            - Preference: Stores user preferences between generated images.

        Returns:
            bool: True if successful, False if an error occurred (the error is logged).
        """
        try:
            with self._lock:
                # Checked before connecting, because connecting creates the file.
                if os.path.exists(self.db_path):
                    self.logger.info(f"Database already exists at {self.db_path}")
                else:
                    self.logger.info(f"Creating new database at {self.db_path}")
                conn = self._get_connection()
                with conn:  # commit on success, roll back on error
                    conn.execute("""
                        CREATE TABLE IF NOT EXISTS User (
                            user_id INTEGER PRIMARY KEY AUTOINCREMENT,
                            username TEXT UNIQUE,
                            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                        )
                    """)
                    # NOTE: SQLite only enforces FOREIGN KEY clauses when
                    # "PRAGMA foreign_keys = ON" is set; this class never sets it.
                    conn.execute("""
                        CREATE TABLE IF NOT EXISTS Preference (
                            preference_id INTEGER PRIMARY KEY AUTOINCREMENT,
                            user_id INTEGER,
                            reference_id TEXT,
                            model_left_id TEXT,
                            model_right_id TEXT,
                            preferred_side TEXT CHECK(preferred_side IN ('left', 'right', 'tie')),
                            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                            FOREIGN KEY (user_id) REFERENCES User(user_id)
                        )
                    """)
        except Exception as e:
            self.logger.error(f"Creating the database failed with following error: {e}")
            return False
        return True

    def create_user(self, username: str) -> Tuple[Optional[int], str]:
        """
        Creates a new user in the database.

        Args:
            username (str): The username of the new user.

        Returns:
            Tuple[Optional[int], str]:
                On success, ``(user_id, username)``. On failure, ``(None, message)``,
                where the message describes the error (e.g. the username already exists).
        """
        try:
            with self._lock:
                conn = self._get_connection()
                with conn:  # commit on success, roll back on error
                    user_id = conn.execute(
                        "INSERT INTO User (username) VALUES (?)",
                        (username,),
                    ).lastrowid
        except sqlite3.IntegrityError:
            # username is declared UNIQUE in the User table.
            msg = f"User '{username}' already exists."
            self.logger.warning(msg)
            return None, msg
        except Exception as e:
            msg = f"Failed to create user '{username}': {e}"
            self.logger.error(msg)
            return None, msg
        self.logger.info(f"User '{username}' created with user_id {user_id}")
        return user_id, username

    def get_user_id_by_username(self, username: str) -> Optional[int]:
        """
        Checks if a username exists in the database and returns the associated user_id.

        Args:
            username (str): The username to look up.

        Returns:
            Optional[int]: The user_id if the username exists, None otherwise or on error.
        """
        try:
            with self._lock:
                row = self._get_connection().execute(
                    "SELECT user_id FROM User WHERE username = ?",
                    (username,),
                ).fetchone()
        except Exception as e:
            self.logger.error(f"Error checking username '{username}': {e}")
            return None
        return row[0] if row else None

    def insert_preference(
        self, user_id: int, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str
    ) -> Tuple[bool, str]:
        """
        Inserts a new preference entry into the database.

        Args:
            user_id (int): ID of the user making the preference.
            reference_id (str): ID of the reference image.
            model_left_id (str): ID of the left model's generated image.
            model_right_id (str): ID of the right model's generated image.
            preferred_side (str): The preferred side ('left', 'right', or 'tie').

        Returns:
            Tuple[bool, str]: ``(True, "")`` on success, otherwise ``(False, message)``
            with a message describing the problem.
        """
        if preferred_side not in self._VALID_SIDES:
            msg = f"Invalid preferred_side value: {preferred_side}"
            self.logger.error(msg)
            return False, msg
        try:
            with self._lock:
                conn = self._get_connection()
                with conn:  # commit on success, roll back on error
                    conn.execute(
                        """
                        INSERT INTO Preference (
                            user_id, reference_id, model_left_id, model_right_id, preferred_side
                        ) VALUES (?, ?, ?, ?, ?)
                        """,
                        (user_id, reference_id, model_left_id, model_right_id, preferred_side),
                    )
        except Exception as e:
            msg = f"Failed to insert preference: {e}"
            self.logger.error(msg)
            return False, msg
        self.logger.info(f"Preference inserted for user_id {user_id}")
        return True, ""

    def get_all_preferences(self) -> List[Tuple]:
        """
        Retrieves all preference entries from the database, ordered by preference_id.

        Returns:
            List[Tuple]: One tuple per row, with columns in Preference-table order;
            an empty list on error.
        """
        try:
            with self._lock:
                # Without ORDER BY, SQLite does not guarantee any row order.
                return self._get_connection().execute(
                    "SELECT * FROM Preference ORDER BY preference_id"
                ).fetchall()
        except Exception as e:
            self.logger.error(f"Failed to retrieve preferences: {e}")
            return []

    def get_preferences_by_user(self, user_id: int) -> List[Tuple]:
        """
        Retrieves all preference entries for a specific user, ordered by preference_id.

        Args:
            user_id (int): The ID of the user.

        Returns:
            List[Tuple]: One tuple per row of the user's preferences; an empty list on error.
        """
        try:
            with self._lock:
                preferences = self._get_connection().execute(
                    "SELECT * FROM Preference WHERE user_id = ? ORDER BY preference_id",
                    (user_id,),
                ).fetchall()
        except Exception as e:
            self.logger.error(f"Failed to retrieve preferences for user_id {user_id}: {e}")
            return []
        self.logger.info(f"Retrieved {len(preferences)} preferences for user_id {user_id}.")
        return preferences

    def map_preferences_to_dicts(self, preferences: List[Tuple]) -> List[Dict[str, Any]]:
        """
        Maps a list of preference tuples to a list of dictionaries using the Preference schema.

        Args:
            preferences (List[Tuple]): List of tuples from the Preference table.

        Returns:
            List[Dict[str, Any]]: List of dictionaries with keys matching the Preference schema.
        """
        return [dict(zip(self._PREFERENCE_COLUMNS, row)) for row in preferences]