# Source path: router/deeptutor/api/utils/task_log_stream.py
# Repository-viewer residue kept for provenance: "Piyush1225's picture", commit "push" (5df8a73).
import asyncio
import contextlib
import json
import logging
import threading
from collections import deque
from collections.abc import AsyncGenerator
from typing import Any
def _format_sse(event: str, payload: dict[str, Any]) -> str:
    """Render one Server-Sent Events frame.

    The frame consists of an ``event:`` line carrying *event*, a ``data:``
    line carrying *payload* serialised as JSON, and the blank line that
    terminates the frame. Non-ASCII characters are emitted as-is and values
    the JSON encoder cannot handle natively are converted with ``str``.
    """
    data = json.dumps(payload, default=str, ensure_ascii=False)
    return "event: {}\ndata: {}\n\n".format(event, data)
# Names of the loggers whose records are forwarded into a task's SSE stream
# while ``capture_task_logs`` is active. Loggers not listed here are not tapped.
_PIPELINE_LOGGER_NAMES: list[str] = [
    "deeptutor.LlamaIndexPipeline",
    "deeptutor.CustomEmbedding",
    "deeptutor.EmbeddingClient",
    "deeptutor.KnowledgeInit",
]
class _TaskStreamHandler(logging.Handler):
    """Forwards log records from pipeline loggers into a task's SSE stream.

    Each record at ``INFO`` level or above that reaches this handler is
    rendered with ``record.getMessage()`` and passed to the manager's
    ``emit_log`` for the task this handler was created for.
    """

    def __init__(self, task_id: str, manager: "KnowledgeTaskStreamManager"):
        """Bind the handler to one task.

        Args:
            task_id: Identifier of the task whose stream receives the lines.
            manager: Stream manager used to publish each log line.
        """
        super().__init__(level=logging.INFO)
        self._task_id = task_id
        self._manager = manager

    def emit(self, record: logging.LogRecord) -> None:
        """Publish the record's message as a ``log`` event for the task.

        Failures never propagate to the code that issued the log call; they
        are routed through ``handleError`` (the standard ``logging`` hook)
        instead of being discarded silently, so a broken stream is visible
        when ``logging.raiseExceptions`` is enabled.
        """
        try:
            self._manager.emit_log(self._task_id, record.getMessage())
        except Exception:
            self.handleError(record)
