"""Batching helpers for future GPU workers.""" from __future__ import annotations from collections import defaultdict from typing import Iterable def bucket_pages(page_records: Iterable[dict]) -> dict[str, list[dict]]: buckets: dict[str, list[dict]] = defaultdict(list) for page in page_records: task = str(page.get("task", "parse")) resolution = str(page.get("resolution_bucket", "default")) buckets[f"{task}:{resolution}"].append(page) return dict(buckets) def batch_gpu_tasks(tasks: Iterable[dict], max_batch_size: int = 4) -> list[dict]: buckets: dict[tuple[str, str, str, str, str, str], list[dict]] = defaultdict(list) for task in tasks: key = ( str(task.get("provider", "huggingface_spaces")), str(task.get("space_name", "zeroshotGPU")), str(task.get("backend", "transformers")), str(task.get("task_type", "unknown")), str(task.get("model_role", "vlm")), str(task.get("model_id", "")), ) buckets[key].append(task) batches: list[dict] = [] safe_batch_size = max(int(max_batch_size), 1) for (provider, space_name, backend, task_type, model_role, model_id), bucket in sorted(buckets.items()): ordered = sorted(bucket, key=lambda item: (-int(item.get("priority", 0)), str(item.get("task_id", "")))) for offset in range(0, len(ordered), safe_batch_size): batch_tasks = ordered[offset : offset + safe_batch_size] batches.append( { "batch_id": f"gb{len(batches) + 1}", "provider": provider, "space_name": space_name, "backend": backend, "task_type": task_type, "model_role": model_role, "model_id": model_id or None, "task_count": len(batch_tasks), "tasks": batch_tasks, } ) return batches