Spaces:
Paused
Paused
| from __future__ import annotations | |
| import threading | |
| import time | |
| import unittest | |
| from types import SimpleNamespace | |
| from unittest.mock import patch | |
| from cryptography.fernet import Fernet | |
| from core.course_bot import TaskResult | |
| from core.db import Database | |
| from core.security import SecretBox | |
| from core.task_manager import TaskManager | |
| from tests.helpers import workspace_tempdir | |
class FakeCourseBot:
    """Test double for ``CourseBot`` that records concurrency instead of working.

    All bookkeeping lives on the class (not the instance) so the test can
    inspect it across every instance that gets constructed while the patch
    is active. Call :meth:`reset` before each test.
    """

    # Guards started_task_ids, current_running and peak_running.
    lock = threading.Lock()
    # Set by the test to let every in-flight run() return.
    release_event = threading.Event()
    # Set once at least two run() calls have started.
    two_started_event = threading.Event()
    # task_id of every run() call, in start order.
    started_task_ids: list[int] = []
    # Highest number of run() calls observed in flight at the same time.
    peak_running = 0
    # Number of run() calls currently in flight.
    current_running = 0

    @classmethod
    def reset(cls) -> None:
        """Restore all shared state: fresh events, empty id list, zeroed counters."""
        cls.release_event = threading.Event()
        cls.two_started_event = threading.Event()
        cls.started_task_ids = []
        cls.peak_running = 0
        cls.current_running = 0

    def __init__(self, *, config, store, task_id, user, password, logger) -> None:
        # Accepts the same keyword arguments as the real CourseBot (assumed from
        # this signature -- confirm against core.course_bot); only task_id is used.
        self.task_id = task_id

    def run(self, stop_event) -> TaskResult:
        """Record the start, block until released/stopped/timed out, then finish.

        Returns ``TaskResult(status="stopped")`` if ``stop_event`` is set when
        the wait ends, otherwise ``TaskResult(status="completed")``.
        """
        cls = type(self)
        with cls.lock:
            cls.started_task_ids.append(self.task_id)
            cls.current_running += 1
            cls.peak_running = max(cls.peak_running, cls.current_running)
            if len(cls.started_task_ids) >= 2:
                cls.two_started_event.set()

        try:
            # Hard 5 s cap (monotonic, so wall-clock jumps can't distort it)
            # so a missing release can never hang the suite.
            deadline = time.monotonic() + 5
            while (
                not stop_event.is_set()
                and not cls.release_event.is_set()
                and time.monotonic() < deadline
            ):
                # Wakes immediately on release; stop_event is re-polled each tick.
                cls.release_event.wait(0.05)
        finally:
            # Always undo the in-flight count so peak/current stay consistent.
            with cls.lock:
                cls.current_running -= 1

        if stop_event.is_set():
            return TaskResult(status="stopped")
        return TaskResult(status="completed")
| class TaskManagerTests(unittest.TestCase): | |
| def test_parallel_limit_caps_concurrent_task_launches(self) -> None: | |
| FakeCourseBot.reset() | |
| with workspace_tempdir("task-manager-") as temp_dir: | |
| store = Database(temp_dir / "test.db", default_parallel_limit=2) | |
| store.init_db() | |
| store.set_parallel_limit(2) | |
| secret_box = SecretBox(Fernet.generate_key().decode("utf-8")) | |
| config = SimpleNamespace() | |
| manager = TaskManager(config=config, store=store, secret_box=secret_box) | |
| user_ids = [] | |
| for index in range(3): | |
| user_id = store.create_user( | |
| f"202300000000{index}", | |
| secret_box.encrypt("pw"), | |
| f"User {index}", | |
| ) | |
| store.add_course(user_id, "free", f"10010{index}", "01") | |
| user_ids.append(user_id) | |
| try: | |
| with patch("core.task_manager.CourseBot", FakeCourseBot): | |
| manager.start() | |
| for user_id in user_ids: | |
| manager.queue_task(user_id, requested_by="tester", requested_by_role="admin") | |
| self.assertTrue(FakeCourseBot.two_started_event.wait(5), "前两个任务没有按预期并行启动") | |
| time.sleep(0.6) | |
| self.assertEqual(len(FakeCourseBot.started_task_ids), 2) | |
| self.assertEqual(FakeCourseBot.peak_running, 2) | |
| FakeCourseBot.release_event.set() | |
| deadline = time.time() + 8 | |
| while time.time() < deadline: | |
| recent_tasks = store.list_recent_tasks(limit=10) | |
| if len(recent_tasks) == 3 and all(task["status"] == "completed" for task in recent_tasks): | |
| break | |
| time.sleep(0.1) | |
| else: | |
| self.fail("任务没有在预期时间内全部完成") | |
| self.assertEqual(len(FakeCourseBot.started_task_ids), 3) | |
| self.assertEqual(FakeCourseBot.peak_running, 2) | |
| finally: | |
| FakeCourseBot.release_event.set() | |
| manager.shutdown() | |
if __name__ == "__main__":
    # Allow running this module directly, outside an external test runner.
    unittest.main()