# deepamr-api / src/api/database.py
# (from commit 3255634 — "Deploy DeepAMR API backend")
"""SQLite database layer for DeepAMR.
Provides user auth, prediction history, activity logging, and dashboard stats.
Uses Python stdlib sqlite3 — zero extra dependencies.
"""
import hashlib
import hmac
import json
import os
import sqlite3
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# Database file lives three levels above this module, i.e. at the project
# root (src/api/database.py -> <root>/deepamr.db).
DB_PATH = Path(__file__).parent.parent.parent / "deepamr.db"
# ---------------------------------------------------------------------------
# Connection helper
# ---------------------------------------------------------------------------
def get_db() -> sqlite3.Connection:
    """Open a connection to the DeepAMR SQLite database.

    Rows are returned as :class:`sqlite3.Row` for mapping-style access.
    WAL journaling is enabled for better read concurrency, and foreign-key
    enforcement is switched on (off by default in SQLite).
    """
    connection = sqlite3.connect(str(DB_PATH))
    connection.row_factory = sqlite3.Row
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        connection.execute(pragma)
    return connection
# ---------------------------------------------------------------------------
# Schema / init
# ---------------------------------------------------------------------------
def init_db():
    """Create the schema, run lightweight migrations, and seed an admin user.

    Idempotent: all DDL uses ``IF NOT EXISTS``, the column migration swallows
    the "duplicate column" OperationalError, and the admin account is only
    seeded when the users table is empty.

    Fix: the connection is now closed in a ``finally`` block so a failure in
    any statement no longer leaks the connection (relevant under WAL mode,
    where a lingering handle can hold the -wal file open).
    """
    conn = get_db()
    try:
        cur = conn.cursor()
        # Core tables. NOTE: executescript() issues an implicit COMMIT before
        # running, so nothing above it should rely on an open transaction.
        cur.executescript("""
            CREATE TABLE IF NOT EXISTS users (
                id TEXT PRIMARY KEY,
                email TEXT UNIQUE NOT NULL,
                name TEXT NOT NULL,
                password_hash TEXT NOT NULL,
                salt TEXT NOT NULL,
                role TEXT NOT NULL DEFAULT 'user',
                organization TEXT,
                created_at TEXT NOT NULL,
                last_login TEXT
            );
            CREATE TABLE IF NOT EXISTS sessions (
                token TEXT PRIMARY KEY,
                user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
                created_at TEXT NOT NULL,
                expires_at TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS predictions (
                id TEXT PRIMARY KEY,
                sample_id TEXT NOT NULL,
                user_id TEXT REFERENCES users(id) ON DELETE SET NULL,
                organism TEXT NOT NULL,
                status TEXT NOT NULL DEFAULT 'pending',
                risk_level TEXT,
                file_name TEXT,
                file_size INTEGER,
                results_json TEXT,
                created_at TEXT NOT NULL,
                completed_at TEXT
            );
            CREATE TABLE IF NOT EXISTS activity_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id TEXT,
                user_name TEXT,
                action TEXT NOT NULL,
                details TEXT,
                timestamp TEXT NOT NULL
            );
        """)
        # Indexes for the hot query paths (history listing, session lookup).
        cur.executescript("""
            CREATE INDEX IF NOT EXISTS idx_predictions_user_id ON predictions(user_id);
            CREATE INDEX IF NOT EXISTS idx_predictions_created_at ON predictions(created_at);
            CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
        """)
        # Migration: add model_version to predictions created by older schemas.
        try:
            cur.execute("ALTER TABLE predictions ADD COLUMN model_version TEXT")
        except sqlite3.OperationalError:
            pass  # column already exists
        # Seed an admin account the first time the database is created.
        row = cur.execute("SELECT COUNT(*) FROM users").fetchone()
        if row[0] == 0:
            admin_pw = os.environ.get("DEEPAMR_ADMIN_PASSWORD")
            if not admin_pw:
                # No configured password: generate a random one and surface it
                # via a warning so the operator can log in once and change it.
                admin_pw = os.urandom(16).hex()
                import warnings
                warnings.warn(
                    f"No DEEPAMR_ADMIN_PASSWORD set. Generated random admin password: {admin_pw}",
                    stacklevel=2,
                )
            salt = os.urandom(16).hex()
            pw_hash = hash_password(admin_pw, salt)
            cur.execute(
                "INSERT INTO users (id, email, name, password_hash, salt, role, organization, created_at) "
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    str(uuid.uuid4()),
                    "admin@deepamr.org",
                    "Admin",
                    pw_hash,
                    salt,
                    "admin",
                    "DeepAMR",
                    datetime.utcnow().isoformat(),
                ),
            )
        conn.commit()
    finally:
        conn.close()
