File size: 1,598 Bytes
c63bddb
1bc2486
c63bddb
 
 
1bc2486
 
c63bddb
 
 
 
 
 
 
 
 
 
 
 
980b52e
 
 
c63bddb
 
 
 
 
 
1bc2486
c63bddb
1bc2486
 
 
 
 
 
 
 
 
c63bddb
 
 
1bc2486
c63bddb
 
 
1bc2486
c63bddb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import asyncio
import logging

from haystack.dataclasses import StreamingChunk

logger = logging.getLogger(__name__)


class StreamingQueue:
    """Bridges Haystack's streaming callback to an async iterator for FastAPI JSON lines streaming.

    Producers push chunks with :meth:`callback` (a coroutine running on the
    event loop) or :meth:`sync_callback` (plain synchronous code, which may be
    running in a worker thread). Consumers read them with ``async for``.
    Leaving the ``async with`` block enqueues an end-of-stream sentinel, so
    iteration stops after every chunk queued before it has been yielded.

    ``asyncio.Queue`` is not thread-safe. Chunks arriving on a thread other
    than the owning event loop's are therefore handed over with
    ``loop.call_soon_threadsafe``. Construct the instance inside the running
    loop (e.g. in the async request handler) so the loop is known before any
    producer thread starts.
    """

    # Unique end-of-stream marker; compared by identity so it can never collide with a real chunk.
    _DONE = object()

    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()
        # Set once the sentinel has been consumed. Later __anext__ calls then
        # stop immediately instead of waiting forever on an empty queue.
        self._exhausted = False
        # Event loop that owns the queue: captured here when constructed inside
        # a running loop, otherwise bound by the first async method call.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None
        self._loop: "asyncio.AbstractEventLoop | None" = loop

    def _bind_loop(self) -> None:
        """Remember the running event loop if none is known yet (must be called from async code)."""
        if self._loop is None:
            self._loop = asyncio.get_running_loop()

    @staticmethod
    def _is_running_in(loop: "asyncio.AbstractEventLoop") -> bool:
        """Return True when the calling thread is currently running *loop*."""
        try:
            return asyncio.get_running_loop() is loop
        except RuntimeError:
            # No loop running on this thread, e.g. a threadpool worker.
            return False

    async def callback(self, chunk: "StreamingChunk") -> None:
        """Enqueue *chunk* from a coroutine running on the event loop."""
        self._bind_loop()
        await self._queue.put(chunk)

    def sync_callback(self, chunk: "StreamingChunk") -> None:
        """Enqueue *chunk* from synchronous code, possibly on another thread.

        On the loop's own thread, or before any loop has been bound, the chunk
        is enqueued directly. From any other thread the enqueue is scheduled on
        the loop with ``call_soon_threadsafe``, so a consumer waiting in
        ``__anext__`` is woken safely.
        """
        loop = self._loop
        if loop is None or self._is_running_in(loop):
            self._queue.put_nowait(chunk)
        else:
            loop.call_soon_threadsafe(self._queue.put_nowait, chunk)

    def __aiter__(self) -> "StreamingQueue":
        return self

    async def __anext__(self) -> "StreamingChunk":
        """Return the next queued chunk, or raise StopAsyncIteration once the done sentinel arrives.

        After the sentinel has been seen, every later call raises
        StopAsyncIteration immediately, as the async-iterator protocol requires.
        """
        if self._exhausted:
            raise StopAsyncIteration
        # Bind before waiting, so producer threads use the thread-safe path for
        # as long as a consumer may be parked in get().
        self._bind_loop()
        item = await self._queue.get()
        if item is self._DONE:
            self._exhausted = True
            logger.info("Stream exhausted")
            raise StopAsyncIteration
        self._log_chunk(item)
        return item

    @staticmethod
    def _log_chunk(item: "StreamingChunk") -> None:
        """Log a one-line summary of *item*; only the first matching kind (tool calls, tool result, finish reason, text) is reported."""
        if item.tool_calls:
            for tc in item.tool_calls:
                logger.info("Tool call [%s] %s args=%s", tc.index, tc.tool_name or "(streaming)", tc.arguments or "")
        elif item.tool_call_result:
            # Truncate so large tool outputs do not flood the log.
            logger.info("Tool result: %s", str(item.tool_call_result.result)[:200])
        elif item.finish_reason:
            logger.info("Finish reason: %s", item.finish_reason)
        elif item.content:
            logger.info("Text chunk: %r", item.content)

    async def __aenter__(self) -> "StreamingQueue":
        self._bind_loop()
        logger.info("StreamingQueue opened")
        return self

    async def __aexit__(self, *args: object) -> None:
        """Enqueue the done sentinel so consumers stop; exceptions from the body are not suppressed."""
        logger.info("StreamingQueue closed, sending done sentinel")
        await self._queue.put(self._DONE)