| import os |
| import sys |
| import asyncio |
| import traceback |
|
|
| import nodes |
| import folder_paths |
| import execution |
| import uuid |
| import urllib |
| import json |
| import glob |
| import struct |
| import ssl |
| import socket |
| import ipaddress |
| from PIL import Image, ImageOps |
| from PIL.PngImagePlugin import PngInfo |
| from io import BytesIO |
|
|
| import aiohttp |
| from aiohttp import web |
| import logging |
|
|
| import mimetypes |
| from comfy.cli_args import args |
| import comfy.utils |
| import comfy.model_management |
| from comfy_api import feature_flags |
| import node_helpers |
| from comfyui_version import __version__ |
| from app.frontend_management import FrontendManager |
| from comfy_api.internal import _ComfyNodeInternal |
|
|
| from app.user_manager import UserManager |
| from app.model_manager import ModelFileManager |
| from app.custom_node_manager import CustomNodeManager |
| from typing import Optional, Union |
| from api_server.routes.internal.internal_routes import InternalRoutes |
| from protocol import BinaryEventTypes |
|
|
| |
| from middleware.cache_middleware import cache_control |
|
|
async def send_socket_catch_exception(function, message):
    """Await ``function(message)`` and log, rather than raise, transport errors.

    ``send_bytes`` / ``send_json`` route every websocket write through this
    helper so that one dead client cannot abort a send loop over the others.

    Args:
        function: Async callable performing the send (e.g. ``ws.send_json``).
        message: Payload handed to ``function``.
    """
    try:
        await function(message)
    except (
        aiohttp.ClientError,
        aiohttp.ClientPayloadError,
        ConnectionResetError,
        BrokenPipeError,
        ConnectionError,
    ) as send_failure:
        logging.warning(f"send error: {send_failure}")
|
|
@web.middleware
async def compress_body(request: web.Request, handler):
    """Turn on response compression for non-empty JSON / plain-text bodies.

    Compression is enabled only when the client's ``Accept-Encoding`` header
    mentions gzip; any other response (including objects that are not a
    ``web.Response``) passes through untouched.
    """
    accept_encoding = request.headers.get("Accept-Encoding", "")
    response: web.Response = await handler(request)
    should_compress = (
        isinstance(response, web.Response)
        and response.content_type in ["application/json", "text/plain"]
        and response.body
        and "gzip" in accept_encoding
    )
    if should_compress:
        response.enable_compression()
    return response