# ---------------------------------------------------------------------------
# Password helpers
# ---------------------------------------------------------------------------
def hash_password(password: str, salt: str) -> str:
    """Derive a hex-encoded PBKDF2-HMAC-SHA256 hash of *password*.

    Uses 100,000 iterations with the given *salt*; both inputs are
    UTF-8 encoded before derivation.
    """
    derived = hashlib.pbkdf2_hmac(
        "sha256",
        password.encode(),
        salt.encode(),
        100_000,
    )
    return derived.hex()
def verify_password(password: str, salt: str, pw_hash: str) -> bool:
    """Check *password* against the stored *pw_hash* for *salt*.

    Accepts both the current PBKDF2 scheme and the legacy
    SHA256(salt + password) scheme so pre-migration accounts keep working.

    Fix: comparisons use ``hmac.compare_digest`` instead of ``==`` so the
    check runs in constant time and does not leak match position via timing.
    """
    new_hash = hash_password(password, salt)
    if hmac.compare_digest(new_hash, pw_hash):
        return True
    # Legacy fallback kept during migration — TODO: remove once all accounts
    # have been rehashed with PBKDF2.
    legacy = hashlib.sha256((salt + password).encode()).hexdigest()
    return hmac.compare_digest(legacy, pw_hash)
# ---------------------------------------------------------------------------
# Users
# ---------------------------------------------------------------------------
def create_user(email: str, name: str, password: str, role: str = "user", organization: str | None = None) -> Dict:
    """Insert a new user and return the sanitized (credential-free) record.

    Raises:
        ValueError: if *email* is already registered (UNIQUE violation).

    Fix: the connection is closed in a ``finally`` block, so it no longer
    leaks if the post-insert SELECT raises; the ValueError is also chained
    to the underlying IntegrityError for debuggability.
    """
    user_id = str(uuid.uuid4())
    salt = os.urandom(16).hex()
    pw_hash = hash_password(password, salt)
    now = datetime.utcnow().isoformat()
    conn = get_db()
    try:
        try:
            conn.execute(
                "INSERT INTO users (id, email, name, password_hash, salt, role, organization, created_at) "
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                (user_id, email, name, pw_hash, salt, role, organization, now),
            )
            conn.commit()
        except sqlite3.IntegrityError as exc:
            # email has a UNIQUE constraint — the only constraint this
            # insert can violate.
            raise ValueError("Email already registered") from exc
        user = dict(conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone())
    finally:
        conn.close()
    return _sanitize_user(user)
def get_user_by_email(email: str) -> Optional[Dict]:
    """Look up a full user row (including credential fields) by email."""
    conn = get_db()
    found = conn.execute("SELECT * FROM users WHERE email = ?", (email,)).fetchone()
    conn.close()
    if found is None:
        return None
    return dict(found)
def get_user_by_id(user_id: str) -> Optional[Dict]:
    """Look up a full user row (including credential fields) by id."""
    conn = get_db()
    found = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
    conn.close()
    if found is None:
        return None
    return dict(found)
def list_users() -> List[Dict]:
    """Return every user, newest first, with credential fields stripped."""
    conn = get_db()
    rows = conn.execute("SELECT * FROM users ORDER BY created_at DESC").fetchall()
    conn.close()
    users = []
    for row in rows:
        users.append(_sanitize_user(dict(row)))
    return users
def delete_user(user_id: str) -> bool:
    """Delete a user by id; True if a row was removed.

    Sessions cascade-delete and predictions keep a NULL user_id via the
    foreign-key actions declared in the schema.
    """
    conn = get_db()
    cursor = conn.execute("DELETE FROM users WHERE id = ?", (user_id,))
    deleted = cursor.rowcount > 0
    conn.commit()
    conn.close()
    return deleted
def update_last_login(user_id: str):
    """Stamp the user's last_login with the current UTC time."""
    stamp = datetime.utcnow().isoformat()
    conn = get_db()
    conn.execute("UPDATE users SET last_login = ? WHERE id = ?", (stamp, user_id))
    conn.commit()
    conn.close()
