File size: 10,359 Bytes
e28c9e4
 
8385899
e28c9e4
 
 
 
 
 
 
 
 
8385899
 
 
e28c9e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8385899
e28c9e4
 
 
 
8385899
e28c9e4
8385899
 
 
 
 
e28c9e4
 
 
 
8385899
e28c9e4
 
8385899
e28c9e4
 
 
 
8385899
e28c9e4
 
 
 
8385899
 
 
 
 
 
e28c9e4
 
 
8385899
 
 
 
 
 
 
e28c9e4
 
 
 
 
 
8385899
e28c9e4
8385899
e28c9e4
 
 
 
 
 
 
 
8385899
e28c9e4
 
 
 
 
8385899
e28c9e4
 
8385899
 
 
 
 
 
 
 
 
 
 
 
 
e28c9e4
 
 
 
8385899
 
 
 
 
e28c9e4
 
 
 
 
8385899
e28c9e4
 
 
 
8385899
e28c9e4
 
 
 
8385899
e28c9e4
 
 
 
 
8385899
e28c9e4
 
 
 
 
8385899
e28c9e4
 
 
 
8385899
e28c9e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8385899
e28c9e4
 
 
 
 
 
 
 
 
 
 
 
 
 
8385899
 
 
 
 
 
e28c9e4
 
 
 
8385899
 
 
 
 
 
 
 
 
 
 
 
e28c9e4
 
 
 
 
 
 
 
 
 
 
 
 
8385899
 
e28c9e4
 
 
8385899
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
from __future__ import annotations

import logging
import threading
import time
from dataclasses import dataclass

from core.course_bot import CourseBot, TaskResult
from core.db import Database
from core.security import SecretBox


# Module-wide logger; every message in this file is emitted under "sacc.task_manager".
LOGGER = logging.getLogger("sacc.task_manager")


@dataclass(slots=True)
class RunningTask:
    """Registry entry tying a launched task to its worker thread and stop signal."""

    # Task identifier; TaskManager also uses it as the key of its running registry.
    task_id: int
    # Worker thread created by TaskManager._launch_task for this task.
    thread: threading.Thread
    # Set by TaskManager.stop_task / TaskManager.shutdown to ask the worker to stop.
    stop_event: threading.Event


