# FlashWorld-ZeroGPU / concurrency_manager.py
# Last commit c8df52d by imlixinyang: "add app!"
import threading
import time
import uuid
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass
from enum import Enum
class TaskStatus(Enum):
    """Lifecycle state of a submitted task; the value is a lowercase label."""

    # Submitted, waiting for a free worker thread.
    QUEUED = "queued"
    # Currently executing on a worker thread.
    RUNNING = "running"
    # Callable returned normally; its return value is stored on the task.
    COMPLETED = "completed"
    # Callable raised; the error message is stored on the task.
    FAILED = "failed"
@dataclass
class Task:
    """Record of one unit of work tracked by ConcurrencyManager.

    Attributes:
        task_id: Unique identifier (a UUID4 string when created by
            ConcurrencyManager.submit_task).
        status: Current lifecycle state.
        created_at: Submission time, as returned by ``time.time()``.
        started_at: Time execution began, or None if not started yet.
        completed_at: Time execution finished, or None if not finished.
        result: Return value of ``function`` once completed successfully.
        error: ``str()`` of the exception raised by ``function``, if any.
        function: The callable to run; None means "nothing to run".
        args: Positional arguments passed to ``function``.
        kwargs: Keyword arguments passed to ``function``.
    """

    task_id: str
    status: TaskStatus
    created_at: float
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    result: Optional[Any] = None
    error: Optional[str] = None
    function: Optional[Callable] = None
    args: tuple = ()
    # Defaults to None (a mutable {} default is rejected by dataclasses);
    # __post_init__ swaps None for a fresh per-instance dict.
    kwargs: dict = None

    def __post_init__(self):
        # Give every instance its own dict; a caller-supplied dict is kept as-is.
        self.kwargs = {} if self.kwargs is None else self.kwargs