|
|
|
|
def create_cors_middleware(allowed_origin: str):
    """Return a middleware that adds CORS headers permitting *allowed_origin*.

    ``OPTIONS`` (preflight) requests are answered with an empty response
    without invoking the route handler; every response, preflight or not,
    receives the same CORS headers.
    """
    cors_headers = (
        ('Access-Control-Allow-Origin', allowed_origin),
        ('Access-Control-Allow-Methods', 'POST, GET, DELETE, PUT, OPTIONS'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization'),
        ('Access-Control-Allow-Credentials', 'true'),
    )

    @web.middleware
    async def cors_middleware(request: web.Request, handler):
        is_preflight = request.method == "OPTIONS"
        response = web.Response() if is_preflight else await handler(request)
        for header_name, header_value in cors_headers:
            response.headers[header_name] = header_value
        return response

    return cors_middleware
|
|
def is_loopback(host):
    """Return whether *host* should be treated as a loopback host.

    Literal IP addresses are checked directly. Other host names are resolved,
    IPv4 first and then IPv6:

    * the result is False if nothing resolves, or if the first resolved
      address is not loopback;
    * once a loopback address has been seen, the result stays True even if a
      later address is not loopback. Mixed results therefore count as
      loopback, which makes ``create_origin_only_middleware`` apply its
      host/origin check rather than skip it.

    Args:
        host: Host name or IP literal, or ``None`` (treated as not loopback).
    """
    if host is None:
        return False
    try:
        return ipaddress.ip_address(host).is_loopback
    except ValueError:
        # Not an IP literal: fall through to name resolution.
        pass

    loopback = False
    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            addr_infos = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
        except socket.gaierror:
            continue
        for _family, _type, _proto, _canonname, sockaddr in addr_infos:
            if not ipaddress.ip_address(sockaddr[0]).is_loopback:
                return loopback
            loopback = True
    return loopback
|
|
|
|
def create_origin_only_middleware():
    """Build a middleware that rejects mismatched Host/Origin requests to loopback hosts.

    When a request carries both ``Host`` and ``Origin`` headers and the host
    is loopback (per ``is_loopback``), the two must name the same host or the
    request is refused with 403. Headers that cannot be parsed are refused
    with 400. ``OPTIONS`` requests that pass the check get an empty response
    without reaching the route handler.
    """
    # ``import urllib`` alone does not guarantee that the ``parse`` submodule
    # is loaded, so import it explicitly for the closure below.
    import urllib.parse

    @web.middleware
    async def origin_only_middleware(request: web.Request, handler):
        if 'Host' in request.headers and 'Origin' in request.headers:
            host = request.headers['Host']
            origin = request.headers['Origin']
            host_domain = host.lower()

            # Malformed headers (unbalanced IPv6 brackets, non-numeric or
            # out-of-range ports) make urllib raise ValueError. Reject them
            # explicitly instead of letting the error surface as a 500.
            try:
                parsed = urllib.parse.urlparse(origin)
                host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
                origin_port = parsed.port
                host_port = host_domain_parsed.port
            except ValueError:
                logging.warning("WARNING: request with malformed host or origin header, returning 400")
                return web.Response(status=400)

            origin_domain = parsed.netloc.lower()
            loopback = is_loopback(host_domain_parsed.hostname)

            # If either header lacks an explicit port, drop the port from the
            # other side so that e.g. "localhost" and "localhost:8188" compare equal.
            if origin_port is None:
                host_domain = host_domain_parsed.hostname
            if host_port is None:
                origin_domain = parsed.hostname

            if loopback and host_domain and origin_domain:
                if host_domain != origin_domain:
                    logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
                    return web.Response(status=403)

        if request.method == "OPTIONS":
            response = web.Response()
        else:
            response = await handler(request)

        return response

    return origin_only_middleware
|
|
class PromptServer:
    """HTTP + websocket server in front of the prompt queue.

    The constructor wires up managers, middlewares and the aiohttp application;
    HTTP route handlers are defined as closures inside ``__init__``.
    """

    def __init__(self, loop):
        # The most recently constructed server is reachable as PromptServer.instance.
        PromptServer.instance = self

        # Register MIME types explicitly instead of relying on the host's database.
        mimetypes.init()
        mimetypes.add_type('application/javascript; charset=utf-8', '.js')
        mimetypes.add_type('image/webp', '.webp')

        self.user_manager = UserManager()
        self.model_file_manager = ModelFileManager()
        self.custom_node_manager = CustomNodeManager()
        self.internal_routes = InternalRoutes(self)
        self.supports = ["custom_nodes_from_web"]
        self.prompt_queue = execution.PromptQueue(self)
        self.loop = loop
        self.messages = asyncio.Queue()
        self.client_session: Optional[aiohttp.ClientSession] = None
        self.number = 0

        # Middleware chain: cache control, optional response compression, then
        # exactly one of CORS headers or same-origin enforcement.
        middlewares = [cache_control]
        if args.enable_compress_response_body:
            middlewares.append(compress_body)
        middlewares.append(
            create_cors_middleware(args.enable_cors_header)
            if args.enable_cors_header
            else create_origin_only_middleware()
        )

        # max_upload_size is scaled by 1024*1024 (presumably MB) into bytes.
        self.app = web.Application(
            client_max_size=round(args.max_upload_size * 1024 * 1024),
            middlewares=middlewares,
        )
        self.sockets = {}
        self.sockets_metadata = {}
        if args.front_end_root is None:
            self.web_root = FrontendManager.init_frontend(args.front_end_version)
        else:
            self.web_root = args.front_end_root
        logging.info(f"[Prompt Server] web root: {self.web_root}")
        routes = web.RouteTableDef()
        self.routes = routes
        self.last_node_id = None
        self.client_id = None

        self.on_prompt_handlers = []
|
| @routes.get('/ws') |
| async def websocket_handler(request): |
| ws = web.WebSocketResponse() |
| await ws.prepare(request) |
| sid = request.rel_url.query.get('clientId', '') |
| if sid: |
| |
| self.sockets.pop(sid, None) |
| else: |
| sid = uuid.uuid4().hex |
|
|
| |
| self.sockets[sid] = ws |
| |
| self.sockets_metadata[sid] = {"feature_flags": {}} |
|
|
| try: |
| |
| await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) |
| |
| if self.client_id == sid and self.last_node_id is not None: |
| await self.send("executing", { "node": self.last_node_id }, sid) |
|
|
| |
| first_message = True |
|
|
| async for msg in ws: |
| if msg.type == aiohttp.WSMsgType.ERROR: |
| logging.warning('ws connection closed with exception %s' % ws.exception()) |
| elif msg.type == aiohttp.WSMsgType.TEXT: |
| try: |
| data = json.loads(msg.data) |
| |
| if first_message and data.get("type") == "feature_flags": |
| |
| client_flags = data.get("data", {}) |
| self.sockets_metadata[sid]["feature_flags"] = client_flags |
|
|
| |
| await self.send( |
| "feature_flags", |
| feature_flags.get_server_features(), |
| sid, |
| ) |
|
|
| logging.debug( |
| f"Feature flags negotiated for client {sid}: {client_flags}" |
| ) |
| first_message = False |
| except json.JSONDecodeError: |
| logging.warning( |
| f"Invalid JSON received from client {sid}: {msg.data}" |
| ) |
| except Exception as e: |
| logging.error(f"Error processing WebSocket message: {e}") |
| finally: |
| self.sockets.pop(sid, None) |
| self.sockets_metadata.pop(sid, None) |
| return ws |
|
|
| @routes.get("/") |
| async def get_root(request): |
| response = web.FileResponse(os.path.join(self.web_root, "index.html")) |
| response.headers['Cache-Control'] = 'no-cache' |
| response.headers["Pragma"] = "no-cache" |
| response.headers["Expires"] = "0" |
| return response |
|
|
| @routes.get("/embeddings") |
| def get_embeddings(request): |
| embeddings = folder_paths.get_filename_list("embeddings") |
| return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) |
|
|
| @routes.get("/models") |
| def list_model_types(request): |
| model_types = list(folder_paths.folder_names_and_paths.keys()) |
|
|
| return web.json_response(model_types) |
|
|
| @routes.get("/models/{folder}") |
| async def get_models(request): |
| folder = request.match_info.get("folder", None) |
| if not folder in folder_paths.folder_names_and_paths: |
| return web.Response(status=404) |
| files = folder_paths.get_filename_list(folder) |
| return web.json_response(files) |
|
|
| @routes.get("/extensions") |
| async def get_extensions(request): |
| files = glob.glob(os.path.join( |
| glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True) |
|
|
| extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)) |
|
|
| for name, dir in nodes.EXTENSION_WEB_DIRS.items(): |
| files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True) |
| extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote( |
| name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files))) |
|
|
| return web.json_response(extensions) |
|
|
| def get_dir_by_type(dir_type): |
| if dir_type is None: |
| dir_type = "input" |
|
|
| if dir_type == "input": |
| type_dir = folder_paths.get_input_directory() |
| elif dir_type == "temp": |
| type_dir = folder_paths.get_temp_directory() |
| elif dir_type == "output": |
| type_dir = folder_paths.get_output_directory() |
|
|
| return type_dir, dir_type |
|
|
| def compare_image_hash(filepath, image): |
| hasher = node_helpers.hasher() |
|
|
| |
| if os.path.exists(filepath): |
| a = hasher() |
| b = hasher() |
| with open(filepath, "rb") as f: |
| a.update(f.read()) |
| b.update(image.file.read()) |
| image.file.seek(0) |
| return a.hexdigest() == b.hexdigest() |
| return False |
|
|
| def image_upload(post, image_save_function=None): |
| image = post.get("image") |
| overwrite = post.get("overwrite") |
| image_is_duplicate = False |
|
|
| image_upload_type = post.get("type") |
| upload_dir, image_upload_type = get_dir_by_type(image_upload_type) |
|
|
| if image and image.file: |
| filename = image.filename |
| if not filename: |
| return web.Response(status=400) |
|
|
| subfolder = post.get("subfolder", "") |
| full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder)) |
| filepath = os.path.abspath(os.path.join(full_output_folder, filename)) |
|
|
| if os.path.commonpath((upload_dir, filepath)) != upload_dir: |
| return web.Response(status=400) |
|
|
| if not os.path.exists(full_output_folder): |
| os.makedirs(full_output_folder) |
|
|
| split = os.path.splitext(filename) |
|
|
| if overwrite is not None and (overwrite == "true" or overwrite == "1"): |
| pass |
| else: |
| i = 1 |
| while os.path.exists(filepath): |
| if compare_image_hash(filepath, image): |
| image_is_duplicate = True |
| break |
| filename = f"{split[0]} ({i}){split[1]}" |
| filepath = os.path.join(full_output_folder, filename) |
| i += 1 |
|
|
| if not image_is_duplicate: |
| if image_save_function is not None: |
| image_save_function(image, post, filepath) |
| else: |
| with open(filepath, "wb") as f: |
| f.write(image.file.read()) |
|
|
| return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) |
| else: |
| return web.Response(status=400) |
|
|
| @routes.post("/upload/image") |
| async def upload_image(request): |
| post = await request.post() |
| return image_upload(post) |
|
|
|
|
| @routes.post("/upload/mask") |
| async def upload_mask(request): |
| post = await request.post() |
|
|
| def image_save_function(image, post, filepath): |
| original_ref = json.loads(post.get("original_ref")) |
| filename, output_dir = folder_paths.annotated_filepath(original_ref['filename']) |
|
|
| if not filename: |
| return web.Response(status=400) |
|
|
| |
| if filename[0] == '/' or '..' in filename: |
| return web.Response(status=400) |
|
|
| if output_dir is None: |
| type = original_ref.get("type", "output") |
| output_dir = folder_paths.get_directory_by_type(type) |
|
|
| if output_dir is None: |
| return web.Response(status=400) |
|
|
| if original_ref.get("subfolder", "") != "": |
| full_output_dir = os.path.join(output_dir, original_ref["subfolder"]) |
| if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir: |
| return web.Response(status=403) |
| output_dir = full_output_dir |
|
|
| file = os.path.join(output_dir, filename) |
|
|
| if os.path.isfile(file): |
| with Image.open(file) as original_pil: |
| metadata = PngInfo() |
| if hasattr(original_pil,'text'): |
| for key in original_pil.text: |
| metadata.add_text(key, original_pil.text[key]) |
| original_pil = original_pil.convert('RGBA') |
| mask_pil = Image.open(image.file).convert('RGBA') |
|
|
| |
| new_alpha = mask_pil.getchannel('A') |
| original_pil.putalpha(new_alpha) |
| original_pil.save(filepath, compress_level=4, pnginfo=metadata) |
|
|
| return image_upload(post, image_save_function) |
|
|
        @routes.get("/view")
        async def view_image(request):
            """Serve a file from one of the known directories (default: output).

            Query parameters:
                filename: file name; it may carry a directory annotation that
                    ``folder_paths.annotated_filepath`` resolves.
                type: directory type used when the name has no annotation
                    (default ``"output"``).
                subfolder: optional sub-directory; must stay inside the base dir.
                preview: ``FORMAT[;QUALITY]`` to re-encode the image as webp/jpeg.
                channel: ``rgb``, ``a``, or (default) ``rgba``.

            Returns 400 for an invalid name or type, 403 when the subfolder
            escapes the base directory, and 404 when there is no such file.
            """
            if "filename" in request.rel_url.query:
                filename = request.rel_url.query["filename"]
                # Split off an optional directory annotation: (bare name, dir or None).
                filename, output_dir = folder_paths.annotated_filepath(filename)

                if not filename:
                    return web.Response(status=400)

                # Security: reject absolute paths and any parent-directory component.
                if filename[0] == '/' or '..' in filename:
                    return web.Response(status=400)

                # No annotation: fall back to the ?type= directory (default "output").
                if output_dir is None:
                    type = request.rel_url.query.get("type", "output")
                    output_dir = folder_paths.get_directory_by_type(type)

                # Unknown directory type.
                if output_dir is None:
                    return web.Response(status=400)

                # The optional subfolder must resolve inside output_dir.
                if "subfolder" in request.rel_url.query:
                    full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                        return web.Response(status=403)
                    output_dir = full_output_dir

                # Drop any directory part still present in the name.
                filename = os.path.basename(filename)
                file = os.path.join(output_dir, filename)

                if os.path.isfile(file):
                    # preview=FORMAT[;QUALITY]: re-encode for a lightweight preview.
                    if 'preview' in request.rel_url.query:
                        with Image.open(file) as img:
                            preview_info = request.rel_url.query['preview'].split(';')
                            image_format = preview_info[0]
                            # Only webp/jpeg are allowed. An alpha channel request forces webp.
                            # NOTE: this is a substring test, so channel=rgba also forces webp.
                            if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''):
                                image_format = 'webp'

                            # Default quality 90; a trailing numeric segment overrides it.
                            quality = 90
                            if preview_info[-1].isdigit():
                                quality = int(preview_info[-1])

                            buffer = BytesIO()
                            # JPEG cannot store alpha, and channel=rgb asks for colour only.
                            if image_format in ['jpeg'] or request.rel_url.query.get('channel', '') == 'rgb':
                                img = img.convert("RGB")
                            img.save(buffer, format=image_format, quality=quality)
                            buffer.seek(0)

                            return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})

                    if 'channel' not in request.rel_url.query:
                        channel = 'rgba'
                    else:
                        channel = request.rel_url.query["channel"]

                    if channel == 'rgb':
                        # Colour channels only, re-encoded as PNG.
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                r, g, b, a = img.split()
                                new_img = Image.merge('RGB', (r, g, b))
                            else:
                                new_img = img.convert("RGB")

                            buffer = BytesIO()
                            new_img.save(buffer, format='PNG')
                            buffer.seek(0)

                            return web.Response(body=buffer.read(), content_type='image/png',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})

                    elif channel == 'a':
                        # Alpha channel only (fully opaque when the image has none),
                        # returned as the alpha of a new RGBA PNG with zero RGB.
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                _, _, _, a = img.split()
                            else:
                                a = Image.new('L', img.size, 255)

                            alpha_img = Image.new('RGBA', img.size)
                            alpha_img.putalpha(a)
                            alpha_buffer = BytesIO()
                            alpha_img.save(alpha_buffer, format='PNG')
                            alpha_buffer.seek(0)

                            return web.Response(body=alpha_buffer.read(), content_type='image/png',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})
                    else:
                        # Serve the raw file with a guessed MIME type (octet-stream if unknown).
                        content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'

                        # Security: serve types a browser would render or execute (HTML, JS, CSS)
                        # as octet-stream so they are not displayed in the server's origin.
                        if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
                            content_type = 'application/octet-stream'

                        return web.FileResponse(
                            file,
                            headers={
                                "Content-Disposition": f"filename=\"{filename}\"",
                                "Content-Type": content_type
                            }
                        )

            return web.Response(status=404)
