| from __future__ import annotations |
|
|
| import threading |
| from typing import Callable |
|
|
| from course_catcher.automation import CourseAutomation |
| from course_catcher.config import AppConfig |
| from course_catcher.db import Database |
|
|
|
|
class TaskManager:
    """Run queued tasks from the database on background worker threads.

    A single daemon ``task-scheduler`` thread polls the database about once per
    second. For each free parallelism slot it claims a queued task, marks it
    running, and executes it on its own daemon thread through
    ``CourseAutomation.run_until_stopped``.
    """

    def __init__(self, db: Database, config: AppConfig) -> None:
        self.db = db
        self.config = config
        self.automation = CourseAutomation(config)
        # task_id -> worker thread; read and mutated under ``_lock``.
        self._workers: dict[int, threading.Thread] = {}
        self._lock = threading.RLock()
        # Shared shutdown flag: checked by the scheduler loop and by each
        # worker's ``should_stop`` callback.
        self._stop_event = threading.Event()
        self._scheduler_thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the scheduler thread unless one is already alive.

        The stop flag is cleared before a new scheduler thread is launched, so
        the manager can be restarted after :meth:`stop`. Clearing it also means
        workers from the previous run that have not yet observed the stop keep
        running.

        NOTE: if called while a stopped scheduler thread is still winding down
        (it notices the flag within about a second), this returns without
        launching a new one.
        """
        with self._lock:
            if self._scheduler_thread and self._scheduler_thread.is_alive():
                return
            # Without this, a restarted scheduler would see the stale flag from
            # a previous stop() and exit immediately.
            self._stop_event.clear()
            self._scheduler_thread = threading.Thread(
                target=self._scheduler_loop,
                name="task-scheduler",
                daemon=True,
            )
            self._scheduler_thread.start()

    def stop(self) -> None:
        """Signal the scheduler and all running workers to stop.

        Returns immediately; it does not wait for any thread to exit.
        """
        self._stop_event.set()

    def _scheduler_loop(self) -> None:
        """Run one scheduling pass roughly every second until stop is requested.

        An exception from a single pass is logged and the loop carries on.
        Letting it escape would kill the scheduler thread and silently halt all
        future scheduling.
        """
        import logging

        logger = logging.getLogger(__name__)
        while not self._stop_event.is_set():
            try:
                self._dispatch_queued_tasks()
            except Exception:
                logger.exception("Task scheduler pass failed; retrying on next poll")
            self._stop_event.wait(1)

    def _dispatch_queued_tasks(self) -> None:
        """Reap finished workers, then launch queued tasks into any free slots."""
        self._reap_workers()
        parallelism = self.db.get_setting_int("max_parallel_tasks", self.config.default_parallelism)
        with self._lock:
            available_slots = max(0, parallelism - len(self._workers))
        if available_slots <= 0:
            return
        for task in self.db.fetch_queued_tasks(available_slots):
            task_id = task["id"]
            with self._lock:
                # A task that still has a tracked worker is already being run.
                if task_id in self._workers:
                    continue
                self.db.mark_task_running(task_id)
                worker = threading.Thread(
                    target=self._run_task,
                    args=(task_id,),
                    name=f"task-{task_id}",
                    daemon=True,
                )
                self._workers[task_id] = worker
                worker.start()

    def _reap_workers(self) -> None:
        """Drop worker threads that are no longer alive, freeing their slots."""
        with self._lock:
            finished = [task_id for task_id, worker in self._workers.items() if not worker.is_alive()]
            for task_id in finished:
                self._workers.pop(task_id, None)

    def _run_task(self, task_id: int) -> None:
        """Execute one task on the calling (worker) thread and record its outcome.

        The terminal status is written exactly once:
        - the automation's final status and error, or
        - "failed" when the linked user has no credentials, or
        - "failed" with the exception text when anything before completion
          raises (including loading the task row).

        An error while writing the closing log entry propagates instead of
        overwriting the status that was already recorded.
        """
        try:
            task = self.db.get_task(task_id)
        except Exception as exc:
            # The scheduler already marked this task running; don't leave it stuck.
            self.db.finish_task(task_id, "failed", str(exc))
            return
        if not task:
            return

        user_id = task["user_id"]

        def log(level: str, message: str) -> None:
            """Append a runner log entry for this task and its user."""
            self.db.add_log(task_id=task_id, user_id=user_id, actor="runner", level=level, message=message)

        try:
            credentials = self.db.get_user_runtime_credentials(user_id)
            if credentials:
                log("INFO", f"任务已启动,当前用户 {credentials['student_id']}。")
                final_status, last_error = self.automation.run_until_stopped(
                    task_id=task_id,
                    user_credentials=credentials,
                    db=self.db,
                    should_stop=lambda: self.db.is_stop_requested(task_id) or self._stop_event.is_set(),
                    log=log,
                )
        except Exception as exc:
            self.db.finish_task(task_id, "failed", str(exc))
            log("ERROR", f"任务崩溃:{exc}")
            return

        if not credentials:
            self.db.finish_task(task_id, "failed", "关联用户不存在")
            log("ERROR", "任务关联的用户记录不存在。")
            return

        # Record the outcome first; the log lines below are informational only.
        self.db.finish_task(task_id, final_status, last_error)
        if final_status == "completed":
            log("SUCCESS", "任务完成,所有待抢课程均已处理。")
        elif final_status == "stopped":
            log("INFO", "任务已停止。")
        else:
            log("ERROR", f"任务失败:{last_error}")
|
|