Spaces:
Running
Running
"""Async ComfyUI API client using aiohttp for HTTP and WebSocket communication.

Based on the pattern from ComfyUI's own websockets_api_example.py.
Communicates with ComfyUI at http://127.0.0.1:8188.
"""
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
| import aiohttp | |
| logger = logging.getLogger(__name__) | |
@dataclass
class ComfyUIResult:
    """Result from a completed ComfyUI generation.

    Attributes:
        prompt_id: The id assigned to the submitted prompt.
        outputs: Raw per-node output mapping from ComfyUI's /history endpoint.
        images: Flattened list of image outputs extracted from ``outputs``.
    """

    prompt_id: str
    outputs: dict[str, Any] = field(default_factory=dict)
    images: list[ImageOutput] = field(default_factory=list)
@dataclass
class ImageOutput:
    """A single output image from ComfyUI.

    Attributes:
        filename: Image filename as stored by ComfyUI.
        subfolder: Subfolder under the output/temp directory (may be "").
        type: Storage location reported by ComfyUI — "output" or "temp".
    """

    filename: str
    subfolder: str
    type: str  # "output" or "temp"
class ComfyUIError(Exception):
    """Error raised for any failure reported by the ComfyUI server (rejected
    prompts, failed downloads/uploads, or generation timeouts)."""