|
|
| @routes.get("/view_metadata/{folder_name}") |
| async def view_metadata(request): |
| folder_name = request.match_info.get("folder_name", None) |
| if folder_name is None: |
| return web.Response(status=404) |
| if not "filename" in request.rel_url.query: |
| return web.Response(status=404) |
|
|
| filename = request.rel_url.query["filename"] |
| if not filename.endswith(".safetensors"): |
| return web.Response(status=404) |
|
|
| safetensors_path = folder_paths.get_full_path(folder_name, filename) |
| if safetensors_path is None: |
| return web.Response(status=404) |
| out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024) |
| if out is None: |
| return web.Response(status=404) |
| dt = json.loads(out) |
| if not "__metadata__" in dt: |
| return web.Response(status=404) |
| return web.json_response(dt["__metadata__"]) |
|
|
| @routes.get("/system_stats") |
| async def system_stats(request): |
| device = comfy.model_management.get_torch_device() |
| device_name = comfy.model_management.get_torch_device_name(device) |
| cpu_device = comfy.model_management.torch.device("cpu") |
| ram_total = comfy.model_management.get_total_memory(cpu_device) |
| ram_free = comfy.model_management.get_free_memory(cpu_device) |
| vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) |
| vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) |
| required_frontend_version = FrontendManager.get_required_frontend_version() |
|
|
| system_stats = { |
| "system": { |
| "os": os.name, |
| "ram_total": ram_total, |
| "ram_free": ram_free, |
| "comfyui_version": __version__, |
| "required_frontend_version": required_frontend_version, |
| "python_version": sys.version, |
| "pytorch_version": comfy.model_management.torch_version, |
| "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", |
| "argv": sys.argv |
| }, |
| "devices": [ |
| { |
| "name": device_name, |
| "type": device.type, |
| "index": device.index, |
| "vram_total": vram_total, |
| "vram_free": vram_free, |
| "torch_vram_total": torch_vram_total, |
| "torch_vram_free": torch_vram_free, |
| } |
| ] |
| } |
| return web.json_response(system_stats) |
|
|
| @routes.get("/features") |
| async def get_features(request): |
| return web.json_response(feature_flags.get_server_features()) |
|
|
| @routes.get("/prompt") |
| async def get_prompt(request): |
| return web.json_response(self.get_queue_info()) |
|
|
| def node_info(node_class): |
| obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] |
| if issubclass(obj_class, _ComfyNodeInternal): |
| return obj_class.GET_NODE_INFO_V1() |
| info = {} |
| info['input'] = obj_class.INPUT_TYPES() |
| info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} |
| info['output'] = obj_class.RETURN_TYPES |
| info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) |
| info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] |
| info['name'] = node_class |
| info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class |
| info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else '' |
| info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes") |
| info['category'] = 'sd' |
| if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: |
| info['output_node'] = True |
| else: |
| info['output_node'] = False |
|
|
| if hasattr(obj_class, 'CATEGORY'): |
| info['category'] = obj_class.CATEGORY |
|
|
| if hasattr(obj_class, 'OUTPUT_TOOLTIPS'): |
| info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS |
|
|
| if getattr(obj_class, "DEPRECATED", False): |
| info['deprecated'] = True |
| if getattr(obj_class, "EXPERIMENTAL", False): |
| info['experimental'] = True |
|
|
| if hasattr(obj_class, 'API_NODE'): |
| info['api_node'] = obj_class.API_NODE |
| return info |
|
|
| @routes.get("/object_info") |
| async def get_object_info(request): |
| with folder_paths.cache_helper: |
| out = {} |
| for x in nodes.NODE_CLASS_MAPPINGS: |
| try: |
| out[x] = node_info(x) |
| except Exception: |
| logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") |
| logging.error(traceback.format_exc()) |
| return web.json_response(out) |
|
|
| @routes.get("/object_info/{node_class}") |
| async def get_object_info_node(request): |
| node_class = request.match_info.get("node_class", None) |
| out = {} |
| if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): |
| out[node_class] = node_info(node_class) |
| return web.json_response(out) |
|
|
| @routes.get("/history") |
| async def get_history(request): |
| max_items = request.rel_url.query.get("max_items", None) |
| if max_items is not None: |
| max_items = int(max_items) |
| return web.json_response(self.prompt_queue.get_history(max_items=max_items)) |
|
|
| @routes.get("/history/{prompt_id}") |
| async def get_history_prompt_id(request): |
| prompt_id = request.match_info.get("prompt_id", None) |
| return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id)) |
|
|
| @routes.get("/queue") |
| async def get_queue(request): |
| queue_info = {} |
| current_queue = self.prompt_queue.get_current_queue_volatile() |
| queue_info['queue_running'] = current_queue[0] |
| queue_info['queue_pending'] = current_queue[1] |
| return web.json_response(queue_info) |
|
|
| @routes.post("/prompt") |
| async def post_prompt(request): |
| logging.info("got prompt") |
| json_data = await request.json() |
| json_data = self.trigger_on_prompt(json_data) |
|
|
| if "number" in json_data: |
| number = float(json_data['number']) |
| else: |
| number = self.number |
| if "front" in json_data: |
| if json_data['front']: |
| number = -number |
|
|
| self.number += 1 |
|
|
| if "prompt" in json_data: |
| prompt = json_data["prompt"] |
| prompt_id = str(json_data.get("prompt_id", uuid.uuid4())) |
|
|
| partial_execution_targets = None |
| if "partial_execution_targets" in json_data: |
| partial_execution_targets = json_data["partial_execution_targets"] |
|
|
| valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets) |
| extra_data = {} |
| if "extra_data" in json_data: |
| extra_data = json_data["extra_data"] |
|
|
| if "client_id" in json_data: |
| extra_data["client_id"] = json_data["client_id"] |
| if valid[0]: |
| outputs_to_execute = valid[2] |
| self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) |
| response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} |
| return web.json_response(response) |
| else: |
| logging.warning("invalid prompt: {}".format(valid[1])) |
| return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) |
| else: |
| error = { |
| "type": "no_prompt", |
| "message": "No prompt provided", |
| "details": "No prompt provided", |
| "extra_info": {} |
| } |
| return web.json_response({"error": error, "node_errors": {}}, status=400) |
|
|
| @routes.post("/queue") |
| async def post_queue(request): |
| json_data = await request.json() |
| if "clear" in json_data: |
| if json_data["clear"]: |
| self.prompt_queue.wipe_queue() |
| if "delete" in json_data: |
| to_delete = json_data['delete'] |
| for id_to_delete in to_delete: |
| delete_func = lambda a: a[1] == id_to_delete |
| self.prompt_queue.delete_queue_item(delete_func) |
|
|
| return web.Response(status=200) |
|
|
| @routes.post("/interrupt") |
| async def post_interrupt(request): |
| try: |
| json_data = await request.json() |
| except json.JSONDecodeError: |
| json_data = {} |
|
|
| |
| prompt_id = json_data.get('prompt_id') |
| if prompt_id: |
| currently_running, _ = self.prompt_queue.get_current_queue() |
|
|
| |
| should_interrupt = False |
| for item in currently_running: |
| |
| if item[1] == prompt_id: |
| logging.info(f"Interrupting prompt {prompt_id}") |
| should_interrupt = True |
| break |
|
|
| if should_interrupt: |
| nodes.interrupt_processing() |
| else: |
| logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt") |
| else: |
| |
| logging.info("Global interrupt (no prompt_id specified)") |
| nodes.interrupt_processing() |
|
|
| return web.Response(status=200) |
|
|
| @routes.post("/free") |
| async def post_free(request): |
| json_data = await request.json() |
| unload_models = json_data.get("unload_models", False) |
| free_memory = json_data.get("free_memory", False) |
| if unload_models: |
| self.prompt_queue.set_flag("unload_models", unload_models) |
| if free_memory: |
| self.prompt_queue.set_flag("free_memory", free_memory) |
| return web.Response(status=200) |
|
|
| @routes.post("/history") |
| async def post_history(request): |
| json_data = await request.json() |
| if "clear" in json_data: |
| if json_data["clear"]: |
| self.prompt_queue.wipe_history() |
| if "delete" in json_data: |
| to_delete = json_data['delete'] |
| for id_to_delete in to_delete: |
| self.prompt_queue.delete_history_item(id_to_delete) |
|
|
| return web.Response(status=200) |
|
|
| async def setup(self): |
| timeout = aiohttp.ClientTimeout(total=None) |
| self.client_session = aiohttp.ClientSession(timeout=timeout) |
|
|
    def add_routes(self):
        """Register all HTTP routes and static directories on ``self.app``.

        Order of registration: manager routes and the ``/internal`` sub-app,
        ``/api``-prefixed copies of the route table, the unprefixed table,
        extension/template/docs static directories, and finally the ``/``
        static route for the web root.
        NOTE(review): aiohttp resolves resources in registration order, so
        the web-root static route is kept last.
        """
        self.user_manager.add_routes(self.routes)
        self.model_file_manager.add_routes(self.routes)
        self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
        self.app.add_subapp('/internal', self.internal_routes.get_app())

        # Mirror every route under an "/api" prefix; the unprefixed paths
        # remain registered as well.
        api_routes = web.RouteTableDef()
        for route in self.routes:
            # Only plain route definitions get the prefix; other entries in the
            # table (e.g. static routes) are skipped.
            if isinstance(route, web.RouteDef):
                api_routes.route(route.method, "/api" + route.path)(route.handler, **route.kwargs)
        self.app.add_routes(api_routes)
        self.app.add_routes(self.routes)

        # Static directories of custom-node web extensions.
        for name, dir in nodes.EXTENSION_WEB_DIRS.items():
            self.app.add_routes([web.static('/extensions/' + name, dir)])

        # Workflow templates, when the frontend provides a templates path.
        workflow_templates_path = FrontendManager.templates_path()
        if workflow_templates_path:
            self.app.add_routes([
                web.static('/templates', workflow_templates_path)
            ])

        # Embedded documentation, when the frontend provides a docs path.
        embedded_docs_path = FrontendManager.embedded_docs_path()
        if embedded_docs_path:
            self.app.add_routes([
                web.static('/docs', embedded_docs_path)
            ])

        self.app.add_routes([
            web.static('/', self.web_root),
        ])
