""" Database Module - Turso (libSQL) Connection Handles all database operations for the CBT Companion app """ import os import libsql_experimental as libsql from datetime import datetime, date import uuid import json class Database: """ Database handler for Turso (libSQL). Required environment variables: - TURSO_DATABASE_URL: Your Turso database URL - TURSO_AUTH_TOKEN: Your Turso auth token """ def __init__(self): url = os.environ.get("TURSO_DATABASE_URL") token = os.environ.get("TURSO_AUTH_TOKEN") if not url or not token: raise ValueError("TURSO_DATABASE_URL and TURSO_AUTH_TOKEN are required") self.conn = libsql.connect(database=url, auth_token=token) self._init_tables() print("✅ Database connected to Turso") def _init_tables(self): """Create tables if they don't exist.""" # Users table self.conn.execute(""" CREATE TABLE IF NOT EXISTS users ( id TEXT PRIMARY KEY, email TEXT UNIQUE NOT NULL, password_hash TEXT NOT NULL, name TEXT NOT NULL, context TEXT DEFAULT 'person', created_at TEXT DEFAULT CURRENT_TIMESTAMP, fcm_token TEXT ) """) # Add fcm_token column to existing deployments that lack it try: self.conn.execute("ALTER TABLE users ADD COLUMN fcm_token TEXT") self.conn.commit() except Exception: pass # Column already exists # Add last_fcm_sent column for cooldown tracking try: self.conn.execute("ALTER TABLE users ADD COLUMN last_fcm_sent TEXT") self.conn.commit() except Exception: pass # Column already exists # Sessions table self.conn.execute(""" CREATE TABLE IF NOT EXISTS sessions ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, started_at TEXT DEFAULT CURRENT_TIMESTAMP, ended_at TEXT, mood_start INTEGER, mood_end INTEGER, locked_group TEXT, stages_reached INTEGER DEFAULT 1, completed INTEGER DEFAULT 0, messages_in_current_stage INTEGER DEFAULT 0, FOREIGN KEY (user_id) REFERENCES users(id) ) """) # Messages table self.conn.execute(""" CREATE TABLE IF NOT EXISTS messages ( id TEXT PRIMARY KEY, session_id TEXT NOT NULL, user_id TEXT NOT NULL, role TEXT NOT NULL, content TEXT NOT NULL, timestamp TEXT DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (session_id) REFERENCES sessions(id), FOREIGN KEY (user_id) REFERENCES users(id) ) """) # Exercises completed self.conn.execute(""" CREATE TABLE IF NOT EXISTS exercises_completed ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, session_id TEXT, exercise_id TEXT NOT NULL, exercise_name TEXT, group_type TEXT, completed_at TEXT DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (user_id) REFERENCES users(id) ) """) # Crisis flags self.conn.execute(""" CREATE TABLE IF NOT EXISTS crisis_flags ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, user_name TEXT NOT NULL, user_email TEXT NOT NULL, session_id TEXT NOT NULL, message_content TEXT NOT NULL, trigger_word TEXT NOT NULL, flagged_at TEXT DEFAULT CURRENT_TIMESTAMP, reviewed INTEGER DEFAULT 0, FOREIGN KEY (user_id) REFERENCES users(id), FOREIGN KEY (session_id) REFERENCES sessions(id) ) """) # User stats self.conn.execute(""" CREATE TABLE IF NOT EXISTS user_stats ( user_id TEXT PRIMARY KEY, total_sessions INTEGER DEFAULT 0, total_exercises INTEGER DEFAULT 0, current_streak INTEGER DEFAULT 0, last_session_date TEXT, distortion_counts TEXT DEFAULT '{}', FOREIGN KEY (user_id) REFERENCES users(id) ) """) # Wearable sensor data self.conn.execute(""" CREATE TABLE IF NOT EXISTS wearable_data ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, ppg REAL NOT NULL, gsr REAL NOT NULL, acc_x REAL NOT NULL, acc_y REAL NOT NULL, acc_z REAL NOT NULL, device_timestamp TEXT, recorded_at TEXT DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (user_id) REFERENCES users(id) ) """) # Migrate: add columns that may be missing from older table versions for col, col_type, default in [ ("dri_score", "REAL", None), ("condition", "TEXT", None), ("acknowledged", "INTEGER", "0"), ("ml_prediction", "TEXT", None), ("ml_confidence", "REAL", None), ("risk_level", "INTEGER", None), ]: try: default_clause = f" DEFAULT {default}" if default is not None else "" self.conn.execute( f"ALTER TABLE wearable_data ADD COLUMN {col} {col_type}{default_clause}" ) except Exception: pass # Column already exists # Create index for faster queries on wearable data self.conn.execute(""" CREATE INDEX IF NOT EXISTS idx_wearable_user_time ON wearable_data(user_id, recorded_at DESC) """) # Device API keys for wearables self.conn.execute(""" CREATE TABLE IF NOT EXISTS device_keys ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, api_key TEXT UNIQUE NOT NULL, device_name TEXT DEFAULT 'ESP32 Wearable', created_at TEXT DEFAULT CURRENT_TIMESTAMP, last_used_at TEXT, is_active INTEGER DEFAULT 1, FOREIGN KEY (user_id) REFERENCES users(id) ) """) # Index for fast API key lookups self.conn.execute(""" CREATE INDEX IF NOT EXISTS idx_device_api_key ON device_keys(api_key) """) # Beck sessions table for cognitive restructuring protocol self.conn.execute(""" CREATE TABLE IF NOT EXISTS beck_sessions ( session_id TEXT PRIMARY KEY, -- Current state in the protocol beck_state TEXT DEFAULT 'VALIDATE', -- Phase 1: Capture original_thought TEXT, initial_belief_rating INTEGER, emotion TEXT, initial_emotion_intensity INTEGER, -- Phase 2: Discovery (6 Questions) q1_evidence_for TEXT, q1_evidence_against TEXT, q2_alternative TEXT, q3_worst TEXT, q3_best TEXT, q3_realistic TEXT, q4_effect TEXT, q5_friend TEXT, q6_action TEXT, -- Phase 3: Reframe adaptive_thought TEXT, new_thought_belief INTEGER, -- Phase 4: Measure final_belief_rating INTEGER, final_emotion_intensity INTEGER, -- Phase 5: Action action_plan TEXT, -- Metadata started_at TEXT DEFAULT CURRENT_TIMESTAMP, completed_at TEXT, belief_improvement INTEGER, emotion_improvement INTEGER, FOREIGN KEY (session_id) REFERENCES sessions(id) ) """) # Depression episodes table (ML-detected high stress periods) self.conn.execute(""" CREATE TABLE IF NOT EXISTS depression_episodes ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, start_time TEXT NOT NULL, end_time TEXT, peak_risk_level INTEGER, total_readings INTEGER DEFAULT 0, avg_confidence REAL, is_active INTEGER DEFAULT 1, created_at TEXT DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (user_id) REFERENCES users(id) ) """) # Index for fast episode lookups self.conn.execute(""" CREATE INDEX IF NOT EXISTS idx_depression_episodes_user ON depression_episodes(user_id, is_active DESC) """) self.conn.commit() # ==================== USER OPERATIONS ==================== def create_user(self, email: str, password_hash: str, name: str, context: str = "person") -> dict: """Create a new user.""" user_id = str(uuid.uuid4()) try: self.conn.execute( "INSERT INTO users (id, email, password_hash, name, context) VALUES (?, ?, ?, ?, ?)", (user_id, email.lower(), password_hash, name, context) ) # Initialize user stats self.conn.execute( "INSERT INTO user_stats (user_id) VALUES (?)", (user_id,) ) self.conn.commit() return {"id": user_id, "email": email, "name": name, "context": context} except Exception as e: if "UNIQUE constraint" in str(e): raise ValueError("Email already exists") raise e def get_user_by_email(self, email: str) -> dict: """Get user by email.""" result = self.conn.execute( "SELECT id, email, password_hash, name, context FROM users WHERE email = ?", (email.lower(),) ).fetchone() if result: return { "id": result[0], "email": result[1], "password_hash": result[2], "name": result[3], "context": result[4] } return None def get_user_by_id(self, user_id: str) -> dict: """Get user by ID.""" result = self.conn.execute( "SELECT id, email, name, context FROM users WHERE id = ?", (user_id,) ).fetchone() if result: return { "id": result[0], "email": result[1], "name": result[2], "context": result[3] } return None def save_fcm_token(self, user_id: str, token: str): """Save or update FCM device token for a user.""" self.conn.execute( "UPDATE users SET fcm_token = ? WHERE id = ?", (token, user_id) ) self.conn.commit() def get_fcm_token(self, user_id: str) -> str: """Get the FCM token for a user, or None.""" result = self.conn.execute( "SELECT fcm_token FROM users WHERE id = ?", (user_id,) ).fetchone() return result[0] if result else None def fcm_cooldown_ok(self, user_id: str, cooldown_minutes: int = 30) -> bool: """Return True if enough time has passed since the last FCM push.""" result = self.conn.execute( "SELECT last_fcm_sent FROM users WHERE id = ?", (user_id,) ).fetchone() if not result or not result[0]: return True try: last = datetime.fromisoformat(result[0]) return (datetime.utcnow() - last).total_seconds() >= cooldown_minutes * 60 except Exception: return True def update_fcm_sent_time(self, user_id: str): """Record the current UTC time as the last FCM send for cooldown tracking.""" self.conn.execute( "UPDATE users SET last_fcm_sent = ? WHERE id = ?", (datetime.utcnow().isoformat(), user_id) ) self.conn.commit() # ==================== SESSION OPERATIONS ==================== def create_session(self, user_id: str) -> str: """Create a new chat session.""" session_id = str(uuid.uuid4()) self.conn.execute( "INSERT INTO sessions (id, user_id) VALUES (?, ?)", (session_id, user_id) ) self.conn.commit() return session_id def get_session(self, session_id: str) -> dict: """Get session by ID.""" result = self.conn.execute( """SELECT id, user_id, started_at, ended_at, mood_start, mood_end, locked_group, stages_reached, completed, messages_in_current_stage FROM sessions WHERE id = ?""", (session_id,) ).fetchone() if result: return { "id": result[0], "user_id": result[1], "started_at": result[2], "ended_at": result[3], "mood_start": result[4], "mood_end": result[5], "locked_group": result[6], "current_stage": result[7], "completed": bool(result[8]), "messages_in_current_stage": result[9] or 0, "state": "COMPLETED" if result[8] else ("IN_PROGRESS" if result[6] else "WAITING_FOR_PROBLEM") } return None def update_session(self, session_id: str, **kwargs): """Update session fields.""" allowed_fields = { "mood_start", "mood_end", "locked_group", "stages_reached", "completed", "ended_at", "messages_in_current_stage" } updates = {k: v for k, v in kwargs.items() if k in allowed_fields} if not updates: return # Map stages_reached to the column if "current_stage" in kwargs: updates["stages_reached"] = kwargs["current_stage"] set_clause = ", ".join(f"{k} = ?" for k in updates.keys()) values = tuple(list(updates.values()) + [session_id]) self.conn.execute( f"UPDATE sessions SET {set_clause} WHERE id = ?", values ) self.conn.commit() def increment_stage_messages(self, session_id: str) -> int: """Increment message counter for current stage.""" self.conn.execute( "UPDATE sessions SET messages_in_current_stage = messages_in_current_stage + 1 WHERE id = ?", (session_id,) ) self.conn.commit() result = self.conn.execute( "SELECT messages_in_current_stage FROM sessions WHERE id = ?", (session_id,) ).fetchone() return result[0] if result else 0 def reset_stage_messages(self, session_id: str): """Reset message counter (called when advancing stage).""" self.conn.execute( "UPDATE sessions SET messages_in_current_stage = 0 WHERE id = ?", (session_id,) ) self.conn.commit() def get_user_sessions(self, user_id: str, limit: int = 20) -> list: """Get user's past sessions.""" results = self.conn.execute( """SELECT id, started_at, ended_at, mood_start, mood_end, locked_group, stages_reached, completed FROM sessions WHERE user_id = ? ORDER BY started_at DESC LIMIT ?""", (user_id, limit) ).fetchall() return [ { "id": r[0], "started_at": r[1], "ended_at": r[2], "mood_start": r[3], "mood_end": r[4], "locked_group": r[5], "stages_reached": r[6], "completed": bool(r[7]) } for r in results ] # ==================== MESSAGE OPERATIONS ==================== def add_message(self, session_id: str, user_id: str, role: str, content: str): """Add a message to conversation history.""" message_id = str(uuid.uuid4()) self.conn.execute( "INSERT INTO messages (id, session_id, user_id, role, content) VALUES (?, ?, ?, ?, ?)", (message_id, session_id, user_id, role, content) ) self.conn.commit() def get_session_messages(self, session_id: str, limit: int = None) -> list: """Get messages for a session.""" query = """SELECT role, content, timestamp FROM messages WHERE session_id = ? ORDER BY timestamp ASC""" if limit: query += f" LIMIT {limit}" results = self.conn.execute(query, (session_id,)).fetchall() return [ {"role": r[0], "content": r[1], "timestamp": r[2]} for r in results ] def get_recent_messages(self, session_id: str, n: int = 6) -> list: """Get last n messages for context window.""" results = self.conn.execute( """SELECT role, content FROM messages WHERE session_id = ? ORDER BY timestamp DESC LIMIT ?""", (session_id, n) ).fetchall() # Reverse to get chronological order return [{"role": r[0], "content": r[1]} for r in reversed(results)] # ==================== EXERCISE OPERATIONS ==================== def log_exercise_completed(self, user_id: str, session_id: str, exercise_id: str, exercise_name: str, group_type: str): """Log a completed exercise.""" entry_id = str(uuid.uuid4()) self.conn.execute( """INSERT INTO exercises_completed (id, user_id, session_id, exercise_id, exercise_name, group_type) VALUES (?, ?, ?, ?, ?, ?)""", (entry_id, user_id, session_id, exercise_id, exercise_name, group_type) ) # Update user stats self.conn.execute( "UPDATE user_stats SET total_exercises = total_exercises + 1 WHERE user_id = ?", (user_id,) ) self.conn.commit() # ==================== CRISIS FLAG OPERATIONS ==================== def flag_crisis(self, user_id: str, user_name: str, user_email: str, session_id: str, message_content: str, trigger_word: str): """Flag a crisis message.""" flag_id = str(uuid.uuid4()) self.conn.execute( """INSERT INTO crisis_flags (id, user_id, user_name, user_email, session_id, message_content, trigger_word) VALUES (?, ?, ?, ?, ?, ?, ?)""", (flag_id, user_id, user_name, user_email, session_id, message_content, trigger_word) ) self.conn.commit() print(f"🚨 CRISIS FLAGGED: User {user_name} ({user_email}) - Trigger: {trigger_word}") # ==================== STATS OPERATIONS ==================== def get_user_stats(self, user_id: str) -> dict: """Get user statistics.""" result = self.conn.execute( """SELECT total_sessions, total_exercises, current_streak, last_session_date, distortion_counts FROM user_stats WHERE user_id = ?""", (user_id,) ).fetchone() if result: return { "total_sessions": result[0], "total_exercises": result[1], "current_streak": result[2], "last_session_date": result[3], "distortion_counts": json.loads(result[4] or "{}") } return { "total_sessions": 0, "total_exercises": 0, "current_streak": 0, "last_session_date": None, "distortion_counts": {} } def update_user_stats_on_session_end(self, user_id: str, locked_group: str): """Update user stats when a session ends.""" today = date.today().isoformat() # Get current stats stats = self.get_user_stats(user_id) # Update distortion counts distortion_counts = stats["distortion_counts"] if locked_group and locked_group != "G0": distortion_counts[locked_group] = distortion_counts.get(locked_group, 0) + 1 # Calculate streak last_date = stats["last_session_date"] if last_date == today: new_streak = stats["current_streak"] elif last_date == (date.today().replace(day=date.today().day - 1)).isoformat(): new_streak = stats["current_streak"] + 1 else: new_streak = 1 # Update stats self.conn.execute( """UPDATE user_stats SET total_sessions = total_sessions + 1, current_streak = ?, last_session_date = ?, distortion_counts = ? WHERE user_id = ?""", (new_streak, today, json.dumps(distortion_counts), user_id) ) self.conn.commit() # ==================== WEARABLE DATA OPERATIONS ==================== def save_wearable_data(self, user_id: str, ppg: float, gsr: float, acc_x: float, acc_y: float, acc_z: float, device_timestamp: str = None) -> str: """Save wearable sensor data.""" record_id = str(uuid.uuid4()) self.conn.execute( """INSERT INTO wearable_data (id, user_id, ppg, gsr, acc_x, acc_y, acc_z, device_timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", (record_id, user_id, ppg, gsr, acc_x, acc_y, acc_z, device_timestamp) ) self.conn.commit() return record_id def get_latest_wearable_data(self, user_id: str) -> dict: """Get the most recent wearable data for a user.""" result = self.conn.execute( """SELECT id, ppg, gsr, acc_x, acc_y, acc_z, device_timestamp, recorded_at FROM wearable_data WHERE user_id = ? ORDER BY recorded_at DESC LIMIT 1""", (user_id,) ).fetchone() if result: return { "id": result[0], "ppg": result[1], "gsr": result[2], "acc_x": result[3], "acc_y": result[4], "acc_z": result[5], "device_timestamp": result[6], "recorded_at": result[7] } return None def get_wearable_history(self, user_id: str, limit: int = 100, offset: int = 0, start_date: str = None, end_date: str = None) -> list: """Get wearable data history for a user.""" query = """SELECT id, ppg, gsr, acc_x, acc_y, acc_z, device_timestamp, recorded_at FROM wearable_data WHERE user_id = ?""" params = [user_id] if start_date: query += " AND recorded_at >= ?" params.append(start_date) if end_date: query += " AND recorded_at <= ?" params.append(end_date) query += " ORDER BY recorded_at DESC LIMIT ? OFFSET ?" params.extend([limit, offset]) results = self.conn.execute(query, tuple(params)).fetchall() return [ { "id": r[0], "ppg": r[1], "gsr": r[2], "acc_x": r[3], "acc_y": r[4], "acc_z": r[5], "device_timestamp": r[6], "recorded_at": r[7] } for r in results ] def get_wearable_stats(self, user_id: str, period: str = "day") -> dict: """Get aggregated statistics for wearable data.""" # Calculate date filter based on period if period == "day": date_filter = "datetime('now', '-1 day')" elif period == "week": date_filter = "datetime('now', '-7 days')" else: # month date_filter = "datetime('now', '-30 days')" result = self.conn.execute( f"""SELECT COUNT(*) as count, AVG(ppg) as avg_ppg, MIN(ppg) as min_ppg, MAX(ppg) as max_ppg, AVG(gsr) as avg_gsr, MIN(gsr) as min_gsr, MAX(gsr) as max_gsr, AVG(acc_x) as avg_acc_x, AVG(acc_y) as avg_acc_y, AVG(acc_z) as avg_acc_z FROM wearable_data WHERE user_id = ? AND recorded_at >= {date_filter}""", (user_id,) ).fetchone() if result and result[0] > 0: return { "reading_count": result[0], "ppg": { "avg": round(result[1], 2) if result[1] else None, "min": round(result[2], 2) if result[2] else None, "max": round(result[3], 2) if result[3] else None }, "gsr": { "avg": round(result[4], 4) if result[4] else None, "min": round(result[5], 4) if result[5] else None, "max": round(result[6], 4) if result[6] else None }, "accelerometer": { "avg_x": round(result[7], 4) if result[7] else None, "avg_y": round(result[8], 4) if result[8] else None, "avg_z": round(result[9], 4) if result[9] else None } } return { "reading_count": 0, "ppg": {"avg": None, "min": None, "max": None}, "gsr": {"avg": None, "min": None, "max": None}, "accelerometer": {"avg_x": None, "avg_y": None, "avg_z": None} } # ==================== DEVICE KEY OPERATIONS ==================== def create_device_key(self, user_id: str, device_name: str = "ESP32 Wearable") -> dict: """Create a new device API key for a user.""" import secrets key_id = str(uuid.uuid4()) # Generate a secure 32-character API key api_key = secrets.token_hex(16) self.conn.execute( """INSERT INTO device_keys (id, user_id, api_key, device_name) VALUES (?, ?, ?, ?)""", (key_id, user_id, api_key, device_name) ) self.conn.commit() return { "id": key_id, "api_key": api_key, "device_name": device_name, "created_at": datetime.utcnow().isoformat() } def get_user_by_device_key(self, api_key: str) -> dict: """Get user associated with a device API key.""" result = self.conn.execute( """SELECT u.id, u.email, u.name, u.context, dk.id as device_id FROM device_keys dk JOIN users u ON dk.user_id = u.id WHERE dk.api_key = ? AND dk.is_active = 1""", (api_key,) ).fetchone() if result: # Update last_used_at self.conn.execute( "UPDATE device_keys SET last_used_at = ? WHERE api_key = ?", (datetime.utcnow().isoformat(), api_key) ) self.conn.commit() return { "id": result[0], "email": result[1], "name": result[2], "context": result[3], "device_id": result[4] } return None def get_user_device_keys(self, user_id: str) -> list: """Get all device keys for a user.""" results = self.conn.execute( """SELECT id, api_key, device_name, created_at, last_used_at, is_active FROM device_keys WHERE user_id = ? ORDER BY created_at DESC""", (user_id,) ).fetchall() return [ { "id": r[0], "api_key": r[1][:8] + "..." + r[1][-4:], # Masked for security "device_name": r[2], "created_at": r[3], "last_used_at": r[4], "is_active": bool(r[5]) } for r in results ] def revoke_device_key(self, key_id: str, user_id: str) -> bool: """Revoke a device API key.""" result = self.conn.execute( "UPDATE device_keys SET is_active = 0 WHERE id = ? AND user_id = ?", (key_id, user_id) ) self.conn.commit() return result.rowcount > 0 def delete_device_key(self, key_id: str, user_id: str) -> bool: """Permanently delete a device API key.""" result = self.conn.execute( "DELETE FROM device_keys WHERE id = ? AND user_id = ?", (key_id, user_id) ) self.conn.commit() return result.rowcount > 0 # ==================== ADMIN OPERATIONS ==================== def get_all_users(self) -> list: """Get all users with basic info for admin panel.""" results = self.conn.execute( """SELECT u.id, u.email, u.name, u.context, u.created_at, s.total_sessions, s.total_exercises, s.current_streak, s.last_session_date FROM users u LEFT JOIN user_stats s ON u.id = s.user_id ORDER BY u.created_at DESC""" ).fetchall() users = [] for r in results: # Count crisis flags for this user flag_count = self.conn.execute( "SELECT COUNT(*) FROM crisis_flags WHERE user_id = ? AND reviewed = 0", (r[0],) ).fetchone()[0] users.append({ "id": r[0], "email": r[1], "name": r[2], "context": r[3], "created_at": r[4], "total_sessions": r[5] or 0, "total_exercises": r[6] or 0, "current_streak": r[7] or 0, "last_session_date": r[8], "unreviewed_alerts": flag_count }) return users def get_user_full_details(self, user_id: str) -> dict: """Get complete user data for admin patient detail view.""" user = self.get_user_by_id(user_id) if not user: return None # Get user stats stats = self.get_user_stats(user_id) # Get created_at created_at = self.conn.execute( "SELECT created_at FROM users WHERE id = ?", (user_id,) ).fetchone() # Get sessions sessions = self.get_user_sessions(user_id, limit=50) # Get crisis history crisis_flags = self.conn.execute( """SELECT id, session_id, message_content, trigger_word, flagged_at, reviewed FROM crisis_flags WHERE user_id = ? ORDER BY flagged_at DESC""", (user_id,) ).fetchall() crisis_history = [ { "id": r[0], "session_id": r[1], "message_content": r[2], "trigger_word": r[3], "flagged_at": r[4], "reviewed": bool(r[5]) } for r in crisis_flags ] # Get latest wearable data latest_wearable = self.get_latest_wearable_data(user_id) return { **user, "created_at": created_at[0] if created_at else None, "stats": stats, "sessions": sessions, "crisis_history": crisis_history, "latest_wearable": latest_wearable } def get_all_crisis_flags(self, reviewed: bool = None) -> list: """Get all crisis flags, optionally filtered by reviewed status.""" query = """SELECT cf.id, cf.user_id, cf.user_name, cf.user_email, cf.session_id, cf.message_content, cf.trigger_word, cf.flagged_at, cf.reviewed FROM crisis_flags cf ORDER BY cf.flagged_at DESC""" if reviewed is not None: query = """SELECT cf.id, cf.user_id, cf.user_name, cf.user_email, cf.session_id, cf.message_content, cf.trigger_word, cf.flagged_at, cf.reviewed FROM crisis_flags cf WHERE cf.reviewed = ? ORDER BY cf.flagged_at DESC""" results = self.conn.execute(query, (1 if reviewed else 0,)).fetchall() else: results = self.conn.execute(query).fetchall() return [ { "id": r[0], "user_id": r[1], "user_name": r[2], "user_email": r[3], "session_id": r[4], "message_content": r[5], "trigger_word": r[6], "flagged_at": r[7], "reviewed": bool(r[8]) } for r in results ] def mark_crisis_reviewed(self, flag_id: str) -> bool: """Mark a crisis flag as reviewed.""" self.conn.execute( "UPDATE crisis_flags SET reviewed = 1 WHERE id = ?", (flag_id,) ) self.conn.commit() return True def get_dashboard_stats(self) -> dict: """Get overview statistics for admin dashboard.""" # Total users total_users = self.conn.execute( "SELECT COUNT(*) FROM users" ).fetchone()[0] # Sessions today sessions_today = self.conn.execute( """SELECT COUNT(*) FROM sessions WHERE DATE(started_at) = DATE('now')""" ).fetchone()[0] # Unreviewed crisis flags unreviewed_alerts = self.conn.execute( "SELECT COUNT(*) FROM crisis_flags WHERE reviewed = 0" ).fetchone()[0] # Average mood improvement (mood_end - mood_start for completed sessions) mood_result = self.conn.execute( """SELECT AVG(mood_end - mood_start) FROM sessions WHERE mood_start IS NOT NULL AND mood_end IS NOT NULL AND completed = 1""" ).fetchone() avg_mood_change = round(mood_result[0], 2) if mood_result[0] else 0 # Total sessions total_sessions = self.conn.execute( "SELECT COUNT(*) FROM sessions" ).fetchone()[0] # Completed sessions completed_sessions = self.conn.execute( "SELECT COUNT(*) FROM sessions WHERE completed = 1" ).fetchone()[0] return { "total_users": total_users, "sessions_today": sessions_today, "unreviewed_alerts": unreviewed_alerts, "avg_mood_change": avg_mood_change, "total_sessions": total_sessions, "completed_sessions": completed_sessions } def get_user_wearable_summary(self, user_id: str) -> dict: """Get vitals summary for a patient.""" # Get stats for day, week, month day_stats = self.get_wearable_stats(user_id, "day") week_stats = self.get_wearable_stats(user_id, "week") month_stats = self.get_wearable_stats(user_id, "month") # Get latest reading latest = self.get_latest_wearable_data(user_id) return { "latest": latest, "day": day_stats, "week": week_stats, "month": month_stats } def get_wearable_timeseries(self, user_id: str, hours: int = 24) -> list: """Get time-series wearable data for charts.""" results = self.conn.execute( """SELECT ppg, gsr, acc_x, acc_y, acc_z, recorded_at FROM wearable_data WHERE user_id = ? AND recorded_at >= datetime('now', ? || ' hours') ORDER BY recorded_at ASC""", (user_id, -hours) ).fetchall() return [ { "ppg": r[0], "gsr": r[1], "acc_x": r[2], "acc_y": r[3], "acc_z": r[4], "recorded_at": r[5] } for r in results ] def get_daily_session_counts(self, days: int = 30) -> list: """Get daily session counts for trend chart.""" results = self.conn.execute( """SELECT DATE(started_at) as day, COUNT(*) as count FROM sessions WHERE started_at >= datetime('now', ? || ' days') GROUP BY DATE(started_at) ORDER BY day ASC""", (-days,) ).fetchall() return [{"date": r[0], "count": r[1]} for r in results] def get_distortion_distribution(self) -> dict: """Get aggregate distortion group distribution.""" results = self.conn.execute( """SELECT locked_group, COUNT(*) as count FROM sessions WHERE locked_group IS NOT NULL AND locked_group != 'G0' GROUP BY locked_group""" ).fetchall() distribution = {"G1": 0, "G2": 0, "G3": 0, "G4": 0} for r in results: if r[0] in distribution: distribution[r[0]] = r[1] return distribution def get_user_mood_history(self, user_id: str, limit: int = 20) -> list: """Get mood history for a user's sessions.""" results = self.conn.execute( """SELECT started_at, mood_start, mood_end, locked_group, completed FROM sessions WHERE user_id = ? AND mood_start IS NOT NULL ORDER BY started_at DESC LIMIT ?""", (user_id, limit) ).fetchall() return [ { "date": r[0], "mood_start": r[1], "mood_end": r[2], "locked_group": r[3], "completed": bool(r[4]) } for r in reversed(results) # Chronological order ] def get_user_distortion_pattern(self, user_id: str) -> dict: """Get distortion pattern for radar chart.""" stats = self.get_user_stats(user_id) counts = stats.get("distortion_counts", {}) return { "G1": counts.get("G1", 0), "G2": counts.get("G2", 0), "G3": counts.get("G3", 0), "G4": counts.get("G4", 0) } # ==================== BECK SESSION OPERATIONS ==================== def create_beck_session(self, session_id: str) -> dict: """Initialize a Beck session when distortion is detected.""" try: self.conn.execute( """INSERT INTO beck_sessions (session_id, beck_state) VALUES (?, 'VALIDATE')""", (session_id,) ) self.conn.commit() return {"session_id": session_id, "beck_state": "VALIDATE"} except Exception as e: print(f"Error creating Beck session: {e}") return None def get_beck_session(self, session_id: str) -> dict: """Get current Beck session state and all data.""" result = self.conn.execute( """SELECT session_id, beck_state, original_thought, initial_belief_rating, emotion, initial_emotion_intensity, q1_evidence_for, q1_evidence_against, q2_alternative, q3_worst, q3_best, q3_realistic, q4_effect, q5_friend, q6_action, adaptive_thought, new_thought_belief, final_belief_rating, final_emotion_intensity, action_plan, started_at, completed_at, belief_improvement, emotion_improvement FROM beck_sessions WHERE session_id = ?""", (session_id,) ).fetchone() if result: return { "session_id": result[0], "beck_state": result[1], "original_thought": result[2], "initial_belief_rating": result[3], "emotion": result[4], "initial_emotion_intensity": result[5], "q1_evidence_for": result[6], "q1_evidence_against": result[7], "q2_alternative": result[8], "q3_worst": result[9], "q3_best": result[10], "q3_realistic": result[11], "q4_effect": result[12], "q5_friend": result[13], "q6_action": result[14], "adaptive_thought": result[15], "new_thought_belief": result[16], "final_belief_rating": result[17], "final_emotion_intensity": result[18], "action_plan": result[19], "started_at": result[20], "completed_at": result[21], "belief_improvement": result[22], "emotion_improvement": result[23] } return None def update_beck_state(self, session_id: str, new_state: str, **fields): """Update state and save any new field values.""" # Start with state update updates = {"beck_state": new_state} updates.update(fields) set_clause = ", ".join(f"{k} = ?" for k in updates.keys()) values = tuple(list(updates.values()) + [session_id]) self.conn.execute( f"UPDATE beck_sessions SET {set_clause} WHERE session_id = ?", values ) self.conn.commit() def complete_beck_session(self, session_id: str): """Mark session complete, calculate improvements.""" # Get current data beck_data = self.get_beck_session(session_id) if not beck_data: return # Calculate improvements belief_improvement = None emotion_improvement = None if beck_data.get('initial_belief_rating') and beck_data.get('final_belief_rating'): belief_improvement = beck_data['initial_belief_rating'] - beck_data['final_belief_rating'] if beck_data.get('initial_emotion_intensity') and beck_data.get('final_emotion_intensity'): emotion_improvement = beck_data['initial_emotion_intensity'] - beck_data['final_emotion_intensity'] # Update completion status self.conn.execute( """UPDATE beck_sessions SET completed_at = CURRENT_TIMESTAMP, belief_improvement = ?, emotion_improvement = ? WHERE session_id = ?""", (belief_improvement, emotion_improvement, session_id) ) self.conn.commit() # ==================== ML INFERENCE & DEPRESSION TRACKING ==================== def get_recent_readings_for_ml(self, user_id: str, limit: int = 50) -> list: """ Get most recent sensor readings for ML inference. Returns data ordered from OLDEST to NEWEST (required for time-series analysis). """ results = self.conn.execute( """SELECT ppg, gsr, acc_x, acc_y, acc_z, recorded_at FROM wearable_data WHERE user_id = ? ORDER BY recorded_at DESC LIMIT ?""", (user_id, limit) ).fetchall() # Reverse to get oldest-to-newest order readings = [ { "ppg": r[0], "gsr": r[1], "acc_x": r[2], "acc_y": r[3], "acc_z": r[4], "timestamp": r[5] } for r in reversed(results) ] return readings def update_ml_prediction(self, record_id: str, prediction: str, confidence: float, risk_level: int): """Update a wearable data record with ML prediction results.""" condition = prediction # prediction is already "NORMAL", "MILD_STRESS", or "HIGH_STRESS" self.conn.execute( """UPDATE wearable_data SET ml_prediction = ?, ml_confidence = ?, risk_level = ?, condition = ?, dri_score = ? WHERE id = ?""", (prediction, confidence, risk_level, condition, confidence, record_id) ) self.conn.commit() def get_active_depression_episode(self, user_id: str) -> dict: """Get the currently active depression episode for a user, if any.""" result = self.conn.execute( """SELECT id, user_id, start_time, peak_risk_level, total_readings, avg_confidence FROM depression_episodes WHERE user_id = ? AND is_active = 1 ORDER BY start_time DESC LIMIT 1""", (user_id,) ).fetchone() if result: return { "id": result[0], "user_id": result[1], "start_time": result[2], "peak_risk_level": result[3], "total_readings": result[4], "avg_confidence": result[5] } return None def start_depression_episode(self, user_id: str, risk_level: int, confidence: float) -> str: """Start a new depression episode.""" episode_id = str(uuid.uuid4()) self.conn.execute( """INSERT INTO depression_episodes (id, user_id, start_time, peak_risk_level, total_readings, avg_confidence, is_active) VALUES (?, ?, CURRENT_TIMESTAMP, ?, 1, ?, 1)""", (episode_id, user_id, risk_level, confidence) ) self.conn.commit() return episode_id def update_depression_episode(self, episode_id: str, risk_level: int, confidence: float): """Update an active depression episode with new reading.""" # Get current stats result = self.conn.execute( """SELECT total_readings, avg_confidence, peak_risk_level FROM depression_episodes WHERE id = ?""", (episode_id,) ).fetchone() if result: total_readings = result[0] avg_confidence = result[1] peak_risk = result[2] # Update rolling average confidence new_total = total_readings + 1 new_avg_confidence = ((avg_confidence * total_readings) + confidence) / new_total # Update peak risk level new_peak_risk = max(peak_risk, risk_level) self.conn.execute( """UPDATE depression_episodes SET total_readings = ?, avg_confidence = ?, peak_risk_level = ? WHERE id = ?""", (new_total, new_avg_confidence, new_peak_risk, episode_id) ) self.conn.commit() def end_depression_episode(self, episode_id: str): """Mark a depression episode as ended.""" self.conn.execute( """UPDATE depression_episodes SET end_time = CURRENT_TIMESTAMP, is_active = 0 WHERE id = ?""", (episode_id,) ) self.conn.commit() def get_user_depression_stats(self, user_id: str) -> dict: """Get depression episode statistics for a user.""" # Total episodes total_result = self.conn.execute( """SELECT COUNT(*) FROM depression_episodes WHERE user_id = ?""", (user_id,) ).fetchone() # Active episode active_result = self.conn.execute( """SELECT COUNT(*) FROM depression_episodes WHERE user_id = ? AND is_active = 1""", (user_id,) ).fetchone() # Last 7 days episodes recent_result = self.conn.execute( """SELECT COUNT(*) FROM depression_episodes WHERE user_id = ? AND start_time >= datetime('now', '-7 days')""", (user_id,) ).fetchone() # Peak risk level in last 7 days peak_result = self.conn.execute( """SELECT MAX(peak_risk_level) FROM depression_episodes WHERE user_id = ? AND start_time >= datetime('now', '-7 days')""", (user_id,) ).fetchone() # Latest ML prediction latest_pred = self.conn.execute( """SELECT ml_prediction, ml_confidence, risk_level, recorded_at FROM wearable_data WHERE user_id = ? AND ml_prediction IS NOT NULL ORDER BY recorded_at DESC LIMIT 1""", (user_id,) ).fetchone() return { "total_episodes": total_result[0] if total_result else 0, "has_active_episode": (active_result[0] > 0) if active_result else False, "episodes_last_7_days": recent_result[0] if recent_result else 0, "peak_risk_last_7_days": peak_result[0] if (peak_result and peak_result[0]) else 0, "latest_prediction": { "prediction": latest_pred[0], "confidence": latest_pred[1], "risk_level": latest_pred[2], "timestamp": latest_pred[3] } if latest_pred else None } def get_all_depression_episodes(self, user_id: str, limit: int = 50) -> list: """Get all depression episodes for a user.""" results = self.conn.execute( """SELECT id, start_time, end_time, peak_risk_level, total_readings, avg_confidence, is_active FROM depression_episodes WHERE user_id = ? ORDER BY start_time DESC LIMIT ?""", (user_id, limit) ).fetchall() return [ { "id": r[0], "start_time": r[1], "end_time": r[2], "peak_risk_level": r[3], "total_readings": r[4], "avg_confidence": r[5], "is_active": bool(r[6]) } for r in results ] def get_ml_prediction_history(self, user_id: str, limit: int = 100) -> list: """Get history of ML predictions for a user.""" results = self.conn.execute( """SELECT ml_prediction, ml_confidence, risk_level, recorded_at FROM wearable_data WHERE user_id = ? AND ml_prediction IS NOT NULL ORDER BY recorded_at DESC LIMIT ?""", (user_id, limit) ).fetchall() return [ { "prediction": r[0], "confidence": r[1], "risk_level": r[2], "timestamp": r[3] } for r in results ] # Singleton instance _db_instance = None def get_db() -> Database: """Get database singleton instance.""" global _db_instance if _db_instance is None: _db_instance = Database() return _db_instance