Spaces:
Sleeping
Sleeping
| import asyncio | |
| import contextlib | |
| import json | |
| import logging | |
| import threading | |
| from collections import deque | |
| from collections.abc import AsyncGenerator | |
| from typing import Any | |
def _format_sse(event: str, payload: dict[str, Any]) -> str:
    """Render one Server-Sent Events frame.

    Args:
        event: Value written on the ``event:`` line.
        payload: Mapping serialised as JSON onto the ``data:`` line.
            Non-ASCII characters are emitted as-is, and values JSON cannot
            encode natively are converted with ``str``.

    Returns:
        The frame text: an ``event:`` line, a ``data:`` line, and a
        terminating blank line.
    """
    encoded_payload = json.dumps(payload, ensure_ascii=False, default=str)
    return "event: {}\ndata: {}\n\n".format(event, encoded_payload)
# Names of the loggers whose records are forwarded into a task's SSE stream
# while ``capture_task_logs`` is active. Loggers not listed here are not tapped.
_PIPELINE_LOGGER_NAMES: list[str] = [
    "deeptutor.LlamaIndexPipeline",
    "deeptutor.CustomEmbedding",
    "deeptutor.EmbeddingClient",
    "deeptutor.KnowledgeInit",
]
class _TaskStreamHandler(logging.Handler):
    """Forwards log records from pipeline loggers into a task's SSE stream.

    Each record at ``INFO`` level or above that reaches this handler is
    rendered with ``record.getMessage()`` and passed to the manager's
    ``emit_log`` for the task this handler was created for.
    """

    def __init__(self, task_id: str, manager: "KnowledgeTaskStreamManager"):
        """Bind the handler to one task.

        Args:
            task_id: Identifier of the task whose stream receives the lines.
            manager: Object exposing ``emit_log(task_id, line)``.
        """
        super().__init__(level=logging.INFO)
        self._task_id = task_id
        self._manager = manager

    def emit(self, record: logging.LogRecord) -> None:
        """Forward one record's message to the task stream.

        A failure while forwarding never propagates to the code that issued
        the log call. Instead of being silently discarded it is handed to
        ``logging.Handler.handleError``, so it is reported (or suppressed)
        according to ``logging.raiseExceptions``.
        """
        try:
            self._manager.emit_log(self._task_id, record.getMessage())
        except Exception:
            self.handleError(record)
