Spaces:
Build error
Build error
| import json | |
| import logging | |
| import random | |
| import urllib.parse | |
| from typing import Optional | |
| import aiohttp | |
| from pydantic import BaseModel | |
| from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL | |
| from open_webui.utils.session_pool import get_session | |
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Base headers that queue_prompt, get_image and get_history merge with their
# Authorization header before each request.
default_headers = {'User-Agent': 'Mozilla/5.0'}
async def queue_prompt(prompt, client_id, base_url, api_key):
    """Submit *prompt* to ``{base_url}/prompt`` and return the decoded JSON reply.

    The request body is ``{'prompt': prompt, 'client_id': client_id}`` sent as
    JSON, with the module's default headers plus a Bearer ``Authorization``
    header built from *api_key*.

    Any exception raised while obtaining the session, sending the request,
    checking the status or decoding the response is logged with its traceback
    and then re-raised unchanged.
    """
    log.info('queue_prompt')
    body = {'prompt': prompt, 'client_id': client_id}
    log.debug(f'queue_prompt data: {body}')
    try:
        http = await get_session()
        request_headers = dict(default_headers)
        request_headers['Authorization'] = f'Bearer {api_key}'
        async with http.post(
            f'{base_url}/prompt',
            json=body,
            headers=request_headers,
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        ) as response:
            # Non-2xx answers become exceptions handled by the except below.
            response.raise_for_status()
            return await response.json()
    except Exception as exc:
        log.exception(f'Error while queuing prompt: {exc}')
        raise
async def get_image(filename, subfolder, folder_type, base_url, api_key):
    """Fetch one image from ``{base_url}/view`` and return its raw bytes.

    *filename*, *subfolder* and *folder_type* are URL-encoded into the
    ``filename``, ``subfolder`` and ``type`` query parameters. The request
    carries the module's default headers plus a Bearer ``Authorization``
    header built from *api_key*. A non-success HTTP status raises via
    ``raise_for_status()``.
    """
    log.info('get_image')
    query = urllib.parse.urlencode(
        {'filename': filename, 'subfolder': subfolder, 'type': folder_type}
    )
    http = await get_session()
    request_headers = dict(default_headers)
    request_headers['Authorization'] = f'Bearer {api_key}'
    async with http.get(
        f'{base_url}/view?{query}',
        headers=request_headers,
        ssl=AIOHTTP_CLIENT_SESSION_SSL,
    ) as response:
        response.raise_for_status()
        return await response.read()
def get_image_url(filename, subfolder, folder_type, base_url):
    """Build the ``{base_url}/view`` URL for an image without fetching it.

    *filename*, *subfolder* and *folder_type* are URL-encoded into the
    ``filename``, ``subfolder`` and ``type`` query parameters, in that order.

    Returns:
        The URL as a string. No request is made and no credentials are
        embedded in it.
    """
    # Log under this function's own name; the label used to be a copy of
    # get_image's, which made the two calls indistinguishable in the logs.
    log.info('get_image_url')
    query = urllib.parse.urlencode(
        {'filename': filename, 'subfolder': subfolder, 'type': folder_type}
    )
    return f'{base_url}/view?{query}'
async def get_history(prompt_id, base_url, api_key):
    """Return the decoded JSON body of ``{base_url}/history/{prompt_id}``.

    The request carries the module's default headers plus a Bearer
    ``Authorization`` header built from *api_key*. A non-success HTTP status
    raises via ``raise_for_status()``.
    """
    log.info('get_history')
    http = await get_session()
    request_headers = dict(default_headers)
    request_headers['Authorization'] = f'Bearer {api_key}'
    async with http.get(
        f'{base_url}/history/{prompt_id}',
        headers=request_headers,
        ssl=AIOHTTP_CLIENT_SESSION_SSL,
    ) as response:
        response.raise_for_status()
        return await response.json()
async def _ws_get_images(ws, workflow, client_id, base_url, api_key):
    """Queue *workflow*, wait on *ws* until it has run, then list its image URLs.

    The workflow is submitted with ``queue_prompt``. Messages are then read
    from *ws* until an ``executing`` text message arrives whose ``node`` is
    ``None`` and whose ``prompt_id`` matches the queued prompt, or until the
    socket reports CLOSED/ERROR (logged as an error). In either case the
    prompt's history is fetched afterwards.

    Returns:
        ``{'data': [{'url': ...}, ...]}`` with one entry per image produced
        by workflow nodes whose ``class_type`` is ``SaveImage`` or
        ``PreviewImage``.
    """
    queued = await queue_prompt(workflow, client_id, base_url, api_key)
    prompt_id = queued['prompt_id']

    terminal_types = (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR)
    async for frame in ws:
        if frame.type == aiohttp.WSMsgType.TEXT:
            event = json.loads(frame.data)
            if event['type'] != 'executing':
                continue
            details = event['data']
            if details['node'] is None and details['prompt_id'] == prompt_id:
                break  # Execution is done
        elif frame.type in terminal_types:
            log.error(f'WebSocket closed unexpectedly: {frame.type}')
            break
        # binary messages (previews) are silently skipped

    all_history = await get_history(prompt_id, base_url, api_key)
    outputs = all_history[prompt_id]['outputs']

    image_node_classes = ('SaveImage', 'PreviewImage')
    collected = []
    for node_id, node_output in outputs.items():
        # Only outputs of image-producing nodes that exist in the submitted
        # workflow are reported.
        if node_id not in workflow:
            continue
        if workflow[node_id].get('class_type') not in image_node_classes:
            continue
        if 'images' not in node_output:
            continue
        collected.extend(
            {'url': get_image_url(img['filename'], img['subfolder'], img['type'], base_url)}
            for img in node_output['images']
        )
    return {'data': collected}
