"""SQLite persistence layer: users, admins, courses, tasks, settings and logs."""

from __future__ import annotations

import sqlite3
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from course_catcher.security import CredentialCipher, hash_password, verify_password


def utc_now() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ`` (seconds precision).

    The fixed-width text form keeps lexicographic ``ORDER BY`` on the stored
    timestamp columns chronological.
    """
    # datetime.utcnow() is deprecated since Python 3.12; format an aware UTC
    # datetime explicitly so the stored text is unchanged.
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")


class Database:
    """Access to the application's SQLite database.

    Each public method opens its own short-lived connection via :meth:`connect`.
    Mutating methods additionally hold ``_write_lock`` so writes issued by
    threads of this process are serialized; at the SQLite level a 30 s busy
    timeout is configured for lock contention.
    """

    # Column lists shared by several queries. They are constants, never user
    # input, so interpolating them into SQL text is safe.
    _ADMIN_COLUMNS = "id, username, role, created_at, updated_at, last_login_at"
    _USER_COLUMNS = "id, student_id, created_at, updated_at, last_login_at"
    _USER_SECRET_COLUMNS = ", password_hash, password_encrypted"
    _COURSE_COLUMNS = (
        "id, user_id, course_id, course_index, status, last_result, created_at, updated_at"
    )
    _LOG_SELECT = """
        SELECT logs.id, logs.task_id, logs.user_id, logs.created_at, logs.level,
               logs.actor, logs.message, users.student_id
        FROM logs
        LEFT JOIN users ON users.id = logs.user_id
    """

    def __init__(self, db_path: Path, cipher: CredentialCipher, default_parallelism: int) -> None:
        """Prepare the database file location and ensure the schema exists.

        Args:
            db_path: Path of the SQLite file; its parent directory is created if missing.
            cipher: Encrypts/decrypts the reversible copy of user passwords.
            default_parallelism: Initial value seeded for the ``max_parallel_tasks``
                setting, and the fallback used by :meth:`get_admin_summary`.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.cipher = cipher
        self.default_parallelism = default_parallelism
        # Reentrant: a method holding the lock may call another write method.
        self._write_lock = threading.RLock()
        self._initialize()

    @contextmanager
    def connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a new connection; commit on clean exit and always close it.

        Rows are returned as :class:`sqlite3.Row` and foreign keys are enforced.
        If the ``with`` body raises, the connection is closed without a commit,
        which discards the open transaction.
        """
        connection = sqlite3.connect(self.db_path, timeout=30, check_same_thread=False)
        try:
            # Setup lives inside the try so a failing PRAGMA cannot leak the handle.
            connection.row_factory = sqlite3.Row
            connection.execute("PRAGMA foreign_keys = ON")
            connection.execute("PRAGMA busy_timeout = 30000")
            yield connection
            connection.commit()
        finally:
            connection.close()

    def _initialize(self) -> None:
        """Create all tables if missing and seed default settings without overwriting."""
        with self._write_lock, self.connect() as connection:
            connection.executescript(
                """
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    student_id TEXT NOT NULL UNIQUE,
                    password_hash TEXT NOT NULL,
                    password_encrypted TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    last_login_at TEXT
                );
                CREATE TABLE IF NOT EXISTS admins (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL UNIQUE,
                    password_hash TEXT NOT NULL,
                    role TEXT NOT NULL CHECK(role IN ('admin', 'superadmin')),
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    last_login_at TEXT
                );
                CREATE TABLE IF NOT EXISTS courses (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    course_id TEXT NOT NULL,
                    course_index TEXT NOT NULL,
                    status TEXT NOT NULL DEFAULT 'pending',
                    last_result TEXT NOT NULL DEFAULT '',
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    UNIQUE(user_id, course_id, course_index),
                    FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
                );
                CREATE TABLE IF NOT EXISTS tasks (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    status TEXT NOT NULL CHECK(status IN ('queued', 'running', 'completed', 'failed', 'stopped')),
                    requested_by_type TEXT NOT NULL,
                    requested_by_name TEXT NOT NULL,
                    stop_requested INTEGER NOT NULL DEFAULT 0,
                    attempt_count INTEGER NOT NULL DEFAULT 0,
                    last_error TEXT NOT NULL DEFAULT '',
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    started_at TEXT,
                    finished_at TEXT,
                    FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
                );
                CREATE TABLE IF NOT EXISTS settings (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS logs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    task_id INTEGER,
                    user_id INTEGER,
                    created_at TEXT NOT NULL,
                    level TEXT NOT NULL,
                    actor TEXT NOT NULL,
                    message TEXT NOT NULL,
                    FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE SET NULL,
                    FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
                );
                """
            )
            # INSERT OR IGNORE: keep any value an admin has already changed.
            connection.executemany(
                "INSERT OR IGNORE INTO settings(key, value) VALUES (?, ?)",
                [
                    ("max_parallel_tasks", str(self.default_parallelism)),
                    ("loop_interval_seconds", "3"),
                ],
            )

    # ------------------------------------------------------------------ admins

    def ensure_superadmin(self, username: str, password: str) -> None:
        """Create the superadmin, or overwrite the credentials of the lowest-id one.

        Raises:
            sqlite3.IntegrityError: if *username* already belongs to a different admin
                row (``admins.username`` is UNIQUE).
        """
        now = utc_now()
        password_hash = hash_password(password)
        with self._write_lock, self.connect() as connection:
            row = connection.execute(
                "SELECT id FROM admins WHERE role = 'superadmin' ORDER BY id LIMIT 1"
            ).fetchone()
            if row:
                connection.execute(
                    """
                    UPDATE admins
                    SET username = ?, password_hash = ?, updated_at = ?
                    WHERE id = ?
                    """,
                    (username, password_hash, now, row["id"]),
                )
            else:
                connection.execute(
                    """
                    INSERT INTO admins(username, password_hash, role, created_at, updated_at)
                    VALUES (?, ?, 'superadmin', ?, ?)
                    """,
                    (username, password_hash, now, now),
                )

    # ---------------------------------------------------------------- settings

    def _read_setting(self, key: str) -> str | None:
        """Return the raw text stored for *key*, or ``None`` when the key is absent."""
        with self.connect() as connection:
            row = connection.execute("SELECT value FROM settings WHERE key = ?", (key,)).fetchone()
        return row["value"] if row else None

    def get_setting_int(self, key: str, fallback: int) -> int:
        """Return setting *key* parsed as ``int``; *fallback* if missing or unparsable."""
        raw = self._read_setting(key)
        if raw is None:
            return fallback
        try:
            return int(raw)
        except (TypeError, ValueError):
            return fallback

    def get_setting_float(self, key: str, fallback: float) -> float:
        """Return setting *key* parsed as ``float``; *fallback* if missing or unparsable."""
        raw = self._read_setting(key)
        if raw is None:
            return fallback
        try:
            return float(raw)
        except (TypeError, ValueError):
            return fallback

    def set_setting(self, key: str, value: str) -> None:
        """Insert or overwrite setting *key* with *value*."""
        with self._write_lock, self.connect() as connection:
            connection.execute(
                """
                INSERT INTO settings(key, value) VALUES (?, ?)
                ON CONFLICT(key) DO UPDATE SET value = excluded.value
                """,
                (key, value),
            )

    def list_admins(self) -> list[dict[str, Any]]:
        """Return all admins (no password hashes), superadmins first, then by username."""
        with self.connect() as connection:
            rows = connection.execute(
                f"SELECT {self._ADMIN_COLUMNS} FROM admins ORDER BY role DESC, username ASC"
            ).fetchall()
        return [dict(row) for row in rows]

    def create_admin(self, username: str, password: str) -> dict[str, Any]:
        """Insert a regular ``admin`` account and return its public columns.

        Raises:
            sqlite3.IntegrityError: if *username* is already taken.
        """
        now = utc_now()
        with self._write_lock, self.connect() as connection:
            connection.execute(
                """
                INSERT INTO admins(username, password_hash, role, created_at, updated_at)
                VALUES (?, ?, 'admin', ?, ?)
                """,
                (username, hash_password(password), now, now),
            )
            row = connection.execute(
                f"SELECT {self._ADMIN_COLUMNS} FROM admins WHERE username = ?",
                (username,),
            ).fetchone()
        return dict(row)

    def verify_admin_login(self, username: str, password: str) -> dict[str, Any] | None:
        """Check admin credentials; on success stamp the login time and return the admin.

        Returns:
            The admin's public columns after the update, or ``None`` when the
            username is unknown or the password does not match.
        """
        with self.connect() as connection:
            row = connection.execute("SELECT * FROM admins WHERE username = ?", (username,)).fetchone()
        if not row or not verify_password(row["password_hash"], password):
            return None
        # One timestamp for both columns, and the write goes through the
        # write lock like every other mutation in this class.
        now = utc_now()
        with self._write_lock, self.connect() as connection:
            connection.execute(
                "UPDATE admins SET last_login_at = ?, updated_at = ? WHERE id = ?",
                (now, now, row["id"]),
            )
            refreshed = connection.execute(
                f"SELECT {self._ADMIN_COLUMNS} FROM admins WHERE id = ?",
                (row["id"],),
            ).fetchone()
        return dict(refreshed) if refreshed else None

    def get_admin_by_id(self, admin_id: int) -> dict[str, Any] | None:
        """Return the admin's public columns, or ``None`` if no such id."""
        with self.connect() as connection:
            row = connection.execute(
                f"SELECT {self._ADMIN_COLUMNS} FROM admins WHERE id = ?",
                (admin_id,),
            ).fetchone()
        return dict(row) if row else None

    # ------------------------------------------------------------------- users

    def _user_fields(self, include_password: bool) -> str:
        """Return the users column list, optionally including the stored secrets."""
        return self._USER_COLUMNS + (self._USER_SECRET_COLUMNS if include_password else "")

    def get_user_by_id(self, user_id: int, include_password: bool = False) -> dict[str, Any] | None:
        """Return a user by primary key, or ``None``.

        With *include_password*, ``password_hash`` and ``password_encrypted`` are included.
        """
        with self.connect() as connection:
            row = connection.execute(
                f"SELECT {self._user_fields(include_password)} FROM users WHERE id = ?",
                (user_id,),
            ).fetchone()
        return dict(row) if row else None

    def get_user_by_student_id(self, student_id: str, include_password: bool = False) -> dict[str, Any] | None:
        """Return a user by ``student_id``, or ``None``.

        With *include_password*, ``password_hash`` and ``password_encrypted`` are included.
        """
        with self.connect() as connection:
            row = connection.execute(
                f"SELECT {self._user_fields(include_password)} FROM users WHERE student_id = ?",
                (student_id,),
            ).fetchone()
        return dict(row) if row else None

    def create_user(self, student_id: str, password: str) -> dict[str, Any]:
        """Insert a user storing both a password hash and a cipher-encrypted copy.

        Returns:
            The new user's public columns.

        Raises:
            sqlite3.IntegrityError: if *student_id* already exists.
        """
        now = utc_now()
        encrypted = self.cipher.encrypt(password)
        with self._write_lock, self.connect() as connection:
            connection.execute(
                """
                INSERT INTO users(student_id, password_hash, password_encrypted, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?)
                """,
                (student_id, hash_password(password), encrypted, now, now),
            )
            row = connection.execute(
                f"SELECT {self._USER_COLUMNS} FROM users WHERE student_id = ?",
                (student_id,),
            ).fetchone()
        return dict(row)

    def create_or_update_user(self, student_id: str, password: str) -> dict[str, Any]:
        """Create the user, or reset the password of an existing one; return the user."""
        # Hold the reentrant write lock across lookup and write so two callers
        # cannot both see "missing" and race into a UNIQUE(student_id) violation.
        with self._write_lock:
            existing = self.get_user_by_student_id(student_id)
            if existing:
                self.update_user_password(existing["id"], password)
                return self.get_user_by_id(existing["id"]) or existing
            return self.create_user(student_id, password)

    def verify_or_create_user_login(self, student_id: str, password: str) -> tuple[dict[str, Any] | None, bool]:
        """Log a user in, auto-registering unknown student ids.

        Returns:
            ``(user, created)``: ``(new_user, True)`` when the account was just
            created, ``(user, False)`` after a successful password check (login
            time is stamped), ``(None, False)`` on a wrong password.
        """
        # Lookup + create under the lock to avoid a duplicate-insert race on
        # concurrent first logins; the password check below runs unlocked.
        with self._write_lock:
            user = self.get_user_by_student_id(student_id, include_password=True)
            if not user:
                return self.create_user(student_id, password), True
        if not verify_password(user["password_hash"], password):
            return None, False
        now = utc_now()
        with self._write_lock, self.connect() as connection:
            connection.execute(
                "UPDATE users SET last_login_at = ?, updated_at = ? WHERE id = ?",
                (now, now, user["id"]),
            )
            refreshed = connection.execute(
                f"SELECT {self._USER_COLUMNS} FROM users WHERE id = ?",
                (user["id"],),
            ).fetchone()
        return dict(refreshed), False

    def update_user_password(self, user_id: int, password: str) -> None:
        """Replace both the hash and the encrypted copy of a user's password."""
        now = utc_now()
        with self._write_lock, self.connect() as connection:
            connection.execute(
                """
                UPDATE users
                SET password_hash = ?, password_encrypted = ?, updated_at = ?
                WHERE id = ?
                """,
                (hash_password(password), self.cipher.encrypt(password), now, user_id),
            )

    def get_user_runtime_credentials(self, user_id: int) -> dict[str, Any] | None:
        """Return ``id``, ``student_id``, ``password_encrypted`` and the decrypted ``password``.

        The result contains the plaintext password; ``None`` if the user does not exist.
        """
        with self.connect() as connection:
            row = connection.execute(
                "SELECT id, student_id, password_encrypted FROM users WHERE id = ?",
                (user_id,),
            ).fetchone()
        if not row:
            return None
        data = dict(row)
        data["password"] = self.cipher.decrypt(data["password_encrypted"])
        return data

    # ----------------------------------------------------------------- courses

    def list_courses(self, user_id: int) -> list[dict[str, Any]]:
        """Return all of a user's courses ordered by status (descending), id, index."""
        with self.connect() as connection:
            rows = connection.execute(
                f"""
                SELECT {self._COURSE_COLUMNS}
                FROM courses
                WHERE user_id = ?
                ORDER BY status DESC, course_id ASC, course_index ASC
                """,
                (user_id,),
            ).fetchall()
        return [dict(row) for row in rows]

    def list_pending_courses(self, user_id: int) -> list[dict[str, Any]]:
        """Return a user's courses whose status is anything other than ``'selected'``."""
        with self.connect() as connection:
            rows = connection.execute(
                f"""
                SELECT {self._COURSE_COLUMNS}
                FROM courses
                WHERE user_id = ? AND status != 'selected'
                ORDER BY course_id ASC, course_index ASC
                """,
                (user_id,),
            ).fetchall()
        return [dict(row) for row in rows]

    def add_course(self, user_id: int, course_id: str, course_index: str) -> None:
        """Add a course as ``'pending'``; an existing identical entry is reset to ``'pending'``.

        On conflict ``last_result`` is left as it was.
        """
        now = utc_now()
        with self._write_lock, self.connect() as connection:
            connection.execute(
                """
                INSERT INTO courses(user_id, course_id, course_index, status, last_result, created_at, updated_at)
                VALUES (?, ?, ?, 'pending', '', ?, ?)
                ON CONFLICT(user_id, course_id, course_index)
                DO UPDATE SET status = 'pending', updated_at = excluded.updated_at
                """,
                (user_id, course_id, course_index, now, now),
            )

    def delete_course(self, course_id: int, user_id: int) -> None:
        """Delete the course row with primary key *course_id* if it belongs to *user_id*."""
        with self._write_lock, self.connect() as connection:
            connection.execute(
                "DELETE FROM courses WHERE id = ? AND user_id = ?",
                (course_id, user_id),
            )

    def mark_course_result(
        self,
        user_id: int,
        course_id: str,
        course_index: str,
        status: str,
        detail: str,
    ) -> None:
        """Record *status* and *detail* (as ``last_result``) on the matching course row."""
        with self._write_lock, self.connect() as connection:
            connection.execute(
                """
                UPDATE courses
                SET status = ?, last_result = ?, updated_at = ?
                WHERE user_id = ? AND course_id = ? AND course_index = ?
                """,
                (status, detail, utc_now(), user_id, course_id, course_index),
            )

    # --------------------------------------------------------------- summaries

    def list_users_with_summary(self) -> list[dict[str, Any]]:
        """Return every user with course counts and their newest queued/running task.

        ``active_task_id`` / ``active_task_status`` are ``None`` when the user has
        no queued or running task. Most recently updated users come first.
        """
        with self.connect() as connection:
            rows = connection.execute(
                """
                SELECT
                    u.id,
                    u.student_id,
                    u.created_at,
                    u.updated_at,
                    u.last_login_at,
                    COALESCE(course_stats.course_count, 0) AS course_count,
                    COALESCE(course_stats.selected_count, 0) AS selected_count,
                    active_task.id AS active_task_id,
                    active_task.status AS active_task_status
                FROM users u
                LEFT JOIN (
                    SELECT
                        user_id,
                        COUNT(*) AS course_count,
                        SUM(CASE WHEN status = 'selected' THEN 1 ELSE 0 END) AS selected_count
                    FROM courses
                    GROUP BY user_id
                ) AS course_stats ON course_stats.user_id = u.id
                LEFT JOIN (
                    SELECT t1.*
                    FROM tasks t1
                    JOIN (
                        SELECT user_id, MAX(id) AS id
                        FROM tasks
                        WHERE status IN ('queued', 'running')
                        GROUP BY user_id
                    ) latest ON latest.id = t1.id
                ) AS active_task ON active_task.user_id = u.id
                ORDER BY u.updated_at DESC, u.student_id ASC
                """
            ).fetchall()
        return [dict(row) for row in rows]

    def get_user_dashboard_state(self, user_id: int) -> dict[str, Any]:
        """Return a user's course counts and their most recent task (any status), if any."""
        with self.connect() as connection:
            counts = connection.execute(
                """
                SELECT
                    COUNT(*) AS total_courses,
                    SUM(CASE WHEN status = 'selected' THEN 1 ELSE 0 END) AS selected_courses,
                    SUM(CASE WHEN status != 'selected' THEN 1 ELSE 0 END) AS pending_courses
                FROM courses
                WHERE user_id = ?
                """,
                (user_id,),
            ).fetchone()
            task = connection.execute(
                """
                SELECT id, status, attempt_count, updated_at, started_at, finished_at
                FROM tasks
                WHERE user_id = ?
                ORDER BY id DESC
                LIMIT 1
                """,
                (user_id,),
            ).fetchone()
        # SUM() yields NULL over zero rows, hence the `or 0`.
        return {
            "total_courses": counts["total_courses"] or 0,
            "selected_courses": counts["selected_courses"] or 0,
            "pending_courses": counts["pending_courses"] or 0,
            "task": dict(task) if task else None,
        }

    def get_admin_summary(self) -> dict[str, Any]:
        """Return global counters for the admin dashboard plus the parallelism setting."""
        with self.connect() as connection:

            def count(sql: str) -> int:
                return connection.execute(sql).fetchone()["total"]

            users = count("SELECT COUNT(*) AS total FROM users")
            running_tasks = count("SELECT COUNT(*) AS total FROM tasks WHERE status = 'running'")
            queued_tasks = count("SELECT COUNT(*) AS total FROM tasks WHERE status = 'queued'")
            pending_courses = count("SELECT COUNT(*) AS total FROM courses WHERE status != 'selected'")
        return {
            "users": users,
            "running_tasks": running_tasks,
            "queued_tasks": queued_tasks,
            "pending_courses": pending_courses,
            "parallelism": self.get_setting_int("max_parallel_tasks", self.default_parallelism),
        }

    # ------------------------------------------------------------------- tasks

    def create_task(self, user_id: int, requested_by_type: str, requested_by_name: str) -> dict[str, Any]:
        """Queue a task for *user_id*, or return the user's existing queued/running task.

        The check and the insert run under the write lock, so at most one
        queued/running task per user is created through this method.
        """
        with self._write_lock, self.connect() as connection:
            existing = connection.execute(
                """
                SELECT * FROM tasks
                WHERE user_id = ? AND status IN ('queued', 'running')
                ORDER BY id DESC
                LIMIT 1
                """,
                (user_id,),
            ).fetchone()
            if existing:
                return dict(existing)
            now = utc_now()
            connection.execute(
                """
                INSERT INTO tasks(
                    user_id, status, requested_by_type, requested_by_name,
                    stop_requested, attempt_count, last_error, created_at, updated_at
                ) VALUES (?, 'queued', ?, ?, 0, 0, '', ?, ?)
                """,
                (user_id, requested_by_type, requested_by_name, now, now),
            )
            created = connection.execute(
                "SELECT * FROM tasks WHERE id = last_insert_rowid()"
            ).fetchone()
        return dict(created)

    def get_task(self, task_id: int) -> dict[str, Any] | None:
        """Return the full task row, or ``None`` if no such id."""
        with self.connect() as connection:
            row = connection.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)).fetchone()
        return dict(row) if row else None

    def fetch_queued_tasks(self, limit: int) -> list[dict[str, Any]]:
        """Return up to *limit* queued tasks, oldest first."""
        with self.connect() as connection:
            rows = connection.execute(
                """
                SELECT * FROM tasks
                WHERE status = 'queued'
                ORDER BY created_at ASC, id ASC
                LIMIT ?
                """,
                (limit,),
            ).fetchall()
        return [dict(row) for row in rows]

    def mark_task_running(self, task_id: int) -> None:
        """Set a task to ``'running'``; ``started_at`` is only set the first time."""
        now = utc_now()
        with self._write_lock, self.connect() as connection:
            connection.execute(
                """
                UPDATE tasks
                SET status = 'running', updated_at = ?, started_at = COALESCE(started_at, ?)
                WHERE id = ?
                """,
                (now, now, task_id),
            )

    def increment_task_attempt(self, task_id: int) -> None:
        """Add one to the task's ``attempt_count``."""
        with self._write_lock, self.connect() as connection:
            connection.execute(
                """
                UPDATE tasks
                SET attempt_count = attempt_count + 1, updated_at = ?
                WHERE id = ?
                """,
                (utc_now(), task_id),
            )

    def request_stop_task(self, user_id: int) -> None:
        """Stop a user's tasks.

        Running tasks only get ``stop_requested = 1`` (readable via
        :meth:`is_stop_requested`); queued tasks are marked ``'stopped'`` and
        finished immediately.
        """
        now = utc_now()
        with self._write_lock, self.connect() as connection:
            connection.execute(
                """
                UPDATE tasks
                SET stop_requested = 1, updated_at = ?
                WHERE user_id = ? AND status = 'running'
                """,
                (now, user_id),
            )
            connection.execute(
                """
                UPDATE tasks
                SET status = 'stopped', stop_requested = 1, updated_at = ?, finished_at = ?
                WHERE user_id = ? AND status = 'queued'
                """,
                (now, now, user_id),
            )

    def is_stop_requested(self, task_id: int) -> bool:
        """Return whether a stop was requested for the task (``False`` if it does not exist)."""
        with self.connect() as connection:
            row = connection.execute(
                "SELECT stop_requested FROM tasks WHERE id = ?",
                (task_id,),
            ).fetchone()
        return bool(row and row["stop_requested"])

    def finish_task(self, task_id: int, status: str, last_error: str = "") -> None:
        """Set the task's terminal *status* and *last_error*, stamping ``finished_at``."""
        now = utc_now()
        with self._write_lock, self.connect() as connection:
            connection.execute(
                """
                UPDATE tasks
                SET status = ?, last_error = ?, updated_at = ?, finished_at = ?
                WHERE id = ?
                """,
                (status, last_error, now, now, task_id),
            )

    # -------------------------------------------------------------------- logs

    def add_log(self, task_id: int | None, user_id: int | None, actor: str, level: str, message: str) -> None:
        """Append a log line; *level* is stored upper-cased."""
        with self._write_lock, self.connect() as connection:
            connection.execute(
                """
                INSERT INTO logs(task_id, user_id, created_at, level, actor, message)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (task_id, user_id, utc_now(), level.upper(), actor, message),
            )

    def get_logs_after(self, last_id: int, user_id: int | None = None, limit: int = 200) -> list[dict[str, Any]]:
        """Return up to *limit* logs with ``id > last_id``, oldest first.

        Args:
            last_id: Exclusive lower bound on ``logs.id``.
            user_id: When given, only that user's logs.
            limit: Maximum number of rows.
        """
        query = self._LOG_SELECT + " WHERE logs.id > ?"
        params: list[Any] = [last_id]
        if user_id is not None:
            query += " AND logs.user_id = ?"
            params.append(user_id)
        query += " ORDER BY logs.id ASC LIMIT ?"
        params.append(limit)
        with self.connect() as connection:
            rows = connection.execute(query, params).fetchall()
        return [dict(row) for row in rows]

    def get_recent_logs(self, user_id: int | None = None, limit: int = 120) -> list[dict[str, Any]]:
        """Return the newest *limit* logs (optionally for one user), in ascending id order."""
        query = self._LOG_SELECT
        params: list[Any] = []
        if user_id is not None:
            query += " WHERE logs.user_id = ?"
            params.append(user_id)
        # Take the newest rows first, then flip so callers get oldest-first.
        query += " ORDER BY logs.id DESC LIMIT ?"
        params.append(limit)
        with self.connect() as connection:
            rows = connection.execute(query, params).fetchall()
        return [dict(row) for row in reversed(rows)]