Spaces:
Paused
Paused
from __future__ import annotations

import json
import sqlite3
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
# Statuses in which a task is finished; Database.set_task_status stamps
# finished_at when a task moves into one of these.
# frozenset so importers cannot mutate the shared constant by accident.
TERMINAL_TASK_STATUSES = frozenset({'success', 'failed', 'stopped', 'interrupted'})
# Statuses of a task that is still in flight (queued, executing, or paused on a captcha).
# NOTE: the SQL in Database hard-codes this same list; keep the two in sync.
ACTIVE_TASK_STATUSES = frozenset({'queued', 'running', 'waiting_captcha'})
def utc_now() -> str:
    """Return the current UTC time as an ISO-8601 string truncated to whole seconds.

    The result carries an explicit ``+00:00`` offset, e.g. ``2024-01-31T08:15:00+00:00``.
    """
    current = datetime.now(tz=timezone.utc)
    return current.isoformat(timespec='seconds')
class Database:
    """SQLite persistence for users, admins, courses, tasks, task logs and settings.

    Every public method opens its own connection through ``_session()``, runs its
    statements in a single transaction (commit on success, rollback on error) and
    closes the connection before returning. No connection is stored on the instance.
    """

    def __init__(self, db_path: Path):
        """Remember the database location and create its parent directory if missing.

        Args:
            db_path: Location of the SQLite database file (anything ``Path`` accepts).
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection with ``sqlite3.Row`` rows and foreign keys enforced.

        The caller owns the returned connection and must close it. Prefer
        ``_session()``, which closes it automatically.
        """
        connection = sqlite3.connect(self.db_path, check_same_thread=False)
        connection.row_factory = sqlite3.Row
        # Foreign-key enforcement is a per-connection setting in SQLite; the
        # ON DELETE CASCADE clauses in the schema depend on it.
        connection.execute('PRAGMA foreign_keys = ON')
        return connection

    @contextmanager
    def _session(self) -> Iterator[sqlite3.Connection]:
        """Yield a fresh connection inside a transaction and always close it afterwards.

        ``with connection:`` only commits (or rolls back on an exception); it never
        closes the connection. The explicit ``close()`` releases the file handle
        deterministically instead of waiting for garbage collection.
        """
        connection = self._connect()
        try:
            with connection:
                yield connection
        finally:
            connection.close()

    def initialize(self) -> None:
        """Create the schema, seed default settings and close out leftover active tasks.

        Safe to call on every start-up:

        - Tables and indexes use ``IF NOT EXISTS``.
        - The default ``max_parallel_tasks`` setting is only inserted when absent.
        - Tasks still 'queued', 'running' or 'waiting_captcha' are moved to
          'interrupted' and given a finish time. A restart message is recorded
          unless an error message was already stored.
        """
        with self._session() as connection:
            # WAL mode is persisted in the database file, so later connections inherit it.
            connection.execute('PRAGMA journal_mode = WAL')
            connection.executescript(
                """
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    student_id TEXT NOT NULL UNIQUE,
                    display_name TEXT NOT NULL DEFAULT '',
                    encrypted_password TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS admins (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL UNIQUE,
                    password_hash TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS courses (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    course_id TEXT NOT NULL,
                    course_index TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    UNIQUE(user_id, course_id, course_index),
                    FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
                );
                CREATE TABLE IF NOT EXISTS tasks (
                    id TEXT PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    status TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    started_at TEXT,
                    finished_at TEXT,
                    requested_by_role TEXT NOT NULL,
                    requested_by_identity TEXT NOT NULL,
                    stop_requested INTEGER NOT NULL DEFAULT 0,
                    last_error TEXT NOT NULL DEFAULT '',
                    total_count INTEGER NOT NULL DEFAULT 0,
                    completed_count INTEGER NOT NULL DEFAULT 0,
                    task_payload TEXT NOT NULL,
                    FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
                );
                CREATE TABLE IF NOT EXISTS task_logs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    task_id TEXT NOT NULL,
                    level TEXT NOT NULL,
                    message TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE
                );
                CREATE TABLE IF NOT EXISTS settings (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_courses_user_id ON courses(user_id);
                CREATE INDEX IF NOT EXISTS idx_tasks_user_id ON tasks(user_id);
                CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
                CREATE INDEX IF NOT EXISTS idx_task_logs_task_id ON task_logs(task_id, id);
                """
            )
            connection.execute(
                'INSERT OR IGNORE INTO settings(key, value) VALUES (?, ?)',
                ('max_parallel_tasks', '2'),
            )
            # Status list mirrors ACTIVE_TASK_STATUSES.
            connection.execute(
                """
                UPDATE tasks
                SET status = 'interrupted',
                    finished_at = ?,
                    last_error = CASE
                        WHEN COALESCE(last_error, '') = '' THEN '应用重启,上一轮任务被中断。'
                        ELSE last_error
                    END
                WHERE status IN ('queued', 'running', 'waiting_captcha')
                """,
                (utc_now(),),
            )

    def get_setting(self, key: str, default: str = '') -> str:
        """Return the stored value for ``key``, or ``default`` when no row exists."""
        with self._session() as connection:
            row = connection.execute('SELECT value FROM settings WHERE key = ?', (key,)).fetchone()
            return row['value'] if row is not None else default

    def set_setting(self, key: str, value: str) -> None:
        """Insert or overwrite the setting ``key`` with ``value`` (upsert)."""
        with self._session() as connection:
            connection.execute(
                """
                INSERT INTO settings(key, value) VALUES (?, ?)
                ON CONFLICT(key) DO UPDATE SET value = excluded.value
                """,
                (key, value),
            )

    def get_max_parallel_tasks(self) -> int:
        """Return the ``max_parallel_tasks`` setting, clamped to the range 1..8.

        Falls back to 2 when the setting is missing or is not an integer string.
        """
        raw = self.get_setting('max_parallel_tasks', '2')
        try:
            return max(1, min(8, int(raw)))
        except ValueError:
            return 2

    def create_user(self, student_id: str, encrypted_password: str, display_name: str = '') -> int:
        """Insert a user and return its new row id.

        Raises:
            sqlite3.IntegrityError: if ``student_id`` already exists (UNIQUE column).
        """
        now = utc_now()
        with self._session() as connection:
            cursor = connection.execute(
                """
                INSERT INTO users(student_id, display_name, encrypted_password, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?)
                """,
                (student_id, display_name, encrypted_password, now, now),
            )
            return int(cursor.lastrowid)

    def update_user(
        self,
        user_id: int,
        *,
        student_id: str | None = None,
        encrypted_password: str | None = None,
        display_name: str | None = None,
    ) -> None:
        """Update only the user fields passed as non-None and refresh ``updated_at``.

        Does nothing (and touches no timestamp) when no field is given.
        """
        fields: list[str] = []
        values: list[Any] = []
        if student_id is not None:
            fields.append('student_id = ?')
            values.append(student_id)
        if encrypted_password is not None:
            fields.append('encrypted_password = ?')
            values.append(encrypted_password)
        if display_name is not None:
            fields.append('display_name = ?')
            values.append(display_name)
        if not fields:
            return
        fields.append('updated_at = ?')
        values.append(utc_now())
        values.append(user_id)
        with self._session() as connection:
            # Column names come from the fixed list above; only values are bound.
            connection.execute(f"UPDATE users SET {', '.join(fields)} WHERE id = ?", values)

    def delete_user(self, user_id: int) -> None:
        """Delete a user; courses and tasks follow via ON DELETE CASCADE."""
        with self._session() as connection:
            connection.execute('DELETE FROM users WHERE id = ?', (user_id,))

    def get_user_by_student_id(self, student_id: str) -> dict[str, Any] | None:
        """Return the full user row for ``student_id`` as a dict, or None."""
        with self._session() as connection:
            row = connection.execute('SELECT * FROM users WHERE student_id = ?', (student_id,)).fetchone()
            return dict(row) if row is not None else None

    def get_user_by_id(self, user_id: int) -> dict[str, Any] | None:
        """Return the full user row for ``user_id`` as a dict, or None."""
        with self._session() as connection:
            row = connection.execute('SELECT * FROM users WHERE id = ?', (user_id,)).fetchone()
            return dict(row) if row is not None else None

    def list_users(self) -> list[dict[str, Any]]:
        """Return all users, oldest first, with extra computed columns.

        Each row also carries ``course_count`` and ``latest_task_status`` (the
        status of the user's most recently created task, or None).
        """
        with self._session() as connection:
            rows = connection.execute(
                """
                SELECT
                    u.*,
                    COUNT(c.id) AS course_count,
                    (
                        SELECT t.status
                        FROM tasks t
                        WHERE t.user_id = u.id
                        -- created_at has 1 s resolution; rowid breaks ties by insertion order
                        ORDER BY t.created_at DESC, t.rowid DESC
                        LIMIT 1
                    ) AS latest_task_status
                FROM users u
                LEFT JOIN courses c ON c.user_id = u.id
                GROUP BY u.id
                ORDER BY u.created_at ASC, u.id ASC
                """
            ).fetchall()
            return [dict(row) for row in rows]

    def list_courses_for_user(self, user_id: int) -> list[dict[str, Any]]:
        """Return a user's courses in insertion order."""
        with self._session() as connection:
            rows = connection.execute(
                """
                SELECT id, course_id, course_index, created_at
                FROM courses
                WHERE user_id = ?
                ORDER BY created_at ASC, id ASC
                """,
                (user_id,),
            ).fetchall()
            return [dict(row) for row in rows]

    def add_course(self, user_id: int, course_id: str, course_index: str) -> None:
        """Attach a course to a user; silently ignored if that exact triple already exists."""
        with self._session() as connection:
            connection.execute(
                """
                INSERT OR IGNORE INTO courses(user_id, course_id, course_index, created_at)
                VALUES (?, ?, ?, ?)
                """,
                (user_id, course_id, course_index, utc_now()),
            )

    def delete_course(self, course_row_id: int, user_id: int | None = None) -> None:
        """Delete a course row.

        When ``user_id`` is given, the row is deleted only if it belongs to that user.
        """
        with self._session() as connection:
            if user_id is None:
                connection.execute('DELETE FROM courses WHERE id = ?', (course_row_id,))
            else:
                connection.execute('DELETE FROM courses WHERE id = ? AND user_id = ?', (course_row_id, user_id))

    def list_admins(self) -> list[dict[str, Any]]:
        """Return all admins (without password hashes), oldest first."""
        with self._session() as connection:
            rows = connection.execute(
                'SELECT id, username, created_at, updated_at FROM admins ORDER BY created_at ASC, id ASC'
            ).fetchall()
            return [dict(row) for row in rows]

    def get_admin_by_username(self, username: str) -> dict[str, Any] | None:
        """Return the full admin row (including ``password_hash``) for ``username``, or None."""
        with self._session() as connection:
            row = connection.execute('SELECT * FROM admins WHERE username = ?', (username,)).fetchone()
            return dict(row) if row is not None else None

    def create_admin(self, username: str, password_hash: str) -> int:
        """Insert an admin and return its new row id.

        Raises:
            sqlite3.IntegrityError: if ``username`` already exists (UNIQUE column).
        """
        now = utc_now()
        with self._session() as connection:
            cursor = connection.execute(
                """
                INSERT INTO admins(username, password_hash, created_at, updated_at)
                VALUES (?, ?, ?, ?)
                """,
                (username, password_hash, now, now),
            )
            return int(cursor.lastrowid)

    def update_admin_password(self, admin_id: int, password_hash: str) -> None:
        """Replace an admin's password hash and refresh ``updated_at``."""
        with self._session() as connection:
            connection.execute(
                'UPDATE admins SET password_hash = ?, updated_at = ? WHERE id = ?',
                (password_hash, utc_now(), admin_id),
            )

    def delete_admin(self, admin_id: int) -> None:
        """Delete an admin row by id."""
        with self._session() as connection:
            connection.execute('DELETE FROM admins WHERE id = ?', (admin_id,))

    def find_active_task_for_user(self, user_id: int) -> dict[str, Any] | None:
        """Return the user's newest task that is queued, running or waiting for a captcha."""
        with self._session() as connection:
            # Status list mirrors ACTIVE_TASK_STATUSES.
            row = connection.execute(
                """
                SELECT * FROM tasks
                WHERE user_id = ? AND status IN ('queued', 'running', 'waiting_captcha')
                ORDER BY created_at DESC, rowid DESC
                LIMIT 1
                """,
                (user_id,),
            ).fetchone()
            return dict(row) if row is not None else None

    def create_task(
        self,
        *,
        user_id: int,
        requested_by_role: str,
        requested_by_identity: str,
        payload: dict[str, Any],
    ) -> str:
        """Queue a new task for ``user_id`` and return its generated UUID4 id.

        ``total_count`` is set to the number of entries in ``payload['courses']``
        (0 when the key is missing or None). The whole payload is stored as JSON.
        """
        task_id = str(uuid.uuid4())
        now = utc_now()
        total_count = len(payload.get('courses') or [])
        with self._session() as connection:
            connection.execute(
                """
                INSERT INTO tasks(
                    id,
                    user_id,
                    status,
                    created_at,
                    requested_by_role,
                    requested_by_identity,
                    total_count,
                    completed_count,
                    task_payload
                )
                VALUES (?, ?, 'queued', ?, ?, ?, ?, 0, ?)
                """,
                (
                    task_id,
                    user_id,
                    now,
                    requested_by_role,
                    requested_by_identity,
                    total_count,
                    json.dumps(payload, ensure_ascii=False),
                ),
            )
        return task_id

    def claim_next_queued_task(self) -> dict[str, Any] | None:
        """Move the oldest queued task to 'running' and return its full row.

        Returns None when nothing is queued. ``started_at`` is set and ``last_error``
        is cleared on the claimed task.
        """
        with self._session() as connection:
            # Take the write lock before choosing a row. Otherwise a concurrent claimer
            # could grab the same row between SELECT and UPDATE, and this call would
            # report an empty queue even though other queued tasks remain.
            connection.execute('BEGIN IMMEDIATE')
            row = connection.execute(
                "SELECT id FROM tasks WHERE status = 'queued' ORDER BY created_at ASC, rowid ASC LIMIT 1"
            ).fetchone()
            if row is None:
                return None
            updated = connection.execute(
                """
                UPDATE tasks
                SET status = 'running', started_at = ?, last_error = ''
                WHERE id = ? AND status = 'queued'
                """,
                (utc_now(), row['id']),
            ).rowcount
            if not updated:
                return None
            claimed = connection.execute('SELECT * FROM tasks WHERE id = ?', (row['id'],)).fetchone()
            return dict(claimed) if claimed is not None else None

    def get_task(self, task_id: str) -> dict[str, Any] | None:
        """Return the full task row as a dict, or None."""
        with self._session() as connection:
            row = connection.execute('SELECT * FROM tasks WHERE id = ?', (task_id,)).fetchone()
            return dict(row) if row is not None else None

    def get_task_with_user(self, task_id: str) -> dict[str, Any] | None:
        """Return the task row plus the owner's ``student_id`` and ``display_name``, or None."""
        with self._session() as connection:
            row = connection.execute(
                """
                SELECT
                    t.*,
                    u.student_id,
                    u.display_name
                FROM tasks t
                JOIN users u ON u.id = t.user_id
                WHERE t.id = ?
                """,
                (task_id,),
            ).fetchone()
            return dict(row) if row is not None else None

    def list_recent_tasks_for_user(self, user_id: int, limit: int = 12) -> list[dict[str, Any]]:
        """Return up to ``limit`` of a user's tasks, newest first (payload column omitted)."""
        with self._session() as connection:
            rows = connection.execute(
                """
                SELECT id, user_id, status, created_at, started_at, finished_at, stop_requested,
                       last_error, total_count, completed_count
                FROM tasks
                WHERE user_id = ?
                ORDER BY created_at DESC, rowid DESC
                LIMIT ?
                """,
                (user_id, limit),
            ).fetchall()
            return [dict(row) for row in rows]

    def list_recent_tasks(self, limit: int = 20) -> list[dict[str, Any]]:
        """Return up to ``limit`` tasks across all users, newest first, with owner info."""
        with self._session() as connection:
            rows = connection.execute(
                """
                SELECT
                    t.id,
                    t.user_id,
                    t.status,
                    t.created_at,
                    t.started_at,
                    t.finished_at,
                    t.stop_requested,
                    t.last_error,
                    t.total_count,
                    t.completed_count,
                    u.student_id,
                    u.display_name
                FROM tasks t
                JOIN users u ON u.id = t.user_id
                ORDER BY t.created_at DESC, t.rowid DESC
                LIMIT ?
                """,
                (limit,),
            ).fetchall()
            return [dict(row) for row in rows]

    def set_task_status(
        self,
        task_id: str,
        status: str,
        *,
        last_error: str | None = None,
        completed_count: int | None = None,
    ) -> None:
        """Set a task's status and optionally its error text and progress.

        - Entering 'running' sets ``started_at``, but only if it is still unset.
        - Entering a status in TERMINAL_TASK_STATUSES stamps ``finished_at``.
        """
        now = utc_now()
        assignments = ['status = ?']
        values: list[Any] = [status]
        if last_error is not None:
            assignments.append('last_error = ?')
            values.append(last_error)
        if completed_count is not None:
            assignments.append('completed_count = ?')
            values.append(completed_count)
        if status == 'running':
            assignments.append('started_at = COALESCE(started_at, ?)')
            values.append(now)
        if status in TERMINAL_TASK_STATUSES:
            assignments.append('finished_at = ?')
            values.append(now)
        values.append(task_id)
        with self._session() as connection:
            # Column names come from the fixed strings above; only values are bound.
            connection.execute(f"UPDATE tasks SET {', '.join(assignments)} WHERE id = ?", values)

    def update_task_progress(self, task_id: str, completed_count: int) -> None:
        """Overwrite a task's ``completed_count``."""
        with self._session() as connection:
            connection.execute('UPDATE tasks SET completed_count = ? WHERE id = ?', (completed_count, task_id))

    def request_task_stop(self, task_id: str) -> None:
        """Set the task's ``stop_requested`` flag.

        The flag is only stored here; it is read back via ``is_task_stop_requested``.
        """
        with self._session() as connection:
            connection.execute('UPDATE tasks SET stop_requested = 1 WHERE id = ?', (task_id,))

    def is_task_stop_requested(self, task_id: str) -> bool:
        """Return True if a stop was requested for the task; False for unknown ids."""
        with self._session() as connection:
            row = connection.execute('SELECT stop_requested FROM tasks WHERE id = ?', (task_id,)).fetchone()
            return row is not None and bool(row['stop_requested'])

    def append_task_log(self, task_id: str, level: str, message: str) -> int:
        """Append a log line to a task and return the new log row id."""
        with self._session() as connection:
            cursor = connection.execute(
                """
                INSERT INTO task_logs(task_id, level, message, created_at)
                VALUES (?, ?, ?, ?)
                """,
                (task_id, level, message, utc_now()),
            )
            return int(cursor.lastrowid)

    def list_task_logs(self, task_id: str, after_id: int = 0, limit: int = 200) -> list[dict[str, Any]]:
        """Return up to ``limit`` log rows of a task with ``id > after_id``, oldest first.

        Passing the last id already seen lets a caller fetch only newer entries.
        """
        with self._session() as connection:
            rows = connection.execute(
                """
                SELECT id, level, message, created_at
                FROM task_logs
                WHERE task_id = ? AND id > ?
                ORDER BY id ASC
                LIMIT ?
                """,
                (task_id, after_id, limit),
            ).fetchall()
            return [dict(row) for row in rows]

    def get_system_snapshot(self) -> dict[str, Any]:
        """Return row counts plus the effective ``max_parallel_tasks`` setting.

        'running' counts tasks in 'running' or 'waiting_captcha'. All counts come
        from one SELECT, so they reflect a single consistent view of the database.
        """
        with self._session() as connection:
            counts = connection.execute(
                """
                SELECT
                    (SELECT COUNT(*) FROM users),
                    (SELECT COUNT(*) FROM admins),
                    (SELECT COUNT(*) FROM tasks WHERE status = 'queued'),
                    (SELECT COUNT(*) FROM tasks WHERE status IN ('running', 'waiting_captcha'))
                """
            ).fetchone()
        return {
            'users': counts[0],
            'admins': counts[1],
            'queued': counts[2],
            'running': counts[3],
            'max_parallel_tasks': self.get_max_parallel_tasks(),
        }