| from __future__ import annotations |
|
|
| import logging |
| import threading |
| import time |
| from dataclasses import dataclass |
| from datetime import datetime |
| from zoneinfo import ZoneInfo |
|
|
| from core.course_bot import CourseBot, TaskResult |
| from core.db import Database |
| from core.security import SecretBox |
|
|
|
|
# Module-level logger shared by every TaskManager log call in this file.
LOGGER = logging.getLogger("sacc.task_manager")
|
|
|
|
@dataclass(slots=True)
class RunningTask:
    """Registry entry for one launched task: its id, worker thread and stop signal."""

    # Id of the task this entry belongs to; also the key in TaskManager._running.
    task_id: int
    # Worker thread running the task; its liveness is polled to prune the registry.
    thread: threading.Thread
    # Set to ask the worker to stop (by stop_task() or shutdown()).
    stop_event: threading.Event
|
|
|
|
class TaskManager:
    """Queue, schedule and run per-user course tasks on background threads.

    A single dispatcher thread polls the store roughly once per second. On
    each pass it prunes finished workers, applies the enabled user schedules
    (automatic start/stop) and launches pending tasks up to the parallel
    limit reported by the store. Every launched task runs in its own daemon
    thread and is tracked in ``_running`` together with its stop event.
    """

    def __init__(self, *, config, store: Database, secret_box: SecretBox) -> None:
        """Keep the collaborators and create the synchronisation primitives.

        Args:
            config: Application configuration object. ``schedule_timezone``
                and ``default_parallel_limit`` are read with ``getattr`` and
                the object is passed through to ``CourseBot``.
            store: Persistence layer for tasks, users, schedules and logs.
            secret_box: Used to decrypt the stored user password before a run.

        Raises:
            Whatever ``ZoneInfo`` raises when the configured time zone name
            cannot be resolved.
        """
        self.config = config
        self.store = store
        self.secret_box = secret_box
        self._started = False
        self._startup_lock = threading.Lock()
        # Guards ``_running``. It is a plain (non-reentrant) Lock, so no code
        # holding it may call a helper that acquires it again.
        self._running_lock = threading.Lock()
        self._shutdown_event = threading.Event()
        self._dispatcher_thread: threading.Thread | None = None
        self._running: dict[int, RunningTask] = {}
        # (parallel_limit, running, fetched_pending) seen on the previous
        # dispatcher pass; used to avoid logging the same snapshot repeatedly.
        self._last_dispatch_snapshot: tuple[int, int, int] | None = None
        self._schedule_timezone = ZoneInfo(getattr(self.config, "schedule_timezone", "Asia/Shanghai"))

    def start(self) -> None:
        """Start the dispatcher thread once; later calls only log and return."""
        with self._startup_lock:
            if self._started:
                LOGGER.info("Task manager start skipped because it is already running")
                return
            LOGGER.info(
                "Task manager starting | db_path=%s default_parallel_limit=%s schedule_timezone=%s",
                self.store.path,
                getattr(self.config, "default_parallel_limit", "-"),
                getattr(self.config, "schedule_timezone", "Asia/Shanghai"),
            )
            self.store.reset_inflight_tasks()
            self._dispatcher_thread = threading.Thread(
                target=self._dispatch_loop,
                name="task-dispatcher",
                daemon=True,
            )
            self._dispatcher_thread.start()
            self._started = True
            LOGGER.info("Task manager started")

    def shutdown(self) -> None:
        """Signal the dispatcher and every registered worker to stop.

        This only sets events; it does not wait for any thread to finish.
        """
        LOGGER.info("Task manager shutdown requested")
        self._shutdown_event.set()
        # Snapshot under the lock and do the rest outside it: the lock is not
        # reentrant, so calling _running_count() while holding it would block
        # forever.
        with self._running_lock:
            running_tasks = list(self._running.values())
        for running_task in running_tasks:
            running_task.stop_event.set()
        LOGGER.info("Task manager stop signal broadcast to %s running task(s)", len(running_tasks))

    def queue_task(self, user_id: int, requested_by: str, requested_by_role: str) -> tuple[dict, bool]:
        """Queue a new task for ``user_id`` unless one is already active.

        Args:
            user_id: Owner of the task.
            requested_by: Identifier of whoever triggered the request.
            requested_by_role: Role of the requester (the scheduler passes
                ``"system"``).

        Returns:
            ``(task, created)``. ``created`` is ``False`` when an active task
            already existed, in which case that task is returned instead.
        """
        active_task = self.store.find_active_task_for_user(user_id)
        if active_task is not None:
            LOGGER.info(
                "Task queue skipped because an active task already exists | task_id=%s user_id=%s status=%s",
                active_task["id"],
                user_id,
                active_task["status"],
            )
            # Prefer the full task record; fall back to the row already held.
            existing = self.store.get_task(active_task["id"]) or active_task
            return existing, False

        task_id = self.store.create_task(user_id, requested_by, requested_by_role)
        LOGGER.info(
            "Task queued | task_id=%s user_id=%s requested_by_role=%s requested_by=%s",
            task_id,
            user_id,
            requested_by_role,
            requested_by,
        )
        self._log(task_id, user_id, "SYSTEM", "INFO", f"任务已进入队列,触发者: {requested_by_role}:{requested_by}")
        return self.store.get_task(task_id), True

    def stop_task(self, task_id: int) -> bool:
        """Request that ``task_id`` stop.

        Returns:
            ``True`` when the store accepted the stop request, ``False`` when
            it rejected it (logged here as "task is not active").
        """
        if not self.store.request_task_stop(task_id):
            LOGGER.info("Task stop request ignored because task is not active | task_id=%s", task_id)
            return False

        LOGGER.info("Task stop requested | task_id=%s", task_id)
        self._log(task_id, None, "SYSTEM", "INFO", "收到停止请求,任务会在安全节点退出。")
        with self._running_lock:
            running_task = self._running.get(task_id)
        # A task that is queued but not launched yet has no registry entry.
        if running_task is not None:
            running_task.stop_event.set()
        return True

    def _current_schedule_now(self) -> datetime:
        """Return the current time in the configured schedule time zone."""
        return datetime.now(self._schedule_timezone)

    def _within_schedule_date_window(self, schedule: dict, today_text: str) -> bool:
        """Tell whether ``today_text`` lies inside the schedule's date range.

        ``start_date`` and ``end_date`` are each optional; a missing or blank
        bound is treated as open. Dates are compared as strings, which is
        only correct for zero-padded ISO ``YYYY-MM-DD`` text (the caller
        passes ``date.isoformat()``).
        """
        start_date = str(schedule.get("start_date") or "").strip()
        end_date = str(schedule.get("end_date") or "").strip()
        if start_date and today_text < start_date:
            return False
        if end_date and today_text > end_date:
            return False
        return True

    def _apply_user_schedules(self) -> None:
        """Apply automatic daily start/stop for every enabled user schedule.

        Times are compared as ``HH:MM`` strings, so stored times are assumed
        to be zero-padded 24-hour text (TODO confirm against the schedule
        editor). Each action is recorded per day via the store's
        ``mark_schedule_auto_*`` calls so it fires at most once a day.
        """
        now = self._current_schedule_now()
        today_text = now.date().isoformat()
        current_time_text = now.strftime("%H:%M")

        for schedule in self.store.list_enabled_user_schedules():
            if not bool(schedule.get("user_is_active", 0)):
                continue
            if not self._within_schedule_date_window(schedule, today_text):
                continue

            start_time = str(schedule.get("daily_start_time") or "").strip()
            stop_time = str(schedule.get("daily_stop_time") or "").strip()
            if not start_time or not stop_time:
                continue

            user_id = int(schedule["user_id"])

            # Daily stop: fires once per day after the stop time has passed.
            if current_time_text >= stop_time and str(schedule.get("last_auto_stop_on") or "") != today_text:
                # Only this branch needs the active task, so look it up here
                # instead of once per schedule on every tick.
                active_task = self.store.find_active_task_for_user(user_id)
                if active_task and self.stop_task(int(active_task["id"])):
                    self._log(int(active_task["id"]), user_id, "SYSTEM", "INFO", "已按管理员定时设置发送停止请求。")
                else:
                    self._log(None, user_id, "SYSTEM", "INFO", "已到定时停止时间,当前没有运行中的任务。")
                self.store.mark_schedule_auto_stop(user_id, today_text)
                continue

            # Daily start: only inside [start_time, stop_time), once per day.
            if not (start_time <= current_time_text < stop_time):
                continue
            if str(schedule.get("last_auto_start_on") or "") == today_text:
                continue

            user = self.store.get_user(user_id)
            if user is None:
                self.store.mark_schedule_auto_start(user_id, today_text)
                continue
            if not self.store.list_courses_for_user(user_id):
                self._log(None, user_id, "SYSTEM", "INFO", "已到定时启动时间,但当前没有课程目标,今天不自动启动。")
                self.store.mark_schedule_auto_start(user_id, today_text)
                continue

            task, created = self.queue_task(user_id, requested_by="scheduler", requested_by_role="system")
            if created:
                self._log(task["id"], user_id, "SYSTEM", "INFO", "已按管理员定时设置自动加入任务队列。")
            else:
                self._log(task["id"], user_id, "SYSTEM", "INFO", "已到定时启动时间,但当前已有任务在运行或排队。")
            self.store.mark_schedule_auto_start(user_id, today_text)

    def _dispatch_loop(self) -> None:
        """Run dispatcher passes about once per second until shutdown."""
        LOGGER.info("Dispatcher loop started")
        while not self._shutdown_event.is_set():
            try:
                self._dispatch_once()
            except Exception:
                # Thread top-level boundary: a failed pass must not end the
                # dispatcher, otherwise no queued task would ever be launched
                # again. Log it and retry on the next tick.
                LOGGER.exception("Dispatcher iteration failed; retrying on next tick")
            # Waiting on the event (instead of sleeping) lets shutdown() wake
            # the loop immediately.
            self._shutdown_event.wait(1)
        LOGGER.info("Dispatcher loop stopped")

    def _dispatch_once(self) -> None:
        """Do one dispatcher pass: prune, apply schedules, launch pending tasks."""
        self._cleanup_running_registry()
        self._apply_user_schedules()

        parallel_limit = self.store.get_parallel_limit()
        running_count = self._running_count()
        available_slots = max(0, parallel_limit - running_count)
        pending_tasks: list[dict] = []
        if available_slots > 0:
            pending_tasks = self.store.list_pending_tasks(available_slots)

        # Log only when the picture changed and there is something to report.
        snapshot = (parallel_limit, running_count, len(pending_tasks))
        if snapshot != self._last_dispatch_snapshot and (running_count > 0 or pending_tasks):
            LOGGER.info(
                "Dispatcher snapshot | parallel_limit=%s running=%s available_slots=%s fetched_pending=%s",
                parallel_limit,
                running_count,
                available_slots,
                len(pending_tasks),
            )
        self._last_dispatch_snapshot = snapshot

        for task in pending_tasks:
            if self._shutdown_event.is_set():
                break
            if not task["is_active"]:
                LOGGER.warning(
                    "Queued task cancelled because user is inactive | task_id=%s user_id=%s",
                    task["id"],
                    task["user_id"],
                )
                self.store.finish_task(task["id"], "failed", "该用户已被禁用,任务未执行。")
                self._log(task["id"], task["user_id"], "SYSTEM", "WARNING", "用户已禁用,队列中的任务被取消。")
                continue
            self._launch_task(task["id"])

    def _launch_task(self, task_id: int) -> None:
        """Register and start a worker thread for ``task_id`` unless one exists."""
        with self._running_lock:
            if task_id in self._running:
                LOGGER.info("Task launch skipped because thread already exists | task_id=%s", task_id)
                return
            stop_event = threading.Event()
            thread = threading.Thread(
                target=self._run_task,
                args=(task_id, stop_event),
                name=f"task-{task_id}",
                daemon=True,
            )
            self._running[task_id] = RunningTask(task_id=task_id, thread=thread, stop_event=stop_event)
        LOGGER.info("Launching task thread | task_id=%s thread_name=%s", task_id, thread.name)
        thread.start()

    def _run_task(self, task_id: int, stop_event: threading.Event) -> None:
        """Worker-thread entry point: run one task and record its final status.

        Args:
            task_id: Id of the task to execute.
            stop_event: Set by ``stop_task``/``shutdown`` to request a stop;
                handed to ``CourseBot.run``.
        """
        try:
            task = self.store.get_task(task_id)
            if task is None:
                LOGGER.warning("Task record disappeared before execution | task_id=%s", task_id)
                return

            user = self.store.get_user(task["user_id"])
            if user is None:
                LOGGER.error(
                    "Task cannot start because user does not exist | task_id=%s user_id=%s",
                    task_id,
                    task["user_id"],
                )
                self.store.finish_task(task_id, "failed", "用户不存在。")
                return

            LOGGER.info("Task runner starting | task_id=%s user_id=%s", task_id, user["id"])
            self.store.mark_task_running(task_id)
            self._log(task_id, user["id"], "SYSTEM", "INFO", "任务开始执行。")

            try:
                password = self.secret_box.decrypt(user["password_encrypted"])
                bot = CourseBot(
                    config=self.config,
                    store=self.store,
                    task_id=task_id,
                    user=user,
                    password=password,
                    logger=lambda level, message: self._log(task_id, user["id"], "RUNNER", level, message),
                )
                result = bot.run(stop_event)
            except Exception as exc:
                # Any failure while preparing or running the bot becomes a
                # "failed" result so the task still gets a final status.
                LOGGER.exception("Task initialization failed | task_id=%s user_id=%s", task_id, user["id"])
                result = TaskResult(status="failed", error=str(exc))
                self._log(task_id, user["id"], "SYSTEM", "ERROR", f"任务初始化失败: {exc}")

            # A run that did not complete while a stop was pending is reported
            # as "stopped" rather than with the bot's own status.
            current_task = self.store.get_task(task_id)
            cancel_requested = bool(current_task) and current_task["status"] == "cancel_requested"
            final_status = result.status
            if result.status != "completed" and (cancel_requested or stop_event.is_set()):
                final_status = "stopped"

            self.store.finish_task(task_id, final_status, result.error)
            final_text = result.error or f"任务已结束,状态: {final_status}"
            level = "ERROR" if final_status == "failed" else "INFO"
            self._log(task_id, user["id"], "SYSTEM", level, final_text)
            LOGGER.info(
                "Task runner finished | task_id=%s user_id=%s final_status=%s",
                task_id,
                user["id"],
                final_status,
            )
        finally:
            # Always drop the registry entry, even when a store call above
            # raised, so the slot is released as soon as the worker ends.
            self._remove_running(task_id)

    def _log(self, task_id: int | None, user_id: int | None, scope: str, level: str, message: str) -> None:
        """Persist a log entry in the store and mirror it to the runtime logger.

        WARNING and ERROR entries that belong to a task also bump that task's
        error counter through ``store.increment_task_errors``.
        """
        upper_level = level.upper()
        self.store.add_log(task_id, user_id, scope, upper_level, message)
        if task_id is not None and upper_level in {"WARNING", "ERROR"}:
            self.store.increment_task_errors(task_id)
        self._emit_runtime_log(task_id=task_id, user_id=user_id, scope=scope, level=upper_level, message=message)

    @staticmethod
    def _emit_runtime_log(*, task_id: int | None, user_id: int | None, scope: str, level: str, message: str) -> None:
        """Write one formatted line to ``LOGGER`` at the level matching ``level``.

        ERROR and WARNING map to the same logger levels; everything else is
        emitted at INFO. A missing id is rendered as ``-``.
        """
        upper_level = level.upper()
        # Test for None explicitly so that an id of 0 is not shown as "-".
        task_label = "-" if task_id is None else task_id
        user_label = "-" if user_id is None else user_id
        rendered = f"task_id={task_label} user_id={user_label} scope={scope} level={upper_level} | {message}"
        if upper_level == "ERROR":
            LOGGER.error(rendered)
        elif upper_level == "WARNING":
            LOGGER.warning(rendered)
        else:
            LOGGER.info(rendered)

    def _running_count(self) -> int:
        """Return the number of registry entries. Must not be called with ``_running_lock`` held."""
        with self._running_lock:
            return len(self._running)

    def _cleanup_running_registry(self) -> None:
        """Drop registry entries whose worker thread is no longer alive."""
        with self._running_lock:
            finished_ids = [
                task_id
                for task_id, running_task in self._running.items()
                if not running_task.thread.is_alive()
            ]
            for task_id in finished_ids:
                self._running.pop(task_id, None)
        for task_id in finished_ids:
            LOGGER.info("Task thread cleaned from registry | task_id=%s", task_id)

    def _remove_running(self, task_id: int) -> None:
        """Remove ``task_id`` from the registry; do nothing if it is absent."""
        with self._running_lock:
            removed = self._running.pop(task_id, None)
        if removed is not None:
            LOGGER.info("Task removed from running registry | task_id=%s", task_id)
|
|
|