zeroshotGPU / zsgdp /gpu /batching.py
Arjunvir Singh
Initial commit: zeroshotGPU MVP with full eval surface
db06ffa
"""Batching helpers for future GPU workers."""
from __future__ import annotations
from collections import defaultdict
from typing import Iterable
def bucket_pages(page_records: Iterable[dict]) -> dict[str, list[dict]]:
    """Group page records by their task and resolution bucket.

    Each record is filed under the key ``"<task>:<resolution_bucket>"``. A
    record without a ``"task"`` entry counts as ``"parse"`` and one without a
    ``"resolution_bucket"`` entry counts as ``"default"``; both values are
    passed through ``str`` before being joined.

    Returns a plain ``dict`` whose keys appear in first-seen order and whose
    lists keep the records in input order.
    """
    grouped: dict[str, list[dict]] = {}
    for record in page_records:
        task_name = str(record.get("task", "parse"))
        resolution_name = str(record.get("resolution_bucket", "default"))
        grouped.setdefault(f"{task_name}:{resolution_name}", []).append(record)
    return grouped
def _str_field(record: dict, name: str, default: str) -> str:
    """Return ``record[name]`` as a string, or *default* when missing or ``None``."""
    value = record.get(name)
    return default if value is None else str(value)


def batch_gpu_tasks(tasks: Iterable[dict], max_batch_size: int = 4) -> list[dict]:
    """Group GPU tasks into batches of compatible work.

    Tasks are bucketed by the tuple ``(provider, space_name, backend,
    task_type, model_role, model_id)``. A field that is absent *or* explicitly
    ``None`` takes its default (``"huggingface_spaces"``, ``"zeroshotGPU"``,
    ``"transformers"``, ``"unknown"``, ``"vlm"`` and ``""`` respectively), so a
    task carrying ``"model_id": None`` lands in the same bucket as one that
    omits the field instead of in a bucket named ``"None"``.

    Buckets are emitted in sorted key order. Inside a bucket, tasks are
    ordered by descending ``priority`` (missing/``None`` counts as 0), then by
    ``task_id``; ties keep their input order. Each bucket is then cut into
    slices of at most ``max_batch_size`` tasks.

    Args:
        tasks: Task dictionaries to batch.
        max_batch_size: Upper bound on tasks per batch; values below 1 are
            treated as 1.

    Returns:
        A list of batch dictionaries with the keys ``batch_id`` (``"gb1"``,
        ``"gb2"``, ... in emission order), ``provider``, ``space_name``,
        ``backend``, ``task_type``, ``model_role``, ``model_id`` (``None``
        when empty), ``task_count`` and ``tasks``.
    """
    buckets: dict[tuple[str, str, str, str, str, str], list[dict]] = defaultdict(list)
    for task in tasks:
        key = (
            _str_field(task, "provider", "huggingface_spaces"),
            _str_field(task, "space_name", "zeroshotGPU"),
            _str_field(task, "backend", "transformers"),
            _str_field(task, "task_type", "unknown"),
            _str_field(task, "model_role", "vlm"),
            _str_field(task, "model_id", ""),
        )
        buckets[key].append(task)

    def dispatch_order(item: dict) -> tuple[int, str]:
        # Negated priority so that higher priorities sort first; an explicit
        # None is treated like a missing priority rather than crashing int().
        priority = item.get("priority")
        return (-int(0 if priority is None else priority), _str_field(item, "task_id", ""))

    # Guard against zero/negative sizes, which would make range() reject the step.
    safe_batch_size = max(int(max_batch_size), 1)

    batches: list[dict] = []
    # Bucket keys are unique, so sorting items() never compares the task lists.
    for (provider, space_name, backend, task_type, model_role, model_id), bucket in sorted(buckets.items()):
        ordered = sorted(bucket, key=dispatch_order)
        for offset in range(0, len(ordered), safe_batch_size):
            batch_tasks = ordered[offset : offset + safe_batch_size]
            batches.append(
                {
                    "batch_id": f"gb{len(batches) + 1}",
                    "provider": provider,
                    "space_name": space_name,
                    "backend": backend,
                    "task_type": task_type,
                    "model_role": model_role,
                    "model_id": model_id or None,
                    "task_count": len(batch_tasks),
                    "tasks": batch_tasks,
                }
            )
    return batches