"""Dynamic micro-batching for higher throughput under concurrency.
Many small concurrent requests are the worst case for a transformer: each pays
the fixed per-call overhead alone. The batcher fixes that. Callers ``await
submit(text)``; the batcher collects everything that arrives within a short
window (or until ``max_batch_size`` is reached), runs ONE forward pass for the
whole group, then fans the results back out to each waiter.
Design notes:
* A single background task owns the queue, so there is no lock contention and
exactly one inference runs at a time (a CPU model has no intra-op parallelism
to gain from concurrent forward passes anyway).
* The model is synchronous and CPU-bound, so it runs in a thread via
``run_in_executor`` to avoid blocking the event loop.
* A failed batch propagates the exception to every waiter in that batch and the
loop keeps going — one bad batch never wedges the service.
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Dict, List
from app.classifier import Classifier, Distribution
from app.metrics import BATCH_SIZE, INFERENCE_LATENCY
logger = logging.getLogger(__name__)
@dataclass
class _Job:
    """One queued prediction request.

    Pairs the input text with the future that the submitting coroutine awaits.
    Instances must be created while an event loop is running (i.e. from a
    coroutine or callback): the future is bound to the *running* loop so it is
    guaranteed to belong to the same loop that later awaits it. Constructing a
    job with no running loop raises ``RuntimeError``.
    """

    # Raw input text that will be passed to the classifier as part of a batch.
    text: str
    # Completed by the batcher worker with the prediction result or an
    # exception. ``get_running_loop`` is used instead of the deprecated
    # ``get_event_loop`` so the future can never attach to a different loop.
    future: "asyncio.Future[Distribution]" = field(
        default_factory=lambda: asyncio.get_running_loop().create_future()
    )
class MicroBatcher:
    """Coalesces concurrent single-text predictions into batched forward passes.

    Callers ``await submit(text)``. A single background task (started with
    ``start()``) pulls jobs off an internal queue, groups whatever arrives
    within ``max_delay_ms`` (capped at ``max_batch_size``), runs one
    ``classifier.predict`` call for the group in the default executor, and
    resolves each caller's future with its own result.

    Every job that reaches the batcher is eventually resolved: with its
    result, with the exception raised by a failed batch, or with
    ``RuntimeError("batcher is shutting down")`` if the batcher is stopped
    before the job completes.
    """

    def __init__(
        self,
        classifier: Classifier,
        max_batch_size: int = 64,
        max_delay_ms: float = 5.0,
    ) -> None:
        """Configure the batcher; no background work starts until ``start()``.

        Args:
            classifier: Object whose ``predict(texts)`` is called once per batch.
            max_batch_size: Upper bound on jobs per batch; values below 1 are
                clamped to 1.
            max_delay_ms: How long to keep collecting after the first job of a
                batch arrives, in milliseconds; negative values are clamped
                to 0 (no waiting).
        """
        self._classifier = classifier
        self._max_batch_size = max(1, max_batch_size)
        # Stored in seconds because the event-loop clock works in seconds.
        self._max_delay = max(0.0, max_delay_ms) / 1000.0
        self._queue: "asyncio.Queue[_Job]" = asyncio.Queue()
        self._worker: asyncio.Task | None = None
        self._closing = False

    async def start(self) -> None:
        """Start the background worker task if it is not already running."""
        if self._worker is None:
            self._closing = False
            self._worker = asyncio.create_task(self._run(), name="micro-batcher")

    async def stop(self) -> None:
        """Stop the worker and fail every job that can no longer be served.

        New submissions are rejected as soon as this is called. The worker is
        cancelled (which fails its partially collected or in-flight batch) and
        anything still sitting in the queue is failed with ``RuntimeError`` so
        that no ``submit()`` caller is left waiting forever.
        """
        self._closing = True
        worker, self._worker = self._worker, None
        try:
            if worker is not None:
                worker.cancel()
                try:
                    await worker
                except asyncio.CancelledError:
                    # Expected outcome of the cancel() above.
                    pass
        finally:
            # Runs even if the worker ended with an unexpected error: queued
            # jobs would otherwise never be resolved.
            self._fail_queued()

    async def submit(self, text: str) -> Distribution:
        """Enqueue one text and await its probability distribution.

        Raises:
            RuntimeError: If the batcher is shutting down, either at call time
                or before the job was processed.
            Exception: Whatever the classifier raised for the batch containing
                this text.
        """
        if self._closing:
            raise RuntimeError("batcher is shutting down")
        job = _Job(text=text)
        await self._queue.put(job)
        return await job.future

    async def submit_many(self, texts: List[str]) -> List[Distribution]:
        """Submit several texts concurrently; they share the batcher's batches.

        Results are returned in the same order as ``texts``.
        """
        return await asyncio.gather(*(self.submit(t) for t in texts))

    @staticmethod
    def _fail_jobs(jobs: List[_Job], exc: BaseException) -> None:
        """Set ``exc`` on every job in ``jobs`` whose future is still pending."""
        for job in jobs:
            # A future may already be done if its caller was cancelled.
            if not job.future.done():
                job.future.set_exception(exc)

    def _fail_queued(self) -> None:
        """Empty the queue and fail each drained job with a shutdown error."""
        stranded: List[_Job] = []
        while not self._queue.empty():
            stranded.append(self._queue.get_nowait())
        if stranded:
            self._fail_jobs(stranded, RuntimeError("batcher is shutting down"))

    async def _collect_batch(self) -> List[_Job]:
        """Block for the first job, then drain up to the window / size limit.

        If the task is cancelled while waiting inside the collection window,
        the jobs already taken off the queue are failed with ``RuntimeError``
        before the cancellation propagates; they are no longer in the queue,
        so nothing else could resolve them.
        """
        first = await self._queue.get()
        batch = [first]
        try:
            if self._max_delay > 0:
                loop = asyncio.get_running_loop()
                deadline = loop.time() + self._max_delay
                while len(batch) < self._max_batch_size:
                    remaining = deadline - loop.time()
                    if remaining <= 0:
                        break
                    try:
                        job = await asyncio.wait_for(
                            self._queue.get(), timeout=remaining
                        )
                    except asyncio.TimeoutError:
                        break
                    batch.append(job)
        except asyncio.CancelledError:
            self._fail_jobs(batch, RuntimeError("batcher is shutting down"))
            raise
        # Opportunistically grab anything already queued without further waiting.
        while len(batch) < self._max_batch_size and not self._queue.empty():
            batch.append(self._queue.get_nowait())
        return batch

    async def _run(self) -> None:
        """Worker loop: collect a batch, run inference, fan results back out.

        Runs until cancelled. A batch whose inference raises (or returns the
        wrong number of results) has that error set on every waiter in the
        batch and the loop continues with the next batch.
        """
        loop = asyncio.get_running_loop()
        while True:
            batch = await self._collect_batch()
            texts = [job.text for job in batch]
            BATCH_SIZE.observe(len(texts))
            try:
                with INFERENCE_LATENCY.time():
                    # predict() is synchronous, so it runs in the default
                    # executor to keep the event loop responsive.
                    results: List[Distribution] = list(
                        await loop.run_in_executor(
                            None, self._classifier.predict, texts
                        )
                    )
                # zip() would silently truncate and leave the unmatched
                # waiters pending forever, so treat a mismatch as a failure.
                if len(results) != len(batch):
                    raise RuntimeError(
                        f"classifier returned {len(results)} results "
                        f"for {len(batch)} inputs"
                    )
                for job, dist in zip(batch, results):
                    if not job.future.done():
                        job.future.set_result(dist)
            except asyncio.CancelledError:
                # Cancelled mid-inference (shutdown): release the waiters of
                # the in-flight batch, then let the cancellation end the task.
                self._fail_jobs(batch, RuntimeError("batcher is shutting down"))
                raise
            except Exception as exc:  # one bad batch must not kill the worker
                logger.exception("batch inference failed for %d items", len(texts))
                self._fail_jobs(batch, exc)
|