# Source path: ggload/app/core/batch.py
# (repository page residue) f2d90b38's picture
# (repository page residue) Upload 120 files
# (repository page residue) 8cdca00 verified
"""
Batch utilities.
- run_batch: generic batch concurrency runner
- BatchTask: SSE task manager for admin batch operations
"""
import asyncio
import time
import uuid
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar
from app.core.logger import logger
T = TypeVar("T")


async def run_batch(
    items: List[str],
    worker: Callable[[str], Awaitable[T]],
    *,
    batch_size: int = 50,
    task: Optional["BatchTask"] = None,
    on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None,
    should_cancel: Optional[Callable[[], bool]] = None,
) -> Dict[str, Dict[str, Any]]:
    """
    Run ``worker`` over ``items`` in concurrent batches.

    An exception raised by ``worker`` for one item is recorded in that item's
    result and does not stop the remaining items.

    Args:
        items: Items to process. Results are keyed by item, so duplicate
            items share a single entry in the returned mapping.
        worker: Async function called once per item.
        batch_size: Number of items awaited together per ``asyncio.gather``
            call. Values that cannot be converted to ``int`` fall back to 50;
            values below 1 are raised to 1.
        task: Optional progress tracker. ``task.record`` is called once per
            item the worker was run for, and ``task.cancelled`` is consulted
            as a cancellation flag.
        on_item: Optional async callback invoked with ``(item, result)`` after
            each item the worker was run for. Exceptions it raises are logged
            and otherwise ignored.
        should_cancel: Optional predicate; a truthy return value requests
            cancellation.

    Returns:
        ``{item: {"ok": True, "data": ...}}`` for successes,
        ``{item: {"ok": False, "error": "..."}}`` for failures, and
        ``{item: {"ok": False, "error": "cancelled", "cancelled": True}}`` for
        items skipped inside a batch that had already been scheduled. Items
        belonging to batches that were never started are absent.
    """
    try:
        batch_size = int(batch_size)
    except Exception:
        batch_size = 50
    batch_size = max(1, batch_size)

    def _cancel_requested() -> bool:
        # Either the caller-supplied predicate or the task flag can cancel.
        return bool((should_cancel and should_cancel()) or (task and task.cancelled))

    async def _notify(item: str, result: Dict[str, Any]) -> None:
        # The callback is best-effort: a failure must never abort the batch,
        # but it is logged so it does not disappear silently.
        if not on_item:
            return
        try:
            await on_item(item, result)
        except Exception as exc:
            logger.warning(
                f"Batch on_item callback failed: {item[:16]}... - {exc}"
            )

    async def _one(item: str) -> tuple[str, Dict[str, Any]]:
        # Cancelled items are neither recorded on the task nor reported
        # through on_item.
        if _cancel_requested():
            return item, {"ok": False, "error": "cancelled", "cancelled": True}

        result: Dict[str, Any]
        # Only the worker call is guarded, so a bookkeeping error below is
        # not mistaken for (and double-counted as) a worker failure.
        try:
            data = await worker(item)
        except Exception as e:
            logger.warning(f"Batch item failed: {item[:16]}... - {e}")
            result = {"ok": False, "error": str(e)}
            if task:
                task.record(False, error=str(e))
        else:
            result = {"ok": True, "data": data}
            if task:
                task.record(True)

        await _notify(item, result)
        return item, result

    results: Dict[str, Dict[str, Any]] = {}

    # Process chunk by chunk so that not every coroutine is created at once.
    for start in range(0, len(items), batch_size):
        if _cancel_requested():
            break
        chunk = items[start : start + batch_size]
        pairs = await asyncio.gather(*(_one(x) for x in chunk))
        results.update(pairs)

    return results