async def comfyui_upload_image(image_file_item, base_url, api_key):
    """Upload one image to ``{base_url}/api/upload/image`` and return the JSON reply.

    *image_file_item* is a 2-item sequence whose second item is
    ``(filename, file_bytes, mime_type)``; the first item is ignored. The
    file is sent as the multipart field ``image`` together with a ``type``
    field set to ``input``. A Bearer ``Authorization`` header is added only
    when *api_key* is truthy. A non-success HTTP status raises via
    ``raise_for_status()``.
    """
    endpoint = f'{base_url}/api/upload/image'
    request_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}

    _unused, file_parts = image_file_item
    filename, file_bytes, mime_type = file_parts

    multipart = aiohttp.FormData()
    multipart.add_field('image', file_bytes, filename=filename, content_type=mime_type)
    multipart.add_field('type', 'input')  # required by ComfyUI

    http = await get_session()
    async with http.post(
        endpoint,
        data=multipart,
        headers=request_headers,
        ssl=AIOHTTP_CLIENT_SESSION_SSL,
    ) as response:
        response.raise_for_status()
        return await response.json()
# One workflow override: which node inputs to set and where the value comes
# from. Consumed by _apply_workflow_nodes.
class ComfyUINodeInput(BaseModel):
    # Selects the value source ('model', 'prompt', 'negative_prompt', 'image',
    # 'width', 'height', 'n', 'steps' or 'seed'). When unset, the literal
    # `value` below is written instead.
    type: Optional[str] = None
    # Keys of the parsed workflow dict whose ['inputs'] entry is modified.
    node_ids: list[str] = []
    # Name of the input to set on each listed node.
    key: Optional[str] = 'text'
    # Literal written to the input when `type` is unset.
    value: Optional[str] = None
# A workflow plus the node overrides to apply to it before submission.
class ComfyUIWorkflow(BaseModel):
    # Workflow serialized as a JSON string; the create/edit entry points parse
    # it with json.loads.
    workflow: str
    # Overrides applied to the parsed workflow by _apply_workflow_nodes.
    nodes: list[ComfyUINodeInput]
# Request payload for comfyui_create_image. Each field is copied into the
# workflow by the node whose `type` names it (see _apply_workflow_nodes).
class ComfyUICreateImageForm(BaseModel):
    workflow: ComfyUIWorkflow
    prompt: str
    negative_prompt: Optional[str] = None
    width: int
    height: int
    # Written by 'n' nodes; their fallback input name is 'batch_size'.
    n: int = 1
    steps: Optional[int] = None
    # 'seed' nodes substitute a random integer when no seed is supplied.
    seed: Optional[int] = None
