| import asyncio |
| import json |
| import logging |
| import random |
| import urllib.parse |
| import urllib.request |
| from typing import Optional |
|
|
| import websocket |
| from open_webui.env import SRC_LOG_LEVELS |
| from pydantic import BaseModel |
|
|
# Module logger; its verbosity comes from the shared per-source level table.
log = logging.getLogger(name=__name__)
log.setLevel(level=SRC_LOG_LEVELS["COMFYUI"])

# Baseline headers for every HTTP request to the ComfyUI server. The request
# helpers below merge an "Authorization: Bearer ..." header on top of these.
default_headers: dict[str, str] = {"User-Agent": "Mozilla/5.0"}
|
|
|
|
def queue_prompt(prompt, client_id, base_url, api_key):
    """Submit a workflow to ComfyUI's ``POST {base_url}/prompt`` endpoint.

    Args:
        prompt: The workflow graph (a JSON-serialisable object) to execute.
        client_id: Client identifier sent alongside the prompt.
        base_url: HTTP(S) base URL of the ComfyUI server.
        api_key: Sent as ``Authorization: Bearer <api_key>``.

    Returns:
        The decoded JSON response; callers read ``["prompt_id"]`` from it.

    Raises:
        Whatever ``urlopen`` or ``json.loads`` raises; it is logged first
        and then re-raised unchanged.
    """
    log.info("queue_prompt")
    p = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(p).encode("utf-8")
    log.debug(f"queue_prompt data: {data}")
    try:
        req = urllib.request.Request(
            f"{base_url}/prompt",
            data=data,
            headers={
                **default_headers,
                # The body is JSON; without this urllib labels it as
                # application/x-www-form-urlencoded.
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
        )
        # Use the response as a context manager so the connection is closed
        # deterministically instead of whenever it is garbage-collected.
        with urllib.request.urlopen(req) as response:
            return json.loads(response.read())
    except Exception as e:
        log.exception(f"Error while queuing prompt: {e}")
        # Bare raise preserves the original traceback without an extra frame.
        raise
|
|
|
|
def get_image(filename, subfolder, folder_type, base_url, api_key):
    """Download one ComfyUI output file and return its raw bytes.

    Issues ``GET {base_url}/view?filename=...&subfolder=...&type=...`` with the
    default headers plus a bearer token. Any ``urlopen`` error propagates.
    """
    log.info("get_image")
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    request_headers = dict(default_headers)
    request_headers["Authorization"] = f"Bearer {api_key}"
    request = urllib.request.Request(
        f"{base_url}/view?{query}", headers=request_headers
    )
    with urllib.request.urlopen(request) as resp:
        body = resp.read()
    return body
|
|
|
|
def get_image_url(filename, subfolder, folder_type, base_url):
    """Build the ``/view`` URL for a ComfyUI output file without fetching it.

    Args:
        filename: Output file name as reported in the prompt history.
        subfolder: Sub-folder reported alongside the file (may be empty).
        folder_type: The history entry's ``type`` field, sent as ``type=``.
        base_url: HTTP(S) base URL of the ComfyUI server.

    Returns:
        ``"{base_url}/view?filename=...&subfolder=...&type=..."`` with the
        query values URL-encoded.
    """
    # Log under this function's own name; the previous "get_image" label made
    # URL construction indistinguishable from actual image downloads.
    log.info("get_image_url")
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    return f"{base_url}/view?{url_values}"
|
|
|
|
def get_history(prompt_id, base_url, api_key):
    """Fetch ComfyUI's execution history for *prompt_id*.

    Returns the decoded JSON body of ``GET {base_url}/history/{prompt_id}``.
    Callers index the result by ``prompt_id`` to reach the entry itself.
    Network and JSON decoding errors propagate to the caller.
    """
    log.info("get_history")
    history_url = f"{base_url}/history/{prompt_id}"
    auth_headers = dict(default_headers, Authorization=f"Bearer {api_key}")
    request = urllib.request.Request(history_url, headers=auth_headers)
    with urllib.request.urlopen(request) as resp:
        raw = resp.read()
    return json.loads(raw)
|
|
|
|
def get_images(ws, prompt, client_id, base_url, api_key):
    """Queue *prompt*, wait for it to finish, and return its output image URLs.

    Args:
        ws: An already-connected websocket whose ``recv()`` yields text (JSON)
            or binary frames. Presumably it must be connected with the same
            *client_id* so ComfyUI routes progress messages to it (confirm).
        prompt: Workflow graph to queue.
        client_id: Client identifier passed to ``queue_prompt``.
        base_url: HTTP(S) base URL of the ComfyUI server.
        api_key: Bearer token for the HTTP calls.

    Returns:
        ``{"data": [{"url": ...}, ...]}`` with one entry per image found in
        the outputs of the prompt's history.

    Raises:
        Anything raised by ``queue_prompt``, ``get_history`` or ``ws.recv``.
        Raises ``KeyError`` if the history lacks this prompt id.

    Note:
        The wait has no timeout. It blocks until the completion message for
        this prompt arrives.
    """
    prompt_id = queue_prompt(prompt, client_id, base_url, api_key)["prompt_id"]

    # Completion is signalled by an "executing" message with node None for our
    # prompt. Binary frames (previews) and other message types are skipped.
    while True:
        out = ws.recv()
        if not isinstance(out, str):
            continue
        message = json.loads(out)
        if message["type"] != "executing":
            continue
        data = message["data"]
        if data["node"] is None and data["prompt_id"] == prompt_id:
            break

    history = get_history(prompt_id, base_url, api_key)[prompt_id]

    # Visit each output node exactly once. The previous nested loop over the
    # same dict emitted every image once per output node, duplicating results
    # whenever the workflow had more than one output node.
    output_images = []
    for node_output in history["outputs"].values():
        for image in node_output.get("images", []):
            url = get_image_url(
                image["filename"], image["subfolder"], image["type"], base_url
            )
            output_images.append({"url": url})
    return {"data": output_images}
|
|
|
|
class ComfyUINodeInput(BaseModel):
    """One override to write into a ComfyUI workflow before it is queued.

    ``comfyui_generate_image`` reads these entries and, for each id in
    ``node_ids``, sets ``workflow[node_id]["inputs"][<key>]`` to a value
    chosen by ``type``.
    """

    # Which value to inject. The specially handled types are "model", "prompt",
    # "negative_prompt", "width", "height", "n", "steps" and "seed". Any other
    # non-empty type writes ``value`` verbatim. None/empty skips the entry.
    type: Optional[str] = None
    # Ids (keys of the decoded workflow dict) of the nodes to modify.
    node_ids: list[str] = []
    # Name of the input field to set in each node's "inputs" dict. Several
    # types (e.g. "prompt", "width") fall back to a type-specific field name
    # when this is None or empty.
    key: Optional[str] = "text"
    # Literal value written for types without special handling.
    value: Optional[str] = None
|
|
|
|
class ComfyUIWorkflow(BaseModel):
    """A ComfyUI workflow together with the overrides to apply to it."""

    # The workflow graph serialised as a JSON string. It is decoded with
    # json.loads and indexed by node id when overrides are applied.
    workflow: str
    # Overrides written into the decoded graph before it is queued.
    nodes: list[ComfyUINodeInput]
|
|
|
|
class ComfyUIGenerateImageForm(BaseModel):
    """Request parameters for a single ComfyUI image-generation run."""

    # Workflow template plus the mapping of parameters to workflow nodes.
    workflow: ComfyUIWorkflow

    # Values injected into nodes whose override type has the same name.
    prompt: str
    negative_prompt: Optional[str] = None
    width: int
    height: int
    # Written into nodes of type "n" (fallback input name "batch_size").
    n: int = 1

    # Optional sampler settings. When seed is None, comfyui_generate_image
    # draws a random seed instead.
    steps: Optional[int] = None
    seed: Optional[int] = None
|
|
|
|
def _apply_workflow_nodes(workflow, model, payload):
    """Write the model name and request parameters into *workflow* in place."""
    for node in payload.workflow.nodes:
        # Entries without a type are ignored.
        if not node.type:
            continue

        # Pick the input field name (with a per-type fallback when node.key is
        # None/empty) and the value to write for this entry.
        if node.type == "model":
            key, value = node.key, model
        elif node.type == "prompt":
            key, value = node.key or "text", payload.prompt
        elif node.type == "negative_prompt":
            key, value = node.key or "text", payload.negative_prompt
        elif node.type == "width":
            key, value = node.key or "width", payload.width
        elif node.type == "height":
            key, value = node.key or "height", payload.height
        elif node.type == "n":
            key, value = node.key or "batch_size", payload.n
        elif node.type == "steps":
            key, value = node.key or "steps", payload.steps
        elif node.type == "seed":
            # Compare against None so an explicit seed of 0 is honoured rather
            # than being replaced by a random seed.
            seed = (
                payload.seed
                if payload.seed is not None
                else random.randint(0, 18446744073709551614)
            )
            key, value = node.key or "seed", seed
        else:
            key, value = node.key, node.value

        for node_id in node.node_ids:
            workflow[node_id]["inputs"][key] = value


async def comfyui_generate_image(
    model: str, payload: ComfyUIGenerateImageForm, client_id, base_url, api_key
):
    """Run a ComfyUI workflow and return the URLs of the generated images.

    Args:
        model: Value written into every "model"-type override node.
        payload: Workflow template, node overrides and generation parameters.
        client_id: Used both in the websocket URL and when queuing the prompt.
        base_url: HTTP(S) base URL of the ComfyUI server. The websocket URL is
            derived by swapping the scheme to ws/wss.
        api_key: Sent as a bearer token on the websocket and HTTP requests.

    Returns:
        ``{"data": [{"url": ...}, ...]}`` on success. Returns ``None`` if the
        websocket cannot be connected or if waiting for or collecting images
        fails.

    Raises:
        Errors from decoding ``payload.workflow.workflow`` or from a node id
        missing in the workflow propagate. They occur before any network I/O.
    """
    ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
    workflow = json.loads(payload.workflow.workflow)
    _apply_workflow_nodes(workflow, model, payload)

    try:
        ws = websocket.WebSocket()
        headers = {"Authorization": f"Bearer {api_key}"}
        # The websocket client is synchronous; connect in a worker thread so a
        # slow or unreachable server does not stall the event loop.
        await asyncio.to_thread(
            ws.connect, f"{ws_url}/ws?clientId={client_id}", header=headers
        )
        log.info("WebSocket connection established.")
    except Exception as e:
        log.exception(f"Failed to connect to WebSocket server: {e}")
        return None

    try:
        log.info("Sending workflow to WebSocket server.")
        log.info(f"Workflow: {workflow}")
        images = await asyncio.to_thread(
            get_images, ws, workflow, client_id, base_url, api_key
        )
    except Exception as e:
        log.exception(f"Error while receiving images: {e}")
        images = None
    finally:
        # Also runs on cancellation (CancelledError is not an Exception), so
        # the socket is never leaked.
        ws.close()

    return images
|
|