Spaces:
Runtime error
Runtime error
| """ | |
| Batch utilities. | |
| - run_batch: generic batch concurrency runner | |
| - BatchTask: SSE task manager for admin batch operations | |
| """ | |
| import asyncio | |
| import time | |
| import uuid | |
| from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar | |
| from app.core.logger import logger | |
T = TypeVar("T")


async def run_batch(
    items: List[str],
    worker: Callable[[str], Awaitable[T]],
    *,
    batch_size: int = 50,
    task: Optional["BatchTask"] = None,
    on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None,
    should_cancel: Optional[Callable[[], bool]] = None,
) -> Dict[str, Dict[str, Any]]:
    """Run ``worker`` over ``items`` in concurrent batches.

    A failing item never aborts the run: its exception is logged and stored
    in that item's result entry instead of being raised.

    Args:
        items: Items to process; each item is also the key of its result.
        worker: Async callable awaited once per item.
        batch_size: Maximum number of items awaited concurrently. A value
            that cannot be converted to ``int`` falls back to 50, and values
            below 1 are raised to 1.
        task: Optional task object; ``task.record`` is called once per
            processed item and ``task.cancelled`` is honoured.
        on_item: Optional async callback awaited with ``(item, result)``
            after each processed item. Exceptions it raises are logged and
            otherwise ignored.
        should_cancel: Optional predicate polled before every batch and
            before every item; a truthy return stops further work.

    Returns:
        ``{item: {"ok": bool, "data": ..., "error": ...}}``. Items skipped by
        a cancellation inside an already started batch carry
        ``"cancelled": True``; items of batches that never started are
        absent from the mapping.
    """
    try:
        batch_size = int(batch_size)
    except Exception:
        batch_size = 50
    batch_size = max(1, batch_size)

    def _cancel_requested() -> bool:
        # Either the caller-supplied predicate or the task flag stops the run.
        return bool((should_cancel and should_cancel()) or (task and task.cancelled))

    async def _report(item: str, result: Dict[str, Any]) -> None:
        # Single reporting path shared by the success and failure outcomes.
        if task:
            if result["ok"]:
                task.record(True)
            else:
                task.record(False, error=result["error"])
        if on_item:
            try:
                await on_item(item, result)
            except Exception as exc:
                # Best-effort callback: keep the batch going, but leave a trace.
                logger.warning(
                    f"Batch on_item callback failed: {str(item)[:16]}... - {exc}"
                )

    async def _one(item: str) -> tuple[str, dict]:
        if _cancel_requested():
            return item, {"ok": False, "error": "cancelled", "cancelled": True}
        # Only the worker call is guarded, so a reporting problem can never
        # turn a successful item into a recorded failure.
        try:
            data = await worker(item)
        except Exception as e:
            logger.warning(f"Batch item failed: {str(item)[:16]}... - {e}")
            result: Dict[str, Any] = {"ok": False, "error": str(e)}
        else:
            result = {"ok": True, "data": data}
        await _report(item, result)
        return item, result

    results: Dict[str, Dict[str, Any]] = {}
    # Process chunk by chunk so that not every coroutine is created up front.
    for start in range(0, len(items), batch_size):
        if _cancel_requested():
            break
        chunk = items[start : start + batch_size]
        pairs = await asyncio.gather(*(_one(x) for x in chunk))
        results.update(pairs)
    return results