class KnowledgeTaskStreamManager:
    """Fan out per-task events to SSE subscribers, with a replay buffer.

    For every task id the manager keeps:

    * a bounded buffer (the most recent 500 events) that is replayed to each
      new subscriber, and
    * a list of ``(queue, loop)`` pairs, one per live subscriber.

    ``emit`` may be called from any thread: shared state is guarded by a
    ``threading.Lock`` and delivery into each subscriber's queue is scheduled
    on that subscriber's event loop with ``call_soon_threadsafe``.
    """

    _instance: "KnowledgeTaskStreamManager | None" = None
    _instance_lock = threading.Lock()
    # Event names that end a task's stream.
    _TERMINAL_EVENTS = frozenset({"complete", "failed"})

    def __init__(self):
        """Create an empty manager with no tasks and no subscribers."""
        self._lock = threading.Lock()
        # task_id -> most recent events, oldest first.
        self._buffers: dict[str, deque[dict[str, Any]]] = {}
        # task_id -> subscriber queues paired with the loop that owns each one.
        self._subscribers: dict[str, list[tuple[asyncio.Queue, asyncio.AbstractEventLoop]]] = {}

    @classmethod
    def get_instance(cls) -> "KnowledgeTaskStreamManager":
        """Return the shared manager, creating it on first use.

        The instance is created while holding ``_instance_lock`` and
        ``_instance`` is re-checked inside the lock, so concurrent first
        callers all receive the same object.
        """
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def ensure_task(self, task_id: str) -> None:
        """Make sure a buffer and a subscriber list exist for ``task_id``."""
        with self._lock:
            self._buffers.setdefault(task_id, deque(maxlen=500))
            self._subscribers.setdefault(task_id, [])

    def emit(self, task_id: str, event: str, payload: dict[str, Any]) -> None:
        """Record an event for a task and deliver it to current subscribers.

        Args:
            task_id: Task the event belongs to.
            event: Event name (for example ``"log"``, ``"complete"``).
            payload: Event data; stored and delivered as given.
        """
        event_payload = {"event": event, "payload": payload}
        with self._lock:
            self._buffers.setdefault(task_id, deque(maxlen=500)).append(event_payload)
            # Snapshot under the lock; deliver outside it.
            subscribers = list(self._subscribers.get(task_id, []))
        for queue, loop in subscribers:
            try:
                loop.call_soon_threadsafe(self._queue_event, queue, event_payload)
            except RuntimeError:
                # The subscriber's loop refused the callback (e.g. it has
                # been closed); skip this subscriber and keep delivering.
                continue

    def emit_log(self, task_id: str, line: str) -> None:
        """Emit a ``"log"`` event carrying one line of text."""
        self.emit(task_id, "log", {"line": line, "task_id": task_id})

    def emit_complete(self, task_id: str, detail: str = "Task completed") -> None:
        """Emit the terminal ``"complete"`` event for a task."""
        self.emit(task_id, "complete", {"detail": detail, "task_id": task_id})

    def emit_failed(self, task_id: str, detail: str) -> None:
        """Emit the terminal ``"failed"`` event for a task."""
        self.emit(task_id, "failed", {"detail": detail, "task_id": task_id})

    def subscribe(
        self, task_id: str
    ) -> tuple[asyncio.Queue[dict[str, Any]], list[dict[str, Any]], asyncio.AbstractEventLoop]:
        """Register a subscriber for ``task_id`` on the running event loop.

        Must be called while an event loop is running, because the loop is
        obtained with ``asyncio.get_running_loop()``.

        Returns:
            ``(queue, backlog, loop)``: the queue that will receive future
            events (capacity 200), a snapshot of the events buffered so far,
            and the loop the queue is bound to. The snapshot and the
            registration happen under the same lock, so an event is in
            exactly one of ``backlog`` or the queue.
        """
        queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=200)
        loop = asyncio.get_running_loop()
        with self._lock:
            self._buffers.setdefault(task_id, deque(maxlen=500))
            self._subscribers.setdefault(task_id, []).append((queue, loop))
            backlog = list(self._buffers[task_id])
        return queue, backlog, loop

    def unsubscribe(
        self, task_id: str, queue: asyncio.Queue[dict[str, Any]], loop: asyncio.AbstractEventLoop
    ) -> None:
        """Remove the subscriber identified by this exact ``(queue, loop)`` pair."""
        with self._lock:
            subscribers = self._subscribers.get(task_id, [])
            self._subscribers[task_id] = [
                (subscriber_queue, subscriber_loop)
                for subscriber_queue, subscriber_loop in subscribers
                if subscriber_queue is not queue or subscriber_loop is not loop
            ]

    async def stream(self, task_id: str) -> AsyncGenerator[str, None]:
        """Yield SSE frames for a task: buffered events first, then live ones.

        The generator finishes after yielding a terminal event
        (``"complete"`` or ``"failed"``), or immediately after the replay
        when the last buffered event is already terminal. The subscription
        is always removed when the generator exits.
        """
        queue, backlog, loop = self.subscribe(task_id)
        try:
            for item in backlog:
                yield _format_sse(item["event"], item["payload"])
            if backlog and backlog[-1]["event"] in self._TERMINAL_EVENTS:
                return
            while True:
                item = await queue.get()
                yield _format_sse(item["event"], item["payload"])
                if item["event"] in self._TERMINAL_EVENTS:
                    break
        finally:
            self.unsubscribe(task_id, queue, loop)

    @staticmethod
    def _queue_event(queue: asyncio.Queue[dict[str, Any]], payload: dict[str, Any]) -> None:
        """Put an event on a subscriber queue without blocking.

        Runs on the subscriber's event loop (scheduled by ``emit``). When the
        queue is full, ordinary events are dropped. A terminal event must
        still reach the consumer, otherwise ``stream`` would wait forever, so
        the oldest queued item is discarded to make room for it.
        """
        try:
            queue.put_nowait(payload)
            return
        except asyncio.QueueFull:
            if payload.get("event") not in KnowledgeTaskStreamManager._TERMINAL_EVENTS:
                return
        try:
            queue.get_nowait()
            queue.put_nowait(payload)
        except (asyncio.QueueEmpty, asyncio.QueueFull):
            # Nothing more can be done without blocking; give up quietly.
            pass
@contextlib.contextmanager
def capture_task_logs(task_id: str):
    """Forward pipeline log records to a task's SSE stream for the ``with`` body.

    On entry, a ``_TaskStreamHandler`` bound to ``task_id`` is attached to
    every logger named in ``_PIPELINE_LOGGER_NAMES``; on exit, including
    when the body raises, it is detached again. Only those loggers are
    tapped, so records from other loggers do not reach the stream.

    The handler is attached to the named logger objects themselves, so it
    receives their records regardless of which thread logs them (for example
    work run through ``run_in_executor``). Delivery to subscribers goes
    through the manager's ``emit``, which uses ``call_soon_threadsafe``.

    NOTE(review): because the handler sits on shared loggers, two overlapping
    captures for different tasks will each receive the other's records from
    these loggers. Confirm whether that is acceptable for callers.

    Args:
        task_id: Task whose stream should receive the log lines.
    """
    manager = KnowledgeTaskStreamManager.get_instance()
    manager.ensure_task(task_id)
    handler = _TaskStreamHandler(task_id, manager)
    attached: list[logging.Logger] = []
    try:
        # Attach inside the ``try`` so a failure part-way through still
        # detaches the loggers that were already tapped.
        for name in _PIPELINE_LOGGER_NAMES:
            pipeline_logger = logging.getLogger(name)
            pipeline_logger.addHandler(handler)
            attached.append(pipeline_logger)
        yield
    finally:
        for pipeline_logger in attached:
            pipeline_logger.removeHandler(handler)
def get_task_stream_manager() -> KnowledgeTaskStreamManager:
    """Return the shared ``KnowledgeTaskStreamManager`` instance."""
    shared_manager = KnowledgeTaskStreamManager.get_instance()
    return shared_manager