class KnowledgeTaskStreamManager:
    """Fan task events out to SSE subscribers, with replay for late joiners.

    For every task id the manager keeps a bounded buffer of the most recent
    events and a list of ``(queue, loop)`` subscribers. ``emit`` may be called
    from any thread: shared state is guarded by a lock and delivery to each
    subscriber is scheduled on that subscriber's event loop.

    Nothing in this class removes a task's buffer, so buffered events stay
    available for replay for as long as the manager instance lives.
    """

    # Number of most recent events retained per task for replay.
    _BUFFER_SIZE = 500
    # Number of undelivered events a single subscriber queue may hold.
    _QUEUE_SIZE = 200
    # Events that mark the end of a task's stream.
    _TERMINAL_EVENTS = frozenset({"complete", "failed"})

    _instance: "KnowledgeTaskStreamManager | None" = None
    _instance_lock = threading.Lock()

    def __init__(self):
        """Create an empty manager with no tasks and no subscribers."""
        self._lock = threading.Lock()
        self._buffers: dict[str, deque[dict[str, Any]]] = {}
        self._subscribers: dict[str, list[tuple[asyncio.Queue, asyncio.AbstractEventLoop]]] = {}

    @classmethod
    def get_instance(cls) -> "KnowledgeTaskStreamManager":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._instance_lock:
                # Re-check under the lock: another thread may have created
                # the instance while this one was waiting.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def ensure_task(self, task_id: str):
        """Make sure a buffer and a subscriber list exist for *task_id*."""
        with self._lock:
            self._buffers.setdefault(task_id, deque(maxlen=self._BUFFER_SIZE))
            self._subscribers.setdefault(task_id, [])

    def emit(self, task_id: str, event: str, payload: dict[str, Any]):
        """Record an event for *task_id* and deliver it to current subscribers.

        Args:
            task_id: Task the event belongs to.
            event: Event name, e.g. ``"log"``, ``"complete"`` or ``"failed"``.
            payload: JSON-serialisable event body.
        """
        event_payload = {"event": event, "payload": payload}
        with self._lock:
            self._buffers.setdefault(task_id, deque(maxlen=self._BUFFER_SIZE)).append(event_payload)
            # Snapshot so delivery happens outside the lock.
            subscribers = list(self._subscribers.get(task_id, []))
        for queue, loop in subscribers:
            try:
                # asyncio queues must only be touched from their own loop.
                loop.call_soon_threadsafe(self._queue_event, queue, event_payload)
            except RuntimeError:
                # The subscriber's loop can no longer accept callbacks
                # (e.g. it has been closed); skip this subscriber.
                continue

    def emit_log(self, task_id: str, line: str):
        """Emit a ``log`` event carrying one log line."""
        self.emit(task_id, "log", {"line": line, "task_id": task_id})

    def emit_complete(self, task_id: str, detail: str = "Task completed"):
        """Emit the terminal ``complete`` event for *task_id*."""
        self.emit(task_id, "complete", {"detail": detail, "task_id": task_id})

    def emit_failed(self, task_id: str, detail: str):
        """Emit the terminal ``failed`` event for *task_id*."""
        self.emit(task_id, "failed", {"detail": detail, "task_id": task_id})

    def subscribe(
        self, task_id: str
    ) -> tuple[asyncio.Queue[dict[str, Any]], list[dict[str, Any]], asyncio.AbstractEventLoop]:
        """Register the running loop as a subscriber of *task_id*.

        Must be called from a running event loop.

        Returns:
            ``(queue, backlog, loop)``: the queue that will receive future
            events, a copy of the events buffered so far, and the loop the
            queue belongs to. Registration and the backlog snapshot happen
            under the same lock acquisition, so an event is never both in the
            backlog and delivered to the queue, nor missing from both.
        """
        queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=self._QUEUE_SIZE)
        loop = asyncio.get_running_loop()
        with self._lock:
            self._buffers.setdefault(task_id, deque(maxlen=self._BUFFER_SIZE))
            self._subscribers.setdefault(task_id, []).append((queue, loop))
            backlog = list(self._buffers[task_id])
        return queue, backlog, loop

    def unsubscribe(self, task_id: str, queue: asyncio.Queue[dict[str, Any]], loop: asyncio.AbstractEventLoop):
        """Remove the ``(queue, loop)`` subscriber from *task_id*.

        Matching is by identity; unknown subscribers are ignored.
        """
        with self._lock:
            subscribers = self._subscribers.get(task_id, [])
            self._subscribers[task_id] = [
                (subscriber_queue, subscriber_loop)
                for subscriber_queue, subscriber_loop in subscribers
                if subscriber_queue is not queue or subscriber_loop is not loop
            ]

    async def stream(self, task_id: str) -> AsyncGenerator[str, None]:
        """Yield SSE frames for *task_id* until a terminal event is sent.

        Buffered events are replayed first. If the last buffered event is
        terminal the generator ends immediately; otherwise it keeps yielding
        live events until ``complete`` or ``failed`` arrives. The subscription
        is always removed when the generator exits.
        """
        queue, backlog, loop = self.subscribe(task_id)
        try:
            for item in backlog:
                yield _format_sse(item["event"], item["payload"])
            if backlog and backlog[-1]["event"] in self._TERMINAL_EVENTS:
                return
            while True:
                item = await queue.get()
                yield _format_sse(item["event"], item["payload"])
                if item["event"] in self._TERMINAL_EVENTS:
                    break
        finally:
            self.unsubscribe(task_id, queue, loop)

    @staticmethod
    def _queue_event(queue: asyncio.Queue[dict[str, Any]], payload: dict[str, Any]):
        """Put an event on a subscriber queue without blocking.

        When the queue is full, ordinary events are dropped. Terminal events
        must not be lost, because ``stream`` only stops once it has seen one,
        so for those the oldest queued event is evicted to make room.
        """
        try:
            queue.put_nowait(payload)
            return
        except asyncio.QueueFull:
            if payload.get("event") not in KnowledgeTaskStreamManager._TERMINAL_EVENTS:
                return
        # Full queue and a terminal event: sacrifice the oldest entry.
        with contextlib.suppress(asyncio.QueueEmpty):
            queue.get_nowait()
        with contextlib.suppress(asyncio.QueueFull):
            queue.put_nowait(payload)
@contextlib.contextmanager
def capture_task_logs(task_id: str):
    """Mirror pipeline log output into the SSE stream of *task_id*.

    While the context is active, a forwarding handler is attached to every
    logger named in ``_PIPELINE_LOGGER_NAMES``; loggers outside that list are
    left alone, so other request logs stay out of the stream. The handler is
    attached to the logger objects themselves and applies no filtering of its
    own, so it also sees records logged from worker threads (for example
    ``run_in_executor`` jobs); delivery to subscribers goes through
    ``call_soon_threadsafe``. The handler is detached again on exit, even if
    the body raises.
    """
    stream_manager = KnowledgeTaskStreamManager.get_instance()
    stream_manager.ensure_task(task_id)
    forwarder = _TaskStreamHandler(task_id, stream_manager)

    tapped: list[logging.Logger] = [
        logging.getLogger(logger_name) for logger_name in _PIPELINE_LOGGER_NAMES
    ]
    for pipeline_logger in tapped:
        pipeline_logger.addHandler(forwarder)
    try:
        yield
    finally:
        for pipeline_logger in tapped:
            pipeline_logger.removeHandler(forwarder)
def get_task_stream_manager() -> KnowledgeTaskStreamManager:
    """Return the shared ``KnowledgeTaskStreamManager`` instance."""
    shared_manager = KnowledgeTaskStreamManager.get_instance()
    return shared_manager