class BatchTask:
    """Progress tracker for one batch operation that fans events out to queues.

    Counters (``processed``, ``ok``, ``fail``) are updated through ``record``.
    Every update, and every terminal transition (``finish``, ``fail_task``,
    ``finish_cancelled``), is pushed as a dict event to each attached queue.
    The last terminal event is kept and can be read back via ``final_event``.
    """

    def __init__(self, total: int):
        self.id = uuid.uuid4().hex
        self.total = int(total)
        # Running counters.
        self.processed = 0
        self.ok = 0
        self.fail = 0
        # Lifecycle state: "running" until one of the terminal methods runs.
        self.status = "running"
        self.warning: Optional[str] = None
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None
        self.created_at = time.time()
        # Subscriber queues and the stored terminal event.
        self._queues: List[asyncio.Queue] = []
        self._final_event: Optional[Dict[str, Any]] = None
        self.cancelled = False

    def _event(self, kind: str) -> Dict[str, Any]:
        """Build the fields shared by every event, in their fixed key order."""
        return dict(
            type=kind,
            task_id=self.id,
            total=self.total,
            processed=self.processed,
            ok=self.ok,
            fail=self.fail,
        )

    def _finalize(self, event: Dict[str, Any]) -> None:
        """Store ``event`` as the terminal event and broadcast it."""
        self._final_event = event
        self._publish(event)

    def snapshot(self) -> Dict[str, Any]:
        """Return the current id, status, counters and warning as a dict."""
        return dict(
            task_id=self.id,
            status=self.status,
            total=self.total,
            processed=self.processed,
            ok=self.ok,
            fail=self.fail,
            warning=self.warning,
        )

    def attach(self) -> asyncio.Queue:
        """Create a new subscriber queue (capacity 200), register and return it."""
        subscriber: asyncio.Queue = asyncio.Queue(maxsize=200)
        self._queues.append(subscriber)
        return subscriber

    def detach(self, q: asyncio.Queue) -> None:
        """Unregister ``q``; unknown queues are ignored."""
        try:
            self._queues.remove(q)
        except ValueError:
            pass

    def _publish(self, event: Dict[str, Any]) -> None:
        """Offer ``event`` to every attached queue without blocking."""
        # Iterate over a copy so attach/detach during delivery cannot interfere.
        for subscriber in tuple(self._queues):
            try:
                subscriber.put_nowait(event)
            except Exception:
                # The event is dropped for a queue that cannot accept it.
                continue

    def record(
        self, ok: bool, *, item: Any = None, detail: Any = None, error: str = ""
    ) -> None:
        """Count one processed item and broadcast a "progress" event.

        ``item`` and ``detail`` are included in the event when not None;
        ``error`` is included when non-empty.
        """
        self.processed += 1
        if not ok:
            self.fail += 1
        else:
            self.ok += 1

        event = self._event("progress")
        for key, value in (("item", item), ("detail", detail)):
            if value is not None:
                event[key] = value
        if error:
            event["error"] = error
        self._publish(event)

    def finish(self, result: Dict[str, Any], *, warning: Optional[str] = None) -> None:
        """Mark the task "done", store result/warning, emit the "done" event."""
        self.status = "done"
        self.result = result
        self.warning = warning

        event = self._event("done")
        event["warning"] = self.warning
        event["result"] = result
        self._finalize(event)

    def fail_task(self, error: str) -> None:
        """Mark the task "error", store the message, emit the "error" event."""
        self.status = "error"
        self.error = error

        event = self._event("error")
        event["error"] = error
        self._finalize(event)

    def cancel(self) -> None:
        """Raise the cancellation flag; status is left unchanged."""
        self.cancelled = True

    def finish_cancelled(self) -> None:
        """Mark the task "cancelled" and emit the "cancelled" event."""
        self.status = "cancelled"
        self._finalize(self._event("cancelled"))

    def final_event(self) -> Optional[Dict[str, Any]]:
        """Return the stored terminal event, or None if none was emitted yet."""
        return self._final_event
# Module-level registry mapping BatchTask.id to its BatchTask instance.
# Entries are added by create_task and removed by delete_task.
| _TASKS: Dict[str, BatchTask] = {} | |
def create_task(total: int) -> BatchTask:
    """Create a BatchTask for ``total`` items and register it under its id."""
    new_task = BatchTask(total)
    _TASKS[new_task.id] = new_task
    return new_task
def get_task(task_id: str) -> Optional[BatchTask]:
    """Return the registered task with this id, or None when it is unknown."""
    try:
        return _TASKS[task_id]
    except KeyError:
        return None
def delete_task(task_id: str) -> None:
    """Remove the task with this id from the registry; unknown ids are ignored."""
    if task_id in _TASKS:
        del _TASKS[task_id]
async def expire_task(task_id: str, delay: int = 300) -> None:
    """Wait ``delay`` seconds, then remove the task from the registry.

    Args:
        task_id: Id of the task to remove.
        delay: Seconds to sleep before the removal (default 300).
    """
    await asyncio.sleep(delay)
    # Deletion goes through delete_task, which ignores ids that are already gone.
    delete_task(task_id)
# Public names exported by this module.
| __all__ = [ | |
| "run_batch", | |
| "BatchTask", | |
| "create_task", | |
| "get_task", | |
| "delete_task", | |
| "expire_task", | |
| ] | |