File size: 3,951 Bytes
e28c9e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from __future__ import annotations

import logging
import threading
from typing import Callable

from course_catcher.automation import CourseAutomation
from course_catcher.config import AppConfig
from course_catcher.db import Database


class TaskManager:
    """Schedule queued tasks onto background worker threads.

    A single scheduler thread polls the database, starts up to
    ``max_parallel_tasks`` workers (falling back to
    ``config.default_parallelism``), and each worker drives
    ``CourseAutomation.run_until_stopped`` for one task.
    """

    def __init__(self, db: Database, config: AppConfig) -> None:
        self.db = db
        self.config = config
        self.automation = CourseAutomation(config)
        # Live worker threads keyed by task id. Within this class only the
        # scheduler loop adds/removes entries.
        self._workers: dict[int, threading.Thread] = {}
        self._lock = threading.RLock()
        # Shared shutdown signal: ends the scheduler loop and is also part of
        # every worker's ``should_stop`` predicate.
        self._stop_event = threading.Event()
        self._scheduler_thread: threading.Thread | None = None

    def start(self) -> None:
        """Launch the scheduler thread unless one is already alive.

        If a previous scheduler thread is still alive (including one that is
        still winding down after ``stop()``), this is a no-op.
        """
        with self._lock:
            if self._scheduler_thread and self._scheduler_thread.is_alive():
                return
            # stop() leaves the event set; without clearing it a restarted
            # scheduler would exit immediately. Note this also withdraws the
            # stop signal from any previous-run worker that has not yet
            # observed it (such workers remain tracked in self._workers).
            self._stop_event.clear()
            self._scheduler_thread = threading.Thread(
                target=self._scheduler_loop,
                name="task-scheduler",
                daemon=True,
            )
            self._scheduler_thread.start()

    def stop(self) -> None:
        """Signal the scheduler and all workers to stop; does not join any thread."""
        self._stop_event.set()

    def _scheduler_loop(self) -> None:
        """Dispatch queued tasks about once per second until stop() is called."""
        while not self._stop_event.is_set():
            try:
                self._dispatch_queued_tasks()
            except Exception:
                # A transient DB error must not kill the only scheduler thread
                # (that would silently halt all scheduling); log and retry on
                # the next tick.
                logging.getLogger(__name__).exception("Task scheduler iteration failed; retrying")
            self._stop_event.wait(1)

    def _dispatch_queued_tasks(self) -> None:
        """Reap finished workers, then start workers for queued tasks up to the parallelism limit."""
        self._reap_workers()
        parallelism = self.db.get_setting_int("max_parallel_tasks", self.config.default_parallelism)
        with self._lock:
            available_slots = max(0, parallelism - len(self._workers))
        if available_slots <= 0:
            return
        for task in self.db.fetch_queued_tasks(available_slots):
            if self._stop_event.is_set():
                # Leave the remaining tasks queued instead of starting workers
                # that would be told to stop immediately.
                break
            task_id = task["id"]
            with self._lock:
                if task_id in self._workers:
                    continue
                self.db.mark_task_running(task_id)
                worker = threading.Thread(
                    target=self._run_task,
                    args=(task_id,),
                    name=f"task-{task_id}",
                    daemon=True,
                )
                try:
                    worker.start()
                except RuntimeError as exc:
                    # e.g. "can't start new thread": the task is already marked
                    # running, so record the failure instead of leaving it stuck.
                    self.db.finish_task(task_id, "failed", str(exc))
                    continue
                self._workers[task_id] = worker

    def _reap_workers(self) -> None:
        """Drop worker threads that have finished from the tracking dict."""
        with self._lock:
            finished = [task_id for task_id, worker in self._workers.items() if not worker.is_alive()]
            for task_id in finished:
                self._workers.pop(task_id, None)

    def _run_task(self, task_id: int) -> None:
        """Worker body: run one task and record its terminal status.

        The scheduler marks the task running before this thread starts, so
        each failure path below records a terminal status via ``finish_task``
        rather than leaving the task "running". If the task row no longer
        exists, nothing is recorded.
        """
        try:
            task = self.db.get_task(task_id)
            if not task:
                return
            user_id = task["user_id"]
        except Exception as exc:  # pragma: no cover - defensive path
            self.db.finish_task(task_id, "failed", str(exc))
            return

        def log(level: str, message: str) -> None:
            self.db.add_log(task_id=task_id, user_id=user_id, actor="runner", level=level, message=message)

        # Set once finish_task has recorded the outcome, so a later failure
        # (e.g. while logging) does not overwrite it with "failed".
        finished = False
        try:
            credentials = self.db.get_user_runtime_credentials(user_id)
            if not credentials:
                self.db.finish_task(task_id, "failed", "关联用户不存在")
                finished = True
                log("ERROR", "任务关联的用户记录不存在。")
                return

            log("INFO", f"任务已启动,当前用户 {credentials['student_id']}。")
            final_status, last_error = self.automation.run_until_stopped(
                task_id=task_id,
                user_credentials=credentials,
                db=self.db,
                should_stop=lambda: self.db.is_stop_requested(task_id) or self._stop_event.is_set(),
                log=log,
            )
            self.db.finish_task(task_id, final_status, last_error)
            finished = True
            if final_status == "completed":
                log("SUCCESS", "任务完成,所有待抢课程均已处理。")
            elif final_status == "stopped":
                log("INFO", "任务已停止。")
            else:
                log("ERROR", f"任务失败:{last_error}")
        except Exception as exc:  # pragma: no cover - defensive path
            if not finished:
                self.db.finish_task(task_id, "failed", str(exc))
            log("ERROR", f"任务崩溃:{exc}")