"""Async ComfyUI API client using aiohttp for HTTP and WebSocket communication.
Based on the pattern from ComfyUI's own websockets_api_example.py.
Communicates with ComfyUI at http://127.0.0.1:8188.
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any
import aiohttp
logger = logging.getLogger(__name__)
@dataclass
class ComfyUIResult:
    """Result from a completed ComfyUI generation."""
    prompt_id: str  # server-assigned id of the executed prompt
    outputs: dict[str, Any] = field(default_factory=dict)  # raw per-node "outputs" mapping from /history
    images: list[ImageOutput] = field(default_factory=list)  # image refs flattened out of `outputs`
@dataclass
class ImageOutput:
    """A single output image from ComfyUI."""
    filename: str  # file name as reported by the server
    subfolder: str  # may be "" when the image sits at the directory root
    type: str  # "output" or "temp"
class ComfyUIError(Exception):
    """Raised when ComfyUI returns an error.

    Covers rejected prompt submissions, failed image downloads/uploads,
    and generation timeouts.
    """
class ComfyUIClient:
    """Async client for the ComfyUI HTTP/WebSocket API.

    Submits workflows over HTTP, tracks execution over the ``/ws``
    WebSocket, and fetches finished outputs from ``/history``.

    Usage:
        client = ComfyUIClient("http://127.0.0.1:8188")
        result = await client.generate(workflow_dict)
        image_bytes = await client.download_image(result.images[0])
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8188"):
        # Strip any trailing slash so f-string path joins stay well-formed.
        self.base_url = base_url.rstrip("/")
        # Stable id for this client; ComfyUI routes WS events by clientId.
        self.client_id = str(uuid.uuid4())
        self._session: aiohttp.ClientSession | None = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared ClientSession, (re)creating it when needed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self) -> None:
        """Close the underlying HTTP session (idempotent)."""
        if self._session and not self._session.closed:
            await self._session.close()
        # Drop the reference so a later call transparently reconnects.
        self._session = None

    # --- Core generation ---
    async def queue_prompt(self, workflow: dict) -> str:
        """Submit a workflow to ComfyUI.

        Args:
            workflow: API-format workflow graph (node id -> node dict).

        Returns:
            The prompt_id assigned by the server, falling back to the
            locally generated one if the response omits it.

        Raises:
            ComfyUIError: If the server rejects the prompt.
        """
        prompt_id = str(uuid.uuid4())
        payload = {
            "prompt": workflow,
            "client_id": self.client_id,
            "prompt_id": prompt_id,
        }
        session = await self._get_session()
        async with session.post(f"{self.base_url}/prompt", json=payload) as resp:
            if resp.status != 200:
                body = await resp.text()
                raise ComfyUIError(f"Prompt rejected (HTTP {resp.status}): {body}")
            data = await resp.json()
            # Prefer the server-assigned id; it is the key used by /history.
            return data.get("prompt_id", prompt_id)

    async def wait_for_completion(
        self, prompt_id: str, timeout: float = 600
    ) -> ComfyUIResult:
        """Wait for a prompt to finish via WebSocket, then fetch results.

        Raises:
            ComfyUIError: If the prompt does not finish within `timeout`
                seconds.
        """
        # Mirror the HTTP scheme: the old code always used ws://, which
        # breaks against an https:// (TLS) ComfyUI deployment.
        scheme = "wss" if self.base_url.startswith("https://") else "ws"
        host = self.base_url.split("://", 1)[-1]
        ws_url = f"{scheme}://{host}/ws?clientId={self.client_id}"
        session = await self._get_session()
        try:
            async with asyncio.timeout(timeout):
                async with session.ws_connect(ws_url) as ws:
                    # Guard against the race where the prompt finished
                    # before the socket opened: completed prompts appear in
                    # /history, so only listen if it is not already there.
                    if prompt_id not in await self.get_history(prompt_id):
                        async for msg in ws:
                            if msg.type == aiohttp.WSMsgType.TEXT:
                                data = json.loads(msg.data)
                                if data.get("type") == "executing":
                                    exec_data = data.get("data", {})
                                    # node == None for our prompt_id means
                                    # the whole graph finished executing.
                                    if (
                                        exec_data.get("node") is None
                                        and exec_data.get("prompt_id") == prompt_id
                                    ):
                                        break
                            # Binary messages are latent previews — skip
        except TimeoutError as exc:
            raise ComfyUIError(
                f"Timeout waiting for prompt {prompt_id} after {timeout}s"
            ) from exc
        return await self._fetch_result(prompt_id)

    async def generate(self, workflow: dict, timeout: float = 600) -> ComfyUIResult:
        """Submit workflow and wait for completion. Returns the result."""
        prompt_id = await self.queue_prompt(workflow)
        logger.info("Queued prompt %s", prompt_id)
        return await self.wait_for_completion(prompt_id, timeout)

    # --- Result fetching ---
    async def _fetch_result(self, prompt_id: str) -> ComfyUIResult:
        """Fetch history for a completed prompt and extract image outputs."""
        history = await self.get_history(prompt_id)
        outputs = history.get(prompt_id, {}).get("outputs", {})
        # Flatten every node's "images" entries into ImageOutput records.
        images = [
            ImageOutput(
                filename=img_info["filename"],
                subfolder=img_info.get("subfolder", ""),
                type=img_info.get("type", "output"),
            )
            for node_output in outputs.values()
            for img_info in node_output.get("images", [])
        ]
        return ComfyUIResult(prompt_id=prompt_id, outputs=outputs, images=images)

    async def download_image(self, image: ImageOutput) -> bytes:
        """Download an output image from ComfyUI's /view endpoint.

        Raises:
            ComfyUIError: If the server does not return HTTP 200.
        """
        params = {
            "filename": image.filename,
            "subfolder": image.subfolder,
            "type": image.type,
        }
        session = await self._get_session()
        async with session.get(f"{self.base_url}/view", params=params) as resp:
            if resp.status != 200:
                raise ComfyUIError(f"Failed to download image: HTTP {resp.status}")
            return await resp.read()

    # --- Monitoring ---
    async def get_history(self, prompt_id: str) -> dict:
        """Get execution history for a prompt (empty until it completes)."""
        session = await self._get_session()
        async with session.get(f"{self.base_url}/history/{prompt_id}") as resp:
            return await resp.json()

    async def get_system_stats(self) -> dict:
        """Get system stats including GPU VRAM info."""
        session = await self._get_session()
        async with session.get(f"{self.base_url}/system_stats") as resp:
            return await resp.json()

    async def get_queue_info(self) -> dict:
        """Get current queue state (running + pending)."""
        session = await self._get_session()
        # /queue returns {"queue_running": [...], "queue_pending": [...]}.
        # The previous GET /prompt endpoint only reports an
        # exec_info.queue_remaining counter, so get_queue_depth() always
        # saw zero pending items.
        async with session.get(f"{self.base_url}/queue") as resp:
            return await resp.json()

    async def get_queue_depth(self) -> int:
        """Get number of pending items in the queue."""
        info = await self.get_queue_info()
        return len(info.get("queue_pending", []))

    async def get_vram_free_gb(self) -> float | None:
        """Get free VRAM in GB for the first device, or None if unavailable."""
        try:
            stats = await self.get_system_stats()
            devices = stats.get("devices", [])
            if devices:
                return devices[0].get("vram_free", 0) / (1024**3)
        except Exception:
            # Best-effort metric: log the failure and fall through to None.
            logger.warning("Failed to get VRAM stats", exc_info=True)
        return None

    async def is_available(self) -> bool:
        """Check if ComfyUI is reachable (5 s probe; never raises)."""
        try:
            session = await self._get_session()
            async with session.get(
                f"{self.base_url}/system_stats", timeout=aiohttp.ClientTimeout(total=5)
            ) as resp:
                return resp.status == 200
        except Exception:
            return False

    async def upload_image(
        self, image_bytes: bytes, filename: str, overwrite: bool = True
    ) -> str:
        """Upload an image to ComfyUI's input directory.

        Returns:
            The filename the server stored the image under (the server may
            rename to avoid collisions when overwrite is disabled).

        Raises:
            ComfyUIError: If the upload fails.
        """
        session = await self._get_session()
        data = aiohttp.FormData()
        data.add_field(
            "image", image_bytes, filename=filename, content_type="image/png"
        )
        # The endpoint expects a lowercase string flag, not a JSON bool.
        data.add_field("overwrite", str(overwrite).lower())
        async with session.post(f"{self.base_url}/upload/image", data=data) as resp:
            if resp.status != 200:
                body = await resp.text()
                raise ComfyUIError(f"Image upload failed (HTTP {resp.status}): {body}")
            result = await resp.json()
            return result.get("name", filename)

    async def get_models(self, folder: str = "loras") -> list[str]:
        """List available models in a folder (loras, checkpoints, etc.).

        Returns an empty list when the endpoint errors or is missing.
        """
        session = await self._get_session()
        async with session.get(f"{self.base_url}/models/{folder}") as resp:
            if resp.status == 200:
                return await resp.json()
            return []
|