"""SQLite database layer for DeepAMR. Provides user auth, prediction history, activity logging, and dashboard stats. Uses Python stdlib sqlite3 — zero extra dependencies. """ import hashlib import json import os import sqlite3 import uuid from datetime import datetime, timedelta from pathlib import Path from typing import Any, Dict, List, Optional, Tuple DB_PATH = Path(__file__).parent.parent.parent / "deepamr.db" # --------------------------------------------------------------------------- # Connection helper # --------------------------------------------------------------------------- def get_db() -> sqlite3.Connection: conn = sqlite3.connect(str(DB_PATH)) conn.row_factory = sqlite3.Row conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA foreign_keys=ON") return conn # --------------------------------------------------------------------------- # Schema / init # --------------------------------------------------------------------------- def init_db(): conn = get_db() cur = conn.cursor() cur.executescript(""" CREATE TABLE IF NOT EXISTS users ( id TEXT PRIMARY KEY, email TEXT UNIQUE NOT NULL, name TEXT NOT NULL, password_hash TEXT NOT NULL, salt TEXT NOT NULL, role TEXT NOT NULL DEFAULT 'user', organization TEXT, created_at TEXT NOT NULL, last_login TEXT ); CREATE TABLE IF NOT EXISTS sessions ( token TEXT PRIMARY KEY, user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, created_at TEXT NOT NULL, expires_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS predictions ( id TEXT PRIMARY KEY, sample_id TEXT NOT NULL, user_id TEXT REFERENCES users(id) ON DELETE SET NULL, organism TEXT NOT NULL, status TEXT NOT NULL DEFAULT 'pending', risk_level TEXT, file_name TEXT, file_size INTEGER, results_json TEXT, created_at TEXT NOT NULL, completed_at TEXT ); CREATE TABLE IF NOT EXISTS activity_log ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id TEXT, user_name TEXT, action TEXT NOT NULL, details TEXT, timestamp TEXT NOT NULL ); """) # Add indexes for performance cur.executescript(""" CREATE INDEX IF NOT EXISTS idx_predictions_user_id ON predictions(user_id); CREATE INDEX IF NOT EXISTS idx_predictions_created_at ON predictions(created_at); CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id); """) # Add model_version column if missing try: cur.execute("ALTER TABLE predictions ADD COLUMN model_version TEXT") except sqlite3.OperationalError: pass # column already exists # Seed admin user if table is empty row = cur.execute("SELECT COUNT(*) FROM users").fetchone() if row[0] == 0: admin_pw = os.environ.get("DEEPAMR_ADMIN_PASSWORD") if not admin_pw: admin_pw = os.urandom(16).hex() import warnings warnings.warn( f"No DEEPAMR_ADMIN_PASSWORD set. Generated random admin password: {admin_pw}", stacklevel=2, ) salt = os.urandom(16).hex() pw_hash = hash_password(admin_pw, salt) cur.execute( "INSERT INTO users (id, email, name, password_hash, salt, role, organization, created_at) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", ( str(uuid.uuid4()), "admin@deepamr.org", "Admin", pw_hash, salt, "admin", "DeepAMR", datetime.utcnow().isoformat(), ), ) conn.commit() conn.close() # --------------------------------------------------------------------------- # Password helpers # --------------------------------------------------------------------------- def hash_password(password: str, salt: str) -> str: return hashlib.pbkdf2_hmac( "sha256", password.encode(), salt.encode(), 100_000 ).hex() def verify_password(password: str, salt: str, pw_hash: str) -> bool: # Support legacy SHA256 hashes during migration new_hash = hash_password(password, salt) if new_hash == pw_hash: return True legacy = hashlib.sha256((salt + password).encode()).hexdigest() return legacy == pw_hash # --------------------------------------------------------------------------- # Users # --------------------------------------------------------------------------- def create_user(email: str, name: str, password: str, role: str = "user", organization: str | None = None) -> Dict: conn = get_db() user_id = str(uuid.uuid4()) salt = os.urandom(16).hex() pw_hash = hash_password(password, salt) now = datetime.utcnow().isoformat() try: conn.execute( "INSERT INTO users (id, email, name, password_hash, salt, role, organization, created_at) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (user_id, email, name, pw_hash, salt, role, organization, now), ) conn.commit() except sqlite3.IntegrityError: conn.close() raise ValueError("Email already registered") user = dict(conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()) conn.close() return _sanitize_user(user) def get_user_by_email(email: str) -> Optional[Dict]: conn = get_db() row = conn.execute("SELECT * FROM users WHERE email = ?", (email,)).fetchone() conn.close() return dict(row) if row else None def get_user_by_id(user_id: str) -> Optional[Dict]: conn = get_db() row = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone() conn.close() return dict(row) if row else None def list_users() -> List[Dict]: conn = get_db() rows = conn.execute("SELECT * FROM users ORDER BY created_at DESC").fetchall() conn.close() return [_sanitize_user(dict(r)) for r in rows] def delete_user(user_id: str) -> bool: conn = get_db() cur = conn.execute("DELETE FROM users WHERE id = ?", (user_id,)) conn.commit() conn.close() return cur.rowcount > 0 def update_last_login(user_id: str): conn = get_db() conn.execute("UPDATE users SET last_login = ? WHERE id = ?", (datetime.utcnow().isoformat(), user_id)) conn.commit() conn.close() def _sanitize_user(user: Dict) -> Dict: """Remove password fields from user dict for API responses.""" return { "id": user["id"], "email": user["email"], "name": user["name"], "role": user["role"], "organization": user.get("organization"), "createdAt": user["created_at"], "lastLogin": user.get("last_login"), } # --------------------------------------------------------------------------- # Sessions # --------------------------------------------------------------------------- def create_session(user_id: str) -> str: conn = get_db() token = str(uuid.uuid4()) now = datetime.utcnow() expires = now + timedelta(days=7) conn.execute( "INSERT INTO sessions (token, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)", (token, user_id, now.isoformat(), expires.isoformat()), ) conn.commit() conn.close() return token def get_session(token: str) -> Optional[Dict]: conn = get_db() row = conn.execute("SELECT * FROM sessions WHERE token = ?", (token,)).fetchone() conn.close() if not row: return None session = dict(row) if datetime.fromisoformat(session["expires_at"]) < datetime.utcnow(): delete_session(token) return None return session def delete_session(token: str): conn = get_db() conn.execute("DELETE FROM sessions WHERE token = ?", (token,)) conn.commit() conn.close() # --------------------------------------------------------------------------- # Predictions # --------------------------------------------------------------------------- def save_prediction( sample_id: str, user_id: Optional[str], organism: str, status: str, risk_level: Optional[str], file_name: Optional[str], file_size: Optional[int], results_json: Optional[str], model_version: Optional[str] = None, ) -> Dict: conn = get_db() pred_id = f"pred-{uuid.uuid4().hex[:8]}" now = datetime.utcnow().isoformat() completed = now if status == "completed" else None conn.execute( "INSERT INTO predictions (id, sample_id, user_id, organism, status, risk_level, file_name, file_size, results_json, created_at, completed_at, model_version) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (pred_id, sample_id, user_id, organism, status, risk_level, file_name, file_size, results_json, now, completed, model_version), ) conn.commit() row = conn.execute("SELECT * FROM predictions WHERE id = ?", (pred_id,)).fetchone() conn.close() return _format_prediction(dict(row)) def get_prediction(pred_id: str) -> Optional[Dict]: conn = get_db() row = conn.execute("SELECT * FROM predictions WHERE id = ?", (pred_id,)).fetchone() conn.close() return _format_prediction(dict(row)) if row else None def list_predictions( user_id: Optional[str] = None, organism: Optional[str] = None, status: Optional[str] = None, risk: Optional[str] = None, search: Optional[str] = None, date_from: Optional[str] = None, date_to: Optional[str] = None, ) -> List[Dict]: conn = get_db() query = "SELECT * FROM predictions WHERE 1=1" params: List[Any] = [] if user_id: query += " AND user_id = ?" params.append(user_id) if organism: query += " AND organism = ?" params.append(organism) if status: query += " AND status = ?" params.append(status) if risk: query += " AND risk_level = ?" params.append(risk) if search: query += " AND (sample_id LIKE ? OR organism LIKE ? OR file_name LIKE ?)" like = f"%{search}%" params.extend([like, like, like]) if date_from: query += " AND created_at >= ?" params.append(date_from) if date_to: query += " AND created_at <= ?" params.append(date_to) query += " ORDER BY created_at DESC" rows = conn.execute(query, params).fetchall() conn.close() return [_format_prediction(dict(r)) for r in rows] def delete_prediction(pred_id: str) -> bool: conn = get_db() cur = conn.execute("DELETE FROM predictions WHERE id = ?", (pred_id,)) conn.commit() conn.close() return cur.rowcount > 0 def get_recent_predictions(limit: int = 5) -> List[Dict]: conn = get_db() rows = conn.execute( "SELECT * FROM predictions ORDER BY created_at DESC LIMIT ?", (limit,) ).fetchall() conn.close() return [_format_prediction(dict(r)) for r in rows] def _format_prediction(p: Dict) -> Dict: """Convert DB row to frontend-friendly format.""" results_data = None if p.get("results_json"): try: results_data = json.loads(p["results_json"]) except json.JSONDecodeError: pass return { "id": p["id"], "sampleId": p["sample_id"], "organism": p["organism"], "status": p["status"], "createdAt": p["created_at"], "completedAt": p.get("completed_at"), "uploadedBy": p.get("user_id", ""), "fileName": p.get("file_name", ""), "fileSize": p.get("file_size", 0), "overallRisk": (p.get("risk_level") or "low").lower(), "results": results_data.get("results") if results_data else None, "detectedGenes": results_data.get("detectedGenes") if results_data else None, "summary": results_data.get("summary") if results_data else None, } # --------------------------------------------------------------------------- # Activity log # --------------------------------------------------------------------------- def log_activity(user_id: Optional[str], user_name: str, action: str, details: Optional[str] = None): conn = get_db() conn.execute( "INSERT INTO activity_log (user_id, user_name, action, details, timestamp) VALUES (?, ?, ?, ?, ?)", (user_id, user_name, action, details, datetime.utcnow().isoformat()), ) conn.commit() conn.close() def get_recent_activity(limit: int = 20) -> List[Dict]: conn = get_db() rows = conn.execute( "SELECT * FROM activity_log ORDER BY timestamp DESC LIMIT ?", (limit,) ).fetchall() conn.close() return [ { "userId": r["user_id"], "userName": r["user_name"], "action": r["action"], "details": r["details"], "timestamp": r["timestamp"], } for r in rows ] # --------------------------------------------------------------------------- # Dashboard stats # --------------------------------------------------------------------------- def get_dashboard_stats() -> Dict: conn = get_db() total = conn.execute("SELECT COUNT(*) FROM predictions").fetchone()[0] resistant = conn.execute("SELECT COUNT(*) FROM predictions WHERE risk_level = 'high'").fetchone()[0] susceptible = conn.execute("SELECT COUNT(*) FROM predictions WHERE risk_level IN ('low', 'minimal')").fetchone()[0] pending = conn.execute("SELECT COUNT(*) FROM predictions WHERE status IN ('pending', 'processing')").fetchone()[0] week_ago = (datetime.utcnow() - timedelta(days=7)).isoformat() two_weeks_ago = (datetime.utcnow() - timedelta(days=14)).isoformat() this_week = conn.execute("SELECT COUNT(*) FROM predictions WHERE created_at >= ?", (week_ago,)).fetchone()[0] last_week = conn.execute( "SELECT COUNT(*) FROM predictions WHERE created_at >= ? AND created_at < ?", (two_weeks_ago, week_ago), ).fetchone()[0] weekly_change = this_week - last_week if last_week else 0 this_week_r = conn.execute( "SELECT COUNT(*) FROM predictions WHERE risk_level = 'high' AND created_at >= ?", (week_ago,) ).fetchone()[0] last_week_r = conn.execute( "SELECT COUNT(*) FROM predictions WHERE risk_level = 'high' AND created_at >= ? AND created_at < ?", (two_weeks_ago, week_ago), ).fetchone()[0] weekly_r_change = this_week_r - last_week_r if last_week_r else 0 conn.close() return { "totalPredictions": total, "resistantCount": resistant, "susceptibleCount": susceptible, "pendingCount": pending, "weeklyChange": { "predictions": weekly_change, "resistant": weekly_r_change, }, } def get_resistance_overview() -> List[Dict]: conn = get_db() resistant = conn.execute("SELECT COUNT(*) FROM predictions WHERE risk_level = 'high'").fetchone()[0] moderate = conn.execute("SELECT COUNT(*) FROM predictions WHERE risk_level = 'moderate'").fetchone()[0] susceptible = conn.execute("SELECT COUNT(*) FROM predictions WHERE risk_level IN ('low', 'minimal')").fetchone()[0] conn.close() return [ {"name": "Resistant", "value": resistant, "color": "#ef4444"}, {"name": "Intermediate", "value": moderate, "color": "#eab308"}, {"name": "Susceptible", "value": susceptible, "color": "#22c55e"}, ] def get_trends() -> List[Dict]: conn = get_db() trends = [] for i in range(6, -1, -1): day = datetime.utcnow() - timedelta(days=i) day_start = day.strftime("%Y-%m-%dT00:00:00") day_end = day.strftime("%Y-%m-%dT23:59:59") label = day.strftime("%b %d") r = conn.execute( "SELECT COUNT(*) FROM predictions WHERE risk_level = 'high' AND created_at BETWEEN ? AND ?", (day_start, day_end), ).fetchone()[0] s = conn.execute( "SELECT COUNT(*) FROM predictions WHERE risk_level IN ('low', 'minimal') AND created_at BETWEEN ? AND ?", (day_start, day_end), ).fetchone()[0] m = conn.execute( "SELECT COUNT(*) FROM predictions WHERE risk_level = 'moderate' AND created_at BETWEEN ? AND ?", (day_start, day_end), ).fetchone()[0] trends.append({"date": label, "resistant": r, "susceptible": s, "intermediate": m}) conn.close() return trends # --------------------------------------------------------------------------- # Admin stats # --------------------------------------------------------------------------- def get_admin_stats() -> Dict: conn = get_db() total_users = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0] week_ago = (datetime.utcnow() - timedelta(days=7)).isoformat() active_users = conn.execute( "SELECT COUNT(*) FROM users WHERE last_login >= ?", (week_ago,) ).fetchone()[0] total_predictions = conn.execute("SELECT COUNT(*) FROM predictions").fetchone()[0] today_start = datetime.utcnow().strftime("%Y-%m-%dT00:00:00") predictions_today = conn.execute( "SELECT COUNT(*) FROM predictions WHERE created_at >= ?", (today_start,) ).fetchone()[0] # Estimate storage from file sizes storage_row = conn.execute("SELECT COALESCE(SUM(file_size), 0) FROM predictions").fetchone() storage_bytes = storage_row[0] storage_gb = round(storage_bytes / (1024**3), 2) conn.close() return { "totalUsers": total_users, "activeUsers": active_users, "totalPredictions": total_predictions, "predictionsToday": predictions_today, "storageUsed": storage_gb, "storageLimit": 10, }