|
|
| def get_queue_info(self): |
| prompt_info = {} |
| exec_info = {} |
| exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining() |
| prompt_info['exec_info'] = exec_info |
| return prompt_info |
|
|
| async def send(self, event, data, sid=None): |
| if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: |
| await self.send_image(data, sid=sid) |
| elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA: |
| |
| preview_image, metadata = data |
| await self.send_image_with_metadata(preview_image, metadata, sid=sid) |
| elif isinstance(data, (bytes, bytearray)): |
| await self.send_bytes(event, data, sid) |
| else: |
| await self.send_json(event, data, sid) |
|
|
| def encode_bytes(self, event, data): |
| if not isinstance(event, int): |
| raise RuntimeError(f"Binary event types must be integers, got {event}") |
|
|
| packed = struct.pack(">I", event) |
| message = bytearray(packed) |
| message.extend(data) |
| return message |
|
|
| async def send_image(self, image_data, sid=None): |
| image_type = image_data[0] |
| image = image_data[1] |
| max_size = image_data[2] |
| if max_size is not None: |
| if hasattr(Image, 'Resampling'): |
| resampling = Image.Resampling.BILINEAR |
| else: |
| resampling = Image.Resampling.LANCZOS |
|
|
| image = ImageOps.contain(image, (max_size, max_size), resampling) |
| type_num = 1 |
| if image_type == "JPEG": |
| type_num = 1 |
| elif image_type == "PNG": |
| type_num = 2 |
|
|
| bytesIO = BytesIO() |
| header = struct.pack(">I", type_num) |
| bytesIO.write(header) |
| image.save(bytesIO, format=image_type, quality=95, compress_level=1) |
| preview_bytes = bytesIO.getvalue() |
| await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) |
|
|
| async def send_image_with_metadata(self, image_data, metadata=None, sid=None): |
| image_type = image_data[0] |
| image = image_data[1] |
| max_size = image_data[2] |
| if max_size is not None: |
| if hasattr(Image, 'Resampling'): |
| resampling = Image.Resampling.BILINEAR |
| else: |
| resampling = Image.Resampling.LANCZOS |
|
|
| image = ImageOps.contain(image, (max_size, max_size), resampling) |
|
|
| mimetype = "image/png" if image_type == "PNG" else "image/jpeg" |
|
|
| |
| if metadata is None: |
| metadata = {} |
| metadata["image_type"] = mimetype |
|
|
| |
| import json |
| metadata_json = json.dumps(metadata).encode('utf-8') |
| metadata_length = len(metadata_json) |
|
|
| |
| bytesIO = BytesIO() |
| image.save(bytesIO, format=image_type, quality=95, compress_level=1) |
| image_bytes = bytesIO.getvalue() |
|
|
| |
| combined_data = bytearray() |
| combined_data.extend(struct.pack(">I", metadata_length)) |
| combined_data.extend(metadata_json) |
| combined_data.extend(image_bytes) |
|
|
| await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid) |
|
|
| async def send_bytes(self, event, data, sid=None): |
| message = self.encode_bytes(event, data) |
|
|
| if sid is None: |
| sockets = list(self.sockets.values()) |
| for ws in sockets: |
| await send_socket_catch_exception(ws.send_bytes, message) |
| elif sid in self.sockets: |
| await send_socket_catch_exception(self.sockets[sid].send_bytes, message) |
|
|
| async def send_json(self, event, data, sid=None): |
| message = {"type": event, "data": data} |
|
|
| if sid is None: |
| sockets = list(self.sockets.values()) |
| for ws in sockets: |
| await send_socket_catch_exception(ws.send_json, message) |
| elif sid in self.sockets: |
| await send_socket_catch_exception(self.sockets[sid].send_json, message) |
|
|
| def send_sync(self, event, data, sid=None): |
| self.loop.call_soon_threadsafe( |
| self.messages.put_nowait, (event, data, sid)) |
|
|
| def queue_updated(self): |
| self.send_sync("status", { "status": self.get_queue_info() }) |
|
|
| async def publish_loop(self): |
| while True: |
| msg = await self.messages.get() |
| await self.send(*msg) |
|
|
| async def start(self, address, port, verbose=True, call_on_start=None): |
| await self.start_multi_address([(address, port)], call_on_start=call_on_start) |
|
|
| async def start_multi_address(self, addresses, call_on_start=None, verbose=True): |
| runner = web.AppRunner(self.app, access_log=None) |
| await runner.setup() |
| ssl_ctx = None |
| scheme = "http" |
| if args.tls_keyfile and args.tls_certfile: |
| ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE) |
| ssl_ctx.load_cert_chain(certfile=args.tls_certfile, |
| keyfile=args.tls_keyfile) |
| scheme = "https" |
|
|
| if verbose: |
| logging.info("Starting server\n") |
| for addr in addresses: |
| address = addr[0] |
| port = addr[1] |
| site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) |
| await site.start() |
|
|
| if not hasattr(self, 'address'): |
| self.address = address |
| self.port = port |
|
|
| if ':' in address: |
| address_print = "[{}]".format(address) |
| else: |
| address_print = address |
|
|
| if verbose: |
| logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port)) |
|
|
| if call_on_start is not None: |
| call_on_start(scheme, self.address, self.port) |
|
|
| def add_on_prompt_handler(self, handler): |
| self.on_prompt_handlers.append(handler) |
|
|
| def trigger_on_prompt(self, json_data): |
| for handler in self.on_prompt_handlers: |
| try: |
| json_data = handler(json_data) |
| except Exception: |
| logging.warning("[ERROR] An error occurred during the on_prompt_handler processing") |
| logging.warning(traceback.format_exc()) |
|
|
| return json_data |
|
|
| def send_progress_text( |
| self, text: Union[bytes, bytearray, str], node_id: str, sid=None |
| ): |
| if isinstance(text, str): |
| text = text.encode("utf-8") |
| node_id_bytes = str(node_id).encode("utf-8") |
|
|
| |
| message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text |
|
|
| self.send_sync(BinaryEventTypes.TEXT, message, sid) |
|
|