Ace-119's picture
Define SQLite schema for users, sessions, and feedback
4fd1054
"""
database/db.py
==============
SQLite persistence layer for user accounts and stress analysis sessions.
Tables
------
- ``users`` — username, password_hash, encrypted_history, created_at
- ``sessions`` — per-analysis snapshots linked to users
Thread-safety is handled by using ``check_same_thread=False`` and
relying on SQLite's internal serialisation (WAL mode).
"""
from __future__ import annotations
import json
import logging
import os
import sqlite3
import time
from typing import Any, Optional
logger = logging.getLogger(__name__)
# Default database path (overridable via env var or constructor arg)
_DEFAULT_DB_PATH = os.environ.get("STRESS_DB_PATH", "stress_detection.db")
class DatabaseManager:
"""Thin wrapper around a SQLite database for user + session storage.
Parameters
----------
db_path : str
File path for the SQLite database. Use ``":memory:"`` for an
ephemeral in-memory database (useful in tests).
"""
def __init__(self, db_path: str = _DEFAULT_DB_PATH) -> None:
self._db_path = db_path
self._conn = sqlite3.connect(
db_path, check_same_thread=False,
)
self._conn.row_factory = sqlite3.Row
self._conn.execute("PRAGMA journal_mode=WAL")
self._conn.execute("PRAGMA foreign_keys=ON")
self._create_tables()
# ------------------------------------------------------------------
# Schema
# ------------------------------------------------------------------
def _create_tables(self) -> None:
"""Create tables if they do not exist."""
self._conn.executescript(
"""
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
encrypted_history TEXT,
created_at REAL NOT NULL
);
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
stress_score REAL NOT NULL,
stress_label TEXT NOT NULL,
temporal_data TEXT NOT NULL,
interventions TEXT NOT NULL,
is_crisis INTEGER NOT NULL DEFAULT 0,
crisis_message TEXT,
matched_triggers TEXT NOT NULL,
attention_weights TEXT NOT NULL,
created_at REAL NOT NULL,
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS idx_sessions_user_id
ON sessions(user_id);
CREATE INDEX IF NOT EXISTS idx_sessions_created_at
ON sessions(created_at);
"""
)
# ------------------------------------------------------------------
# User CRUD
# ------------------------------------------------------------------
def create_user(
self,
username: str,
password_hash: str,
) -> int:
"""Insert a new user and return their ``id``.
Raises
------
sqlite3.IntegrityError
If the username already exists.
"""
cur = self._conn.execute(
"INSERT INTO users (username, password_hash, encrypted_history, created_at) "
"VALUES (?, ?, NULL, ?)",
(username, password_hash, time.time()),
)
self._conn.commit()
return cur.lastrowid # type: ignore[return-value]
def get_user(self, username: str) -> Optional[dict[str, Any]]:
"""Return a user dict or ``None`` if not found."""
row = self._conn.execute(
"SELECT id, username, password_hash, encrypted_history, created_at "
"FROM users WHERE username = ?",
(username,),
).fetchone()
if row is None:
return None
return dict(row)
def user_exists(self, username: str) -> bool:
"""Check whether a username is already taken."""
row = self._conn.execute(
"SELECT 1 FROM users WHERE username = ?", (username,)
).fetchone()
return row is not None
def update_encrypted_history(
self, username: str, encrypted_history: str
) -> None:
"""Persist the user's updated encrypted temporal history."""
self._conn.execute(
"UPDATE users SET encrypted_history = ? WHERE username = ?",
(encrypted_history, username),
)
self._conn.commit()
# ------------------------------------------------------------------
# Session CRUD
# ------------------------------------------------------------------
def save_session(
self,
username: str,
stress_score: float,
stress_label: str,
temporal_data: dict,
interventions: list[dict],
is_crisis: bool,
crisis_message: Optional[str],
matched_triggers: list[str],
attention_weights: list[float],
) -> int:
"""Persist a single analysis session and return its ``id``."""
user = self.get_user(username)
if user is None:
raise ValueError(f"User '{username}' not found")
cur = self._conn.execute(
"INSERT INTO sessions "
"(user_id, stress_score, stress_label, temporal_data, "
" interventions, is_crisis, crisis_message, matched_triggers, "
" attention_weights, created_at) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
user["id"],
stress_score,
stress_label,
json.dumps(temporal_data),
json.dumps(interventions),
int(is_crisis),
crisis_message,
json.dumps(matched_triggers),
json.dumps(attention_weights),
time.time(),
),
)
self._conn.commit()
return cur.lastrowid # type: ignore[return-value]
def get_sessions(
self,
username: str,
limit: int = 50,
offset: int = 0,
) -> list[dict[str, Any]]:
"""Return past sessions for a user, newest first.
Parameters
----------
username : str
The user whose sessions to retrieve.
limit : int
Maximum number of sessions to return.
offset : int
Number of sessions to skip (for pagination).
"""
user = self.get_user(username)
if user is None:
return []
rows = self._conn.execute(
"SELECT id, stress_score, stress_label, temporal_data, "
"interventions, is_crisis, crisis_message, matched_triggers, "
"attention_weights, created_at "
"FROM sessions WHERE user_id = ? "
"ORDER BY created_at DESC LIMIT ? OFFSET ?",
(user["id"], limit, offset),
).fetchall()
sessions = []
for row in rows:
session = dict(row)
session["temporal_data"] = json.loads(session["temporal_data"])
session["interventions"] = json.loads(session["interventions"])
session["is_crisis"] = bool(session["is_crisis"])
session["matched_triggers"] = json.loads(session["matched_triggers"])
session["attention_weights"] = json.loads(session["attention_weights"])
sessions.append(session)
return sessions
def get_session_count(self, username: str) -> int:
"""Return the total number of sessions for a user."""
user = self.get_user(username)
if user is None:
return 0
row = self._conn.execute(
"SELECT COUNT(*) as cnt FROM sessions WHERE user_id = ?",
(user["id"],),
).fetchone()
return row["cnt"]
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def close(self) -> None:
"""Close the database connection."""
self._conn.close()