File size: 4,762 Bytes
43a2563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Dynamic micro-batching for higher throughput under concurrency.

Many small concurrent requests are the worst case for a transformer: each pays
the fixed per-call overhead alone. The batcher fixes that. Callers ``await
submit(text)``; the batcher collects everything that arrives within a short
window (or until ``max_batch_size`` is reached), runs ONE forward pass for the
whole group, then fans the results back out to each waiter.

Design notes:
* A single background task owns the queue, so there is no lock contention and
  exactly one inference runs at a time (a CPU model has no intra-op parallelism
  to gain from concurrent forward passes anyway).
* The model is synchronous and CPU-bound, so it runs in a thread via
  ``run_in_executor`` to avoid blocking the event loop.
* A failed batch propagates the exception to every waiter in that batch and the
  loop keeps going — one bad batch never wedges the service.
"""
from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass, field
from typing import Dict, List

from app.classifier import Classifier, Distribution
from app.metrics import BATCH_SIZE, INFERENCE_LATENCY

logger = logging.getLogger(__name__)


@dataclass
class _Job:
    text: str
    future: "asyncio.Future[Distribution]" = field(default_factory=lambda: asyncio.get_event_loop().create_future())


class MicroBatcher:
    """Coalesces concurrent single-text predictions into batched forward passes.

    Callers await :meth:`submit`. A single background task (created by
    :meth:`start`) takes jobs off an internal queue, groups them into batches
    of at most ``max_batch_size`` gathered over at most ``max_delay_ms``, calls
    ``classifier.predict`` once per batch in the default executor, and resolves
    each caller's future with its own result.

    Every job the batcher has accepted is eventually resolved: with its
    result, with the exception raised by its failed batch, or with
    ``RuntimeError`` if the batcher is stopped before the job was answered.
    """

    # Shared by ``submit`` and the shutdown paths so callers see one message.
    _SHUTDOWN_MESSAGE = "batcher is shutting down"

    def __init__(
        self,
        classifier: Classifier,
        max_batch_size: int = 64,
        max_delay_ms: float = 5.0,
    ) -> None:
        """Configure the batcher; no background work starts until ``start()``.

        Args:
            classifier: Object whose ``predict(texts)`` is called once per
                batch with the list of texts.
            max_batch_size: Upper bound on jobs per batch; values below 1 are
                clamped to 1.
            max_delay_ms: How long to keep collecting after the first job of a
                batch arrives, in milliseconds; negative values are clamped to
                0 (no waiting, only already-queued jobs are batched).
        """
        self._classifier = classifier
        self._max_batch_size = max(1, max_batch_size)
        self._max_delay = max(0.0, max_delay_ms) / 1000.0
        self._queue: "asyncio.Queue[_Job]" = asyncio.Queue()
        self._worker: asyncio.Task | None = None
        self._closing = False

    async def start(self) -> None:
        """Start the background worker task; a no-op if one already exists."""
        if self._worker is None:
            self._closing = False
            self._worker = asyncio.create_task(self._run(), name="micro-batcher")

    async def stop(self) -> None:
        """Stop the worker and fail every job that has not been answered.

        New submissions are rejected from the moment this is called. The
        worker is cancelled and awaited; any job still sitting in the queue is
        then resolved with ``RuntimeError`` so that no ``submit()`` caller is
        left waiting forever. The batcher can be restarted with ``start()``.
        """
        self._closing = True
        worker = self._worker
        try:
            if worker is not None:
                worker.cancel()
                try:
                    await worker
                except asyncio.CancelledError:
                    pass
        finally:
            # Runs even if the worker had already died with an unexpected
            # error, so the batcher is always left restartable and drained.
            self._worker = None
            self._fail_jobs(
                self._drain_queue(), RuntimeError(self._SHUTDOWN_MESSAGE)
            )

    async def submit(self, text: str) -> Distribution:
        """Enqueue one text and await its probability distribution.

        Raises:
            RuntimeError: If the batcher is shutting down, either at call time
                or before this job was processed.
            Exception: Whatever the batch inference raised, if the batch this
                text landed in failed.
        """
        if self._closing:
            raise RuntimeError(self._SHUTDOWN_MESSAGE)
        job = _Job(text=text)
        await self._queue.put(job)
        return await job.future

    async def submit_many(self, texts: List[str]) -> List[Distribution]:
        """Submit several texts concurrently; they share the batcher's batches.

        Results are returned in the same order as ``texts``. The first failure
        among the submissions is raised.
        """
        return await asyncio.gather(*(self.submit(t) for t in texts))

    def _drain_queue(self) -> List[_Job]:
        """Remove and return every job currently queued, without waiting."""
        drained = []
        while True:
            try:
                drained.append(self._queue.get_nowait())
            except asyncio.QueueEmpty:
                return drained

    @staticmethod
    def _fail_jobs(jobs, exc: BaseException) -> None:
        """Set ``exc`` on every job in ``jobs`` whose future is still pending."""
        for job in jobs:
            # A future may already be done, e.g. cancelled by its caller.
            if not job.future.done():
                job.future.set_exception(exc)

    async def _collect_batch(self) -> List[_Job]:
        """Block for the first job, then drain up to the window / size limit.

        If the task is cancelled while the window is open, the jobs already
        taken off the queue are failed with ``RuntimeError`` before the
        cancellation propagates; otherwise they would be lost and their
        callers would hang.
        """
        loop = asyncio.get_running_loop()
        # Nothing has been dequeued yet if this first wait is cancelled.
        batch = [await self._queue.get()]
        try:
            if self._max_delay > 0:
                deadline = loop.time() + self._max_delay
                while len(batch) < self._max_batch_size:
                    remaining = deadline - loop.time()
                    if remaining <= 0:
                        break
                    try:
                        job = await asyncio.wait_for(
                            self._queue.get(), timeout=remaining
                        )
                    except asyncio.TimeoutError:
                        break
                    batch.append(job)
        except asyncio.CancelledError:
            self._fail_jobs(batch, RuntimeError(self._SHUTDOWN_MESSAGE))
            raise
        # Opportunistically grab anything already queued without further waiting.
        while len(batch) < self._max_batch_size and not self._queue.empty():
            batch.append(self._queue.get_nowait())
        return batch

    async def _run(self) -> None:
        """Worker loop: collect a batch, run inference, fan results back out.

        Runs until cancelled. A failing batch has its exception delivered to
        every waiter in that batch and the loop continues with the next one.
        """
        loop = asyncio.get_running_loop()
        while True:
            batch = await self._collect_batch()
            texts = [job.text for job in batch]
            BATCH_SIZE.observe(len(texts))
            try:
                with INFERENCE_LATENCY.time():
                    # predict() is synchronous, so keep it off the event loop.
                    results = await loop.run_in_executor(
                        None, self._classifier.predict, texts
                    )
                results = list(results)
                if len(results) != len(batch):
                    # zip() would silently truncate and leave waiters hanging.
                    raise ValueError(
                        "classifier returned %d results for %d inputs"
                        % (len(results), len(batch))
                    )
            except asyncio.CancelledError:
                # Shutdown mid-inference: the executor call cannot be
                # interrupted, but the waiters must still be released.
                self._fail_jobs(batch, RuntimeError(self._SHUTDOWN_MESSAGE))
                raise
            except Exception as exc:  # one bad batch must not kill the worker
                logger.exception("batch inference failed for %d items", len(texts))
                self._fail_jobs(batch, exc)
            else:
                for job, dist in zip(batch, results):
                    if not job.future.done():
                        job.future.set_result(dist)