class ComfyUIClient:
    """Async client for the ComfyUI API.

    Usage:
        client = ComfyUIClient("http://127.0.0.1:8188")
        result = await client.generate(workflow_dict)
        image_bytes = await client.download_image(result.images[0])
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8188"):
        self.base_url = base_url.rstrip("/")
        # Stable per-client id so WebSocket events can be matched to our prompts.
        self.client_id = str(uuid.uuid4())
        self._session: aiohttp.ClientSession | None = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it lazily if needed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self) -> None:
        """Close the underlying HTTP session, if open."""
        if self._session and not self._session.closed:
            await self._session.close()
        # Drop the reference so a later call lazily creates a fresh session
        # instead of reusing a closed one.
        self._session = None

    # --- Core generation ---

    async def queue_prompt(self, workflow: dict) -> str:
        """Submit a workflow to ComfyUI. Returns the prompt_id.

        Raises:
            ComfyUIError: If the server rejects the prompt (non-200 response).
        """
        prompt_id = str(uuid.uuid4())
        payload = {
            "prompt": workflow,
            "client_id": self.client_id,
            "prompt_id": prompt_id,
        }
        session = await self._get_session()
        async with session.post(f"{self.base_url}/prompt", json=payload) as resp:
            if resp.status != 200:
                body = await resp.text()
                raise ComfyUIError(f"Prompt rejected (HTTP {resp.status}): {body}")
            data = await resp.json()
            # Prefer the server-assigned id; fall back to the one we proposed.
            return data.get("prompt_id", prompt_id)

    async def wait_for_completion(
        self, prompt_id: str, timeout: float = 600
    ) -> ComfyUIResult:
        """Wait for a prompt to finish via WebSocket, then fetch results.

        Raises:
            ComfyUIError: If the prompt does not complete within ``timeout``
                seconds.
        """
        # Derive the WebSocket URL from the HTTP base URL, preserving TLS:
        # http:// -> ws://, https:// -> wss://. (The previous str.replace()
        # approach also hardcoded ws:// for https servers.)
        if self.base_url.startswith("https://"):
            ws_url = f"wss://{self.base_url.removeprefix('https://')}/ws?clientId={self.client_id}"
        else:
            ws_url = f"ws://{self.base_url.removeprefix('http://')}/ws?clientId={self.client_id}"
        session = await self._get_session()
        try:
            async with asyncio.timeout(timeout):
                async with session.ws_connect(ws_url) as ws:
                    async for msg in ws:
                        if msg.type == aiohttp.WSMsgType.TEXT:
                            data = json.loads(msg.data)
                            if data.get("type") == "executing":
                                exec_data = data.get("data", {})
                                # node == None signals execution finished for
                                # that prompt.
                                if (
                                    exec_data.get("node") is None
                                    and exec_data.get("prompt_id") == prompt_id
                                ):
                                    break
                        # Binary messages are latent previews — skip
        except TimeoutError as err:
            # Chain the original timeout so callers can see the full cause.
            raise ComfyUIError(
                f"Timeout waiting for prompt {prompt_id} after {timeout}s"
            ) from err
        return await self._fetch_result(prompt_id)

    async def generate(self, workflow: dict, timeout: float = 600) -> ComfyUIResult:
        """Submit workflow and wait for completion. Returns the result."""
        prompt_id = await self.queue_prompt(workflow)
        logger.info("Queued prompt %s", prompt_id)
        return await self.wait_for_completion(prompt_id, timeout)

    # --- Result fetching ---

    async def _fetch_result(self, prompt_id: str) -> ComfyUIResult:
        """Fetch history for a completed prompt and extract image outputs."""
        history = await self.get_history(prompt_id)
        prompt_history = history.get(prompt_id, {})
        outputs = prompt_history.get("outputs", {})
        images: list[ImageOutput] = []
        # Collect image records from every node that produced any.
        for _node_id, node_output in outputs.items():
            for img_info in node_output.get("images", []):
                images.append(
                    ImageOutput(
                        filename=img_info["filename"],
                        subfolder=img_info.get("subfolder", ""),
                        type=img_info.get("type", "output"),
                    )
                )
        return ComfyUIResult(
            prompt_id=prompt_id,
            outputs=outputs,
            images=images,
        )

    async def download_image(self, image: ImageOutput) -> bytes:
        """Download an output image from ComfyUI.

        Raises:
            ComfyUIError: If the server responds with a non-200 status.
        """
        params = {
            "filename": image.filename,
            "subfolder": image.subfolder,
            "type": image.type,
        }
        session = await self._get_session()
        async with session.get(f"{self.base_url}/view", params=params) as resp:
            if resp.status != 200:
                raise ComfyUIError(f"Failed to download image: HTTP {resp.status}")
            return await resp.read()

    # --- Monitoring ---

    async def get_history(self, prompt_id: str) -> dict:
        """Get execution history for a prompt."""
        session = await self._get_session()
        async with session.get(f"{self.base_url}/history/{prompt_id}") as resp:
            return await resp.json()

    async def get_system_stats(self) -> dict:
        """Get system stats including GPU VRAM info."""
        session = await self._get_session()
        async with session.get(f"{self.base_url}/system_stats") as resp:
            return await resp.json()

    async def get_queue_info(self) -> dict:
        """Get current queue state (running + pending)."""
        session = await self._get_session()
        async with session.get(f"{self.base_url}/prompt") as resp:
            return await resp.json()

    async def get_queue_depth(self) -> int:
        """Get number of pending items in the queue."""
        info = await self.get_queue_info()
        return len(info.get("queue_pending", []))

    async def get_vram_free_gb(self) -> float | None:
        """Get free VRAM in GB, or None if unavailable."""
        try:
            stats = await self.get_system_stats()
            devices = stats.get("devices", [])
            if devices:
                # Best-effort: first device only; "vram_free" is in bytes.
                return devices[0].get("vram_free", 0) / (1024**3)
        except Exception:
            # Deliberate best-effort: monitoring must not break generation.
            logger.warning("Failed to get VRAM stats", exc_info=True)
        return None

    async def is_available(self) -> bool:
        """Check if ComfyUI is reachable (5s probe, never raises)."""
        try:
            session = await self._get_session()
            async with session.get(
                f"{self.base_url}/system_stats", timeout=aiohttp.ClientTimeout(total=5)
            ) as resp:
                return resp.status == 200
        except Exception:
            return False

    async def upload_image(
        self, image_bytes: bytes, filename: str, overwrite: bool = True
    ) -> str:
        """Upload an image to ComfyUI's input directory. Returns the stored filename.

        Raises:
            ComfyUIError: If the upload fails (non-200 response).
        """
        session = await self._get_session()
        data = aiohttp.FormData()
        data.add_field(
            "image", image_bytes, filename=filename, content_type="image/png"
        )
        # ComfyUI expects a lowercase string flag ("true"/"false").
        data.add_field("overwrite", str(overwrite).lower())
        async with session.post(f"{self.base_url}/upload/image", data=data) as resp:
            if resp.status != 200:
                body = await resp.text()
                raise ComfyUIError(f"Image upload failed (HTTP {resp.status}): {body}")
            result = await resp.json()
            # Server may rename on collision; prefer the name it reports.
            return result.get("name", filename)

    async def get_models(self, folder: str = "loras") -> list[str]:
        """List available models in a folder (loras, checkpoints, etc.).

        Returns an empty list on any non-200 response (best-effort).
        """
        session = await self._get_session()
        async with session.get(f"{self.base_url}/models/{folder}") as resp:
            if resp.status == 200:
                return await resp.json()
            return []