Spaces:
Running
Running
File size: 10,253 Bytes
b6d1c13 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 | """Class for Database handling using SQLite."""
import logging
import os
import sqlite3
from sqlite3 import Connection
from typing import Any, Dict, List, Optional, Tuple
class SQLiteDB:
    """
    Handles SQLite database operations for an image generation comparison experiment.

    One ``sqlite3`` connection is opened lazily by the first method that needs
    it and is then reused for the lifetime of the instance (see ``close``).

    Attributes:
        db_folder_path (str): Directory where the SQLite database file is stored.
        experiment_name (str): Name of the experiment (used to name the database file).
        db_filename (str): Filename of the SQLite database.
        db_path (str): Full path to the SQLite database file.
        log_path (str): Log file handed to ``logging.basicConfig``.
        logger (logging.Logger): Logger instance for logging database operations.
        conn (Optional[Connection]): Shared connection, or None until first use.
    """

    def __init__(self, db_folder_path: str, experiment_name: str = "arena"):
        """
        Initializes the SQLiteDB instance with dataset and database configuration.

        The database folder is created if it is missing, because both the log
        file configured here and the database file opened later live inside it.

        Args:
            db_folder_path (str): Directory where the SQLite database file is stored.
            experiment_name (str, optional): Name of the experiment. Defaults to "arena".
        """
        # Assigned first so that __del__/close never meet an instance that
        # lacks this attribute, even if a later line of __init__ raises.
        self.conn: Optional[Connection] = None
        self.experiment_name = experiment_name
        self.db_folder_path = db_folder_path
        self.db_filename = f"{experiment_name.lower()}.db"
        self.db_path = os.path.join(db_folder_path, self.db_filename)
        self.log_path = os.path.join(db_folder_path, "log_db.txt")
        # "" means the current directory, which exists and cannot be makedirs'd.
        if db_folder_path:
            os.makedirs(db_folder_path, exist_ok=True)
        # basicConfig only takes effect if the root logger has no handlers yet.
        logging.basicConfig(filename=self.log_path, filemode="a", level=logging.DEBUG)
        # Named logger: propagates to the root handlers configured above.
        self.logger = logging.getLogger(__name__)

    def __del__(self):
        # Best-effort cleanup; a finalizer must never let an exception escape.
        try:
            self.close()
        except Exception:
            pass

    def close(self) -> None:
        """Close the shared connection if it is open. Safe to call repeatedly."""
        conn = getattr(self, "conn", None)
        if conn is not None:
            self.conn = None
            conn.close()

    def _get_connection(self) -> Connection:
        """
        Return the shared connection, opening it on first use.

        ``check_same_thread=False`` lets the connection opened by one thread be
        used from another. Per the ``sqlite3`` docs, concurrent writes through
        a shared connection may need to be serialized by the caller.
        """
        if self.conn is None:
            self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        return self.conn

    def initialize_database(self) -> bool:
        """
        Initializes the SQLite database and creates required tables if they do not exist.

        The ``CREATE TABLE IF NOT EXISTS`` statements run even when the database
        file already exists. A previous lazy connect may have created the file
        empty, so its existence does not imply the tables exist.

        Tables:
            - User: Stores user information.
            - Preference: Stores user preferences between generated images.

        Returns:
            bool: True if successful, False if an error occurred.
        """
        try:
            if os.path.exists(self.db_path):
                self.logger.info("Database already exists at %s", self.db_path)
            else:
                self.logger.info("Creating new database at %s", self.db_path)
                # Must precede connect(): SQLite cannot create a file inside a
                # directory that does not exist.
                if self.db_folder_path:
                    os.makedirs(self.db_folder_path, exist_ok=True)
            conn = self._get_connection()
            with conn:  # commits on success, rolls back on exception
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS User (
                        user_id INTEGER PRIMARY KEY AUTOINCREMENT,
                        username TEXT UNIQUE,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)
                # NOTE(review): SQLite only enforces the FOREIGN KEY clause
                # below when "PRAGMA foreign_keys = ON" is set on the
                # connection, which this class does not do.
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS Preference (
                        preference_id INTEGER PRIMARY KEY AUTOINCREMENT,
                        user_id INTEGER,
                        reference_id TEXT,
                        model_left_id TEXT,
                        model_right_id TEXT,
                        preferred_side TEXT CHECK(preferred_side IN ('left', 'right', 'tie')),
                        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        FOREIGN KEY (user_id) REFERENCES User(user_id)
                    )
                """)
        except Exception as e:
            self.logger.error("Creating the database failed with following error: %s", e)
            return False
        return True

    def create_user(self, username: str) -> Tuple[Optional[int], str]:
        """
        Creates a new user in the database.

        Args:
            username (str): The username of the new user.

        Returns:
            Tuple[Optional[int], str]:
                On success ``(user_id, username)``. On failure ``(None, message)``,
                where the message says the user already exists (UNIQUE violation)
                or carries the exception text.
        """
        try:
            conn = self._get_connection()
            with conn:  # rolls back the INSERT if it raises
                cursor = conn.execute("INSERT INTO User (username) VALUES (?)", (username,))
            user_id = cursor.lastrowid
        except sqlite3.IntegrityError:
            msg = f"User '{username}' already exists."
            self.logger.warning(msg)
            return None, msg
        except Exception as e:
            msg = f"Failed to create user '{username}': {e}"
            self.logger.error(msg)
            return None, msg
        self.logger.info("User '%s' created with user_id %s", username, user_id)
        return user_id, username

    def get_user_id_by_username(self, username: str) -> Optional[int]:
        """
        Checks if a username exists in the database and returns the associated user_id.

        Args:
            username (str): The username to look up.

        Returns:
            Optional[int]: The user_id if the username exists, None otherwise
            (also None if the query fails; the error is logged).
        """
        try:
            row = (
                self._get_connection()
                .execute("SELECT user_id FROM User WHERE username = ?", (username,))
                .fetchone()
            )
        except Exception as e:
            self.logger.error("Error checking username '%s': %s", username, e)
            return None
        return row[0] if row else None

    def insert_preference(
        self, user_id: int, reference_id: str, model_left_id: str, model_right_id: str, preferred_side: str
    ) -> Tuple[bool, str]:
        """
        Inserts a new preference entry into the database.

        Args:
            user_id (int): ID of the user making the preference.
            reference_id (str): ID of the reference image.
            model_left_id (str): ID of the left model's generated image.
            model_right_id (str): ID of the right model's generated image.
            preferred_side (str): The preferred side ('left', 'right', or 'tie').

        Returns:
            Tuple[bool, str]: ``(True, "")`` if insertion was successful, otherwise
            ``(False, message)`` describing the validation error or exception.
        """
        # Checked here (and not only by the table's CHECK constraint) to give a
        # clear message without touching the database.
        if preferred_side not in {"left", "right", "tie"}:
            msg = f"Invalid preferred_side value: {preferred_side}"
            self.logger.error(msg)
            return False, msg
        try:
            conn = self._get_connection()
            with conn:
                conn.execute(
                    """
                    INSERT INTO Preference (
                        user_id, reference_id, model_left_id, model_right_id, preferred_side
                    ) VALUES (?, ?, ?, ?, ?)
                    """,
                    (user_id, reference_id, model_left_id, model_right_id, preferred_side),
                )
        except Exception as e:
            msg = f"Failed to insert preference: {e}"
            self.logger.error(msg)
            return False, msg
        self.logger.info("Preference inserted for user_id %s", user_id)
        return True, ""

    def get_all_preferences(self) -> List[Tuple]:
        """
        Retrieves all preference entries from the database.

        Returns:
            List[Tuple]: One tuple per Preference row, in the table's column order;
            an empty list if the query fails (the error is logged).
        """
        try:
            return self._get_connection().execute("SELECT * FROM Preference").fetchall()
        except Exception as e:
            self.logger.error("Failed to retrieve preferences: %s", e)
            return []

    def get_preferences_by_user(self, user_id: int) -> List[Tuple]:
        """
        Retrieves all preference entries for a specific user.

        Args:
            user_id (int): The ID of the user.

        Returns:
            List[Tuple]: The user's Preference rows; an empty list if the query
            fails (the error is logged).
        """
        try:
            preferences = (
                self._get_connection()
                .execute("SELECT * FROM Preference WHERE user_id = ?", (user_id,))
                .fetchall()
            )
        except Exception as e:
            self.logger.error("Failed to retrieve preferences for user_id %s: %s", user_id, e)
            return []
        self.logger.info("Retrieved %d preferences for user_id %s.", len(preferences), user_id)
        return preferences

    def map_preferences_to_dicts(self, preferences: List[Tuple]) -> List[Dict[str, Any]]:
        """
        Maps a list of preference tuples to a list of dictionaries using the Preference schema.

        Args:
            preferences (List[Tuple]): List of tuples from the Preference table.

        Returns:
            List[Dict[str, Any]]: List of dictionaries with keys matching the Preference schema.
        """
        # Same order as the columns in the Preference CREATE TABLE statement.
        keys = (
            "preference_id",
            "user_id",
            "reference_id",
            "model_left_id",
            "model_right_id",
            "preferred_side",
            "timestamp",
        )
        return [dict(zip(keys, row)) for row in preferences]
|