# SACC-release / core/database.py
# cacode: Deploy updated SCU course catcher (e28c9e4, verified)
from __future__ import annotations

import json
import sqlite3
import uuid
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
# Statuses that end a task; entering one of them stamps the task's
# ``finished_at`` column (see ``Database.set_task_status``).
TERMINAL_TASK_STATUSES = {
    'success',
    'failed',
    'stopped',
    'interrupted',
}

# Statuses of a task that is still pending or in progress. The same three
# values are spelled out literally in the SQL issued by ``Database``.
ACTIVE_TASK_STATUSES = {
    'queued',
    'running',
    'waiting_captcha',
}
def utc_now() -> str:
    """Return the current UTC time as an ISO-8601 string with second precision."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat(timespec='seconds')
class Database:
    """SQLite persistence layer for users, admins, courses, tasks, logs and settings.

    Each public method opens its own connection, runs inside one transaction
    and closes the connection before returning, so no database handle
    outlives the call that created it.
    """

    def __init__(self, db_path: Path):
        """Remember the database location and create its parent directory.

        Args:
            db_path: Location of the SQLite file; coerced to ``Path``.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection with name-based row access and foreign keys on.

        The caller owns the returned connection and is responsible for
        closing it; ``_session`` does that automatically.
        """
        connection = sqlite3.connect(self.db_path, check_same_thread=False)
        connection.row_factory = sqlite3.Row
        # Foreign-key enforcement (and thus ON DELETE CASCADE) is a
        # per-connection setting in SQLite, so it is enabled on every open.
        connection.execute('PRAGMA foreign_keys = ON')
        return connection

    @contextmanager
    def _session(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that is committed or rolled back, then closed.

        Using ``sqlite3.Connection`` directly as a context manager only ends
        the transaction; it never closes the handle. Wrapping it here makes
        sure every call releases its connection even when the body raises.
        """
        connection = self._connect()
        try:
            with connection:  # commit on success, rollback on exception
                yield connection
        finally:
            connection.close()

    def _fetch_one(self, sql: str, params: Sequence[Any] = ()) -> dict[str, Any] | None:
        """Run *sql* and return the first row as a dict, or ``None`` if empty."""
        with self._session() as connection:
            row = connection.execute(sql, params).fetchone()
            return dict(row) if row else None

    def _fetch_all(self, sql: str, params: Sequence[Any] = ()) -> list[dict[str, Any]]:
        """Run *sql* and return every row as a dict."""
        with self._session() as connection:
            return [dict(row) for row in connection.execute(sql, params).fetchall()]

    def _execute(self, sql: str, params: Sequence[Any] = ()) -> None:
        """Run a single write statement inside its own transaction."""
        with self._session() as connection:
            connection.execute(sql, params)

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def initialize(self) -> None:
        """Create the schema if missing and reset state left by a previous run.

        Besides creating tables and indexes this seeds the default
        ``max_parallel_tasks`` setting and marks every task that was still
        queued, running or waiting for a captcha as ``interrupted``.
        """
        with self._session() as connection:
            connection.execute('PRAGMA journal_mode = WAL')
            connection.executescript(
                """
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    student_id TEXT NOT NULL UNIQUE,
                    display_name TEXT NOT NULL DEFAULT '',
                    encrypted_password TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS admins (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL UNIQUE,
                    password_hash TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS courses (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    course_id TEXT NOT NULL,
                    course_index TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    UNIQUE(user_id, course_id, course_index),
                    FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
                );
                CREATE TABLE IF NOT EXISTS tasks (
                    id TEXT PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    status TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    started_at TEXT,
                    finished_at TEXT,
                    requested_by_role TEXT NOT NULL,
                    requested_by_identity TEXT NOT NULL,
                    stop_requested INTEGER NOT NULL DEFAULT 0,
                    last_error TEXT NOT NULL DEFAULT '',
                    total_count INTEGER NOT NULL DEFAULT 0,
                    completed_count INTEGER NOT NULL DEFAULT 0,
                    task_payload TEXT NOT NULL,
                    FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
                );
                CREATE TABLE IF NOT EXISTS task_logs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    task_id TEXT NOT NULL,
                    level TEXT NOT NULL,
                    message TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE
                );
                CREATE TABLE IF NOT EXISTS settings (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_courses_user_id ON courses(user_id);
                CREATE INDEX IF NOT EXISTS idx_tasks_user_id ON tasks(user_id);
                CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
                CREATE INDEX IF NOT EXISTS idx_task_logs_task_id ON task_logs(task_id, id);
                """
            )
            # Seed the default only when the key is absent so an existing
            # configured value survives restarts.
            connection.execute(
                'INSERT OR IGNORE INTO settings(key, value) VALUES (?, ?)',
                ('max_parallel_tasks', '2'),
            )
            # Tasks still marked active at start-up cannot be running any
            # more; close them out and keep any error already recorded.
            connection.execute(
                """
                UPDATE tasks
                SET status = 'interrupted',
                    finished_at = ?,
                    last_error = CASE
                        WHEN COALESCE(last_error, '') = '' THEN '应用重启,上一轮任务被中断。'
                        ELSE last_error
                    END
                WHERE status IN ('queued', 'running', 'waiting_captcha')
                """,
                (utc_now(),),
            )

    # ------------------------------------------------------------------
    # Settings
    # ------------------------------------------------------------------

    def get_setting(self, key: str, default: str = '') -> str:
        """Return the stored value for *key*, or *default* when it is absent."""
        with self._session() as connection:
            row = connection.execute(
                'SELECT value FROM settings WHERE key = ?', (key,)
            ).fetchone()
            return row['value'] if row else default

    def set_setting(self, key: str, value: str) -> None:
        """Insert *key* or overwrite its current value."""
        self._execute(
            """
            INSERT INTO settings(key, value) VALUES (?, ?)
            ON CONFLICT(key) DO UPDATE SET value = excluded.value
            """,
            (key, value),
        )

    def get_max_parallel_tasks(self) -> int:
        """Return the ``max_parallel_tasks`` setting clamped to the range 1..8.

        Falls back to 2 when the stored value is not an integer.
        """
        raw = self.get_setting('max_parallel_tasks', '2')
        try:
            return max(1, min(8, int(raw)))
        except ValueError:
            return 2

    # ------------------------------------------------------------------
    # Users
    # ------------------------------------------------------------------

    def create_user(self, student_id: str, encrypted_password: str, display_name: str = '') -> int:
        """Insert a user and return its row id.

        Raises:
            sqlite3.IntegrityError: If *student_id* is already taken.
        """
        now = utc_now()
        with self._session() as connection:
            cursor = connection.execute(
                """
                INSERT INTO users(student_id, display_name, encrypted_password, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?)
                """,
                (student_id, display_name, encrypted_password, now, now),
            )
            return int(cursor.lastrowid)

    def update_user(
        self,
        user_id: int,
        *,
        student_id: str | None = None,
        encrypted_password: str | None = None,
        display_name: str | None = None,
    ) -> None:
        """Update the given columns of a user; ``None`` leaves a column untouched.

        Does nothing (and does not bump ``updated_at``) when every optional
        argument is ``None``.
        """
        changes = {
            'student_id': student_id,
            'encrypted_password': encrypted_password,
            'display_name': display_name,
        }
        fields: list[str] = []
        values: list[Any] = []
        for column, value in changes.items():
            if value is not None:
                fields.append(f'{column} = ?')
                values.append(value)
        if not fields:
            return
        fields.append('updated_at = ?')
        values.append(utc_now())
        values.append(user_id)
        # Only fixed column names are interpolated; all values stay bound.
        self._execute(f"UPDATE users SET {', '.join(fields)} WHERE id = ?", values)

    def delete_user(self, user_id: int) -> None:
        """Delete a user; courses and tasks go with it via ON DELETE CASCADE."""
        self._execute('DELETE FROM users WHERE id = ?', (user_id,))

    def get_user_by_student_id(self, student_id: str) -> dict[str, Any] | None:
        """Return the user row with this student id, or ``None``."""
        return self._fetch_one('SELECT * FROM users WHERE student_id = ?', (student_id,))

    def get_user_by_id(self, user_id: int) -> dict[str, Any] | None:
        """Return the user row with this primary key, or ``None``."""
        return self._fetch_one('SELECT * FROM users WHERE id = ?', (user_id,))

    def list_users(self) -> list[dict[str, Any]]:
        """Return all users, oldest first, with course count and latest task status.

        Each row carries every ``users`` column plus ``course_count`` and
        ``latest_task_status`` (``None`` when the user has no tasks).
        """
        # Timestamps have one-second resolution, so rowid / id break ties
        # and keep the result deterministic.
        return self._fetch_all(
            """
            SELECT
                u.*,
                COUNT(c.id) AS course_count,
                (
                    SELECT t.status
                    FROM tasks t
                    WHERE t.user_id = u.id
                    ORDER BY t.created_at DESC, t.rowid DESC
                    LIMIT 1
                ) AS latest_task_status
            FROM users u
            LEFT JOIN courses c ON c.user_id = u.id
            GROUP BY u.id
            ORDER BY u.created_at ASC, u.id ASC
            """
        )

    # ------------------------------------------------------------------
    # Courses
    # ------------------------------------------------------------------

    def list_courses_for_user(self, user_id: int) -> list[dict[str, Any]]:
        """Return the user's courses in insertion order."""
        return self._fetch_all(
            """
            SELECT id, course_id, course_index, created_at
            FROM courses
            WHERE user_id = ?
            ORDER BY created_at ASC, id ASC
            """,
            (user_id,),
        )

    def add_course(self, user_id: int, course_id: str, course_index: str) -> None:
        """Add a course for the user; an identical existing entry is left as is."""
        self._execute(
            """
            INSERT OR IGNORE INTO courses(user_id, course_id, course_index, created_at)
            VALUES (?, ?, ?, ?)
            """,
            (user_id, course_id, course_index, utc_now()),
        )

    def delete_course(self, course_row_id: int, user_id: int | None = None) -> None:
        """Delete a course row, optionally only if it belongs to *user_id*."""
        if user_id is None:
            self._execute('DELETE FROM courses WHERE id = ?', (course_row_id,))
        else:
            self._execute(
                'DELETE FROM courses WHERE id = ? AND user_id = ?',
                (course_row_id, user_id),
            )

    # ------------------------------------------------------------------
    # Admins
    # ------------------------------------------------------------------

    def list_admins(self) -> list[dict[str, Any]]:
        """Return all admins, oldest first, without their password hashes."""
        return self._fetch_all(
            'SELECT id, username, created_at, updated_at FROM admins ORDER BY created_at ASC, id ASC'
        )

    def get_admin_by_username(self, username: str) -> dict[str, Any] | None:
        """Return the full admin row (including ``password_hash``), or ``None``."""
        return self._fetch_one('SELECT * FROM admins WHERE username = ?', (username,))

    def create_admin(self, username: str, password_hash: str) -> int:
        """Insert an admin and return its row id.

        Raises:
            sqlite3.IntegrityError: If *username* is already taken.
        """
        now = utc_now()
        with self._session() as connection:
            cursor = connection.execute(
                """
                INSERT INTO admins(username, password_hash, created_at, updated_at)
                VALUES (?, ?, ?, ?)
                """,
                (username, password_hash, now, now),
            )
            return int(cursor.lastrowid)

    def update_admin_password(self, admin_id: int, password_hash: str) -> None:
        """Replace an admin's password hash and bump ``updated_at``."""
        self._execute(
            'UPDATE admins SET password_hash = ?, updated_at = ? WHERE id = ?',
            (password_hash, utc_now(), admin_id),
        )

    def delete_admin(self, admin_id: int) -> None:
        """Delete the admin with this primary key."""
        self._execute('DELETE FROM admins WHERE id = ?', (admin_id,))

    # ------------------------------------------------------------------
    # Tasks
    # ------------------------------------------------------------------

    def find_active_task_for_user(self, user_id: int) -> dict[str, Any] | None:
        """Return the user's newest queued/running/waiting task, or ``None``."""
        return self._fetch_one(
            """
            SELECT * FROM tasks
            WHERE user_id = ? AND status IN ('queued', 'running', 'waiting_captcha')
            ORDER BY created_at DESC, rowid DESC
            LIMIT 1
            """,
            (user_id,),
        )

    def create_task(
        self,
        *,
        user_id: int,
        requested_by_role: str,
        requested_by_identity: str,
        payload: dict[str, Any],
    ) -> str:
        """Queue a new task and return its generated UUID.

        ``total_count`` is taken from the length of ``payload['courses']``
        (0 when the key is missing); the payload itself is stored as JSON.
        """
        task_id = str(uuid.uuid4())
        self._execute(
            """
            INSERT INTO tasks(
                id,
                user_id,
                status,
                created_at,
                requested_by_role,
                requested_by_identity,
                total_count,
                completed_count,
                task_payload
            )
            VALUES (?, ?, 'queued', ?, ?, ?, ?, 0, ?)
            """,
            (
                task_id,
                user_id,
                utc_now(),
                requested_by_role,
                requested_by_identity,
                len(payload.get('courses', [])),
                json.dumps(payload, ensure_ascii=False),
            ),
        )
        return task_id

    def claim_next_queued_task(self) -> dict[str, Any] | None:
        """Move the oldest queued task to ``running`` and return it.

        Returns ``None`` when nothing is queued, or when another worker
        claimed the selected task between the lookup and the update.
        """
        with self._session() as connection:
            # rowid breaks ties between tasks created within the same second
            # so the queue stays first-in, first-out.
            candidate = connection.execute(
                "SELECT id FROM tasks WHERE status = 'queued' "
                'ORDER BY created_at ASC, rowid ASC LIMIT 1'
            ).fetchone()
            if not candidate:
                return None
            # Re-checking the status in the WHERE clause makes the claim
            # atomic: only one caller can flip a given row out of 'queued'.
            updated = connection.execute(
                """
                UPDATE tasks
                SET status = 'running', started_at = ?, last_error = ''
                WHERE id = ? AND status = 'queued'
                """,
                (utc_now(), candidate['id']),
            ).rowcount
            if not updated:
                return None
            claimed = connection.execute(
                'SELECT * FROM tasks WHERE id = ?', (candidate['id'],)
            ).fetchone()
            return dict(claimed) if claimed else None

    def get_task(self, task_id: str) -> dict[str, Any] | None:
        """Return the task row with this id, or ``None``."""
        return self._fetch_one('SELECT * FROM tasks WHERE id = ?', (task_id,))

    def get_task_with_user(self, task_id: str) -> dict[str, Any] | None:
        """Return the task row plus its owner's ``student_id`` and ``display_name``."""
        return self._fetch_one(
            """
            SELECT
                t.*,
                u.student_id,
                u.display_name
            FROM tasks t
            JOIN users u ON u.id = t.user_id
            WHERE t.id = ?
            """,
            (task_id,),
        )

    def list_recent_tasks_for_user(self, user_id: int, limit: int = 12) -> list[dict[str, Any]]:
        """Return up to *limit* of the user's tasks, newest first, without payloads."""
        return self._fetch_all(
            """
            SELECT id, user_id, status, created_at, started_at, finished_at, stop_requested,
                   last_error, total_count, completed_count
            FROM tasks
            WHERE user_id = ?
            ORDER BY created_at DESC, rowid DESC
            LIMIT ?
            """,
            (user_id, limit),
        )

    def list_recent_tasks(self, limit: int = 20) -> list[dict[str, Any]]:
        """Return up to *limit* tasks across all users, newest first, with owner info."""
        return self._fetch_all(
            """
            SELECT
                t.id,
                t.user_id,
                t.status,
                t.created_at,
                t.started_at,
                t.finished_at,
                t.stop_requested,
                t.last_error,
                t.total_count,
                t.completed_count,
                u.student_id,
                u.display_name
            FROM tasks t
            JOIN users u ON u.id = t.user_id
            ORDER BY t.created_at DESC, t.rowid DESC
            LIMIT ?
            """,
            (limit,),
        )

    def set_task_status(
        self,
        task_id: str,
        status: str,
        *,
        last_error: str | None = None,
        completed_count: int | None = None,
    ) -> None:
        """Set a task's status and, optionally, its error text and progress.

        Entering ``running`` fills ``started_at`` if it is still empty;
        entering any terminal status stamps ``finished_at``.
        """
        assignments = ['status = ?']
        values: list[Any] = [status]
        if last_error is not None:
            assignments.append('last_error = ?')
            values.append(last_error)
        if completed_count is not None:
            assignments.append('completed_count = ?')
            values.append(completed_count)
        if status == 'running':
            # COALESCE keeps the original start time on a repeated transition.
            assignments.append('started_at = COALESCE(started_at, ?)')
            values.append(utc_now())
        if status in TERMINAL_TASK_STATUSES:
            assignments.append('finished_at = ?')
            values.append(utc_now())
        values.append(task_id)
        # Only fixed column fragments are interpolated; all values stay bound.
        self._execute(f"UPDATE tasks SET {', '.join(assignments)} WHERE id = ?", values)

    def update_task_progress(self, task_id: str, completed_count: int) -> None:
        """Overwrite the task's ``completed_count``."""
        self._execute(
            'UPDATE tasks SET completed_count = ? WHERE id = ?',
            (completed_count, task_id),
        )

    def request_task_stop(self, task_id: str) -> None:
        """Raise the task's ``stop_requested`` flag."""
        self._execute('UPDATE tasks SET stop_requested = 1 WHERE id = ?', (task_id,))

    def is_task_stop_requested(self, task_id: str) -> bool:
        """Return whether the stop flag is set; ``False`` for an unknown task."""
        with self._session() as connection:
            row = connection.execute(
                'SELECT stop_requested FROM tasks WHERE id = ?', (task_id,)
            ).fetchone()
            return bool(row and row['stop_requested'])

    # ------------------------------------------------------------------
    # Task logs
    # ------------------------------------------------------------------

    def append_task_log(self, task_id: str, level: str, message: str) -> int:
        """Append a log line to a task and return the new log row id."""
        with self._session() as connection:
            cursor = connection.execute(
                """
                INSERT INTO task_logs(task_id, level, message, created_at)
                VALUES (?, ?, ?, ?)
                """,
                (task_id, level, message, utc_now()),
            )
            return int(cursor.lastrowid)

    def list_task_logs(self, task_id: str, after_id: int = 0, limit: int = 200) -> list[dict[str, Any]]:
        """Return up to *limit* log rows of a task with id greater than *after_id*."""
        return self._fetch_all(
            """
            SELECT id, level, message, created_at
            FROM task_logs
            WHERE task_id = ? AND id > ?
            ORDER BY id ASC
            LIMIT ?
            """,
            (task_id, after_id, limit),
        )

    # ------------------------------------------------------------------
    # Overview
    # ------------------------------------------------------------------

    def get_system_snapshot(self) -> dict[str, Any]:
        """Return user/admin counts, queue depth, running count and the parallel limit.

        ``running`` includes tasks that are waiting for a captcha.
        """
        # Read the setting first so only one connection is open at a time.
        max_parallel = self.get_max_parallel_tasks()
        with self._session() as connection:

            def count(sql: str) -> int:
                return connection.execute(sql).fetchone()[0]

            return {
                'users': count('SELECT COUNT(*) FROM users'),
                'admins': count('SELECT COUNT(*) FROM admins'),
                'queued': count("SELECT COUNT(*) FROM tasks WHERE status = 'queued'"),
                'running': count(
                    "SELECT COUNT(*) FROM tasks WHERE status IN ('running', 'waiting_captcha')"
                ),
                'max_parallel_tasks': max_parallel,
            }