Spaces:
Running
Running
| """ | |
| database.py - SQLite Database Manager for Dressa App | |
| Handles: | |
| - User session management | |
| - Upload tracking | |
| - Rating storage | |
| - Corpus growth tracking | |
| """ | |
| import sqlite3 | |
| import uuid | |
| import json | |
| from datetime import datetime | |
| from pathlib import Path | |
| import os | |
| from typing import List, Optional, Dict, Any | |
| from contextlib import contextmanager | |
| import logging | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default database path: the DRESSA_DB_PATH environment variable when it is
# set to a non-empty value, otherwise "user_study.db" next to this file.
# An empty value is treated as unset because Path("") resolves to the
# current directory, which SQLite cannot open as a database file.
APP_DIR = Path(__file__).parent.resolve()
DEFAULT_DB_PATH = Path(
    os.getenv("DRESSA_DB_PATH") or str(APP_DIR / "user_study.db")
)
class Database:
    """SQLite database manager for user study data.

    Every public method opens its own short-lived connection through
    ``_get_connection`` and closes it before returning, so an instance
    keeps no open database handle between calls.
    """

    # The only rating labels accepted by save_rating() and
    # save_evaluation_rating().
    _VALID_RATINGS = ('similar', 'not_similar')

    def __init__(self, db_path: Optional[Path] = None):
        """
        Initialize the database.

        Creates the parent directory of the database file if needed and
        creates any missing tables.

        Args:
            db_path: Path to SQLite database file. DEFAULT_DB_PATH is used
                when this is omitted (or otherwise falsy).
        """
        self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _init_db(self) -> None:
        """Create tables if they don't exist."""
        with self._get_connection() as conn:
            # Users table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    user_id TEXT PRIMARY KEY,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            # Uploads table (user's query images)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS uploads (
                    upload_id TEXT PRIMARY KEY,
                    user_id TEXT,
                    filepath TEXT,
                    uploaded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    added_to_corpus INTEGER DEFAULT 0,
                    num_ratings INTEGER DEFAULT 0,
                    FOREIGN KEY (user_id) REFERENCES users(user_id)
                )
            """)
            # Ratings table (legacy, kept for backward compatibility)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS ratings (
                    rating_id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT,
                    upload_id TEXT,
                    result_image_path TEXT,
                    model TEXT,
                    rating TEXT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (user_id) REFERENCES users(user_id),
                    FOREIGN KEY (upload_id) REFERENCES uploads(upload_id)
                )
            """)
            # New evaluation_ratings table with provenance
            conn.execute("""
                CREATE TABLE IF NOT EXISTS evaluation_ratings (
                    rating_id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    upload_id TEXT NOT NULL,
                    result_image_id TEXT NOT NULL,
                    rating TEXT NOT NULL,
                    provenance TEXT NOT NULL,
                    display_position INTEGER,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.commit()
        logger.info("Database initialized at %s", self.db_path)

    @contextmanager
    def _get_connection(self):
        """
        Context manager for database connections.

        Yields a connection whose rows are ``sqlite3.Row`` objects (so
        columns can be read by name) and always closes it on exit. Nothing
        is committed automatically: callers that write must call
        ``conn.commit()`` themselves.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()

    # ==================== Internal Helpers ====================

    def _fetch_one(self, query: str, params: tuple = ()) -> Optional[Dict]:
        """Run a read-only query and return its first row as a dict, or None."""
        with self._get_connection() as conn:
            row = conn.execute(query, params).fetchone()
        return dict(row) if row is not None else None

    def _fetch_all(self, query: str, params: tuple = ()) -> List[Dict]:
        """Run a read-only query and return every row as a dict."""
        with self._get_connection() as conn:
            return [dict(row) for row in conn.execute(query, params).fetchall()]

    @classmethod
    def _validate_rating(cls, rating: str) -> None:
        """Raise ValueError unless ``rating`` is one of the accepted labels."""
        if rating not in cls._VALID_RATINGS:
            raise ValueError(f"Invalid rating: {rating}. "
                             f"Must be 'similar' or 'not_similar'")

    @staticmethod
    def _bump_rating_count(conn, upload_id: str) -> int:
        """
        Add one to an upload's ``num_ratings`` on an already-open connection.

        Does not commit, so the caller can make the increment part of a
        larger transaction. Returns the new count, or 0 when no upload with
        that ID exists.
        """
        conn.execute(
            """UPDATE uploads
               SET num_ratings = num_ratings + 1
               WHERE upload_id = ?""",
            (upload_id,)
        )
        row = conn.execute(
            "SELECT num_ratings FROM uploads WHERE upload_id = ?",
            (upload_id,)
        ).fetchone()
        return row['num_ratings'] if row is not None else 0

    # ==================== User Methods ====================

    def create_user(self) -> str:
        """
        Create a new user with a unique ID.

        Returns:
            user_id: UUID string for the new user
        """
        user_id = str(uuid.uuid4())
        with self._get_connection() as conn:
            conn.execute(
                "INSERT INTO users (user_id) VALUES (?)",
                (user_id,)
            )
            conn.commit()
        logger.info("Created new user: %s", user_id)
        return user_id

    def get_user(self, user_id: str) -> Optional[Dict]:
        """Get user by ID, or None when no such user exists."""
        return self._fetch_one(
            "SELECT * FROM users WHERE user_id = ?",
            (user_id,)
        )

    def user_exists(self, user_id: str) -> bool:
        """Check if a user exists."""
        return self.get_user(user_id) is not None

    # ==================== Upload Methods ====================

    def create_upload(self, user_id: str, filepath: str) -> str:
        """
        Record a new image upload.

        Args:
            user_id: ID of the user who uploaded
            filepath: Path where the uploaded image is stored

        Returns:
            upload_id: UUID string for the new upload
        """
        upload_id = str(uuid.uuid4())
        with self._get_connection() as conn:
            conn.execute(
                """INSERT INTO uploads (upload_id, user_id, filepath)
                   VALUES (?, ?, ?)""",
                (upload_id, user_id, filepath)
            )
            conn.commit()
        logger.info("Created upload %s for user %s", upload_id, user_id)
        return upload_id

    def get_upload(self, upload_id: str) -> Optional[Dict]:
        """Get upload by ID, or None when no such upload exists."""
        return self._fetch_one(
            "SELECT * FROM uploads WHERE upload_id = ?",
            (upload_id,)
        )

    def get_user_uploads(self, user_id: str) -> List[Dict]:
        """Get all uploads for a user, newest first."""
        return self._fetch_all(
            """SELECT * FROM uploads
               WHERE user_id = ?
               ORDER BY uploaded_at DESC""",
            (user_id,)
        )

    def mark_added_to_corpus(self, upload_id: str) -> None:
        """Mark an upload as added to the corpus."""
        with self._get_connection() as conn:
            conn.execute(
                """UPDATE uploads
                   SET added_to_corpus = 1
                   WHERE upload_id = ?""",
                (upload_id,)
            )
            conn.commit()
        logger.info("Marked upload %s as added to corpus", upload_id)

    def increment_upload_ratings(self, upload_id: str) -> int:
        """
        Increment the rating count for an upload.

        Returns:
            New rating count, or 0 when the upload does not exist
        """
        with self._get_connection() as conn:
            # Read the new value before committing so the fetch never
            # depends on cursor state surviving the commit.
            new_count = self._bump_rating_count(conn, upload_id)
            conn.commit()
        return new_count

    # ==================== Rating Methods ====================

    def save_rating(
        self,
        user_id: str,
        upload_id: str,
        result_image_path: str,
        model: str,
        rating: str
    ) -> int:
        """
        Save a user's rating for a result image.

        The rating row and the upload's rating counter are written in one
        transaction, so they cannot get out of step.

        Args:
            user_id: ID of the user
            upload_id: ID of the query image upload
            result_image_path: Path to the result image being rated
            model: Which model returned this result
            rating: 'similar' or 'not_similar'

        Returns:
            rating_id: ID of the new rating

        Raises:
            ValueError: If rating is not 'similar' or 'not_similar'
        """
        self._validate_rating(rating)
        with self._get_connection() as conn:
            cursor = conn.execute(
                """INSERT INTO ratings
                   (user_id, upload_id, result_image_path, model, rating)
                   VALUES (?, ?, ?, ?, ?)""",
                (user_id, upload_id, result_image_path, model, rating)
            )
            rating_id = cursor.lastrowid
            # Increment upload's rating count
            self._bump_rating_count(conn, upload_id)
            conn.commit()
        logger.info("Saved rating %s: %s for %s", rating_id, rating, model)
        return rating_id

    def get_upload_ratings(self, upload_id: str) -> List[Dict]:
        """Get all legacy ratings for an upload, oldest first."""
        return self._fetch_all(
            """SELECT * FROM ratings
               WHERE upload_id = ?
               ORDER BY timestamp""",
            (upload_id,)
        )

    def get_user_ratings(self, user_id: str) -> List[Dict]:
        """Get all legacy ratings by a user, oldest first."""
        return self._fetch_all(
            """SELECT * FROM ratings
               WHERE user_id = ?
               ORDER BY timestamp""",
            (user_id,)
        )

    def has_rated(self, upload_id: str, result_image_path: str) -> bool:
        """Check if a result image has already been rated for this upload."""
        match = self._fetch_one(
            """SELECT 1 AS found FROM ratings
               WHERE upload_id = ? AND result_image_path = ?
               LIMIT 1""",
            (upload_id, result_image_path)
        )
        return match is not None

    # ==================== Evaluation Rating Methods ====================

    def save_evaluation_rating(
        self,
        user_id: str,
        upload_id: str,
        result_image_id: str,
        rating: str,
        provenance: dict,
        display_position: int
    ) -> int:
        """
        Save a user's evaluation rating with provenance information.

        The rating row and the upload's rating counter are written in one
        transaction. The provenance dict is stored as a JSON string.

        Args:
            user_id: ID of the user
            upload_id: ID of the query image upload
            result_image_id: Path/ID of the result image being rated
            rating: 'similar' or 'not_similar'
            provenance: Dict mapping model_name -> rank (1-indexed)
            display_position: Position in the shuffled display order

        Returns:
            rating_id: ID of the new evaluation rating

        Raises:
            ValueError: If rating is not 'similar' or 'not_similar'
        """
        self._validate_rating(rating)
        with self._get_connection() as conn:
            cursor = conn.execute(
                """INSERT INTO evaluation_ratings
                   (user_id, upload_id, result_image_id, rating,
                    provenance, display_position)
                   VALUES (?, ?, ?, ?, ?, ?)""",
                (user_id, upload_id, result_image_id, rating,
                 json.dumps(provenance), display_position)
            )
            rating_id = cursor.lastrowid
            # Increment upload's rating count
            self._bump_rating_count(conn, upload_id)
            conn.commit()
        logger.info(
            "Saved evaluation rating: %s for %s (position %s, provenance: %s)",
            rating, result_image_id, display_position, provenance
        )
        return rating_id

    def get_evaluation_ratings(self, upload_id: str) -> List[Dict]:
        """
        Get all evaluation ratings for an upload, oldest first.

        The stored provenance JSON is decoded back into a Python object.
        """
        rows = self._fetch_all(
            """SELECT * FROM evaluation_ratings
               WHERE upload_id = ?
               ORDER BY timestamp""",
            (upload_id,)
        )
        # Parse provenance JSON
        for row in rows:
            row['provenance'] = json.loads(row['provenance'])
        return rows

    # ==================== Analytics Methods ====================

    def get_stats(self) -> Dict[str, Any]:
        """
        Get overall statistics.

        The rating figures are computed from the legacy ``ratings`` table
        only; rows in ``evaluation_ratings`` are not included.

        Returns:
            Dict with keys total_users, total_uploads, corpus_additions,
            total_ratings, ratings_by_model and ratings_breakdown
        """
        with self._get_connection() as conn:

            def count(query: str) -> int:
                return conn.execute(query).fetchone()['count']

            stats: Dict[str, Any] = {}
            # Total users
            stats['total_users'] = count("SELECT COUNT(*) as count FROM users")
            # Total uploads
            stats['total_uploads'] = count(
                "SELECT COUNT(*) as count FROM uploads"
            )
            # Uploads added to corpus
            stats['corpus_additions'] = count(
                "SELECT COUNT(*) as count FROM uploads WHERE added_to_corpus = 1"
            )
            # Total ratings
            stats['total_ratings'] = count(
                "SELECT COUNT(*) as count FROM ratings"
            )
            # Ratings by model
            stats['ratings_by_model'] = {
                row['model']: row['count']
                for row in conn.execute(
                    """SELECT model, COUNT(*) as count
                       FROM ratings GROUP BY model"""
                ).fetchall()
            }
            # Similar vs not similar
            stats['ratings_breakdown'] = {
                row['rating']: row['count']
                for row in conn.execute(
                    """SELECT rating, COUNT(*) as count
                       FROM ratings GROUP BY rating"""
                ).fetchall()
            }
        return stats

    def export_ratings_csv(self, output_path: str) -> None:
        """
        Export all legacy ratings to CSV for analysis.

        Each rating is joined with its upload to include the query image
        path; ratings whose upload row is missing are omitted by the join.

        Args:
            output_path: Destination file, overwritten if it exists
        """
        import csv

        columns = [
            'rating_id', 'user_id', 'upload_id', 'query_image_path',
            'result_image_path', 'model', 'rating', 'timestamp'
        ]
        with self._get_connection() as conn:
            rows = conn.execute("""
                SELECT r.*, u.filepath as query_image_path
                FROM ratings r
                JOIN uploads u ON r.upload_id = u.upload_id
                ORDER BY r.timestamp
            """).fetchall()
        # Explicit UTF-8 so non-ASCII file paths export the same way on
        # every platform instead of depending on the locale encoding.
        with open(output_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(columns)
            writer.writerows([row[name] for name in columns] for row in rows)
        logger.info("Exported %s ratings to %s", len(rows), output_path)
# Convenience function for testing
def test_database():
    """
    Smoke-test the basic database operations.

    Creates a user, an upload and one legacy rating in a throwaway SQLite
    file, printing the outcome of each step and the resulting stats.
    """
    import tempfile

    # A private temporary directory avoids clashing with a stale file left
    # by an earlier run and is removed even when one of the steps raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        db = Database(Path(tmp_dir) / "dressa_test.db")
        print("Testing database.py...")

        # Test user creation
        user_id = db.create_user()
        print(f" Created user: {user_id[:8]}...")

        # Test upload creation
        upload_id = db.create_upload(user_id, "/path/to/test.jpg")
        print(f" Created upload: {upload_id[:8]}...")

        # Test rating
        rating_id = db.save_rating(
            user_id, upload_id, "/path/to/result.jpg",
            "openai_clip", "similar"
        )
        print(f" Saved rating: {rating_id}")

        # Test stats
        stats = db.get_stats()
        print(f" Stats: {stats}")

    print("\nDatabase tests complete!")
# Script entry point: run the smoke test when this file is executed directly.
if __name__ == "__main__":
    test_database()