def _sanitize_user(user: Dict) -> Dict:
"""Remove password fields from user dict for API responses."""
return {
"id": user["id"],
"email": user["email"],
"name": user["name"],
"role": user["role"],
"organization": user.get("organization"),
"createdAt": user["created_at"],
"lastLogin": user.get("last_login"),
}
# ---------------------------------------------------------------------------
# Sessions
# ---------------------------------------------------------------------------
def create_session(user_id: str) -> str:
    """Create a 7-day session for *user_id* and return its token (a UUID4)."""
    token = str(uuid.uuid4())
    issued = datetime.utcnow()
    expiry = issued + timedelta(days=7)
    conn = get_db()
    conn.execute(
        "INSERT INTO sessions (token, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)",
        (token, user_id, issued.isoformat(), expiry.isoformat()),
    )
    conn.commit()
    conn.close()
    return token
def get_session(token: str) -> Optional[Dict]:
    """Return the session row for *token*, or None if unknown or expired.

    Side effect: an expired session found here is deleted from the table.
    """
    conn = get_db()
    row = conn.execute("SELECT * FROM sessions WHERE token = ?", (token,)).fetchone()
    conn.close()
    if row is None:
        return None
    session = dict(row)
    expired = datetime.fromisoformat(session["expires_at"]) < datetime.utcnow()
    if expired:
        delete_session(token)  # lazy cleanup of dead sessions
        return None
    return session
def delete_session(token: str):
    """Remove the session identified by *token* (no-op if absent)."""
    conn = get_db()
    try:
        conn.execute("DELETE FROM sessions WHERE token = ?", (token,))
        conn.commit()
    finally:
        conn.close()
# ---------------------------------------------------------------------------
# Predictions
# ---------------------------------------------------------------------------
def save_prediction(
    sample_id: str,
    user_id: Optional[str],
    organism: str,
    status: str,
    risk_level: Optional[str],
    file_name: Optional[str],
    file_size: Optional[int],
    results_json: Optional[str],
    model_version: Optional[str] = None,
) -> Dict:
    """Insert a prediction row and return it in frontend format.

    completed_at is stamped only when the row is created already
    'completed'; otherwise it stays NULL.

    Fix: the connection is closed in a ``finally`` block so it no longer
    leaks when the INSERT or re-SELECT raises.
    """
    pred_id = f"pred-{uuid.uuid4().hex[:8]}"
    now = datetime.utcnow().isoformat()
    completed = now if status == "completed" else None
    conn = get_db()
    try:
        conn.execute(
            "INSERT INTO predictions (id, sample_id, user_id, organism, status, risk_level, file_name, file_size, results_json, created_at, completed_at, model_version) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (pred_id, sample_id, user_id, organism, status, risk_level, file_name, file_size, results_json, now, completed, model_version),
        )
        conn.commit()
        row = conn.execute("SELECT * FROM predictions WHERE id = ?", (pred_id,)).fetchone()
    finally:
        conn.close()
    return _format_prediction(dict(row))
def get_prediction(pred_id: str) -> Optional[Dict]:
    """Fetch one prediction by id in frontend format, or None if absent."""
    conn = get_db()
    row = conn.execute("SELECT * FROM predictions WHERE id = ?", (pred_id,)).fetchone()
    conn.close()
    if row is None:
        return None
    return _format_prediction(dict(row))
def list_predictions(
    user_id: Optional[str] = None,
    organism: Optional[str] = None,
    status: Optional[str] = None,
    risk: Optional[str] = None,
    search: Optional[str] = None,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
) -> List[Dict]:
    """List predictions newest-first, applying any of the optional filters.

    All filters are combined with AND; *search* does a substring match
    against sample_id, organism, and file_name. Every value is bound as a
    parameter — nothing is interpolated into the SQL.
    """
    clauses = ["1=1"]
    params: List[Any] = []
    # Exact-match column filters, applied only when a value is supplied.
    for clause, value in (
        ("user_id = ?", user_id),
        ("organism = ?", organism),
        ("status = ?", status),
        ("risk_level = ?", risk),
    ):
        if value:
            clauses.append(clause)
            params.append(value)
    if search:
        clauses.append("(sample_id LIKE ? OR organism LIKE ? OR file_name LIKE ?)")
        params.extend([f"%{search}%"] * 3)
    if date_from:
        clauses.append("created_at >= ?")
        params.append(date_from)
    if date_to:
        clauses.append("created_at <= ?")
        params.append(date_to)
    query = (
        "SELECT * FROM predictions WHERE "
        + " AND ".join(clauses)
        + " ORDER BY created_at DESC"
    )
    conn = get_db()
    rows = conn.execute(query, params).fetchall()
    conn.close()
    return [_format_prediction(dict(row)) for row in rows]