| def _apply_workflow_nodes(workflow, nodes, model, payload): | |
| """Mutate *workflow* dict in-place based on typed node definitions.""" | |
| for node in nodes: | |
| if node.type: | |
| if node.type == 'model': | |
| for node_id in node.node_ids: | |
| workflow[node_id]['inputs'][node.key] = model | |
| elif node.type == 'prompt': | |
| for node_id in node.node_ids: | |
| workflow[node_id]['inputs'][node.key if node.key else 'text'] = payload.prompt | |
| elif node.type == 'negative_prompt': | |
| for node_id in node.node_ids: | |
| workflow[node_id]['inputs'][node.key if node.key else 'text'] = payload.negative_prompt | |
| elif node.type == 'image': | |
| if isinstance(payload.image, list): | |
| for idx, node_id in enumerate(node.node_ids): | |
| if idx < len(payload.image): | |
| workflow[node_id]['inputs'][node.key] = payload.image[idx] | |
| else: | |
| for node_id in node.node_ids: | |
| workflow[node_id]['inputs'][node.key] = payload.image | |
| elif node.type == 'width': | |
| for node_id in node.node_ids: | |
| workflow[node_id]['inputs'][node.key if node.key else 'width'] = payload.width | |
| elif node.type == 'height': | |
| for node_id in node.node_ids: | |
| workflow[node_id]['inputs'][node.key if node.key else 'height'] = payload.height | |
| elif node.type == 'n': | |
| for node_id in node.node_ids: | |
| workflow[node_id]['inputs'][node.key if node.key else 'batch_size'] = payload.n | |
| elif node.type == 'steps': | |
| for node_id in node.node_ids: | |
| workflow[node_id]['inputs'][node.key if node.key else 'steps'] = payload.steps | |
| elif node.type == 'seed': | |
| seed = payload.seed if payload.seed else random.randint(0, 1125899906842624) | |
| for node_id in node.node_ids: | |
| workflow[node_id]['inputs'][node.key] = seed | |
| else: | |
| for node_id in node.node_ids: | |
| workflow[node_id]['inputs'][node.key] = node.value | |
async def comfyui_create_image(model: str, payload: ComfyUICreateImageForm, client_id, base_url, api_key):
    """Run the payload's workflow on ComfyUI and return the generated image URLs.

    The workflow JSON is parsed, the payload's node overrides are applied to
    it, and a WebSocket is opened to ``/ws?clientId={client_id}`` on the
    ws/wss form of *base_url*. The workflow is then queued and awaited through
    ``_ws_get_images``.

    Returns:
        The ``{'data': [...]}`` dict from ``_ws_get_images``, or ``None`` when
        the WebSocket handshake or any later step fails (the failure is
        logged with its traceback). Errors while parsing the workflow or
        applying the overrides are not caught here.
    """
    ws_base = base_url.replace('http://', 'ws://')
    ws_base = ws_base.replace('https://', 'wss://')

    workflow = json.loads(payload.workflow.workflow)
    _apply_workflow_nodes(workflow, payload.workflow.nodes, model, payload)

    auth_headers = {'Authorization': f'Bearer {api_key}'}
    http = await get_session()
    try:
        async with http.ws_connect(
            f'{ws_base}/ws?clientId={client_id}',
            headers=auth_headers,
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        ) as socket:
            log.info('WebSocket connection established.')
            log.info('Sending workflow to WebSocket server.')
            log.info(f'Workflow: {workflow}')
            return await _ws_get_images(socket, workflow, client_id, base_url, api_key)
    except aiohttp.WSServerHandshakeError as exc:
        log.exception(f'Failed to connect to WebSocket server: {exc}')
        return None
    except Exception as exc:
        log.exception(f'Error during image generation: {exc}')
        return None
# Request payload for comfyui_edit_image. Each field is copied into the
# workflow by the node whose `type` names it (see _apply_workflow_nodes).
class ComfyUIEditImageForm(BaseModel):
    workflow: ComfyUIWorkflow
    # A list is distributed positionally over an 'image' node's node_ids;
    # a single string is written to every listed node id.
    image: str | list[str]
    prompt: str
    width: Optional[int] = None
    height: Optional[int] = None
    # Written by 'n' nodes; their fallback input name is 'batch_size'.
    n: Optional[int] = None
    steps: Optional[int] = None
    # 'seed' nodes substitute a random integer when no seed is supplied.
    seed: Optional[int] = None
async def comfyui_edit_image(model: str, payload: ComfyUIEditImageForm, client_id, base_url, api_key):
    """Run the payload's editing workflow on ComfyUI and return the result URLs.

    The workflow JSON is parsed, the payload's node overrides are applied to
    it, and a WebSocket is opened to ``/ws?clientId={client_id}`` on the
    ws/wss form of *base_url*. The workflow is then queued and awaited through
    ``_ws_get_images``.

    Returns:
        The ``{'data': [...]}`` dict from ``_ws_get_images``, or ``None`` when
        the WebSocket handshake or any later step fails (the failure is
        logged with its traceback). Errors while parsing the workflow or
        applying the overrides are not caught here.
    """
    ws_base = base_url.replace('http://', 'ws://')
    ws_base = ws_base.replace('https://', 'wss://')

    workflow = json.loads(payload.workflow.workflow)
    _apply_workflow_nodes(workflow, payload.workflow.nodes, model, payload)

    auth_headers = {'Authorization': f'Bearer {api_key}'}
    http = await get_session()
    try:
        async with http.ws_connect(
            f'{ws_base}/ws?clientId={client_id}',
            headers=auth_headers,
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        ) as socket:
            log.info('WebSocket connection established.')
            log.info('Sending workflow to WebSocket server.')
            log.info(f'Workflow: {workflow}')
            return await _ws_get_images(socket, workflow, client_id, base_url, api_key)
    except aiohttp.WSServerHandshakeError as exc:
        log.exception(f'Failed to connect to WebSocket server: {exc}')
        return None
    except Exception as exc:
        log.exception(f'Error during image editing: {exc}')
        return None