File size: 8,452 Bytes
ed37502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""Async ComfyUI API client using aiohttp for HTTP and WebSocket communication.

Based on the pattern from ComfyUI's own websockets_api_example.py.
Communicates with ComfyUI at http://127.0.0.1:8188.
"""

from __future__ import annotations

import asyncio
import json
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any

import aiohttp

logger = logging.getLogger(__name__)


@dataclass
class ComfyUIResult:
    """Everything collected for one finished ComfyUI prompt.

    Instances are built by ``ComfyUIClient._fetch_result`` from the prompt's
    ``/history`` entry.
    """

    # Identifier of the prompt these results belong to.
    prompt_id: str
    # Raw per-node ``outputs`` mapping copied from the history entry.
    outputs: dict[str, Any] = field(default_factory=dict)
    # Image references pulled out of the ``images`` lists inside ``outputs``.
    images: list[ImageOutput] = field(default_factory=list)


@dataclass
class ImageOutput:
    """Reference to one image produced by a ComfyUI node.

    The three fields are sent verbatim as the query parameters of the
    ``/view`` request made by ``ComfyUIClient.download_image``.
    """

    filename: str  # file name as reported by the server
    subfolder: str  # defaults to "" when the server omits it
    type: str  # "output" or "temp"


class ComfyUIError(Exception):
    """Signals a failed interaction with the ComfyUI server.

    Used for rejected prompts, failed uploads/downloads, and timeouts while
    waiting for a prompt to finish.
    """


class ComfyUIClient:
    """Async client for the ComfyUI HTTP + WebSocket API.

    Usage:
        client = ComfyUIClient("http://127.0.0.1:8188")
        result = await client.generate(workflow_dict)
        image_bytes = await client.download_image(result.images[0])

    The client is also an async context manager that closes its HTTP session
    on exit:

        async with ComfyUIClient() as client:
            result = await client.generate(workflow_dict)

    NOTE(review): every prompt queued by one instance carries the same
    ``client_id``, and each ``wait_for_completion`` call opens its own
    WebSocket with that id. Confirm the server delivers messages to all of
    them before waiting on several prompts from one instance concurrently.
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8188"):
        # Strip trailing "/" so f"{base_url}/path" never produces "//path".
        self.base_url = base_url.rstrip("/")
        # Sent with each queued prompt and used as the WebSocket clientId.
        self.client_id = str(uuid.uuid4())
        self._session: aiohttp.ClientSession | None = None

    async def __aenter__(self) -> ComfyUIClient:
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        await self.close()

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it lazily or after close()."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self) -> None:
        """Close the shared HTTP session if open; safe to call repeatedly."""
        if self._session and not self._session.closed:
            await self._session.close()

    def _ws_url(self) -> str:
        """Build the WebSocket URL, mapping http -> ws and https -> wss."""
        scheme, sep, rest = self.base_url.partition("://")
        if not sep:
            # No scheme given (e.g. "127.0.0.1:8188"): treat as plain http.
            scheme, rest = "http", self.base_url
        ws_scheme = "wss" if scheme.lower() == "https" else "ws"
        return f"{ws_scheme}://{rest}/ws?clientId={self.client_id}"

    async def _get_json(self, path: str) -> Any:
        """GET ``base_url + path`` and decode the JSON response body.

        Raises:
            ComfyUIError: if the server answers with a non-200 status.
        """
        session = await self._get_session()
        async with session.get(f"{self.base_url}{path}") as resp:
            if resp.status != 200:
                body = await resp.text()
                raise ComfyUIError(f"GET {path} failed (HTTP {resp.status}): {body}")
            return await resp.json()

    # --- Core generation ---

    async def queue_prompt(self, workflow: dict) -> str:
        """Submit a workflow to ComfyUI.

        Args:
            workflow: The API-format workflow graph to execute.

        Returns:
            The ``prompt_id`` reported by the server, or the locally generated
            one if the response does not include it.

        Raises:
            ComfyUIError: if the server rejects the prompt (non-200 status).
        """
        prompt_id = str(uuid.uuid4())
        payload = {
            "prompt": workflow,
            "client_id": self.client_id,
            "prompt_id": prompt_id,
        }
        session = await self._get_session()
        async with session.post(f"{self.base_url}/prompt", json=payload) as resp:
            if resp.status != 200:
                body = await resp.text()
                raise ComfyUIError(f"Prompt rejected (HTTP {resp.status}): {body}")
            data = await resp.json()
            return data.get("prompt_id", prompt_id)

    async def wait_for_completion(
        self, prompt_id: str, timeout: float = 600
    ) -> ComfyUIResult:
        """Wait for a queued prompt to finish, then fetch its results.

        Completion is signalled by an ``executing`` WebSocket message whose
        ``node`` is ``None`` for this ``prompt_id``. The socket is opened
        after the prompt was queued, so once connected the prompt's history
        is checked as well. A prompt that finished before the socket existed
        therefore does not hang until the timeout.

        Args:
            prompt_id: ID returned by :meth:`queue_prompt`.
            timeout: Maximum seconds to spend connecting and waiting.

        Returns:
            The prompt's outputs and extracted image references.

        Raises:
            ComfyUIError: on timeout, on an ``execution_error`` or
                ``execution_interrupted`` message for this prompt, or if the
                WebSocket errors or closes before the prompt finishes.
        """
        try:
            async with asyncio.timeout(timeout):
                session = await self._get_session()
                async with session.ws_connect(self._ws_url()) as ws:
                    # Assumes ComfyUI only lists a prompt in /history once it
                    # has finished. If it is already there, its completion
                    # message went out before this socket connected.
                    history = await self.get_history(prompt_id)
                    if prompt_id not in history:
                        await self._wait_for_finish_message(ws, prompt_id)
        except TimeoutError:
            raise ComfyUIError(
                f"Timeout waiting for prompt {prompt_id} after {timeout}s"
            ) from None

        return await self._fetch_result(prompt_id)

    async def _wait_for_finish_message(
        self, ws: aiohttp.ClientWebSocketResponse, prompt_id: str
    ) -> None:
        """Consume WebSocket messages until ``prompt_id`` finishes executing.

        Raises:
            ComfyUIError: on a socket error, an execution error/interruption
                for this prompt, or if the socket closes first.
        """
        async for msg in ws:
            if msg.type == aiohttp.WSMsgType.ERROR:
                raise ComfyUIError(
                    f"WebSocket error while waiting for prompt {prompt_id}: "
                    f"{ws.exception()}"
                )
            if msg.type != aiohttp.WSMsgType.TEXT:
                continue  # binary messages are latent previews — skip
            event = json.loads(msg.data)
            event_type = event.get("type")
            data = event.get("data") or {}
            if data.get("prompt_id") != prompt_id:
                continue  # message about another prompt or a global status
            if event_type == "executing" and data.get("node") is None:
                return
            if event_type in ("execution_error", "execution_interrupted"):
                detail = data.get("exception_message") or event_type
                raise ComfyUIError(
                    f"Prompt {prompt_id} failed on node {data.get('node_id')}: "
                    f"{detail}"
                )
        # `async for` ends when the socket closes; reaching here means no
        # completion message arrived, so returning a result would be wrong.
        raise ComfyUIError(f"WebSocket closed before prompt {prompt_id} finished")

    async def generate(self, workflow: dict, timeout: float = 600) -> ComfyUIResult:
        """Submit ``workflow`` and wait up to ``timeout`` seconds for its result."""
        prompt_id = await self.queue_prompt(workflow)
        logger.info("Queued prompt %s", prompt_id)
        return await self.wait_for_completion(prompt_id, timeout)

    # --- Result fetching ---

    async def _fetch_result(self, prompt_id: str) -> ComfyUIResult:
        """Fetch history for a finished prompt and extract its image outputs.

        Missing history entries yield an empty result rather than an error.
        """
        history = await self.get_history(prompt_id)
        outputs = history.get(prompt_id, {}).get("outputs", {})

        images: list[ImageOutput] = []
        for node_output in outputs.values():
            for img_info in node_output.get("images", []):
                images.append(
                    ImageOutput(
                        filename=img_info["filename"],
                        subfolder=img_info.get("subfolder", ""),
                        type=img_info.get("type", "output"),
                    )
                )

        return ComfyUIResult(prompt_id=prompt_id, outputs=outputs, images=images)

    async def download_image(self, image: ImageOutput) -> bytes:
        """Download an output image's raw bytes via ``/view``.

        Raises:
            ComfyUIError: on a non-200 response.
        """
        params = {
            "filename": image.filename,
            "subfolder": image.subfolder,
            "type": image.type,
        }
        session = await self._get_session()
        async with session.get(f"{self.base_url}/view", params=params) as resp:
            if resp.status != 200:
                raise ComfyUIError(f"Failed to download image: HTTP {resp.status}")
            return await resp.read()

    # --- Monitoring ---

    async def get_history(self, prompt_id: str) -> dict:
        """Get the execution history entry for ``prompt_id`` (keyed by prompt id).

        Raises:
            ComfyUIError: on a non-200 response.
        """
        return await self._get_json(f"/history/{prompt_id}")

    async def get_system_stats(self) -> dict:
        """Get system stats, including per-device VRAM info.

        Raises:
            ComfyUIError: on a non-200 response.
        """
        return await self._get_json("/system_stats")

    async def get_queue_info(self) -> dict:
        """Get the current queue state (running + pending).

        Raises:
            ComfyUIError: on a non-200 response.
        """
        return await self._get_json("/prompt")

    async def get_queue_depth(self) -> int:
        """Return the number of pending items in the queue."""
        info = await self.get_queue_info()
        return len(info.get("queue_pending", []))

    async def get_vram_free_gb(self) -> float | None:
        """Return free VRAM of the first device in GiB, or None if unavailable.

        Best-effort: any failure is logged and reported as None.
        """
        try:
            stats = await self.get_system_stats()
            devices = stats.get("devices", [])
            if devices:
                return devices[0].get("vram_free", 0) / (1024**3)
        except Exception:
            logger.warning("Failed to get VRAM stats", exc_info=True)
        return None

    async def is_available(self) -> bool:
        """Return True if ComfyUI answers ``/system_stats`` with 200 within 5s."""
        try:
            session = await self._get_session()
            async with session.get(
                f"{self.base_url}/system_stats", timeout=aiohttp.ClientTimeout(total=5)
            ) as resp:
                return resp.status == 200
        except Exception:
            # Deliberate best-effort health check: any failure means "not available".
            logger.debug("ComfyUI availability check failed", exc_info=True)
            return False

    async def upload_image(
        self,
        image_bytes: bytes,
        filename: str,
        overwrite: bool = True,
        *,
        content_type: str = "image/png",
    ) -> str:
        """Upload an image to ComfyUI's input directory.

        Args:
            image_bytes: Encoded image data.
            filename: Name to store the file under.
            overwrite: Whether an existing file of the same name is replaced.
            content_type: MIME type sent for the file part (default PNG).

        Returns:
            The stored filename reported by the server, or ``filename``.

        Raises:
            ComfyUIError: on a non-200 response.
        """
        session = await self._get_session()
        data = aiohttp.FormData()
        data.add_field(
            "image", image_bytes, filename=filename, content_type=content_type
        )
        data.add_field("overwrite", str(overwrite).lower())
        async with session.post(f"{self.base_url}/upload/image", data=data) as resp:
            if resp.status != 200:
                body = await resp.text()
                raise ComfyUIError(f"Image upload failed (HTTP {resp.status}): {body}")
            result = await resp.json()
            return result.get("name", filename)

    async def get_models(self, folder: str = "loras") -> list[str]:
        """List available models in ``folder`` (loras, checkpoints, ...).

        Returns an empty list on any non-200 response.
        """
        session = await self._get_session()
        async with session.get(f"{self.base_url}/models/{folder}") as resp:
            if resp.status == 200:
                return await resp.json()
            return []