def delete_prediction(pred_id: str) -> bool:
    """Delete a prediction by id; True if a row was actually removed."""
    conn = get_db()
    cursor = conn.execute("DELETE FROM predictions WHERE id = ?", (pred_id,))
    removed = cursor.rowcount > 0
    conn.commit()
    conn.close()
    return removed
def get_recent_predictions(limit: int = 5) -> List[Dict]:
    """Return the *limit* most recently created predictions, newest first."""
    conn = get_db()
    recent = conn.execute(
        "SELECT * FROM predictions ORDER BY created_at DESC LIMIT ?", (limit,)
    ).fetchall()
    conn.close()
    formatted = []
    for row in recent:
        formatted.append(_format_prediction(dict(row)))
    return formatted
def _format_prediction(p: Dict) -> Dict:
"""Convert DB row to frontend-friendly format."""
results_data = None
if p.get("results_json"):
try:
results_data = json.loads(p["results_json"])
except json.JSONDecodeError:
pass
return {
"id": p["id"],
"sampleId": p["sample_id"],
"organism": p["organism"],
"status": p["status"],
"createdAt": p["created_at"],
"completedAt": p.get("completed_at"),
"uploadedBy": p.get("user_id", ""),
"fileName": p.get("file_name", ""),
"fileSize": p.get("file_size", 0),
"overallRisk": (p.get("risk_level") or "low").lower(),
"results": results_data.get("results") if results_data else None,
"detectedGenes": results_data.get("detectedGenes") if results_data else None,
"summary": results_data.get("summary") if results_data else None,
}
# ---------------------------------------------------------------------------
# Activity log
# ---------------------------------------------------------------------------
def log_activity(user_id: Optional[str], user_name: str, action: str, details: Optional[str] = None):
    """Append one row to the activity log with a UTC timestamp."""
    record = (user_id, user_name, action, details, datetime.utcnow().isoformat())
    conn = get_db()
    conn.execute(
        "INSERT INTO activity_log (user_id, user_name, action, details, timestamp) VALUES (?, ?, ?, ?, ?)",
        record,
    )
    conn.commit()
    conn.close()
def get_recent_activity(limit: int = 20) -> List[Dict]:
    """Return the *limit* most recent activity entries, newest first,
    with keys converted to the frontend's camelCase convention."""
    conn = get_db()
    rows = conn.execute(
        "SELECT * FROM activity_log ORDER BY timestamp DESC LIMIT ?", (limit,)
    ).fetchall()
    conn.close()
    entries = []
    for row in rows:
        entries.append({
            "userId": row["user_id"],
            "userName": row["user_name"],
            "action": row["action"],
            "details": row["details"],
            "timestamp": row["timestamp"],
        })
    return entries
# ---------------------------------------------------------------------------
# Dashboard stats
# ---------------------------------------------------------------------------
def get_dashboard_stats() -> Dict:
    """Aggregate prediction counts and week-over-week deltas for the dashboard.

    Fix: the weekly deltas were written as ``this - last if last else 0``,
    which (ternary binds looser than ``-``) reported a delta of 0 whenever
    the previous week had zero rows — hiding all growth from a cold start.
    The delta is now always ``this_week - last_week``. The connection is
    also closed via ``finally`` so query errors cannot leak it.
    """
    conn = get_db()
    try:
        total = conn.execute("SELECT COUNT(*) FROM predictions").fetchone()[0]
        resistant = conn.execute("SELECT COUNT(*) FROM predictions WHERE risk_level = 'high'").fetchone()[0]
        susceptible = conn.execute("SELECT COUNT(*) FROM predictions WHERE risk_level IN ('low', 'minimal')").fetchone()[0]
        pending = conn.execute("SELECT COUNT(*) FROM predictions WHERE status IN ('pending', 'processing')").fetchone()[0]
        week_ago = (datetime.utcnow() - timedelta(days=7)).isoformat()
        two_weeks_ago = (datetime.utcnow() - timedelta(days=14)).isoformat()
        this_week = conn.execute("SELECT COUNT(*) FROM predictions WHERE created_at >= ?", (week_ago,)).fetchone()[0]
        last_week = conn.execute(
            "SELECT COUNT(*) FROM predictions WHERE created_at >= ? AND created_at < ?",
            (two_weeks_ago, week_ago),
        ).fetchone()[0]
        weekly_change = this_week - last_week
        this_week_r = conn.execute(
            "SELECT COUNT(*) FROM predictions WHERE risk_level = 'high' AND created_at >= ?", (week_ago,)
        ).fetchone()[0]
        last_week_r = conn.execute(
            "SELECT COUNT(*) FROM predictions WHERE risk_level = 'high' AND created_at >= ? AND created_at < ?",
            (two_weeks_ago, week_ago),
        ).fetchone()[0]
        weekly_r_change = this_week_r - last_week_r
    finally:
        conn.close()
    return {
        "totalPredictions": total,
        "resistantCount": resistant,
        "susceptibleCount": susceptible,
        "pendingCount": pending,
        "weeklyChange": {
            "predictions": weekly_change,
            "resistant": weekly_r_change,
        },
    }
