"""Async ComfyUI API client using aiohttp for HTTP and WebSocket communication. Based on the pattern from ComfyUI's own websockets_api_example.py. Communicates with ComfyUI at http://127.0.0.1:8188. """ from __future__ import annotations import asyncio import json import logging import uuid from dataclasses import dataclass, field from typing import Any import aiohttp logger = logging.getLogger(__name__) @dataclass class ComfyUIResult: """Result from a completed ComfyUI generation.""" prompt_id: str outputs: dict[str, Any] = field(default_factory=dict) images: list[ImageOutput] = field(default_factory=list) @dataclass class ImageOutput: """A single output image from ComfyUI.""" filename: str subfolder: str type: str # "output" or "temp" class ComfyUIError(Exception): """Raised when ComfyUI returns an error.""" class ComfyUIClient: """Async client for the ComfyUI API. Usage: client = ComfyUIClient("http://127.0.0.1:8188") result = await client.generate(workflow_dict) image_bytes = await client.download_image(result.images[0]) """ def __init__(self, base_url: str = "http://127.0.0.1:8188"): self.base_url = base_url.rstrip("/") self.client_id = str(uuid.uuid4()) self._session: aiohttp.ClientSession | None = None async def _get_session(self) -> aiohttp.ClientSession: if self._session is None or self._session.closed: self._session = aiohttp.ClientSession() return self._session async def close(self) -> None: if self._session and not self._session.closed: await self._session.close() # --- Core generation --- async def queue_prompt(self, workflow: dict) -> str: """Submit a workflow to ComfyUI. Returns the prompt_id.""" prompt_id = str(uuid.uuid4()) payload = { "prompt": workflow, "client_id": self.client_id, "prompt_id": prompt_id, } session = await self._get_session() async with session.post(f"{self.base_url}/prompt", json=payload) as resp: if resp.status != 200: body = await resp.text() raise ComfyUIError(f"Prompt rejected (HTTP {resp.status}): {body}") data = await resp.json() return data.get("prompt_id", prompt_id) async def wait_for_completion( self, prompt_id: str, timeout: float = 600 ) -> ComfyUIResult: """Wait for a prompt to finish via WebSocket, then fetch results.""" ws_host = self.base_url.replace("http://", "").replace("https://", "") ws_url = f"ws://{ws_host}/ws?clientId={self.client_id}" session = await self._get_session() try: async with asyncio.timeout(timeout): async with session.ws_connect(ws_url) as ws: async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: data = json.loads(msg.data) if data.get("type") == "executing": exec_data = data.get("data", {}) if ( exec_data.get("node") is None and exec_data.get("prompt_id") == prompt_id ): break # Binary messages are latent previews — skip except TimeoutError: raise ComfyUIError( f"Timeout waiting for prompt {prompt_id} after {timeout}s" ) return await self._fetch_result(prompt_id) async def generate(self, workflow: dict, timeout: float = 600) -> ComfyUIResult: """Submit workflow and wait for completion. Returns the result.""" prompt_id = await self.queue_prompt(workflow) logger.info("Queued prompt %s", prompt_id) return await self.wait_for_completion(prompt_id, timeout) # --- Result fetching --- async def _fetch_result(self, prompt_id: str) -> ComfyUIResult: """Fetch history for a completed prompt and extract image outputs.""" history = await self.get_history(prompt_id) prompt_history = history.get(prompt_id, {}) outputs = prompt_history.get("outputs", {}) images: list[ImageOutput] = [] for _node_id, node_output in outputs.items(): for img_info in node_output.get("images", []): images.append( ImageOutput( filename=img_info["filename"], subfolder=img_info.get("subfolder", ""), type=img_info.get("type", "output"), ) ) return ComfyUIResult( prompt_id=prompt_id, outputs=outputs, images=images, ) async def download_image(self, image: ImageOutput) -> bytes: """Download an output image from ComfyUI.""" params = { "filename": image.filename, "subfolder": image.subfolder, "type": image.type, } session = await self._get_session() async with session.get(f"{self.base_url}/view", params=params) as resp: if resp.status != 200: raise ComfyUIError(f"Failed to download image: HTTP {resp.status}") return await resp.read() # --- Monitoring --- async def get_history(self, prompt_id: str) -> dict: """Get execution history for a prompt.""" session = await self._get_session() async with session.get(f"{self.base_url}/history/{prompt_id}") as resp: return await resp.json() async def get_system_stats(self) -> dict: """Get system stats including GPU VRAM info.""" session = await self._get_session() async with session.get(f"{self.base_url}/system_stats") as resp: return await resp.json() async def get_queue_info(self) -> dict: """Get current queue state (running + pending).""" session = await self._get_session() async with session.get(f"{self.base_url}/prompt") as resp: return await resp.json() async def get_queue_depth(self) -> int: """Get number of pending items in the queue.""" info = await self.get_queue_info() return len(info.get("queue_pending", [])) async def get_vram_free_gb(self) -> float | None: """Get free VRAM in GB, or None if unavailable.""" try: stats = await self.get_system_stats() devices = stats.get("devices", []) if devices: return devices[0].get("vram_free", 0) / (1024**3) except Exception: logger.warning("Failed to get VRAM stats", exc_info=True) return None async def is_available(self) -> bool: """Check if ComfyUI is reachable.""" try: session = await self._get_session() async with session.get( f"{self.base_url}/system_stats", timeout=aiohttp.ClientTimeout(total=5) ) as resp: return resp.status == 200 except Exception: return False async def upload_image( self, image_bytes: bytes, filename: str, overwrite: bool = True ) -> str: """Upload an image to ComfyUI's input directory. Returns the stored filename.""" session = await self._get_session() data = aiohttp.FormData() data.add_field( "image", image_bytes, filename=filename, content_type="image/png" ) data.add_field("overwrite", str(overwrite).lower()) async with session.post(f"{self.base_url}/upload/image", data=data) as resp: if resp.status != 200: body = await resp.text() raise ComfyUIError(f"Image upload failed (HTTP {resp.status}): {body}") result = await resp.json() return result.get("name", filename) async def get_models(self, folder: str = "loras") -> list[str]: """List available models in a folder (loras, checkpoints, etc.).""" session = await self._get_session() async with session.get(f"{self.base_url}/models/{folder}") as resp: if resp.status == 200: return await resp.json() return []