| """ |
| Batch utilities. |
| |
| - run_batch: generic batch concurrency runner |
| - BatchTask: SSE task manager for admin batch operations |
| """ |
|
|
| import asyncio |
| import time |
| import uuid |
| from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar |
|
|
| from app.core.logger import logger |
|
|
# Result type returned by the worker coroutine passed to run_batch.
T = TypeVar("T")
|
|
|
|
async def run_batch(
    items: List[str],
    worker: Callable[[str], Awaitable[T]],
    *,
    batch_size: int = 50,
    task: Optional["BatchTask"] = None,
    on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None,
    should_cancel: Optional[Callable[[], bool]] = None,
) -> Dict[str, Dict[str, Any]]:
    """
    Run ``worker`` over ``items`` in concurrent batches.

    Batches are processed one after another; the items inside a batch are
    awaited together with ``asyncio.gather``. An exception raised by the
    worker for one item is logged and reported in that item's result; it does
    not abort the other items.

    Args:
        items: Items to process. Results are keyed by item, so a repeated
            item keeps only the outcome stored last.
        worker: Async function called once per item.
        batch_size: Number of items awaited concurrently. Values that cannot
            be converted to ``int`` fall back to 50; values below 1 become 1.
        task: Optional progress tracker. ``task.record`` is called for every
            item whose worker ran, and ``task.cancelled`` is polled.
        on_item: Optional async callback awaited with ``(item, result)``
            after each item whose worker ran. Its exceptions are logged and
            otherwise ignored.
        should_cancel: Optional callable; a truthy return value requests
            cancellation.

    Returns:
        ``{item: {"ok": bool, "data": ..., "error": ...}}``. Cancellation is
        checked before each batch and before each item starts: items skipped
        inside a started batch map to
        ``{"ok": False, "error": "cancelled", "cancelled": True}`` (they are
        neither recorded on ``task`` nor passed to ``on_item``), and items of
        batches that never started are absent from the result.
    """
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError, OverflowError):
        # Unusable batch size: fall back to the default instead of failing.
        batch_size = 50
    batch_size = max(1, batch_size)

    def _is_cancelled() -> bool:
        """Return True when either cancellation source requests a stop."""
        return bool((should_cancel and should_cancel()) or (task and task.cancelled))

    async def _notify(item: str, result: Dict[str, Any]) -> None:
        """Await ``on_item`` if given; a failing callback never fails the item."""
        if not on_item:
            return
        try:
            await on_item(item, result)
        except Exception as exc:
            # Best-effort callback: keep going, but leave a trace in the log
            # instead of silently discarding the error.
            logger.warning(f"Batch on_item callback failed: {item[:16]}... - {exc}")

    async def _one(item: str) -> tuple[str, dict]:
        """Process a single item and return ``(item, result)``."""
        if _is_cancelled():
            return item, {"ok": False, "error": "cancelled", "cancelled": True}
        try:
            data = await worker(item)
        except Exception as e:
            logger.warning(f"Batch item failed: {item[:16]}... - {e}")
            result: Dict[str, Any] = {"ok": False, "error": str(e)}
            if task:
                task.record(False, error=str(e))
        else:
            result = {"ok": True, "data": data}
            if task:
                task.record(True)
        await _notify(item, result)
        return item, result

    results: Dict[str, dict] = {}
    for start in range(0, len(items), batch_size):
        if _is_cancelled():
            break
        chunk = items[start : start + batch_size]
        pairs = await asyncio.gather(*(_one(entry) for entry in chunk))
        results.update(pairs)

    return results