class BatchTask:
    """Progress state for one batch operation, broadcast to subscriber queues.

    Counters (``processed`` / ``ok`` / ``fail``) are updated through
    ``record``; every update and every terminal transition (``finish``,
    ``fail_task``, ``finish_cancelled``) is pushed as an event dict to each
    queue obtained from ``attach``. The terminal event is also kept and
    exposed through ``final_event``.
    """

    def __init__(self, total: int):
        self.id = uuid.uuid4().hex
        self.total = int(total)
        self.processed = 0
        self.ok = 0
        self.fail = 0
        self.status = "running"
        self.warning: Optional[str] = None
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None
        self.created_at = time.time()
        # Subscriber queues currently attached to this task.
        self._queues: List[asyncio.Queue] = []
        # Terminal event, set once the task is done / errored / cancelled.
        self._final_event: Optional[Dict[str, Any]] = None
        self.cancelled = False

    def _event(self, event_type: str, **extra: Any) -> Dict[str, Any]:
        """Build an event dict: type, task id, current counters, then ``extra``."""
        payload: Dict[str, Any] = {
            "type": event_type,
            "task_id": self.id,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
        }
        payload.update(extra)
        return payload

    def _emit_final(self, event: Dict[str, Any]) -> None:
        """Store ``event`` as the terminal event and broadcast it."""
        self._final_event = event
        self._publish(event)

    def snapshot(self) -> Dict[str, Any]:
        """Return the current status, counters and warning as a plain dict."""
        return {
            "task_id": self.id,
            "status": self.status,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
            "warning": self.warning,
        }

    def attach(self) -> asyncio.Queue:
        """Register and return a new subscriber queue (capacity 200 events)."""
        subscriber: asyncio.Queue = asyncio.Queue(maxsize=200)
        self._queues.append(subscriber)
        return subscriber

    def detach(self, q: asyncio.Queue) -> None:
        """Unregister ``q``; a queue that is not attached is ignored."""
        try:
            self._queues.remove(q)
        except ValueError:
            pass

    def _publish(self, event: Dict[str, Any]) -> None:
        """Offer ``event`` to every attached queue without blocking."""
        # Iterate over a copy so the subscriber list can change meanwhile.
        for subscriber in tuple(self._queues):
            try:
                subscriber.put_nowait(event)
            except Exception:
                # Drop the event for this subscriber if its queue is full
                # or closed.
                continue

    def record(
        self, ok: bool, *, item: Any = None, detail: Any = None, error: str = ""
    ) -> None:
        """Count one processed item and broadcast a ``progress`` event.

        ``item`` and ``detail`` are included in the event only when not
        ``None``; ``error`` only when non-empty.
        """
        self.processed += 1
        if ok:
            self.ok += 1
        else:
            self.fail += 1

        optional: Dict[str, Any] = {}
        if item is not None:
            optional["item"] = item
        if detail is not None:
            optional["detail"] = detail
        if error:
            optional["error"] = error
        self._publish(self._event("progress", **optional))

    def finish(self, result: Dict[str, Any], *, warning: Optional[str] = None) -> None:
        """Mark the task ``done``, store ``result`` / ``warning``, emit ``done``."""
        self.status = "done"
        self.result = result
        self.warning = warning
        self._emit_final(self._event("done", warning=self.warning, result=result))

    def fail_task(self, error: str) -> None:
        """Mark the task ``error``, store the message, emit an ``error`` event."""
        self.status = "error"
        self.error = error
        self._emit_final(self._event("error", error=error))

    def cancel(self) -> None:
        """Set the cancellation flag; status is left unchanged."""
        self.cancelled = True

    def finish_cancelled(self) -> None:
        """Mark the task ``cancelled`` and emit a ``cancelled`` event."""
        self.status = "cancelled"
        self._emit_final(self._event("cancelled"))

    def final_event(self) -> Optional[Dict[str, Any]]:
        """Return the terminal event, or ``None`` if none was emitted yet."""
        return self._final_event
# Module-level registry of batch tasks keyed by task id; entries are added by
# create_task, looked up by get_task and removed by delete_task.
_TASKS: Dict[str, BatchTask] = {}
def create_task(total: int) -> BatchTask:
    """Create a ``BatchTask`` for ``total`` items and store it in the registry."""
    created = BatchTask(total)
    _TASKS[created.id] = created
    return created
def get_task(task_id: str) -> Optional[BatchTask]:
    """Return the registered task with this id, or ``None`` if there is none."""
    return _TASKS[task_id] if task_id in _TASKS else None
def delete_task(task_id: str) -> None:
    """Remove the task with this id from the registry; unknown ids are ignored."""
    if task_id in _TASKS:
        del _TASKS[task_id]
async def expire_task(task_id: str, delay: int = 300) -> None:
    """Wait ``delay`` seconds, then remove the task from the registry.

    Args:
        task_id: Id of the task to remove.
        delay: Seconds to sleep before removal (default 300).
    """
    # The registry entry is removed only after the sleep has completed.
    await asyncio.sleep(delay)
    delete_task(task_id)
# Names exported by `from <this module> import *`.
__all__ = [
"run_batch",
"BatchTask",
"create_task",
"get_task",
"delete_task",
"expire_task",
]