class TaskManager:
    def __init__(self, *, config, store: Database, secret_box: SecretBox) -> None:
        self.config = config
        self.store = store
        self.secret_box = secret_box
        self._started = False
        self._startup_lock = threading.Lock()
        self._running_lock = threading.Lock()
        self._shutdown_event = threading.Event()
        self._dispatcher_thread: threading.Thread | None = None
        self._running: dict[int, RunningTask] = {}
        self._last_dispatch_snapshot: tuple[int, int, int] | None = None

    def start(self) -> None:
        with self._startup_lock:
            if self._started:
                LOGGER.info("Task manager start skipped because it is already running")
                return
            LOGGER.info(
                "Task manager starting | db_path=%s default_parallel_limit=%s",
                self.store.path,
                getattr(self.config, "default_parallel_limit", "-"),
            )
            self.store.reset_inflight_tasks()
            self._dispatcher_thread = threading.Thread(target=self._dispatch_loop, name="task-dispatcher", daemon=True)
            self._dispatcher_thread.start()
            self._started = True
            LOGGER.info("Task manager started")

    def shutdown(self) -> None:
        LOGGER.info("Task manager shutdown requested")
        self._shutdown_event.set()
        with self._running_lock:
            for running_task in self._running.values():
                running_task.stop_event.set()
        LOGGER.info("Task manager stop signal broadcast to %s running task(s)", self._running_count())

    def queue_task(self, user_id: int, requested_by: str, requested_by_role: str) -> tuple[dict, bool]:
        active_task = self.store.find_active_task_for_user(user_id)
        if active_task is not None:
            LOGGER.info(
                "Task queue skipped because an active task already exists | task_id=%s user_id=%s status=%s",
                active_task["id"],
                user_id,
                active_task["status"],
            )
            return self.store.get_task(active_task["id"]) or active_task, False

        task_id = self.store.create_task(user_id, requested_by, requested_by_role)
        LOGGER.info(
            "Task queued | task_id=%s user_id=%s requested_by_role=%s requested_by=%s",
            task_id,
            user_id,
            requested_by_role,
            requested_by,
        )
        self._log(task_id, user_id, "SYSTEM", "INFO", f"任务已进入队列,触发者: {requested_by_role}:{requested_by}")
        return self.store.get_task(task_id), True

    def stop_task(self, task_id: int) -> bool:
        requested = self.store.request_task_stop(task_id)
        if not requested:
            LOGGER.info("Task stop request ignored because task is not active | task_id=%s", task_id)
            return False
        LOGGER.info("Task stop requested | task_id=%s", task_id)
        self._log(task_id, None, "SYSTEM", "INFO", "收到停止请求,任务会在安全节点退出。")
        with self._running_lock:
            running_task = self._running.get(task_id)
            if running_task is not None:
                running_task.stop_event.set()
        return True

    def _dispatch_loop(self) -> None:
        LOGGER.info("Dispatcher loop started")
        while not self._shutdown_event.is_set():
            self._cleanup_running_registry()
            parallel_limit = self.store.get_parallel_limit()
            running_count = self._running_count()
            available_slots = max(0, parallel_limit - running_count)
            pending_tasks: list[dict] = []
            if available_slots > 0:
                pending_tasks = self.store.list_pending_tasks(available_slots)

            snapshot = (parallel_limit, running_count, len(pending_tasks))
            if snapshot != self._last_dispatch_snapshot and (running_count > 0 or pending_tasks):
                LOGGER.info(
                    "Dispatcher snapshot | parallel_limit=%s running=%s available_slots=%s fetched_pending=%s",
                    parallel_limit,
                    running_count,
                    available_slots,
                    len(pending_tasks),
                )
                self._last_dispatch_snapshot = snapshot

            if pending_tasks:
                for task in pending_tasks:
                    if self._shutdown_event.is_set():
                        break
                    if not task["is_active"]:
                        LOGGER.warning(
                            "Queued task cancelled because user is inactive | task_id=%s user_id=%s",
                            task["id"],
                            task["user_id"],
                        )
                        self.store.finish_task(task["id"], "failed", "该用户已被禁用,任务未执行。")
                        self._log(task["id"], task["user_id"], "SYSTEM", "WARNING", "用户已禁用,队列中的任务被取消。")
                        continue
                    self._launch_task(task["id"])
            time.sleep(1)
        LOGGER.info("Dispatcher loop stopped")

    def _launch_task(self, task_id: int) -> None:
        with self._running_lock:
            if task_id in self._running:
                LOGGER.info("Task launch skipped because thread already exists | task_id=%s", task_id)
                return
            stop_event = threading.Event()
            thread = threading.Thread(target=self._run_task, args=(task_id, stop_event), name=f"task-{task_id}", daemon=True)
            self._running[task_id] = RunningTask(task_id=task_id, thread=thread, stop_event=stop_event)
            LOGGER.info("Launching task thread | task_id=%s thread_name=%s", task_id, thread.name)
            thread.start()

    def _run_task(self, task_id: int, stop_event: threading.Event) -> None:
        task = self.store.get_task(task_id)
        if task is None:
            LOGGER.warning("Task record disappeared before execution | task_id=%s", task_id)
            self._remove_running(task_id)
            return

        user = self.store.get_user(task["user_id"])
        if user is None:
            LOGGER.error("Task cannot start because user does not exist | task_id=%s user_id=%s", task_id, task["user_id"])
            self.store.finish_task(task_id, "failed", "用户不存在。")
            self._remove_running(task_id)
            return

        LOGGER.info("Task runner starting | task_id=%s user_id=%s", task_id, user["id"])
        self.store.mark_task_running(task_id)
        self._log(task_id, user["id"], "SYSTEM", "INFO", "任务开始执行。")

        try:
            password = self.secret_box.decrypt(user["password_encrypted"])
            bot = CourseBot(
                config=self.config,
                store=self.store,
                task_id=task_id,
                user=user,
                password=password,
                logger=lambda level, message: self._log(task_id, user["id"], "RUNNER", level, message),
            )
            result = bot.run(stop_event)
        except Exception as exc:  # pragma: no cover - defensive fallback
            LOGGER.exception("Task initialization failed | task_id=%s user_id=%s", task_id, user["id"])
            result = TaskResult(status="failed", error=str(exc))
            self._log(task_id, user["id"], "SYSTEM", "ERROR", f"任务初始化失败: {exc}")

        current_task = self.store.get_task(task_id)
        final_status = result.status
        if current_task and current_task["status"] == "cancel_requested" and result.status != "completed":
            final_status = "stopped"
        elif stop_event.is_set() and result.status != "completed":
            final_status = "stopped"

        self.store.finish_task(task_id, final_status, result.error)
        final_text = result.error or f"任务已结束,状态: {final_status}"
        level = "ERROR" if final_status == "failed" else "INFO"
        self._log(task_id, user["id"], "SYSTEM", level, final_text)
        LOGGER.info(
            "Task runner finished | task_id=%s user_id=%s final_status=%s",
            task_id,
            user["id"],
            final_status,
        )
        self._remove_running(task_id)

    def _log(self, task_id: int | None, user_id: int | None, scope: str, level: str, message: str) -> None:
        self.store.add_log(task_id, user_id, scope, level, message)
        self._emit_runtime_log(task_id=task_id, user_id=user_id, scope=scope, level=level, message=message)

    @staticmethod
    def _emit_runtime_log(*, task_id: int | None, user_id: int | None, scope: str, level: str, message: str) -> None:
        upper_level = level.upper()
        rendered = f"task_id={task_id or '-'} user_id={user_id or '-'} scope={scope} level={upper_level} | {message}"
        if upper_level == "ERROR":
            LOGGER.error(rendered)
        elif upper_level == "WARNING":
            LOGGER.warning(rendered)
        else:
            LOGGER.info(rendered)

    def _running_count(self) -> int:
        with self._running_lock:
            return len(self._running)

    def _cleanup_running_registry(self) -> None:
        finished_ids: list[int] = []
        with self._running_lock:
            for task_id, running_task in self._running.items():
                if not running_task.thread.is_alive():
                    finished_ids.append(task_id)
            for task_id in finished_ids:
                self._running.pop(task_id, None)
        for task_id in finished_ids:
            LOGGER.info("Task thread cleaned from registry | task_id=%s", task_id)

    def _remove_running(self, task_id: int) -> None:
        with self._running_lock:
            removed = self._running.pop(task_id, None)
        if removed is not None:
            LOGGER.info("Task removed from running registry | task_id=%s", task_id)