| from __future__ import annotations |
|
|
| import threading |
| import time |
| import unittest |
| from types import SimpleNamespace |
| from unittest.mock import patch |
|
|
| from cryptography.fernet import Fernet |
|
|
| from core.course_bot import TaskResult |
| from core.db import Database |
| from core.security import SecretBox |
| from core.task_manager import TaskManager |
| from tests.helpers import workspace_tempdir |
|
|
|
|
class FakeCourseBot:
    """Stand-in for the real course bot that records how many runs overlap.

    All bookkeeping is stored on the class rather than on instances, so a
    test can observe every bot created by the code under test. Call
    :meth:`reset` before each scenario to clear the shared state.
    """

    # Guards the counters and the started-task list below.
    lock = threading.Lock()
    # Set by the test to let every blocked ``run`` call return.
    release_event = threading.Event()
    # Set once at least two ``run`` calls have begun.
    two_started_event = threading.Event()
    # Task ids in the order their ``run`` calls began.
    started_task_ids: list[int] = []
    # Highest number of ``run`` calls seen in flight at the same time.
    peak_running = 0
    # Number of ``run`` calls currently in flight.
    current_running = 0

    # Longest time, in seconds, that ``run`` blocks before giving up.
    _MAX_BLOCK_SECONDS = 5
    # How often, in seconds, ``run`` re-checks ``stop_event`` while blocked.
    _POLL_SECONDS = 0.05

    @classmethod
    def reset(cls) -> None:
        """Replace both events and zero the counters and started-task list.

        The lock itself is kept; only the observable state is renewed.
        """
        cls.release_event = threading.Event()
        cls.two_started_event = threading.Event()
        cls.started_task_ids = []
        cls.peak_running = 0
        cls.current_running = 0

    def __init__(self, *, config, store, task_id, user, password, logger) -> None:
        """Keep only ``task_id``; the other keyword arguments are ignored."""
        self.task_id = task_id

    def run(self, stop_event) -> TaskResult:
        """Record the start, block until released/stopped/timed out, then finish.

        Args:
            stop_event: Event-like object exposing ``is_set()``; when set, the
                run ends early and reports ``"stopped"``.

        Returns:
            ``TaskResult(status="stopped")`` if ``stop_event`` is set when the
            wait ends, otherwise ``TaskResult(status="completed")``.
        """
        cls = type(self)

        with cls.lock:
            cls.started_task_ids.append(self.task_id)
            cls.current_running += 1
            cls.peak_running = max(cls.peak_running, cls.current_running)
            if len(cls.started_task_ids) >= 2:
                cls.two_started_event.set()

        # A monotonic clock keeps the timeout immune to wall-clock adjustments.
        deadline = time.monotonic() + cls._MAX_BLOCK_SECONDS
        while not stop_event.is_set() and time.monotonic() < deadline:
            # ``release_event`` is looked up on every pass because ``reset``
            # rebinds it; waiting on it (instead of sleeping) returns as soon
            # as the test releases the bots.
            if cls.release_event.wait(cls._POLL_SECONDS):
                break

        with cls.lock:
            cls.current_running -= 1

        if stop_event.is_set():
            return TaskResult(status="stopped")
        return TaskResult(status="completed")
|
|
|
|
class TaskManagerTests(unittest.TestCase):
    """Scheduling tests for ``TaskManager`` with ``FakeCourseBot`` patched in."""

    def test_parallel_limit_caps_concurrent_task_launches(self) -> None:
        """Queue three tasks under a parallel limit of 2 and check the cap.

        Two tasks must start together, the third must not start until the
        first two are released, and the peak concurrency must stay at 2.
        """
        FakeCourseBot.reset()
        with workspace_tempdir("task-manager-") as temp_dir:
            store = Database(temp_dir / "test.db", default_parallel_limit=2)
            store.init_db()
            store.set_parallel_limit(2)
            secret_box = SecretBox(Fernet.generate_key().decode("utf-8"))
            config = SimpleNamespace()
            manager = TaskManager(config=config, store=store, secret_box=secret_box)

            # One user (with one course) per task that will be queued.
            user_ids = []
            for index in range(3):
                user_id = store.create_user(
                    f"202300000000{index}",
                    secret_box.encrypt("pw"),
                    f"User {index}",
                )
                store.add_course(user_id, "free", f"10010{index}", "01")
                user_ids.append(user_id)

            try:
                # Swap the bot class as seen from core.task_manager for the
                # fake, for the whole scenario.
                with patch("core.task_manager.CourseBot", FakeCourseBot):
                    manager.start()
                    for user_id in user_ids:
                        manager.queue_task(user_id, requested_by="tester", requested_by_role="admin")

                    self.assertTrue(FakeCourseBot.two_started_event.wait(5), "前两个任务没有按预期并行启动")
                    # Pause before checking that no third task was launched.
                    # NOTE(review): 0.6 s is presumably longer than the
                    # manager's scheduling interval -- confirm.
                    time.sleep(0.6)
                    self.assertEqual(len(FakeCourseBot.started_task_ids), 2)
                    self.assertEqual(FakeCourseBot.peak_running, 2)

                    FakeCourseBot.release_event.set()

                    # A monotonic clock keeps the polling timeout immune to
                    # wall-clock adjustments.
                    deadline = time.monotonic() + 8
                    while time.monotonic() < deadline:
                        recent_tasks = store.list_recent_tasks(limit=10)
                        if len(recent_tasks) == 3 and all(task["status"] == "completed" for task in recent_tasks):
                            break
                        time.sleep(0.1)
                    else:
                        # The loop ran out of time without seeing three
                        # completed tasks.
                        self.fail("任务没有在预期时间内全部完成")

                    self.assertEqual(len(FakeCourseBot.started_task_ids), 3)
                    self.assertEqual(FakeCourseBot.peak_running, 2)
            finally:
                # Unblock any bot still waiting so shutdown is not held up by
                # a failed assertion above.
                FakeCourseBot.release_event.set()
                manager.shutdown()
|
|
|
|
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|
|