# Provenance (from the hosting page header): commit ed37502, "Initial deployment - Content Engine", by dippoo (8.45 kB).
"""Async ComfyUI API client using aiohttp for HTTP and WebSocket communication.
Based on the pattern from ComfyUI's own websockets_api_example.py.
Talks to ComfyUI at http://127.0.0.1:8188 by default; pass a different base_url to ComfyUIClient to target another server.
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any
import aiohttp
# Per-module logger named after this module; no handlers are configured here.
logger = logging.getLogger(name=__name__)
@dataclass
class ImageOutput:
    """A single image produced by a ComfyUI generation.

    The three fields are exactly the query parameters sent to ComfyUI's
    ``/view`` endpoint when downloading the image.
    """

    filename: str
    subfolder: str
    type: str  # "output" or "temp"


@dataclass
class ComfyUIResult:
    """Outcome of a completed ComfyUI generation.

    Defined after ``ImageOutput`` so the ``images`` annotation does not rely on
    a forward reference.
    """

    prompt_id: str
    # Raw per-node ``outputs`` mapping taken from the prompt's history entry.
    outputs: dict[str, Any] = field(default_factory=dict)
    # Every image listed under an ``images`` key in any node's outputs.
    images: list[ImageOutput] = field(default_factory=list)
class ComfyUIError(Exception):
    """Signals that a ComfyUI request was rejected, failed, or timed out."""

    pass
class ComfyUIClient:
    """Async client for the ComfyUI HTTP and WebSocket API.

    A single ``aiohttp.ClientSession`` is created lazily on first use and
    shared by all requests until :meth:`close` is called. The client is also
    an async context manager that calls :meth:`close` on exit.

    Usage:
        async with ComfyUIClient("http://127.0.0.1:8188") as client:
            result = await client.generate(workflow_dict)
            image_bytes = await client.download_image(result.images[0])
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8188"):
        self.base_url = base_url.rstrip("/")
        # Sent with every queued prompt and used as the WebSocket ``clientId``,
        # tying this client's prompts to the socket it listens on.
        self.client_id = str(uuid.uuid4())
        self._session: aiohttp.ClientSession | None = None

    async def __aenter__(self) -> ComfyUIClient:
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        await self.close()

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, opening a new one if needed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self) -> None:
        """Close the shared HTTP session, if any. Safe to call repeatedly."""
        session, self._session = self._session, None
        if session is not None and not session.closed:
            await session.close()

    def _ws_url(self) -> str:
        """Return the WebSocket URL, using ``wss`` when ``base_url`` is HTTPS."""
        if self.base_url.startswith("https://"):
            scheme, host = "wss", self.base_url.removeprefix("https://")
        else:
            scheme, host = "ws", self.base_url.removeprefix("http://")
        return f"{scheme}://{host}/ws?clientId={self.client_id}"

    async def _get_json(self, path: str) -> Any:
        """GET ``base_url + path`` and return the decoded JSON body.

        Raises:
            ComfyUIError: if ComfyUI responds with a non-200 status.
        """
        session = await self._get_session()
        async with session.get(f"{self.base_url}{path}") as resp:
            if resp.status != 200:
                body = await resp.text()
                raise ComfyUIError(f"GET {path} failed (HTTP {resp.status}): {body}")
            return await resp.json()

    # --- Core generation ---

    async def queue_prompt(self, workflow: dict) -> str:
        """Submit a workflow to ComfyUI's ``/prompt`` endpoint.

        Returns:
            The prompt id reported by ComfyUI, or the id generated and sent by
            this client if the response does not contain one.

        Raises:
            ComfyUIError: if ComfyUI rejects the prompt (non-200 response).
        """
        prompt_id = str(uuid.uuid4())
        payload = {
            "prompt": workflow,
            "client_id": self.client_id,
            "prompt_id": prompt_id,
        }
        session = await self._get_session()
        async with session.post(f"{self.base_url}/prompt", json=payload) as resp:
            if resp.status != 200:
                body = await resp.text()
                raise ComfyUIError(f"Prompt rejected (HTTP {resp.status}): {body}")
            data = await resp.json()
            return data.get("prompt_id", prompt_id)

    async def wait_for_completion(
        self, prompt_id: str, timeout: float = 600
    ) -> ComfyUIResult:
        """Wait for a queued prompt to finish, then fetch its outputs.

        Completion is signalled over the WebSocket by an ``executing`` event
        whose ``node`` is ``None`` and whose ``prompt_id`` matches. The socket
        is opened only after the prompt was queued, so a prompt that finishes
        very quickly could complete before the connection exists. To close
        that gap, the history endpoint is checked once right after connecting.

        Raises:
            ComfyUIError: on timeout, if the WebSocket errors or closes before
                the prompt finishes, or if ComfyUI records the run as failed.
        """
        session = await self._get_session()
        try:
            async with asyncio.timeout(timeout):
                async with session.ws_connect(self._ws_url()) as ws:
                    # NOTE: assumes ComfyUI writes the history entry no later
                    # than it emits the final ``executing`` event, so a prompt
                    # absent from history here will still be announced on ws.
                    if prompt_id not in await self.get_history(prompt_id):
                        await self._wait_for_done_message(ws, prompt_id)
        except TimeoutError as exc:
            raise ComfyUIError(
                f"Timeout waiting for prompt {prompt_id} after {timeout}s"
            ) from exc
        return await self._fetch_result(prompt_id)

    @staticmethod
    async def _wait_for_done_message(
        ws: aiohttp.ClientWebSocketResponse, prompt_id: str
    ) -> None:
        """Consume WebSocket messages until ComfyUI reports ``prompt_id`` done.

        Raises:
            ComfyUIError: if the socket reports an error or closes first.
        """
        async for msg in ws:
            if msg.type == aiohttp.WSMsgType.ERROR:
                raise ComfyUIError(
                    f"WebSocket error while waiting for prompt {prompt_id}: "
                    f"{ws.exception()}"
                )
            if msg.type != aiohttp.WSMsgType.TEXT:
                continue  # binary frames are latent previews
            event = json.loads(msg.data)
            if event.get("type") != "executing":
                continue
            data = event.get("data", {})
            if data.get("node") is None and data.get("prompt_id") == prompt_id:
                return
        raise ComfyUIError(f"WebSocket closed before prompt {prompt_id} finished")

    async def generate(self, workflow: dict, timeout: float = 600) -> ComfyUIResult:
        """Submit a workflow and wait for it to finish. Returns the result."""
        prompt_id = await self.queue_prompt(workflow)
        logger.info("Queued prompt %s", prompt_id)
        return await self.wait_for_completion(prompt_id, timeout)

    # --- Result fetching ---

    async def _fetch_result(self, prompt_id: str) -> ComfyUIResult:
        """Fetch the history of a finished prompt and collect its images.

        Raises:
            ComfyUIError: if the history entry's ``status.status_str`` is
                ``"error"``; otherwise a failed run would come back as an empty
                result indistinguishable from a workflow with no image output.
        """
        history = await self.get_history(prompt_id)
        prompt_history = history.get(prompt_id, {})
        status = prompt_history.get("status") or {}
        if status.get("status_str") == "error":
            raise ComfyUIError(
                f"Prompt {prompt_id} failed: {self._failure_detail(status)}"
            )
        outputs = prompt_history.get("outputs", {})
        images = [
            ImageOutput(
                filename=img_info["filename"],
                subfolder=img_info.get("subfolder", ""),
                type=img_info.get("type", "output"),
            )
            for node_output in outputs.values()
            for img_info in node_output.get("images", [])
        ]
        return ComfyUIResult(prompt_id=prompt_id, outputs=outputs, images=images)

    @staticmethod
    def _failure_detail(status: dict) -> str:
        """Best-effort human-readable reason from a history ``status`` block.

        ``status["messages"]`` is expected to hold ``[event, data]`` pairs;
        anything else falls back to a generic description.
        """
        for entry in status.get("messages") or []:
            if not (isinstance(entry, (list, tuple)) and len(entry) == 2):
                continue
            event, details = entry
            if event == "execution_interrupted":
                return "execution was interrupted"
            if event == "execution_error" and isinstance(details, dict):
                node = (
                    details.get("node_type")
                    or details.get("node_id")
                    or "unknown node"
                )
                message = details.get("exception_message") or "unknown error"
                return f"{node}: {message}"
        return "ComfyUI reported status 'error'"

    async def download_image(self, image: ImageOutput) -> bytes:
        """Download an output image from ComfyUI's ``/view`` endpoint.

        Raises:
            ComfyUIError: on a non-200 response.
        """
        params = {
            "filename": image.filename,
            "subfolder": image.subfolder,
            "type": image.type,
        }
        session = await self._get_session()
        async with session.get(f"{self.base_url}/view", params=params) as resp:
            if resp.status != 200:
                raise ComfyUIError(f"Failed to download image: HTTP {resp.status}")
            return await resp.read()

    # --- Monitoring ---

    async def get_history(self, prompt_id: str) -> dict:
        """Get the execution history for a prompt (keyed by prompt id)."""
        return await self._get_json(f"/history/{prompt_id}")

    async def get_system_stats(self) -> dict:
        """Get system stats, including per-device VRAM info."""
        return await self._get_json("/system_stats")

    async def get_queue_info(self) -> dict:
        """Get the current queue state (running + pending)."""
        return await self._get_json("/prompt")

    async def get_queue_depth(self) -> int:
        """Get the number of pending items in the queue."""
        info = await self.get_queue_info()
        return len(info.get("queue_pending", []))

    async def get_vram_free_gb(self) -> float | None:
        """Get free VRAM of the first device in GiB, or None if unavailable.

        Best-effort: any failure is logged and reported as None.
        """
        try:
            stats = await self.get_system_stats()
            devices = stats.get("devices", [])
            if devices:
                return devices[0].get("vram_free", 0) / (1024**3)
        except Exception:
            logger.warning("Failed to get VRAM stats", exc_info=True)
        return None

    async def is_available(self) -> bool:
        """Return True if ComfyUI answers ``/system_stats`` with 200 within 5s."""
        try:
            session = await self._get_session()
            async with session.get(
                f"{self.base_url}/system_stats", timeout=aiohttp.ClientTimeout(total=5)
            ) as resp:
                return resp.status == 200
        except Exception:
            return False

    async def upload_image(
        self, image_bytes: bytes, filename: str, overwrite: bool = True
    ) -> str:
        """Upload an image to ComfyUI's input directory.

        Returns:
            The ``name`` ComfyUI reports for the stored file, falling back to
            ``filename`` if the response omits it.

        Raises:
            ComfyUIError: on a non-200 response.
        """
        session = await self._get_session()
        data = aiohttp.FormData()
        data.add_field(
            "image", image_bytes, filename=filename, content_type="image/png"
        )
        data.add_field("overwrite", str(overwrite).lower())
        async with session.post(f"{self.base_url}/upload/image", data=data) as resp:
            if resp.status != 200:
                body = await resp.text()
                raise ComfyUIError(f"Image upload failed (HTTP {resp.status}): {body}")
            result = await resp.json()
            return result.get("name", filename)

    async def get_models(self, folder: str = "loras") -> list[str]:
        """List available models in a folder (loras, checkpoints, etc.).

        Returns an empty list when ComfyUI answers with a non-200 status.
        """
        session = await self._get_session()
        async with session.get(f"{self.base_url}/models/{folder}") as resp:
            if resp.status == 200:
                return await resp.json()
            return []