| from __future__ import annotations
|
|
|
| import logging |
| import re |
| import threading |
| import time |
| from dataclasses import dataclass, field |
| from datetime import datetime |
| from zoneinfo import ZoneInfo |
|
|
| from core.course_bot import CourseBot, TaskResult |
| from core.db import Database |
| from core.login_modes import LOGIN_MODE_LABELS, LOGIN_MODE_UNIFIED, normalize_login_mode |
| from core.security import SecretBox |
| from core.task_modes import TASK_RUN_MODE_LABELS, TASK_RUN_MODE_STABLE, normalize_task_run_mode |
|
|
|
|
# Process-level logger for this module; TaskManager also mirrors every task log line here.
LOGGER: logging.Logger = logging.getLogger("sacc.task_manager")
|
|
|
|
|
| @dataclass(slots=True) |
| class RunningTask: |
| task_id: int |
| thread: threading.Thread |
| stop_event: threading.Event |
|
|
|
|
| @dataclass(slots=True) |
| class LoginTwoFactorChallenge: |
| task_id: int |
| user_id: int |
| student_id: str |
| display_name: str |
| phone_mask: str |
| created_at: float |
| expires_at: float |
| event: threading.Event = field(default_factory=threading.Event) |
| submitted_code: str = "" |
| submitted_by: str = "" |
|
|
|
|
class TaskManager:
    """Queue, schedule and run course-selection tasks on background threads.

    ``start()`` launches a single dispatcher thread. About once a second it
    applies the per-user schedules stored in the database. It then starts
    pending tasks, up to the configured parallel limit, each on its own daemon
    thread running a ``CourseBot``.

    The manager also brokers SMS two-factor login codes. A task thread blocks
    in ``request_login_2fa_code`` until someone calls ``submit_login_2fa_code``.
    """

    # Seconds a task thread waits for an SMS login code before timing out.
    login_2fa_timeout_seconds = 300

    def __init__(self, *, config, store: Database, secret_box: SecretBox) -> None:
        self.config = config
        self.store = store
        self.secret_box = secret_box
        self._started = False
        # Serialises start() so only one dispatcher thread is ever created.
        self._startup_lock = threading.Lock()
        # Makes "find active task, else create one" atomic within this process.
        self._queue_lock = threading.Lock()
        # Guards self._running (touched by the dispatcher, task threads and API callers).
        self._running_lock = threading.Lock()
        self._shutdown_event = threading.Event()

        self._dispatcher_thread: threading.Thread | None = None
        self._running: dict[int, RunningTask] = {}
        # Last (parallel_limit, running, fetched_pending) tuple that was logged; avoids log spam.
        self._last_dispatch_snapshot: tuple[int, int, int] | None = None
        self._schedule_timezone = ZoneInfo(getattr(self.config, "schedule_timezone", "Asia/Shanghai"))
        self._login_2fa_lock = threading.Lock()
        self._login_2fa_challenges: dict[int, LoginTwoFactorChallenge] = {}

    def start(self) -> None:
        """Reset stale DB state, normalise legacy rows and start the dispatcher.

        Calling it again after a successful start only logs and returns.
        """
        with self._startup_lock:
            if self._started:
                LOGGER.info("Task manager start skipped because it is already running")
                return
            LOGGER.info(
                "Task manager starting | db_path=%s default_parallel_limit=%s schedule_timezone=%s",
                self.store.path,
                getattr(self.config, "default_parallel_limit", "-"),
                getattr(self.config, "schedule_timezone", "Asia/Shanghai"),
            )
            self.store.reset_inflight_tasks()
            normalized_modes = self.store.normalize_legacy_task_modes()
            if normalized_modes:
                LOGGER.info("Normalized %s legacy task mode record(s) to stable", normalized_modes)
            normalized_login_modes = self.store.normalize_legacy_login_modes()
            if normalized_login_modes:
                LOGGER.info("Normalized %s legacy login mode record(s) to unified", normalized_login_modes)
            self._dispatcher_thread = threading.Thread(target=self._dispatch_loop, name="task-dispatcher", daemon=True)
            self._dispatcher_thread.start()
            self._started = True
        LOGGER.info("Task manager started")

    def shutdown(self) -> None:
        """Signal the dispatcher and every running task to stop.

        Only sets events and drops pending 2FA challenges; no thread is joined.
        Each task's bot receives its stop event, and any thread waiting in
        ``request_login_2fa_code`` notices the shutdown within about 0.2 s.
        """
        LOGGER.info("Task manager shutdown requested")
        self._shutdown_event.set()
        with self._running_lock:
            for running_task in self._running.values():
                running_task.stop_event.set()
        with self._login_2fa_lock:
            self._login_2fa_challenges.clear()
        # _running_count() acquires _running_lock itself (non-reentrant), so call it outside the block above.
        LOGGER.info("Task manager stop signal broadcast to %s running task(s)", self._running_count())

    def queue_task(
        self,
        user_id: int,
        requested_by: str,
        requested_by_role: str,
        requested_mode: str = TASK_RUN_MODE_STABLE,
        login_mode: str = LOGIN_MODE_UNIFIED,
        use_proxy: bool = False,
    ) -> tuple[dict, bool]:
        """Queue a task for ``user_id`` unless the user already has an active one.

        Returns:
            ``(task, created)``. On creation this is the new task row and
            ``True``. Otherwise it is the existing active task and ``False``.
        """
        normalized_mode = normalize_task_run_mode(requested_mode)
        normalized_login_mode = normalize_login_mode(login_mode)
        # Check-then-create under one lock so concurrent requests in this process cannot both create a task.
        with self._queue_lock:
            active_task = self.store.find_active_task_for_user(user_id)
            if active_task is None:
                task_id = self.store.create_task(
                    user_id,
                    requested_by,
                    requested_by_role,
                    requested_mode=normalized_mode,
                    login_mode=normalized_login_mode,
                    use_proxy=use_proxy,
                )
            else:
                task_id = int(active_task["id"])

        if active_task is not None:
            LOGGER.info(
                "Task queue skipped because an active task already exists | task_id=%s user_id=%s status=%s",
                active_task["id"],
                user_id,
                active_task["status"],
            )
            self._log(
                task_id,
                user_id,
                "SYSTEM",
                "INFO",
                f"收到新的启动请求,但已有活动任务处于 {active_task['status']} 状态,本次不重复创建。触发者: {requested_by_role}:{requested_by}。",
            )
            return self.store.get_task(active_task["id"]) or active_task, False

        course_count = len(self.store.list_courses_for_user(user_id))
        LOGGER.info(
            "Task queued | task_id=%s user_id=%s requested_by_role=%s requested_by=%s requested_mode=%s login_mode=%s use_proxy=%s",
            task_id,
            user_id,
            requested_by_role,
            requested_by,
            normalized_mode,
            normalized_login_mode,
            bool(use_proxy),
        )
        self._log(
            task_id,
            user_id,
            "SYSTEM",
            "INFO",
            f"任务已进入队列,触发者: {requested_by_role}:{requested_by},模式: {self._mode_label(normalized_mode)},登录方式: {self._login_mode_label(normalized_login_mode)},代理: {'启用' if use_proxy else '关闭'},待选课程 {course_count} 门。",
        )
        return self.store.get_task(task_id), True

    def stop_task(self, task_id: int) -> bool:
        """Request a cooperative stop of ``task_id``.

        Returns:
            ``False`` if the store did not accept the stop request. Otherwise
            ``True``, after signalling the worker (if one is running) and
            dropping any pending 2FA challenge.
        """
        requested = self.store.request_task_stop(task_id)
        if not requested:
            LOGGER.info("Task stop request ignored because task is not active | task_id=%s", task_id)
            return False
        LOGGER.info("Task stop requested | task_id=%s", task_id)
        self._log(task_id, None, "SYSTEM", "INFO", "收到停止请求,任务会在安全节点退出。")
        with self._running_lock:
            running_task = self._running.get(task_id)
            if running_task is not None:
                running_task.stop_event.set()
        self._remove_login_2fa_challenge(task_id)
        return True

    def request_login_2fa_code(self, task_id: int, user: dict, phone_mask: str, stop_event: threading.Event) -> str:
        """Block the calling task thread until an SMS code is submitted for ``task_id``.

        Registers a challenge that the ``get_login_2fa_challenge*`` /
        ``list_login_2fa_challenges`` accessors can see. Sets the task status to
        ``waiting_2fa``, then polls every 0.2 s.

        Returns:
            The 6-digit code accepted by ``submit_login_2fa_code``.

        Raises:
            RuntimeError: ``stop_event`` or manager shutdown fired, or an empty
                code was delivered.
            TimeoutError: no code arrived within ``login_2fa_timeout_seconds``.
        """
        now = time.time()
        user_id = int(user["id"])
        challenge = LoginTwoFactorChallenge(
            task_id=task_id,
            user_id=user_id,
            student_id=str(user.get("student_id") or ""),
            display_name=str(user.get("display_name") or ""),
            phone_mask=str(phone_mask or "绑定手机"),
            created_at=now,
            expires_at=now + self.login_2fa_timeout_seconds,
        )
        with self._login_2fa_lock:
            self._prune_login_2fa_challenges_locked(now=now)
            self._login_2fa_challenges[task_id] = challenge

        self.store.update_task_status(task_id, "waiting_2fa", "等待短信验证码。")
        self._log(
            task_id,
            user_id,
            "SYSTEM",
            "INFO",
            f"已发送短信验证码,请在 {self.login_2fa_timeout_seconds // 60} 分钟内到面板输入 6 位验证码。手机号={challenge.phone_mask}。",
        )

        next_wait_log_at = now + 60
        while time.time() < challenge.expires_at:
            if stop_event.is_set() or self._shutdown_event.is_set():
                self._remove_login_2fa_challenge(task_id)
                self._log(task_id, user_id, "SYSTEM", "WARNING", "短信验证码等待被停止,已清理本次验证码挑战。")
                raise RuntimeError("短信验证码等待已停止。")
            # Short waits keep the stop/shutdown check above responsive.
            if challenge.event.wait(timeout=0.2):
                code = challenge.submitted_code
                self._remove_login_2fa_challenge(task_id)
                if not code:
                    self._log(task_id, user_id, "SYSTEM", "WARNING", "短信验证码等待被停止,未收到有效提交。")
                    raise RuntimeError("短信验证码等待已停止。")
                current_task = self.store.get_task(task_id)
                # Only flip back to running if nothing (e.g. a stop request) changed the status meanwhile.
                if current_task and current_task.get("status") == "waiting_2fa":
                    self.store.update_task_status(task_id, "running", "")
                self._log(task_id, user_id, "SYSTEM", "INFO", "短信验证码已转交后台 Selenium,会继续完成统一认证。")
                return code
            current_time = time.time()
            if current_time >= next_wait_log_at:
                remaining_seconds = max(0, int(challenge.expires_at - current_time))
                self._log(
                    task_id,
                    user_id,
                    "SYSTEM",
                    "INFO",
                    f"仍在等待短信验证码提交,剩余约 {remaining_seconds} 秒。",
                )
                next_wait_log_at = current_time + 60

        self._remove_login_2fa_challenge(task_id)
        self._log(task_id, user_id, "SYSTEM", "ERROR", "等待短信验证码超时,已清理本次验证码挑战。")
        raise TimeoutError("等待短信验证码超时,请重新启动任务。")

    def submit_login_2fa_code(self, task_id: int, code: str, *, submitted_by: str) -> tuple[bool, str]:
        """Hand a 6-digit SMS code to the task thread waiting on ``task_id``.

        Returns:
            ``(accepted, message)``, where ``message`` is user-facing text.
        """
        normalized_code = str(code or "").strip()
        if not re.fullmatch(r"\d{6}", normalized_code):
            return False, "短信验证码必须是 6 位数字。"
        now = time.time()
        with self._login_2fa_lock:
            self._prune_login_2fa_challenges_locked(now=now)
            challenge = self._login_2fa_challenges.get(task_id)
            if challenge is None:
                return False, "当前没有等待验证码的任务,或验证码已过期。"
            if challenge.event.is_set():
                return False, "验证码已经提交,请等待后台继续登录。"
            # Write the payload before setting the event so the waiter always reads it.
            challenge.submitted_code = normalized_code
            challenge.submitted_by = submitted_by
            challenge.event.set()

        self._log(task_id, challenge.user_id, "SYSTEM", "INFO", f"已收到短信验证码,后台继续登录。提交者: {submitted_by}")
        return True, "验证码已提交,后台会继续完成登录。"

    def get_login_2fa_challenge(self, task_id: int) -> dict | None:
        """Return the serialized unexpired challenge for ``task_id``, or ``None``."""
        with self._login_2fa_lock:
            self._prune_login_2fa_challenges_locked()
            challenge = self._login_2fa_challenges.get(task_id)
            return None if challenge is None else self._serialize_login_2fa_challenge(challenge)

    def get_login_2fa_challenge_for_user(self, user_id: int) -> dict | None:
        """Return the first serialized unexpired challenge owned by ``user_id``, or ``None``."""
        with self._login_2fa_lock:
            self._prune_login_2fa_challenges_locked()
            for challenge in self._login_2fa_challenges.values():
                if challenge.user_id == int(user_id):
                    return self._serialize_login_2fa_challenge(challenge)
        return None

    def list_login_2fa_challenges(self) -> list[dict]:
        """Return all unexpired challenges, soonest-expiring first."""
        with self._login_2fa_lock:
            self._prune_login_2fa_challenges_locked()
            challenges = [self._serialize_login_2fa_challenge(challenge) for challenge in self._login_2fa_challenges.values()]
        return sorted(challenges, key=lambda item: int(item["remaining_seconds"]))

    def _remove_login_2fa_challenge(self, task_id: int) -> None:
        """Forget the challenge for ``task_id``, if any."""
        with self._login_2fa_lock:
            self._login_2fa_challenges.pop(task_id, None)

    def _prune_login_2fa_challenges_locked(self, *, now: float | None = None) -> None:
        """Drop expired challenges. The caller must hold ``_login_2fa_lock``."""
        current_time = time.time() if now is None else now
        expired_task_ids = [
            task_id
            for task_id, challenge in self._login_2fa_challenges.items()
            if current_time >= challenge.expires_at
        ]
        for task_id in expired_task_ids:
            self._login_2fa_challenges.pop(task_id, None)

    @staticmethod
    def _serialize_login_2fa_challenge(challenge: LoginTwoFactorChallenge) -> dict:
        """Public view of a challenge. ``remaining_seconds`` is rounded up and never negative."""
        remaining_seconds = max(0, int(challenge.expires_at - time.time() + 0.999))
        return {
            "task_id": challenge.task_id,
            "user_id": challenge.user_id,
            "student_id": challenge.student_id,
            "display_name": challenge.display_name,
            "phone_mask": challenge.phone_mask,
            "remaining_seconds": remaining_seconds,
        }

    def _current_schedule_now(self) -> datetime:
        """Current time, aware in the configured schedule timezone."""
        return datetime.now(self._schedule_timezone)

    def _within_schedule_date_window(self, schedule: dict, today_text: str) -> bool:
        """Whether ``today_text`` falls in the schedule's optional inclusive date range.

        The check compares strings. That is only correct if the stored dates are
        ISO ``YYYY-MM-DD``, like ``today_text`` (TODO confirm against the DB layer).
        """
        start_date = str(schedule.get("start_date") or "").strip()
        end_date = str(schedule.get("end_date") or "").strip()
        if start_date and today_text < start_date:
            return False
        if end_date and today_text > end_date:
            return False
        return True

    def _apply_user_schedules(self) -> None:
        """Auto-stop and auto-start tasks according to each user's enabled daily schedule.

        Each action fires at most once per calendar day. This is tracked through
        ``mark_schedule_auto_stop`` / ``mark_schedule_auto_start``.

        Times are compared as ``HH:MM`` strings in the schedule timezone. A
        window whose start is not before its stop therefore never auto-starts.
        NOTE(review): this assumes stored times are zero-padded ``HH:MM``.
        """
        now = self._current_schedule_now()
        today_text = now.date().isoformat()
        current_time_text = now.strftime("%H:%M")

        for schedule in self.store.list_enabled_user_schedules():
            if not bool(schedule.get("user_is_active", 0)):
                continue
            if not self._within_schedule_date_window(schedule, today_text):
                continue

            start_time = str(schedule.get("daily_start_time") or "").strip()
            stop_time = str(schedule.get("daily_stop_time") or "").strip()
            if not start_time or not stop_time:
                continue

            user_id = int(schedule["user_id"])
            active_task = self.store.find_active_task_for_user(user_id)

            # Past today's stop time: stop once, then skip the start logic for this user.
            if current_time_text >= stop_time and str(schedule.get("last_auto_stop_on") or "") != today_text:
                if active_task and self.stop_task(int(active_task["id"])):
                    self._log(int(active_task["id"]), user_id, "SYSTEM", "INFO", "已按管理员定时设置发送停止请求。")
                else:
                    self._log(None, user_id, "SYSTEM", "INFO", "已到定时停止时间,当前没有运行中的任务。")
                self.store.mark_schedule_auto_stop(user_id, today_text)
                continue

            if not (start_time <= current_time_text < stop_time):
                continue
            if str(schedule.get("last_auto_start_on") or "") == today_text:
                continue

            user = self.store.get_user(user_id)
            if user is None:
                self.store.mark_schedule_auto_start(user_id, today_text)
                continue
            if not self.store.list_courses_for_user(user_id):
                self._log(None, user_id, "SYSTEM", "INFO", "已到定时启动时间,但当前没有课程目标,今天不自动启动。")
                self.store.mark_schedule_auto_start(user_id, today_text)
                continue

            task, created = self.queue_task(
                user_id,
                requested_by="scheduler",
                requested_by_role="system",
                requested_mode=TASK_RUN_MODE_STABLE,
                login_mode=LOGIN_MODE_UNIFIED,
                use_proxy=False,
            )
            if created:
                self._log(task["id"], user_id, "SYSTEM", "INFO", "已按管理员定时设置自动加入任务队列。")
            else:
                self._log(task["id"], user_id, "SYSTEM", "INFO", "已到定时启动时间,但当前已有任务在运行或排队。")
            self.store.mark_schedule_auto_start(user_id, today_text)

    def _dispatch_loop(self) -> None:
        """Dispatcher thread body: apply schedules and launch pending tasks about once a second.

        Each stage is guarded. A transient failure, such as a database error,
        is logged and retried on the next tick instead of killing the only
        dispatcher thread (which would leave every queued task stuck).
        """
        LOGGER.info("Dispatcher loop started")
        while not self._shutdown_event.is_set():
            self._cleanup_running_registry()
            try:
                self._apply_user_schedules()
            except Exception:
                LOGGER.exception("Applying user schedules failed; retrying on the next dispatcher tick")
            try:
                self._dispatch_pending_tasks()
            except Exception:
                LOGGER.exception("Dispatching pending tasks failed; retrying on the next dispatcher tick")
            # Waiting on the shutdown event (rather than time.sleep) lets shutdown() end the loop immediately.
            self._shutdown_event.wait(1)
        LOGGER.info("Dispatcher loop stopped")

    def _dispatch_pending_tasks(self) -> None:
        """Fetch up to the number of free slots of pending tasks and launch them."""
        parallel_limit = self.store.get_parallel_limit()
        running_count = self._running_count()
        available_slots = max(0, parallel_limit - running_count)
        pending_tasks: list[dict] = []
        if available_slots > 0:
            pending_tasks = self.store.list_pending_tasks(available_slots)

        snapshot = (parallel_limit, running_count, len(pending_tasks))
        # Log only when something changed and there is activity, to keep the idle loop quiet.
        if snapshot != self._last_dispatch_snapshot and (running_count > 0 or pending_tasks):
            LOGGER.info(
                "Dispatcher snapshot | parallel_limit=%s running=%s available_slots=%s fetched_pending=%s",
                parallel_limit,
                running_count,
                available_slots,
                len(pending_tasks),
            )
            self._last_dispatch_snapshot = snapshot

        for task in pending_tasks:
            if self._shutdown_event.is_set():
                break
            if not task["is_active"]:
                LOGGER.warning(
                    "Queued task cancelled because user is inactive | task_id=%s user_id=%s",
                    task["id"],
                    task["user_id"],
                )
                self.store.finish_task(task["id"], "failed", "该用户已被禁用,任务未执行。")
                self._log(task["id"], task["user_id"], "SYSTEM", "WARNING", "用户已禁用,队列中的任务被取消。")
                continue
            # Re-read the count so it reflects tasks launched earlier in this batch.
            self._log(
                task["id"],
                task["user_id"],
                "SYSTEM",
                "INFO",
                f"调度器已获取执行名额,准备启动任务线程。当前并发 {self._running_count()}/{parallel_limit}。",
            )
            self._launch_task(task["id"])

    def _launch_task(self, task_id: int) -> None:
        """Register and start a worker thread for ``task_id`` unless one is already registered."""
        with self._running_lock:
            if task_id in self._running:
                # A previously launched worker may not have marked the task running yet.
                LOGGER.info("Task launch skipped because thread already exists | task_id=%s", task_id)
                return
            stop_event = threading.Event()
            thread = threading.Thread(target=self._run_task, args=(task_id, stop_event), name=f"task-{task_id}", daemon=True)
            # Register before start() so the worker's own _remove_running() always finds its entry.
            self._running[task_id] = RunningTask(task_id=task_id, thread=thread, stop_event=stop_event)
        # DB / log I/O happens outside _running_lock so it never blocks other registry users.
        LOGGER.info("Launching task thread | task_id=%s thread_name=%s", task_id, thread.name)
        task = self.store.get_task(task_id)
        self._log(
            task_id,
            int(task["user_id"]) if task else None,
            "SYSTEM",
            "INFO",
            f"调度器已分配执行线程 {thread.name},任务即将开始运行。",
        )
        thread.start()

    def _run_task(self, task_id: int, stop_event: threading.Event) -> None:
        """Worker-thread body: run one task end to end and persist its outcome.

        The registry entry is always removed on exit. If anything outside the
        bot's own error handling raises (for example a database call), the
        exception is logged. The task is then marked ``failed`` unless a final
        status was already written, so it is not left looking in-flight.
        """
        finalized = False
        try:
            task = self.store.get_task(task_id)
            if task is None:
                LOGGER.warning("Task record disappeared before execution | task_id=%s", task_id)
                return

            user = self.store.get_user(task["user_id"])
            if user is None:
                LOGGER.error("Task cannot start because user does not exist | task_id=%s user_id=%s", task_id, task["user_id"])
                self.store.finish_task(task_id, "failed", "用户不存在。")
                finalized = True
                return

            LOGGER.info("Task runner starting | task_id=%s user_id=%s", task_id, user["id"])
            if not self.store.mark_task_running(task_id):
                latest_task = self.store.get_task(task_id) or task
                latest_status = str(latest_task.get("status") or "")
                if latest_status == "cancel_requested":
                    self.store.finish_task(task_id, "stopped", "任务启动前已收到停止请求。")
                    finalized = True
                    self._log(task_id, user["id"], "SYSTEM", "INFO", "任务启动前已收到停止请求,未启动 Selenium。")
                else:
                    self._log(task_id, user["id"], "SYSTEM", "WARNING", f"任务启动前状态已变为 {latest_status or '未知'},未覆盖为运行态。")
                return

            course_count = len(self.store.list_courses_for_user(user["id"]))
            requested_mode = normalize_task_run_mode(task.get("requested_mode"))
            effective_mode = normalize_task_run_mode(task.get("effective_mode"))
            login_mode = normalize_login_mode(task.get("login_mode"))
            use_proxy = bool(int(task.get("use_proxy") or 0))
            refresh_interval = user.get("refresh_interval_seconds") or getattr(self.config, "poll_interval_seconds", "-")
            self._log(
                task_id,
                user["id"],
                "SYSTEM",
                "INFO",
                "任务开始执行。"
                f"学号={user.get('student_id') or '-'},"
                f"待选课程 {course_count} 门,"
                f"请求模式={self._mode_label(requested_mode)},当前模式={self._mode_label(effective_mode)},"
                f"登录方式={self._login_mode_label(login_mode)},"
                f"代理={'启用' if use_proxy else '关闭'},"
                f"刷新间隔={refresh_interval} 秒。",
            )

            result = self._run_bot(task_id, user, stop_event)
            final_status = self._resolve_final_status(task_id, result, stop_event)
            self.store.finish_task(task_id, final_status, result.error)
            finalized = True

            final_task = self.store.get_task(task_id) or {}
            final_attempts = int(final_task.get("total_attempts") or 0)
            final_errors = int(final_task.get("total_errors") or 0)
            remaining_courses = len(self.store.list_courses_for_user(user["id"]))
            final_text = result.error or f"任务已结束,状态: {final_status}"
            final_text = f"{final_text} 总尝试 {final_attempts} 次,累计异常 {final_errors} 次,剩余待选课程 {remaining_courses} 门。"
            level = "ERROR" if final_status == "failed" else "INFO"
            self._log(task_id, user["id"], "SYSTEM", level, final_text)
            LOGGER.info(
                "Task runner finished | task_id=%s user_id=%s final_status=%s",
                task_id,
                user["id"],
                final_status,
            )
        except Exception as exc:
            LOGGER.exception("Task runner crashed unexpectedly | task_id=%s", task_id)
            if not finalized:
                try:
                    self.store.finish_task(task_id, "failed", str(exc))
                except Exception:
                    LOGGER.exception("Could not mark crashed task as failed | task_id=%s", task_id)
        finally:
            self._remove_running(task_id)

    def _run_bot(self, task_id: int, user: dict, stop_event: threading.Event) -> TaskResult:
        """Decrypt the password, build a ``CourseBot`` and run it.

        Any exception (decryption, construction or the run itself) is logged and
        turned into a ``failed`` ``TaskResult``.
        """
        try:
            password = self.secret_box.decrypt(user["password_encrypted"])
            bot = CourseBot(
                config=self.config,
                store=self.store,
                task_id=task_id,
                user=user,
                password=password,
                logger=lambda level, message: self._log(task_id, user["id"], "RUNNER", level, message),
                login_2fa_code_provider=lambda phone_mask, current_stop_event: self.request_login_2fa_code(
                    task_id,
                    user,
                    phone_mask,
                    current_stop_event,
                ),
            )
            return bot.run(stop_event)
        except Exception as exc:
            LOGGER.exception("Task initialization failed | task_id=%s user_id=%s", task_id, user["id"])
            result = TaskResult(status="failed", error=str(exc))
            self._log(task_id, user["id"], "SYSTEM", "ERROR", f"任务初始化失败: {exc}")
            return result

    def _resolve_final_status(self, task_id: int, result: TaskResult, stop_event: threading.Event) -> str:
        """Status to persist.

        A stop request (in the DB or via ``stop_event``) turns any result that
        is not ``completed`` into ``stopped``.
        """
        if result.status == "completed":
            return result.status
        current_task = self.store.get_task(task_id)
        if current_task and current_task["status"] == "cancel_requested":
            return "stopped"
        if stop_event.is_set():
            return "stopped"
        return result.status

    def _log(self, task_id: int | None, user_id: int | None, scope: str, level: str, message: str) -> None:
        """Persist a log row and mirror it to the process logger.

        WARNING/ERROR entries attached to a task also bump that task's error counter.
        """
        upper_level = level.upper()
        self.store.add_log(task_id, user_id, scope, upper_level, message)
        if task_id is not None and upper_level in {"WARNING", "ERROR"}:
            self.store.increment_task_errors(task_id)
        self._emit_runtime_log(task_id=task_id, user_id=user_id, scope=scope, level=upper_level, message=message)

    @staticmethod
    def _emit_runtime_log(*, task_id: int | None, user_id: int | None, scope: str, level: str, message: str) -> None:
        """Write one task log line to ``LOGGER`` at the matching level (``-`` marks a missing id)."""
        upper_level = level.upper()
        # Use "is None" rather than truthiness so a legitimate id of 0 is not rendered as "-".
        task_text = "-" if task_id is None else task_id
        user_text = "-" if user_id is None else user_id
        rendered = f"task_id={task_text} user_id={user_text} scope={scope} level={upper_level} | {message}"
        if upper_level == "ERROR":
            LOGGER.error(rendered)
        elif upper_level == "WARNING":
            LOGGER.warning(rendered)
        else:
            LOGGER.info(rendered)

    def _running_count(self) -> int:
        """Number of registered worker threads."""
        with self._running_lock:
            return len(self._running)

    @staticmethod
    def _mode_label(mode: str) -> str:
        """Display label for a task run mode, falling back to the raw value."""
        return TASK_RUN_MODE_LABELS.get(mode, mode)

    @staticmethod
    def _login_mode_label(mode: str) -> str:
        """Display label for a login mode, falling back to the raw value."""
        return LOGIN_MODE_LABELS.get(mode, mode)

    def _cleanup_running_registry(self) -> None:
        """Drop registry entries whose worker thread has exited."""
        with self._running_lock:
            finished_ids = [
                task_id
                for task_id, running_task in self._running.items()
                if not running_task.thread.is_alive()
            ]
            for task_id in finished_ids:
                self._running.pop(task_id, None)
        for task_id in finished_ids:
            LOGGER.info("Task thread cleaned from registry | task_id=%s", task_id)

    def _remove_running(self, task_id: int) -> None:
        """Remove ``task_id`` from the registry, logging only if it was present."""
        with self._running_lock:
            removed = self._running.pop(task_id, None)
        if removed is not None:
            LOGGER.info("Task removed from running registry | task_id=%s", task_id)
|
|
|
|
|
|
|
|
|