|
|
|
|
|
|
|
|
from typing import Any |
|
|
|
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
from socketio import ASGIApp, AsyncServer |
|
|
|
|
|
from invokeai.app.services.events.events_common import ( |
|
|
BatchEnqueuedEvent, |
|
|
BulkDownloadCompleteEvent, |
|
|
BulkDownloadErrorEvent, |
|
|
BulkDownloadEventBase, |
|
|
BulkDownloadStartedEvent, |
|
|
DownloadCancelledEvent, |
|
|
DownloadCompleteEvent, |
|
|
DownloadErrorEvent, |
|
|
DownloadEventBase, |
|
|
DownloadProgressEvent, |
|
|
DownloadStartedEvent, |
|
|
FastAPIEvent, |
|
|
InvocationCompleteEvent, |
|
|
InvocationErrorEvent, |
|
|
InvocationProgressEvent, |
|
|
InvocationStartedEvent, |
|
|
ModelEventBase, |
|
|
ModelInstallCancelledEvent, |
|
|
ModelInstallCompleteEvent, |
|
|
ModelInstallDownloadProgressEvent, |
|
|
ModelInstallDownloadsCompleteEvent, |
|
|
ModelInstallErrorEvent, |
|
|
ModelInstallStartedEvent, |
|
|
ModelLoadCompleteEvent, |
|
|
ModelLoadStartedEvent, |
|
|
QueueClearedEvent, |
|
|
QueueEventBase, |
|
|
QueueItemStatusChangedEvent, |
|
|
register_events, |
|
|
) |
|
|
|
|
|
|
|
|
class QueueSubscriptionEvent(BaseModel): |
|
|
"""Event data for subscribing to the socket.io queue room. |
|
|
This is a pydantic model to ensure the data is in the correct format.""" |
|
|
|
|
|
queue_id: str |
|
|
|
|
|
|
|
|
class BulkDownloadSubscriptionEvent(BaseModel): |
|
|
"""Event data for subscribing to the socket.io bulk downloads room. |
|
|
This is a pydantic model to ensure the data is in the correct format.""" |
|
|
|
|
|
bulk_download_id: str |
|
|
|
|
|
|
|
|
QUEUE_EVENTS = { |
|
|
InvocationStartedEvent, |
|
|
InvocationProgressEvent, |
|
|
InvocationCompleteEvent, |
|
|
InvocationErrorEvent, |
|
|
QueueItemStatusChangedEvent, |
|
|
BatchEnqueuedEvent, |
|
|
QueueClearedEvent, |
|
|
} |
|
|
|
|
|
MODEL_EVENTS = { |
|
|
DownloadCancelledEvent, |
|
|
DownloadCompleteEvent, |
|
|
DownloadErrorEvent, |
|
|
DownloadProgressEvent, |
|
|
DownloadStartedEvent, |
|
|
ModelLoadStartedEvent, |
|
|
ModelLoadCompleteEvent, |
|
|
ModelInstallDownloadProgressEvent, |
|
|
ModelInstallDownloadsCompleteEvent, |
|
|
ModelInstallStartedEvent, |
|
|
ModelInstallCompleteEvent, |
|
|
ModelInstallCancelledEvent, |
|
|
ModelInstallErrorEvent, |
|
|
} |
|
|
|
|
|
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent} |
|
|
|
|
|
|
|
|
class SocketIO: |
|
|
_sub_queue = "subscribe_queue" |
|
|
_unsub_queue = "unsubscribe_queue" |
|
|
|
|
|
_sub_bulk_download = "subscribe_bulk_download" |
|
|
_unsub_bulk_download = "unsubscribe_bulk_download" |
|
|
|
|
|
def __init__(self, app: FastAPI): |
|
|
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") |
|
|
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io") |
|
|
app.mount("/ws", self._app) |
|
|
|
|
|
self._sio.on(self._sub_queue, handler=self._handle_sub_queue) |
|
|
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue) |
|
|
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download) |
|
|
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download) |
|
|
|
|
|
register_events(QUEUE_EVENTS, self._handle_queue_event) |
|
|
register_events(MODEL_EVENTS, self._handle_model_event) |
|
|
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event) |
|
|
|
|
|
async def _handle_sub_queue(self, sid: str, data: Any) -> None: |
|
|
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id) |
|
|
|
|
|
async def _handle_unsub_queue(self, sid: str, data: Any) -> None: |
|
|
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id) |
|
|
|
|
|
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: |
|
|
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) |
|
|
|
|
|
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: |
|
|
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) |
|
|
|
|
|
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): |
|
|
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id) |
|
|
|
|
|
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None: |
|
|
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json")) |
|
|
|
|
|
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None: |
|
|
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id) |
|
|
|