from __future__ import annotations

import sqlite3
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from course_catcher.security import CredentialCipher, hash_password, verify_password
|
|
|
|
def utc_now() -> str:
    """Return the current UTC time as an ISO-8601 string with a trailing ``Z``.

    Microseconds are dropped, e.g. ``"2024-05-01T12:34:56Z"``. The aware
    ``datetime.now(timezone.utc)`` replaces the deprecated ``datetime.utcnow()``;
    tzinfo is stripped before formatting so the output does not gain a
    ``+00:00`` offset in addition to the ``Z``.
    """
    return datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None).isoformat() + "Z"
|
|
|
|
| class Database: |
| def __init__(self, db_path: Path, cipher: CredentialCipher, default_parallelism: int) -> None: |
| self.db_path = Path(db_path) |
| self.db_path.parent.mkdir(parents=True, exist_ok=True) |
| self.cipher = cipher |
| self.default_parallelism = default_parallelism |
| self._write_lock = threading.RLock() |
| self._initialize() |
|
|
| @contextmanager |
| def connect(self) -> sqlite3.Connection: |
| connection = sqlite3.connect(self.db_path, timeout=30, check_same_thread=False) |
| connection.row_factory = sqlite3.Row |
| connection.execute("PRAGMA foreign_keys = ON") |
| connection.execute("PRAGMA busy_timeout = 30000") |
| try: |
| yield connection |
| connection.commit() |
| finally: |
| connection.close() |
|
|
| def _initialize(self) -> None: |
| with self._write_lock, self.connect() as connection: |
| connection.executescript( |
| """ |
| CREATE TABLE IF NOT EXISTS users ( |
| id INTEGER PRIMARY KEY AUTOINCREMENT, |
| student_id TEXT NOT NULL UNIQUE, |
| password_hash TEXT NOT NULL, |
| password_encrypted TEXT NOT NULL, |
| created_at TEXT NOT NULL, |
| updated_at TEXT NOT NULL, |
| last_login_at TEXT |
| ); |
| |
| CREATE TABLE IF NOT EXISTS admins ( |
| id INTEGER PRIMARY KEY AUTOINCREMENT, |
| username TEXT NOT NULL UNIQUE, |
| password_hash TEXT NOT NULL, |
| role TEXT NOT NULL CHECK(role IN ('admin', 'superadmin')), |
| created_at TEXT NOT NULL, |
| updated_at TEXT NOT NULL, |
| last_login_at TEXT |
| ); |
| |
| CREATE TABLE IF NOT EXISTS courses ( |
| id INTEGER PRIMARY KEY AUTOINCREMENT, |
| user_id INTEGER NOT NULL, |
| course_id TEXT NOT NULL, |
| course_index TEXT NOT NULL, |
| status TEXT NOT NULL DEFAULT 'pending', |
| last_result TEXT NOT NULL DEFAULT '', |
| created_at TEXT NOT NULL, |
| updated_at TEXT NOT NULL, |
| UNIQUE(user_id, course_id, course_index), |
| FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE |
| ); |
| |
| CREATE TABLE IF NOT EXISTS tasks ( |
| id INTEGER PRIMARY KEY AUTOINCREMENT, |
| user_id INTEGER NOT NULL, |
| status TEXT NOT NULL CHECK(status IN ('queued', 'running', 'completed', 'failed', 'stopped')), |
| requested_by_type TEXT NOT NULL, |
| requested_by_name TEXT NOT NULL, |
| stop_requested INTEGER NOT NULL DEFAULT 0, |
| attempt_count INTEGER NOT NULL DEFAULT 0, |
| last_error TEXT NOT NULL DEFAULT '', |
| created_at TEXT NOT NULL, |
| updated_at TEXT NOT NULL, |
| started_at TEXT, |
| finished_at TEXT, |
| FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE |
| ); |
| |
| CREATE TABLE IF NOT EXISTS settings ( |
| key TEXT PRIMARY KEY, |
| value TEXT NOT NULL |
| ); |
| |
| CREATE TABLE IF NOT EXISTS logs ( |
| id INTEGER PRIMARY KEY AUTOINCREMENT, |
| task_id INTEGER, |
| user_id INTEGER, |
| created_at TEXT NOT NULL, |
| level TEXT NOT NULL, |
| actor TEXT NOT NULL, |
| message TEXT NOT NULL, |
| FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE SET NULL, |
| FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE |
| ); |
| """ |
| ) |
| connection.execute( |
| "INSERT OR IGNORE INTO settings(key, value) VALUES (?, ?)", |
| ("max_parallel_tasks", str(self.default_parallelism)), |
| ) |
| connection.execute( |
| "INSERT OR IGNORE INTO settings(key, value) VALUES (?, ?)", |
| ("loop_interval_seconds", "3"), |
| ) |
|
|
| def ensure_superadmin(self, username: str, password: str) -> None: |
| now = utc_now() |
| password_hash = hash_password(password) |
| with self._write_lock, self.connect() as connection: |
| row = connection.execute( |
| "SELECT id FROM admins WHERE role = 'superadmin' ORDER BY id LIMIT 1" |
| ).fetchone() |
| if row: |
| connection.execute( |
| """ |
| UPDATE admins |
| SET username = ?, password_hash = ?, updated_at = ? |
| WHERE id = ? |
| """, |
| (username, password_hash, now, row["id"]), |
| ) |
| else: |
| connection.execute( |
| """ |
| INSERT INTO admins(username, password_hash, role, created_at, updated_at) |
| VALUES (?, ?, 'superadmin', ?, ?) |
| """, |
| (username, password_hash, now, now), |
| ) |
|
|
| def get_setting_int(self, key: str, fallback: int) -> int: |
| with self.connect() as connection: |
| row = connection.execute("SELECT value FROM settings WHERE key = ?", (key,)).fetchone() |
| if not row: |
| return fallback |
| try: |
| return int(row["value"]) |
| except (TypeError, ValueError): |
| return fallback |
|
|
| def get_setting_float(self, key: str, fallback: float) -> float: |
| with self.connect() as connection: |
| row = connection.execute("SELECT value FROM settings WHERE key = ?", (key,)).fetchone() |
| if not row: |
| return fallback |
| try: |
| return float(row["value"]) |
| except (TypeError, ValueError): |
| return fallback |
|
|
| def set_setting(self, key: str, value: str) -> None: |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| """ |
| INSERT INTO settings(key, value) VALUES (?, ?) |
| ON CONFLICT(key) DO UPDATE SET value = excluded.value |
| """, |
| (key, value), |
| ) |
|
|
| def list_admins(self) -> list[dict[str, Any]]: |
| with self.connect() as connection: |
| rows = connection.execute( |
| "SELECT id, username, role, created_at, updated_at, last_login_at FROM admins ORDER BY role DESC, username ASC" |
| ).fetchall() |
| return [dict(row) for row in rows] |
|
|
| def create_admin(self, username: str, password: str) -> dict[str, Any]: |
| now = utc_now() |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| """ |
| INSERT INTO admins(username, password_hash, role, created_at, updated_at) |
| VALUES (?, ?, 'admin', ?, ?) |
| """, |
| (username, hash_password(password), now, now), |
| ) |
| row = connection.execute( |
| "SELECT id, username, role, created_at, updated_at, last_login_at FROM admins WHERE username = ?", |
| (username,), |
| ).fetchone() |
| return dict(row) |
|
|
| def verify_admin_login(self, username: str, password: str) -> dict[str, Any] | None: |
| with self.connect() as connection: |
| row = connection.execute("SELECT * FROM admins WHERE username = ?", (username,)).fetchone() |
| if not row or not verify_password(row["password_hash"], password): |
| return None |
| connection.execute( |
| "UPDATE admins SET last_login_at = ?, updated_at = ? WHERE id = ?", |
| (utc_now(), utc_now(), row["id"]), |
| ) |
| refreshed = connection.execute( |
| "SELECT id, username, role, created_at, updated_at, last_login_at FROM admins WHERE id = ?", |
| (row["id"],), |
| ).fetchone() |
| return dict(refreshed) |
|
|
| def get_admin_by_id(self, admin_id: int) -> dict[str, Any] | None: |
| with self.connect() as connection: |
| row = connection.execute( |
| "SELECT id, username, role, created_at, updated_at, last_login_at FROM admins WHERE id = ?", |
| (admin_id,), |
| ).fetchone() |
| return dict(row) if row else None |
|
|
| def get_user_by_id(self, user_id: int, include_password: bool = False) -> dict[str, Any] | None: |
| fields = "id, student_id, created_at, updated_at, last_login_at" |
| if include_password: |
| fields += ", password_hash, password_encrypted" |
| with self.connect() as connection: |
| row = connection.execute(f"SELECT {fields} FROM users WHERE id = ?", (user_id,)).fetchone() |
| return dict(row) if row else None |
|
|
| def get_user_by_student_id(self, student_id: str, include_password: bool = False) -> dict[str, Any] | None: |
| fields = "id, student_id, created_at, updated_at, last_login_at" |
| if include_password: |
| fields += ", password_hash, password_encrypted" |
| with self.connect() as connection: |
| row = connection.execute( |
| f"SELECT {fields} FROM users WHERE student_id = ?", |
| (student_id,), |
| ).fetchone() |
| return dict(row) if row else None |
|
|
| def create_user(self, student_id: str, password: str) -> dict[str, Any]: |
| now = utc_now() |
| encrypted = self.cipher.encrypt(password) |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| """ |
| INSERT INTO users(student_id, password_hash, password_encrypted, created_at, updated_at) |
| VALUES (?, ?, ?, ?, ?) |
| """, |
| (student_id, hash_password(password), encrypted, now, now), |
| ) |
| row = connection.execute( |
| "SELECT id, student_id, created_at, updated_at, last_login_at FROM users WHERE student_id = ?", |
| (student_id,), |
| ).fetchone() |
| return dict(row) |
|
|
| def create_or_update_user(self, student_id: str, password: str) -> dict[str, Any]: |
| existing = self.get_user_by_student_id(student_id) |
| if existing: |
| self.update_user_password(existing["id"], password) |
| return self.get_user_by_id(existing["id"]) or existing |
| return self.create_user(student_id, password) |
|
|
| def verify_or_create_user_login(self, student_id: str, password: str) -> tuple[dict[str, Any] | None, bool]: |
| user = self.get_user_by_student_id(student_id, include_password=True) |
| if not user: |
| created = self.create_user(student_id, password) |
| return created, True |
| if not verify_password(user["password_hash"], password): |
| return None, False |
| now = utc_now() |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| "UPDATE users SET last_login_at = ?, updated_at = ? WHERE id = ?", |
| (now, now, user["id"]), |
| ) |
| refreshed = connection.execute( |
| "SELECT id, student_id, created_at, updated_at, last_login_at FROM users WHERE id = ?", |
| (user["id"],), |
| ).fetchone() |
| return dict(refreshed), False |
|
|
| def update_user_password(self, user_id: int, password: str) -> None: |
| now = utc_now() |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| """ |
| UPDATE users |
| SET password_hash = ?, password_encrypted = ?, updated_at = ? |
| WHERE id = ? |
| """, |
| (hash_password(password), self.cipher.encrypt(password), now, user_id), |
| ) |
|
|
| def get_user_runtime_credentials(self, user_id: int) -> dict[str, Any] | None: |
| with self.connect() as connection: |
| row = connection.execute( |
| "SELECT id, student_id, password_encrypted FROM users WHERE id = ?", |
| (user_id,), |
| ).fetchone() |
| if not row: |
| return None |
| data = dict(row) |
| data["password"] = self.cipher.decrypt(data["password_encrypted"]) |
| return data |
|
|
| def list_courses(self, user_id: int) -> list[dict[str, Any]]: |
| with self.connect() as connection: |
| rows = connection.execute( |
| """ |
| SELECT id, user_id, course_id, course_index, status, last_result, created_at, updated_at |
| FROM courses |
| WHERE user_id = ? |
| ORDER BY status DESC, course_id ASC, course_index ASC |
| """, |
| (user_id,), |
| ).fetchall() |
| return [dict(row) for row in rows] |
|
|
| def list_pending_courses(self, user_id: int) -> list[dict[str, Any]]: |
| with self.connect() as connection: |
| rows = connection.execute( |
| """ |
| SELECT id, user_id, course_id, course_index, status, last_result, created_at, updated_at |
| FROM courses |
| WHERE user_id = ? AND status != 'selected' |
| ORDER BY course_id ASC, course_index ASC |
| """, |
| (user_id,), |
| ).fetchall() |
| return [dict(row) for row in rows] |
|
|
| def add_course(self, user_id: int, course_id: str, course_index: str) -> None: |
| now = utc_now() |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| """ |
| INSERT INTO courses(user_id, course_id, course_index, status, last_result, created_at, updated_at) |
| VALUES (?, ?, ?, 'pending', '', ?, ?) |
| ON CONFLICT(user_id, course_id, course_index) |
| DO UPDATE SET status = 'pending', updated_at = excluded.updated_at |
| """, |
| (user_id, course_id, course_index, now, now), |
| ) |
|
|
| def delete_course(self, course_id: int, user_id: int) -> None: |
| with self._write_lock, self.connect() as connection: |
| connection.execute("DELETE FROM courses WHERE id = ? AND user_id = ?", (course_id, user_id)) |
|
|
| def mark_course_result( |
| self, |
| user_id: int, |
| course_id: str, |
| course_index: str, |
| status: str, |
| detail: str, |
| ) -> None: |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| """ |
| UPDATE courses |
| SET status = ?, last_result = ?, updated_at = ? |
| WHERE user_id = ? AND course_id = ? AND course_index = ? |
| """, |
| (status, detail, utc_now(), user_id, course_id, course_index), |
| ) |
|
|
| def list_users_with_summary(self) -> list[dict[str, Any]]: |
| with self.connect() as connection: |
| rows = connection.execute( |
| """ |
| SELECT |
| u.id, |
| u.student_id, |
| u.created_at, |
| u.updated_at, |
| u.last_login_at, |
| COALESCE(course_stats.course_count, 0) AS course_count, |
| COALESCE(course_stats.selected_count, 0) AS selected_count, |
| active_task.id AS active_task_id, |
| active_task.status AS active_task_status |
| FROM users u |
| LEFT JOIN ( |
| SELECT |
| user_id, |
| COUNT(*) AS course_count, |
| SUM(CASE WHEN status = 'selected' THEN 1 ELSE 0 END) AS selected_count |
| FROM courses |
| GROUP BY user_id |
| ) AS course_stats ON course_stats.user_id = u.id |
| LEFT JOIN ( |
| SELECT t1.* |
| FROM tasks t1 |
| JOIN ( |
| SELECT user_id, MAX(id) AS id |
| FROM tasks |
| WHERE status IN ('queued', 'running') |
| GROUP BY user_id |
| ) latest ON latest.id = t1.id |
| ) AS active_task ON active_task.user_id = u.id |
| ORDER BY u.updated_at DESC, u.student_id ASC |
| """ |
| ).fetchall() |
| return [dict(row) for row in rows] |
|
|
| def get_user_dashboard_state(self, user_id: int) -> dict[str, Any]: |
| with self.connect() as connection: |
| counts = connection.execute( |
| """ |
| SELECT |
| COUNT(*) AS total_courses, |
| SUM(CASE WHEN status = 'selected' THEN 1 ELSE 0 END) AS selected_courses, |
| SUM(CASE WHEN status != 'selected' THEN 1 ELSE 0 END) AS pending_courses |
| FROM courses |
| WHERE user_id = ? |
| """, |
| (user_id,), |
| ).fetchone() |
| task = connection.execute( |
| """ |
| SELECT id, status, attempt_count, updated_at, started_at, finished_at |
| FROM tasks |
| WHERE user_id = ? |
| ORDER BY id DESC |
| LIMIT 1 |
| """, |
| (user_id,), |
| ).fetchone() |
| return { |
| "total_courses": counts["total_courses"] or 0, |
| "selected_courses": counts["selected_courses"] or 0, |
| "pending_courses": counts["pending_courses"] or 0, |
| "task": dict(task) if task else None, |
| } |
|
|
| def get_admin_summary(self) -> dict[str, Any]: |
| with self.connect() as connection: |
| users = connection.execute("SELECT COUNT(*) AS total FROM users").fetchone()["total"] |
| running_tasks = connection.execute( |
| "SELECT COUNT(*) AS total FROM tasks WHERE status = 'running'" |
| ).fetchone()["total"] |
| queued_tasks = connection.execute( |
| "SELECT COUNT(*) AS total FROM tasks WHERE status = 'queued'" |
| ).fetchone()["total"] |
| pending_courses = connection.execute( |
| "SELECT COUNT(*) AS total FROM courses WHERE status != 'selected'" |
| ).fetchone()["total"] |
| return { |
| "users": users, |
| "running_tasks": running_tasks, |
| "queued_tasks": queued_tasks, |
| "pending_courses": pending_courses, |
| "parallelism": self.get_setting_int("max_parallel_tasks", self.default_parallelism), |
| } |
|
|
| def create_task(self, user_id: int, requested_by_type: str, requested_by_name: str) -> dict[str, Any]: |
| with self._write_lock, self.connect() as connection: |
| existing = connection.execute( |
| """ |
| SELECT * |
| FROM tasks |
| WHERE user_id = ? AND status IN ('queued', 'running') |
| ORDER BY id DESC |
| LIMIT 1 |
| """, |
| (user_id,), |
| ).fetchone() |
| if existing: |
| return dict(existing) |
| now = utc_now() |
| connection.execute( |
| """ |
| INSERT INTO tasks( |
| user_id, status, requested_by_type, requested_by_name, stop_requested, |
| attempt_count, last_error, created_at, updated_at |
| ) |
| VALUES (?, 'queued', ?, ?, 0, 0, '', ?, ?) |
| """, |
| (user_id, requested_by_type, requested_by_name, now, now), |
| ) |
| created = connection.execute( |
| "SELECT * FROM tasks WHERE id = last_insert_rowid()" |
| ).fetchone() |
| return dict(created) |
|
|
| def get_task(self, task_id: int) -> dict[str, Any] | None: |
| with self.connect() as connection: |
| row = connection.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)).fetchone() |
| return dict(row) if row else None |
|
|
| def fetch_queued_tasks(self, limit: int) -> list[dict[str, Any]]: |
| with self.connect() as connection: |
| rows = connection.execute( |
| """ |
| SELECT * |
| FROM tasks |
| WHERE status = 'queued' |
| ORDER BY created_at ASC, id ASC |
| LIMIT ? |
| """, |
| (limit,), |
| ).fetchall() |
| return [dict(row) for row in rows] |
|
|
| def mark_task_running(self, task_id: int) -> None: |
| now = utc_now() |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| """ |
| UPDATE tasks |
| SET status = 'running', updated_at = ?, started_at = COALESCE(started_at, ?) |
| WHERE id = ? |
| """, |
| (now, now, task_id), |
| ) |
|
|
| def increment_task_attempt(self, task_id: int) -> None: |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| """ |
| UPDATE tasks |
| SET attempt_count = attempt_count + 1, updated_at = ? |
| WHERE id = ? |
| """, |
| (utc_now(), task_id), |
| ) |
|
|
| def request_stop_task(self, user_id: int) -> None: |
| now = utc_now() |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| """ |
| UPDATE tasks |
| SET stop_requested = 1, updated_at = ? |
| WHERE user_id = ? AND status = 'running' |
| """, |
| (now, user_id), |
| ) |
| connection.execute( |
| """ |
| UPDATE tasks |
| SET status = 'stopped', stop_requested = 1, updated_at = ?, finished_at = ? |
| WHERE user_id = ? AND status = 'queued' |
| """, |
| (now, now, user_id), |
| ) |
|
|
| def is_stop_requested(self, task_id: int) -> bool: |
| with self.connect() as connection: |
| row = connection.execute( |
| "SELECT stop_requested FROM tasks WHERE id = ?", |
| (task_id,), |
| ).fetchone() |
| return bool(row and row["stop_requested"]) |
|
|
| def finish_task(self, task_id: int, status: str, last_error: str = "") -> None: |
| now = utc_now() |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| """ |
| UPDATE tasks |
| SET status = ?, last_error = ?, updated_at = ?, finished_at = ? |
| WHERE id = ? |
| """, |
| (status, last_error, now, now, task_id), |
| ) |
|
|
| def add_log(self, task_id: int | None, user_id: int | None, actor: str, level: str, message: str) -> None: |
| with self._write_lock, self.connect() as connection: |
| connection.execute( |
| """ |
| INSERT INTO logs(task_id, user_id, created_at, level, actor, message) |
| VALUES (?, ?, ?, ?, ?, ?) |
| """, |
| (task_id, user_id, utc_now(), level.upper(), actor, message), |
| ) |
|
|
| def get_logs_after(self, last_id: int, user_id: int | None = None, limit: int = 200) -> list[dict[str, Any]]: |
| query = """ |
| SELECT |
| logs.id, |
| logs.task_id, |
| logs.user_id, |
| logs.created_at, |
| logs.level, |
| logs.actor, |
| logs.message, |
| users.student_id |
| FROM logs |
| LEFT JOIN users ON users.id = logs.user_id |
| WHERE logs.id > ? |
| """ |
| params: list[Any] = [last_id] |
| if user_id is not None: |
| query += " AND logs.user_id = ?" |
| params.append(user_id) |
| query += " ORDER BY logs.id ASC LIMIT ?" |
| params.append(limit) |
| with self.connect() as connection: |
| rows = connection.execute(query, params).fetchall() |
| return [dict(row) for row in rows] |
|
|
| def get_recent_logs(self, user_id: int | None = None, limit: int = 120) -> list[dict[str, Any]]: |
| query = """ |
| SELECT |
| logs.id, |
| logs.task_id, |
| logs.user_id, |
| logs.created_at, |
| logs.level, |
| logs.actor, |
| logs.message, |
| users.student_id |
| FROM logs |
| LEFT JOIN users ON users.id = logs.user_id |
| """ |
| params: list[Any] = [] |
| if user_id is not None: |
| query += " WHERE logs.user_id = ?" |
| params.append(user_id) |
| query += " ORDER BY logs.id DESC LIMIT ?" |
| params.append(limit) |
| with self.connect() as connection: |
| rows = connection.execute(query, params).fetchall() |
| items = [dict(row) for row in rows] |
| items.reverse() |
| return items |
|
|