Spaces:
Paused
Paused
| from __future__ import annotations | |
| import logging | |
| import threading | |
| import time | |
| from dataclasses import dataclass | |
| from core.course_bot import CourseBot, TaskResult | |
| from core.db import Database | |
| from core.security import SecretBox | |
# Module-wide logger; every TaskManager message and mirrored task log line goes here.
LOGGER: logging.Logger = logging.getLogger("sacc.task_manager")
@dataclass
class RunningTask:
    """Registry entry for one task currently executing on a worker thread.

    Instances are created with keyword arguments by ``TaskManager._launch_task``,
    which is why this must be a dataclass (a bare annotated class has no
    matching ``__init__``).
    """

    # Database id of the task being executed.
    task_id: int
    # Worker thread running the task; polled with ``is_alive()`` for cleanup.
    thread: threading.Thread
    # Set to ask the running task to stop (by stop_task() or shutdown()).
    stop_event: threading.Event
class TaskManager:
    """Dispatch queued course-bot tasks onto worker threads.

    A single daemon dispatcher thread polls the store every
    ``_DISPATCH_INTERVAL_SECONDS``. It launches pending tasks while the number
    of registered worker threads is below ``store.get_parallel_limit()`` and
    tracks them in an in-memory registry keyed by task id.
    """

    # Pause between dispatcher polling rounds, in seconds.
    _DISPATCH_INTERVAL_SECONDS = 1.0

    def __init__(self, *, config, store: Database, secret_box: SecretBox) -> None:
        """Store collaborators; no thread is started until :meth:`start`.

        Args:
            config: Application config, passed through to ``CourseBot``; its
                optional ``default_parallel_limit`` is only used for logging.
            store: Persistence layer for tasks, users and task logs.
            secret_box: Decrypts each user's stored password before a run.
        """
        self.config = config
        self.store = store
        self.secret_box = secret_box
        self._started = False
        self._startup_lock = threading.Lock()
        # Guards ``_running``. It is a plain (non-reentrant) Lock, so code that
        # already holds it must not call ``_running_count`` (which acquires it).
        self._running_lock = threading.Lock()
        self._shutdown_event = threading.Event()
        self._dispatcher_thread: threading.Thread | None = None
        self._running: dict[int, RunningTask] = {}
        # Last (parallel_limit, running, fetched_pending) triple seen by the
        # dispatcher; used to avoid logging an unchanged snapshot every round.
        self._last_dispatch_snapshot: tuple[int, int, int] | None = None

    def start(self) -> None:
        """Reset in-flight tasks in the store and start the dispatcher thread.

        Calling it again while already started only logs and returns.
        """
        with self._startup_lock:
            if self._started:
                LOGGER.info("Task manager start skipped because it is already running")
                return
            LOGGER.info(
                "Task manager starting | db_path=%s default_parallel_limit=%s",
                self.store.path,
                getattr(self.config, "default_parallel_limit", "-"),
            )
            # Let the store reset tasks recorded as in-flight (presumably left
            # over from a previous process) before dispatching begins.
            self.store.reset_inflight_tasks()
            self._dispatcher_thread = threading.Thread(
                target=self._dispatch_loop, name="task-dispatcher", daemon=True
            )
            self._dispatcher_thread.start()
            self._started = True
        LOGGER.info("Task manager started")

    def shutdown(self) -> None:
        """Stop the dispatcher loop and set the stop event of every running task.

        Does not join any thread; it is up to each running bot to notice its
        stop event.
        """
        LOGGER.info("Task manager shutdown requested")
        self._shutdown_event.set()
        with self._running_lock:
            signalled = list(self._running.values())
            for running_task in signalled:
                running_task.stop_event.set()
        # Report the count we actually signalled; calling _running_count() while
        # holding _running_lock would deadlock on the non-reentrant lock.
        LOGGER.info("Task manager stop signal broadcast to %s running task(s)", len(signalled))

    def queue_task(self, user_id: int, requested_by: str, requested_by_role: str) -> tuple[dict, bool]:
        """Queue a task for ``user_id`` unless the user already has an active one.

        Returns:
            ``(task, created)``: the task row (the existing active task, or the
            newly created one) and whether a new task was created.
        """
        active_task = self.store.find_active_task_for_user(user_id)
        if active_task is not None:
            LOGGER.info(
                "Task queue skipped because an active task already exists | task_id=%s user_id=%s status=%s",
                active_task["id"],
                user_id,
                active_task["status"],
            )
            # Prefer a fresh read; fall back to the row we already have.
            return self.store.get_task(active_task["id"]) or active_task, False
        task_id = self.store.create_task(user_id, requested_by, requested_by_role)
        LOGGER.info(
            "Task queued | task_id=%s user_id=%s requested_by_role=%s requested_by=%s",
            task_id,
            user_id,
            requested_by_role,
            requested_by,
        )
        self._log(task_id, user_id, "SYSTEM", "INFO", f"任务已进入队列,触发者: {requested_by_role}:{requested_by}")
        return self.store.get_task(task_id), True

    def stop_task(self, task_id: int) -> bool:
        """Request that ``task_id`` stop.

        Returns:
            False if the store did not accept the request (the task is not
            active), otherwise True. When the task has a live worker thread,
            its stop event is also set.
        """
        requested = self.store.request_task_stop(task_id)
        if not requested:
            LOGGER.info("Task stop request ignored because task is not active | task_id=%s", task_id)
            return False
        LOGGER.info("Task stop requested | task_id=%s", task_id)
        self._log(task_id, None, "SYSTEM", "INFO", "收到停止请求,任务会在安全节点退出。")
        with self._running_lock:
            running_task = self._running.get(task_id)
            if running_task is not None:
                running_task.stop_event.set()
        return True

    def _dispatch_loop(self) -> None:
        """Dispatcher thread body: run dispatch rounds until shutdown is requested."""
        LOGGER.info("Dispatcher loop started")
        while not self._shutdown_event.is_set():
            try:
                self._dispatch_once()
            except Exception:
                # This is the only dispatcher thread; a transient store error
                # must not end it silently. Log and retry on the next round.
                LOGGER.exception("Dispatcher iteration failed; retrying next round")
            # Event.wait instead of time.sleep so shutdown() wakes us at once.
            self._shutdown_event.wait(self._DISPATCH_INTERVAL_SECONDS)
        LOGGER.info("Dispatcher loop stopped")

    def _dispatch_once(self) -> None:
        """One dispatch round: reap finished threads, then fill free slots."""
        self._cleanup_running_registry()
        parallel_limit = self.store.get_parallel_limit()
        running_count = self._running_count()
        available_slots = max(0, parallel_limit - running_count)
        pending_tasks: list[dict] = []
        if available_slots > 0:
            pending_tasks = self.store.list_pending_tasks(available_slots)
        snapshot = (parallel_limit, running_count, len(pending_tasks))
        # Only log when the snapshot changed and there is something going on.
        if snapshot != self._last_dispatch_snapshot and (running_count > 0 or pending_tasks):
            LOGGER.info(
                "Dispatcher snapshot | parallel_limit=%s running=%s available_slots=%s fetched_pending=%s",
                parallel_limit,
                running_count,
                available_slots,
                len(pending_tasks),
            )
        self._last_dispatch_snapshot = snapshot
        for task in pending_tasks:
            if self._shutdown_event.is_set():
                break
            if not task["is_active"]:
                LOGGER.warning(
                    "Queued task cancelled because user is inactive | task_id=%s user_id=%s",
                    task["id"],
                    task["user_id"],
                )
                self.store.finish_task(task["id"], "failed", "该用户已被禁用,任务未执行。")
                self._log(task["id"], task["user_id"], "SYSTEM", "WARNING", "用户已禁用,队列中的任务被取消。")
                continue
            self._launch_task(task["id"])

    def _launch_task(self, task_id: int) -> None:
        """Register and start a worker thread for ``task_id`` unless one exists."""
        with self._running_lock:
            if task_id in self._running:
                LOGGER.info("Task launch skipped because thread already exists | task_id=%s", task_id)
                return
            stop_event = threading.Event()
            thread = threading.Thread(
                target=self._run_task, args=(task_id, stop_event), name=f"task-{task_id}", daemon=True
            )
            self._running[task_id] = RunningTask(task_id=task_id, thread=thread, stop_event=stop_event)
            LOGGER.info("Launching task thread | task_id=%s thread_name=%s", task_id, thread.name)
            thread.start()

    def _run_task(self, task_id: int, stop_event: threading.Event) -> None:
        """Worker-thread entry point for one task.

        Always drops the task from the running registry on exit, even when a
        store call raises, and routes unexpected errors through LOGGER.
        """
        try:
            self._execute_task(task_id, stop_event)
        except Exception:
            LOGGER.exception("Task runner crashed | task_id=%s", task_id)
        finally:
            self._remove_running(task_id)

    def _execute_task(self, task_id: int, stop_event: threading.Event) -> None:
        """Load the task and its user, run the bot, and persist the final status."""
        task = self.store.get_task(task_id)
        if task is None:
            LOGGER.warning("Task record disappeared before execution | task_id=%s", task_id)
            return
        user = self.store.get_user(task["user_id"])
        if user is None:
            LOGGER.error("Task cannot start because user does not exist | task_id=%s user_id=%s", task_id, task["user_id"])
            self.store.finish_task(task_id, "failed", "用户不存在。")
            return
        user_id = user["id"]
        LOGGER.info("Task runner starting | task_id=%s user_id=%s", task_id, user_id)
        self.store.mark_task_running(task_id)
        self._log(task_id, user_id, "SYSTEM", "INFO", "任务开始执行。")
        try:
            password = self.secret_box.decrypt(user["password_encrypted"])
            bot = CourseBot(
                config=self.config,
                store=self.store,
                task_id=task_id,
                user=user,
                password=password,
                logger=lambda level, message: self._log(task_id, user_id, "RUNNER", level, message),
            )
            result = bot.run(stop_event)
        except Exception as exc:  # pragma: no cover - defensive fallback
            LOGGER.exception("Task initialization failed | task_id=%s user_id=%s", task_id, user_id)
            result = TaskResult(status="failed", error=str(exc))
            self._log(task_id, user_id, "SYSTEM", "ERROR", f"任务初始化失败: {exc}")
        current_task = self.store.get_task(task_id)
        # A stop request (recorded in the store or signalled via the event)
        # turns any non-completed outcome into "stopped".
        stop_requested = bool(current_task and current_task["status"] == "cancel_requested") or stop_event.is_set()
        final_status = "stopped" if stop_requested and result.status != "completed" else result.status
        self.store.finish_task(task_id, final_status, result.error)
        final_text = result.error or f"任务已结束,状态: {final_status}"
        level = "ERROR" if final_status == "failed" else "INFO"
        self._log(task_id, user_id, "SYSTEM", level, final_text)
        LOGGER.info(
            "Task runner finished | task_id=%s user_id=%s final_status=%s",
            task_id,
            user_id,
            final_status,
        )

    def _log(self, task_id: int | None, user_id: int | None, scope: str, level: str, message: str) -> None:
        """Persist a task log entry via the store and mirror it to LOGGER."""
        self.store.add_log(task_id, user_id, scope, level, message)
        self._emit_runtime_log(task_id=task_id, user_id=user_id, scope=scope, level=level, message=message)

    @staticmethod
    def _emit_runtime_log(*, task_id: int | None, user_id: int | None, scope: str, level: str, message: str) -> None:
        """Write one rendered task log line to LOGGER.

        ``level`` is matched case-insensitively. "ERROR" and "WARNING" map to the
        same LOGGER methods; any other value is logged at INFO.
        """
        upper_level = level.upper()
        rendered = (
            f"task_id={'-' if task_id is None else task_id} "
            f"user_id={'-' if user_id is None else user_id} "
            f"scope={scope} level={upper_level} | {message}"
        )
        if upper_level == "ERROR":
            LOGGER.error(rendered)
        elif upper_level == "WARNING":
            LOGGER.warning(rendered)
        else:
            LOGGER.info(rendered)

    def _running_count(self) -> int:
        """Return the number of registered worker threads (acquires _running_lock)."""
        with self._running_lock:
            return len(self._running)

    def _cleanup_running_registry(self) -> None:
        """Drop registry entries whose worker thread is no longer alive."""
        with self._running_lock:
            finished_ids = [tid for tid, entry in self._running.items() if not entry.thread.is_alive()]
            for tid in finished_ids:
                del self._running[tid]
        for tid in finished_ids:
            LOGGER.info("Task thread cleaned from registry | task_id=%s", tid)

    def _remove_running(self, task_id: int) -> None:
        """Remove ``task_id`` from the registry if present (no-op otherwise)."""
        with self._running_lock:
            removed = self._running.pop(task_id, None)
        if removed is not None:
            LOGGER.info("Task removed from running registry | task_id=%s", task_id)