class ConcurrencyManager:
    """Run submitted callables on a fixed pool of daemon worker threads.

    Queued tasks are dequeued in submission (FIFO) order by
    ``max_concurrent`` worker threads, so at most ``max_concurrent`` tasks
    execute at once. A task moves ``queued_tasks`` -> ``running_tasks`` ->
    ``completed_tasks``; all three collections are guarded by ``self.lock``.
    Finished tasks stay in ``completed_tasks`` until
    :meth:`cleanup_old_tasks` removes them.
    """

    # Longest time an idle worker or a waiter sleeps before re-checking the
    # shutdown flag or its own timeout, even if no notification arrives.
    _POLL_INTERVAL = 0.1

    def __init__(self, max_concurrent: int = 2):
        """
        Concurrency-control manager; starts its worker threads immediately.

        Args:
            max_concurrent: Maximum number of tasks executed at the same time
                (one worker thread is started per slot).
        """
        self.max_concurrent = max_concurrent
        self.running_tasks: Dict[str, Task] = {}
        self.queued_tasks: List[Task] = []
        self.completed_tasks: Dict[str, Task] = {}
        self.lock = threading.RLock()
        # Condition bound to ``self.lock``: notified when a task is queued
        # (wakes idle workers), when a task finishes (wakes wait_for_task),
        # and on shutdown. Must exist before the workers start.
        self._state_changed = threading.Condition(self.lock)
        self.worker_threads: List[threading.Thread] = []
        self.shutdown_event = threading.Event()
        # Start the worker threads
        self._start_workers()

    def _start_workers(self):
        """Start ``max_concurrent`` daemon worker threads."""
        for _ in range(self.max_concurrent):
            worker = threading.Thread(target=self._worker_loop, daemon=True)
            worker.start()
            self.worker_threads.append(worker)

    def _worker_loop(self):
        """Worker thread main loop: claim and run tasks until shutdown."""
        while not self.shutdown_event.is_set():
            try:
                # Blocks for at most _POLL_INTERVAL when the queue is empty,
                # so no extra sleep is needed here.
                task = self._get_next_task()
                if task:
                    self._execute_task(task)
            except Exception as e:
                print(f"Worker thread error: {e}")
                time.sleep(1)

    def _get_next_task(self) -> Optional[Task]:
        """Claim the oldest queued task, waiting briefly if none is queued.

        The task is popped from ``queued_tasks``, marked RUNNING and put in
        ``running_tasks`` under a single lock acquisition, so
        get_task_status never loses sight of it in between.

        Returns:
            The claimed task, or None if the queue was still empty after the
            wait or shutdown has been requested.
        """
        with self._state_changed:
            if not self.queued_tasks and not self.shutdown_event.is_set():
                # Woken early by submit_task()/shutdown(); otherwise times out.
                self._state_changed.wait(timeout=self._POLL_INTERVAL)
            if self.shutdown_event.is_set() or not self.queued_tasks:
                return None
            task = self.queued_tasks.pop(0)
            task.status = TaskStatus.RUNNING
            task.started_at = time.time()
            self.running_tasks[task.task_id] = task
            return task

    def _execute_task(self, task: Task):
        """Run a claimed task and record its outcome.

        On success the callable's return value is stored in ``task.result``
        and the status becomes COMPLETED. If it raises an ``Exception``,
        ``str(exc)`` is stored in ``task.error`` and the status becomes
        FAILED. The bookkeeping runs in ``finally`` so the task always leaves
        ``running_tasks`` -- even for a ``BaseException`` such as
        ``SystemExit``, which is recorded and then re-raised -- and waiters
        are never left blocked on a task stuck in RUNNING.
        """
        status = TaskStatus.FAILED
        result = None
        error = None
        try:
            # Execute the task (a task without a callable completes with None)
            if task.function:
                result = task.function(*task.args, **task.kwargs)
            status = TaskStatus.COMPLETED
        except Exception as e:
            error = str(e)
        except BaseException as e:
            # Not swallowed; recorded so the task is not left RUNNING.
            error = repr(e)
            raise
        finally:
            with self._state_changed:
                task.result = result
                task.error = error
                task.status = status
                task.completed_at = time.time()
                self.completed_tasks[task.task_id] = task
                self.running_tasks.pop(task.task_id, None)
                # Wake wait_for_task() callers (and any idle workers).
                self._state_changed.notify_all()

    def submit_task(self, func: Callable, *args, **kwargs) -> str:
        """
        Submit a task for asynchronous execution.

        Args:
            func: Callable to execute.
            *args: Positional arguments for ``func``.
            **kwargs: Keyword arguments for ``func``.

        Returns:
            task_id: Identifier to pass to get_task_status / wait_for_task.
        """
        task_id = str(uuid.uuid4())
        task = Task(
            task_id=task_id,
            status=TaskStatus.QUEUED,
            created_at=time.time(),
            function=func,
            args=args,
            kwargs=kwargs,
        )
        with self._state_changed:
            self.queued_tasks.append(task)
            # notify_all: waiters share this condition, so a single notify
            # could wake a waiter instead of an idle worker.
            self._state_changed.notify_all()
        return task_id

    def get_task_status(self, task_id: str) -> Optional[Task]:
        """Return the Task for ``task_id`` (running, completed or queued), or None if unknown."""
        with self.lock:
            if task_id in self.running_tasks:
                return self.running_tasks[task_id]
            elif task_id in self.completed_tasks:
                return self.completed_tasks[task_id]
            else:
                # Check the tasks still waiting in the queue
                for task in self.queued_tasks:
                    if task.task_id == task_id:
                        return task
        return None

    def get_queue_status(self) -> Dict[str, Any]:
        """Return a snapshot of queue sizes and the ids of running/queued tasks."""
        with self.lock:
            return {
                "max_concurrent": self.max_concurrent,
                "running_count": len(self.running_tasks),
                "queued_count": len(self.queued_tasks),
                "completed_count": len(self.completed_tasks),
                "running_tasks": [task.task_id for task in self.running_tasks.values()],
                "queued_tasks": [task.task_id for task in self.queued_tasks],
            }

    def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> Task:
        """
        Block until the task reaches COMPLETED or FAILED.

        Args:
            task_id: Task identifier returned by submit_task.
            timeout: Maximum seconds to wait; None waits indefinitely and 0
                checks exactly once. An unknown task_id is treated like an
                unfinished task, so it waits until the timeout expires (forever
                when timeout is None).

        Returns:
            Task: The finished task.

        Raises:
            TimeoutError: If the task has not finished within ``timeout``.
        """
        # monotonic() is immune to wall-clock adjustments.
        deadline = None if timeout is None else time.monotonic() + timeout
        with self._state_changed:
            while True:
                task = self.get_task_status(task_id)
                if task and task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
                    return task
                wait_for = self._POLL_INTERVAL
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        raise TimeoutError(f"Task {task_id} timed out after {timeout} seconds")
                    wait_for = min(wait_for, remaining)
                # Releases the lock while waiting; woken by task completion.
                self._state_changed.wait(timeout=wait_for)

    def cleanup_old_tasks(self, max_age_hours: int = 24):
        """Drop finished tasks whose completion is older than ``max_age_hours``."""
        current_time = time.time()
        max_age_seconds = max_age_hours * 3600
        with self.lock:
            # Collect first, then delete: the dict must not change size while
            # it is being iterated.
            old_tasks = [
                task_id for task_id, task in self.completed_tasks.items()
                if current_time - task.completed_at > max_age_seconds
            ]
            for task_id in old_tasks:
                del self.completed_tasks[task_id]

    def shutdown(self):
        """Stop the workers and join each for up to 5 seconds.

        A task that is already running is not interrupted. Tasks still
        queued are not started.
        """
        self.shutdown_event.set()
        # Wake idle workers right away instead of after their poll interval.
        with self._state_changed:
            self._state_changed.notify_all()
        for worker in self.worker_threads:
            worker.join(timeout=5)