|
|
|
|
class BatchTask:
    """Progress tracker for one batch job that fans events out to queues.

    Counters are advanced through ``record``. Each progress update and the
    terminal event (done / error / cancelled) is pushed to every queue handed
    out by ``attach``; the terminal event is also kept for later retrieval
    through ``final_event``.
    """

    def __init__(self, total: int):
        self.id = uuid.uuid4().hex
        self.created_at = time.time()
        self.total = int(total)
        # Running counters.
        self.processed = 0
        self.ok = 0
        self.fail = 0
        # Lifecycle state; "running" until a terminal method is called.
        self.status = "running"
        self.cancelled = False
        self.warning: Optional[str] = None
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None
        # Subscriber queues and the last terminal event, if any.
        self._queues: List[asyncio.Queue] = []
        self._final_event: Optional[Dict[str, Any]] = None

    def snapshot(self) -> Dict[str, Any]:
        """Return the current status and counters as a plain dict."""
        state: Dict[str, Any] = {"task_id": self.id, "status": self.status}
        state["total"] = self.total
        state["processed"] = self.processed
        state["ok"] = self.ok
        state["fail"] = self.fail
        state["warning"] = self.warning
        return state

    def attach(self) -> asyncio.Queue:
        """Create a bounded (200-entry) subscriber queue and register it."""
        subscriber: asyncio.Queue = asyncio.Queue(maxsize=200)
        self._queues.append(subscriber)
        return subscriber

    def detach(self, q: asyncio.Queue) -> None:
        """Unregister a subscriber queue; unknown queues are ignored."""
        try:
            self._queues.remove(q)
        except ValueError:
            pass

    def _event(self, kind: str, **extra: Any) -> Dict[str, Any]:
        """Build an event dict: type, task id, counters, then ``extra`` in order."""
        return {
            "type": kind,
            "task_id": self.id,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
            **extra,
        }

    def _publish(self, event: Dict[str, Any]) -> None:
        """Offer ``event`` to every subscriber without blocking."""
        # Iterate over a copy so the subscriber list may change meanwhile.
        for subscriber in tuple(self._queues):
            try:
                subscriber.put_nowait(event)
            except Exception:
                # Best effort: a subscriber that cannot take the event
                # (e.g. its queue is full) simply misses it.
                continue

    def record(
        self, ok: bool, *, item: Any = None, detail: Any = None, error: str = ""
    ) -> None:
        """Count one processed item and publish a ``progress`` event.

        ``item`` and ``detail`` are included in the event only when not
        ``None``; ``error`` only when non-empty.
        """
        self.processed += 1
        if not ok:
            self.fail += 1
        else:
            self.ok += 1

        extras: Dict[str, Any] = {
            key: value
            for key, value in (("item", item), ("detail", detail))
            if value is not None
        }
        if error:
            extras["error"] = error
        self._publish(self._event("progress", **extras))

    def finish(self, result: Dict[str, Any], *, warning: Optional[str] = None) -> None:
        """Mark the task as done, store ``result`` and publish a ``done`` event."""
        self.status = "done"
        self.result = result
        self.warning = warning
        done_event = self._event("done", warning=self.warning, result=result)
        self._final_event = done_event
        self._publish(done_event)

    def fail_task(self, error: str) -> None:
        """Mark the task as failed and publish an ``error`` event."""
        self.status = "error"
        self.error = error
        error_event = self._event("error", error=error)
        self._final_event = error_event
        self._publish(error_event)

    def cancel(self) -> None:
        """Set the ``cancelled`` flag; no event is published here."""
        self.cancelled = True

    def finish_cancelled(self) -> None:
        """Mark the task as cancelled and publish a ``cancelled`` event."""
        self.status = "cancelled"
        cancelled_event = self._event("cancelled")
        self._final_event = cancelled_event
        self._publish(cancelled_event)

    def final_event(self) -> Optional[Dict[str, Any]]:
        """Return the stored terminal event, or ``None`` if there is none yet."""
        return self._final_event
|
|
|
|
# Module-level registry of batch tasks, keyed by BatchTask.id.
_TASKS: Dict[str, BatchTask] = {}
|
|
|
|
def create_task(total: int) -> BatchTask:
    """Create a ``BatchTask`` for ``total`` items and register it under its id."""
    new_task = BatchTask(total)
    _TASKS[new_task.id] = new_task
    return new_task
|
|
|
|
def get_task(task_id: str) -> Optional[BatchTask]:
    """Return the registered task for ``task_id``, or ``None`` if there is none."""
    try:
        return _TASKS[task_id]
    except KeyError:
        return None
|
|
|
|
def delete_task(task_id: str) -> None:
    """Remove ``task_id`` from the registry; unknown ids are ignored."""
    if task_id in _TASKS:
        del _TASKS[task_id]
|
|
|
|
async def expire_task(task_id: str, delay: int = 300) -> None:
    """Sleep for ``delay`` seconds, then drop ``task_id`` from the registry."""
    await asyncio.sleep(delay)
    # Same effect as delete_task: a missing id is not an error.
    _TASKS.pop(task_id, None)
|
|
|
|
# Names exported by `from <module> import *`.
__all__ = [
    "run_batch",
    "BatchTask",
    "create_task",
    "get_task",
    "delete_task",
    "expire_task",
]
|
|