def get_resistance_overview() -> List[Dict]:
    """Count predictions per resistance category for the overview pie chart."""
    buckets = (
        ("Resistant", "SELECT COUNT(*) FROM predictions WHERE risk_level = 'high'", "#ef4444"),
        ("Intermediate", "SELECT COUNT(*) FROM predictions WHERE risk_level = 'moderate'", "#eab308"),
        ("Susceptible", "SELECT COUNT(*) FROM predictions WHERE risk_level IN ('low', 'minimal')", "#22c55e"),
    )
    conn = get_db()
    overview = []
    for label, sql, color in buckets:
        count = conn.execute(sql).fetchone()[0]
        overview.append({"name": label, "value": count, "color": color})
    conn.close()
    return overview
def get_trends() -> List[Dict]:
    """Daily resistant/susceptible/intermediate counts for the last 7 days.

    Fix: the old upper bound ``...T23:59:59`` with BETWEEN silently dropped
    rows whose isoformat created_at carries fractional seconds (e.g.
    ``...T23:59:59.123456`` sorts after ``...T23:59:59``). Each day is now a
    half-open range [day 00:00:00, next-day 00:00:00).
    """
    conn = get_db()
    trends: List[Dict] = []
    for i in range(6, -1, -1):
        day = datetime.utcnow() - timedelta(days=i)
        day_start = day.strftime("%Y-%m-%dT00:00:00")
        next_start = (day + timedelta(days=1)).strftime("%Y-%m-%dT00:00:00")
        label = day.strftime("%b %d")
        counts = {}
        for key, predicate in (
            ("resistant", "risk_level = 'high'"),
            ("susceptible", "risk_level IN ('low', 'minimal')"),
            ("intermediate", "risk_level = 'moderate'"),
        ):
            counts[key] = conn.execute(
                f"SELECT COUNT(*) FROM predictions WHERE {predicate} AND created_at >= ? AND created_at < ?",
                (day_start, next_start),
            ).fetchone()[0]
        trends.append({
            "date": label,
            "resistant": counts["resistant"],
            "susceptible": counts["susceptible"],
            "intermediate": counts["intermediate"],
        })
    conn.close()
    return trends
# ---------------------------------------------------------------------------
# Admin stats
# ---------------------------------------------------------------------------
def get_admin_stats() -> Dict:
    """Aggregate platform-wide counters for the admin dashboard.

    "Active" means logged in within the last 7 days; storage is estimated
    from the sum of uploaded file sizes, reported in GB.
    """
    week_ago = (datetime.utcnow() - timedelta(days=7)).isoformat()
    today_start = datetime.utcnow().strftime("%Y-%m-%dT00:00:00")
    conn = get_db()

    def _count(sql: str, args: Tuple = ()) -> int:
        # Small helper: every stat here is a single-scalar SELECT.
        return conn.execute(sql, args).fetchone()[0]

    total_users = _count("SELECT COUNT(*) FROM users")
    active_users = _count("SELECT COUNT(*) FROM users WHERE last_login >= ?", (week_ago,))
    total_predictions = _count("SELECT COUNT(*) FROM predictions")
    predictions_today = _count("SELECT COUNT(*) FROM predictions WHERE created_at >= ?", (today_start,))
    storage_bytes = _count("SELECT COALESCE(SUM(file_size), 0) FROM predictions")
    conn.close()
    return {
        "totalUsers": total_users,
        "activeUsers": active_users,
        "totalPredictions": total_predictions,
        "predictionsToday": predictions_today,
        "storageUsed": round(storage_bytes / (1024**3), 2),
        "storageLimit": 10,
    }