diff --git "a/webui.py" "b/webui.py" --- "a/webui.py" +++ "b/webui.py" @@ -1,1227 +1,1329 @@ -import gradio as gr -import websocket -import uuid -import json -import urllib.request -import urllib.parse -from PIL import Image -import io -import os -import random -import time -import threading -import base64 - -try: - from fastapi import FastAPI, Request - from fastapi.responses import JSONResponse, FileResponse - import uvicorn -except Exception: - FastAPI = None - uvicorn = None - -# --- Constants and Setup --- -BASE_DIR = os.path.dirname(__file__) -# Allow overriding data directory via environment variable WEBUI_DATA_DIR -DATA_DIR = os.path.abspath(os.getenv('WEBUI_DATA_DIR', os.path.join(BASE_DIR, 'data'))) -WORKFLOWS_DIR = os.path.join(DATA_DIR, 'workflows') -OUTPUT_DIR = os.path.join(DATA_DIR, 'outputs') -PRESETS_FILE = os.path.join(DATA_DIR, 'presets.json') -USER_CONFIG_FILE = os.path.join(DATA_DIR, 'user_config.json') -os.makedirs(OUTPUT_DIR, exist_ok=True) -os.makedirs(WORKFLOWS_DIR, exist_ok=True) -SCHEDULER_STOP = False - -# --- Auto-save Config Manager (20s interval, debounced) --- -CONFIG_SAVE_INTERVAL = int(os.getenv('WEBUI_CONFIG_INTERVAL', '20')) # seconds -_pending_config = {} -_config_lock = threading.Lock() -_config_saver_thread = None -_config_changed = False - -def start_config_saver(): - """Start the background config saver thread.""" - global _config_saver_thread - if _config_saver_thread is None or not _config_saver_thread.is_alive(): - _config_saver_thread = threading.Thread(target=_config_saver_loop, daemon=True) - _config_saver_thread.start() - -def _flush_pending_config(): - """Flush any pending config changes immediately.""" - global _config_changed - with _config_lock: - if _config_changed and _pending_config: - try: - config = load_user_config() - config.update(_pending_config) - save_user_config(config) - print(f"[Auto-save] Configuration saved at {time.strftime('%H:%M:%S')}") - except Exception as e: - print(f"[Auto-save] Error saving config: {e}") - finally: - _pending_config.clear() - _config_changed = False - -def _config_saver_loop(): - """Background thread that saves config every CONFIG_SAVE_INTERVAL if changed.""" - while True: - time.sleep(CONFIG_SAVE_INTERVAL) - _flush_pending_config() - -def set_config_save_interval(seconds: int): - """Update autosave interval at runtime (min 5s).""" - global CONFIG_SAVE_INTERVAL - try: - seconds = int(seconds) - if seconds < 5: - seconds = 5 - except Exception: - seconds = 20 - CONFIG_SAVE_INTERVAL = seconds - queue_config_update(config_save_interval=seconds) - return seconds - -def queue_config_update(**kwargs): - """Queue config updates to be saved in the next interval.""" - global _config_changed - with _config_lock: - _pending_config.update(kwargs) - _config_changed = True - -# --- Preset Management Functions --- -def load_presets(): - """Loads presets from the JSON file. If not found, creates it with defaults.""" - default_presets = { - "None": {"positive": "", "negative": ""}, - "✨ 推荐风格": { - "positive": "best quality, very aesthetic, highres, absurdres, sensitive", - "negative": "lowres, (bad), bad feet, text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, artistic error, username, scan, [abstract], english text, shiny_skin" - }, - "🎨 动漫风格": { - "positive": "masterpiece, best quality, anime, 1girl, beautiful detailed eyes, detailed face", - "negative": "photorealistic, 3d, extra limbs, bad anatomy, ugly, deformed" - }, - "📸 写实风格": { - "positive": "photorealistic, high quality, detailed, professional photography", - "negative": "anime, cartoon, drawing, painting, sketch" - } - } - - if not os.path.exists(PRESETS_FILE): - with open(PRESETS_FILE, 'w', encoding='utf-8') as f: - json.dump(default_presets, f, indent=4) - return default_presets - else: - try: - with open(PRESETS_FILE, 'r', encoding='utf-8') as f: - presets = json.load(f) - if "None" not in presets: - presets["None"] = {"positive": "", "negative": ""} - return presets - except (json.JSONDecodeError, IOError): - with open(PRESETS_FILE, 'w', encoding='utf-8') as f: - json.dump(default_presets, f, indent=4) - return default_presets - -def save_presets(presets_dict): - """Saves the given dictionary to the presets JSON file.""" - with open(PRESETS_FILE, 'w', encoding='utf-8') as f: - json.dump(presets_dict, f, indent=4) - -def combine_prompts(prefix, main_prompt): - """Combines prefix and main prompt intelligently.""" - if prefix and main_prompt: - return f"{prefix.strip()}, {main_prompt.strip()}" - elif prefix: - return prefix.strip() - elif main_prompt: - return main_prompt.strip() - return "" - -def select_preset(preset_name): - """Selects a preset and returns its values.""" - preset_data = GLOBAL_PRESETS.get(preset_name, {"positive": "", "negative": ""}) - return preset_name, preset_data["positive"], preset_data["negative"] - -def save_or_update_preset(preset_name, positive_prefix, negative_prefix): - """Saves or updates a preset.""" - if not preset_name or not preset_name.strip(): - return gr.update(), "Preset name cannot be empty." - - preset_name = preset_name.strip() - GLOBAL_PRESETS[preset_name] = {"positive": positive_prefix, "negative": negative_prefix} - save_presets(GLOBAL_PRESETS) - return gr.update(choices=list(GLOBAL_PRESETS.keys()), value=preset_name), f"Preset '{preset_name}' saved." - -def delete_preset(preset_name): - """Deletes a preset.""" - if not preset_name or preset_name.strip() in ["", "None"]: - return gr.update(), gr.update(), gr.update(), gr.update(), "Cannot delete this preset." - - preset_name = preset_name.strip() - if preset_name in GLOBAL_PRESETS: - del GLOBAL_PRESETS[preset_name] - save_presets(GLOBAL_PRESETS) - return (gr.update(choices=list(GLOBAL_PRESETS.keys()), value="None"), - "None", "", "", f"Preset '{preset_name}' deleted.") - return gr.update(), gr.update(), gr.update(), gr.update(), f"Preset '{preset_name}' not found." - -# Load presets globally -GLOBAL_PRESETS = load_presets() - -# --- User Config Management Functions --- -def load_user_config(): - """Loads user configuration from JSON file.""" - default_config = { - "server_address": "127.0.0.1:8188", - "model": "", - "sampler": "euler", - "scheduler": "normal", - "steps": 30, - "cfg": 6.0, - "width": 768, - "height": 1280, - "batch_size": 1, - "batch_count": 1, - "seed": 757831338432565, - "after_generate": "randomize", - "positive_prefix": "", - "negative_prefix": "", - "positive_prompt": "best quality,very aesthetic,highres,absurdres,sensitive,A girl dressed in a maid costume with a personality, kneeling in front of her master,", - "negative_prompt": "lowres,(bad),bad feet,text,error,fewer,extra,missing,worst quality,jpeg artifacts,low quality,watermark,unfinished,displeasing,oldest,early,chromatic aberration,signature,artistic error,username,scan,[abstract],english text,shiny_skin,", - "preset_name": "None", - "current_workflow": "workflow_template", - "language": "en", - "config_save_interval": 20, - # OpenAI API defaults (can override generator settings) - "api_server_address": "", - "api_model": "", - "api_sampler": "", - "api_scheduler": "", - "api_steps": 30, - "api_cfg": 6.0, - "api_width": 768, - "api_height": 1280, - "api_seed": 757831338432565, - "api_after_generate": "randomize", - "api_positive_prefix": "", - "api_negative_prefix": "", - "api_workflow": "workflow_template", - "api_return": "url", - "api_n": 1 - } - - if not os.path.exists(USER_CONFIG_FILE): - save_user_config(default_config) - return default_config - - try: - with open(USER_CONFIG_FILE, 'r', encoding='utf-8') as f: - config = json.load(f) - # 合并默认配置,确保新字段存在 - for key, value in default_config.items(): - if key not in config: - config[key] = value - return config - except (json.JSONDecodeError, IOError): - save_user_config(default_config) - return default_config - -def save_user_config(config_dict): - """Saves user configuration to JSON file.""" - with open(USER_CONFIG_FILE, 'w', encoding='utf-8') as f: - json.dump(config_dict, f, indent=4) - -def update_user_config(**kwargs): - """Updates specific configuration values.""" - config = load_user_config() - for key, value in kwargs.items(): - config[key] = value - save_user_config(config) - -# Load user config globally -USER_CONFIG = load_user_config() - -# --- Workflow Management Functions --- -def load_workflows(): - """Loads all workflow files from the workflows directory.""" - workflows = {} - if not os.path.exists(WORKFLOWS_DIR): - return workflows - - for filename in os.listdir(WORKFLOWS_DIR): - if filename.endswith('.json'): - workflow_name = filename[:-5] # Remove .json extension - workflow_path = os.path.join(WORKFLOWS_DIR, filename) - try: - with open(workflow_path, 'r', encoding='utf-8') as f: - workflows[workflow_name] = json.load(f) - except (json.JSONDecodeError, IOError) as e: - print(f"Error loading workflow {workflow_name}: {e}") - - return workflows - -def save_workflow(workflow_name, workflow_content): - """Saves a workflow to the workflows directory.""" - if not workflow_name or not workflow_name.strip(): - return False, "Workflow name cannot be empty." - - workflow_name = workflow_name.strip() - workflow_path = os.path.join(WORKFLOWS_DIR, f"{workflow_name}.json") - - try: - # Validate JSON content - json.loads(workflow_content) - with open(workflow_path, 'w', encoding='utf-8') as f: - f.write(workflow_content) - return True, f"Workflow '{workflow_name}' saved successfully." - except json.JSONDecodeError as e: - return False, f"Invalid JSON format: {e}" - except IOError as e: - return False, f"Error saving workflow: {e}" - -def delete_workflow(workflow_name): - """Deletes a workflow file.""" - if not workflow_name or workflow_name.strip() in ["", "workflow_template"]: - return False, "Cannot delete this workflow." - - workflow_name = workflow_name.strip() - workflow_path = os.path.join(WORKFLOWS_DIR, f"{workflow_name}.json") - - if os.path.exists(workflow_path): - try: - os.remove(workflow_path) - return True, f"Workflow '{workflow_name}' deleted successfully." - except IOError as e: - return False, f"Error deleting workflow: {e}" - else: - return False, f"Workflow '{workflow_name}' not found." - -def load_workflow_content(workflow_name): - """Loads the content of a specific workflow.""" - if not workflow_name or workflow_name == "workflow_template": - # Load default template - workflow_path = os.path.join(WORKFLOWS_DIR, "workflow_template.json") - else: - workflow_path = os.path.join(WORKFLOWS_DIR, f"{workflow_name}.json") - - if os.path.exists(workflow_path): - try: - with open(workflow_path, 'r', encoding='utf-8') as f: - return json.load(f) - except (json.JSONDecodeError, IOError) as e: - print(f"Error loading workflow {workflow_name}: {e}") - return None - return None - -# Load workflows globally -GLOBAL_WORKFLOWS = load_workflows() - -# --- OpenAI-compatible API Server (FastAPI) --- -OPENAI_SERVER_THREAD = None -OPENAI_UVICORN_SERVER = None -OPENAI_API_APP = None - -def _ensure_fastapi_available(): - if FastAPI is None or uvicorn is None: - raise RuntimeError("FastAPI/uvicorn not installed. Please install with: pip install fastapi uvicorn") - -def _encode_image_b64(path: str) -> str: - with open(path, 'rb') as f: - return base64.b64encode(f.read()).decode('utf-8') - -def generate_image_sync(server_address, positive_prefix, negative_prefix, positive_prompt, negative_prompt, model, sampler, scheduler, steps, cfg, width, height, seed, after_generate, batch_size, batch_count, current_workflow): - """Synchronous image generation that returns list of saved file paths.""" - # Normalize server address - if not server_address.startswith("http://") and not server_address.startswith("https://"): - server_address = "http://" + server_address - server_address = server_address.rstrip('/') - - ws_address = "ws://" + server_address[len("http://"):] - if server_address.startswith("https://"): - ws_address = "wss://" + server_address[len("https://"):] - - client_id = str(uuid.uuid4()) - all_generated_images = [] - initial_seed = seed - - for i in range(batch_count): - if after_generate == "randomize": - current_seed = random.randint(0, 2**32 - 1) - elif after_generate == "increment": - current_seed = initial_seed + i - elif after_generate == "decrement": - current_seed = initial_seed - i - else: # "fixed" - current_seed = initial_seed - - ws = websocket.WebSocket() - try: - ws.connect(f"{ws_address}/ws?clientId={client_id}") - - workflow_content = load_workflow_content(current_workflow) - if workflow_content is None: - break - workflow_content = json.dumps(workflow_content) - - final_positive_prompt = combine_prompts(positive_prefix, positive_prompt) - final_negative_prompt = combine_prompts(negative_prefix, negative_prompt) - - workflow_content = workflow_content.replace('%prompt%', final_positive_prompt) - workflow_content = workflow_content.replace('%negative_prompt%', final_negative_prompt) - workflow_content = workflow_content.replace('%model%', model) - workflow_content = workflow_content.replace('%width%', str(width)) - workflow_content = workflow_content.replace('%height%', str(height)) - workflow_content = workflow_content.replace('%batch_size%', str(batch_size)) - workflow_content = workflow_content.replace('%seed%', str(current_seed)) - workflow_content = workflow_content.replace('%steps%', str(steps)) - workflow_content = workflow_content.replace('%cfg%', str(cfg)) - workflow_content = workflow_content.replace('%sampler%', sampler) - workflow_content = workflow_content.replace('%scheduler%', scheduler) - - prompt_workflow = json.loads(workflow_content) - prompt_data = queue_prompt(prompt_workflow, client_id, server_address) - prompt_id = prompt_data['prompt_id'] - - while True: - out = ws.recv() - if not isinstance(out, str): - continue - message = json.loads(out) - if message['type'] == 'executing': - data = message['data'] - if data['node'] is None and data['prompt_id'] == prompt_id: - break - - history = get_history(prompt_id, server_address)[prompt_id] - images_output = [] - for node_id in history['outputs']: - if 'images' in history['outputs'][node_id]: - for image in history['outputs'][node_id]['images']: - image_data = get_image(image['filename'], image['subfolder'], image['type'], server_address) - images_output.append(image_data) - - if not images_output: - continue - - pil_images = [Image.open(io.BytesIO(data)) for data in images_output] - for img_idx, img in enumerate(pil_images): - filename = f"{int(time.time())}_{current_seed}_{img_idx}.png" - filepath = os.path.join(OUTPUT_DIR, filename) - img.save(filepath) - all_generated_images.append(filepath) - - finally: - if ws.connected: - ws.close() - - return all_generated_images - -def _create_openai_app(get_config): - """Create FastAPI app with OpenAI-compatible routes.""" - app = FastAPI() - - @app.get("/v1/files/{filename}") - def get_file(filename: str): - path = os.path.join(OUTPUT_DIR, filename) - if os.path.exists(path): - return FileResponse(path) - return JSONResponse(status_code=404, content={"error": {"message": "File not found"}}) - - @app.post("/v1/chat/completions") - async def chat_completions(req: Request): - body = await req.json() - model = body.get("model", "gpt-image-proxy") - messages = body.get("messages", []) - n = int(body.get("n", get_config()["api_n"])) - return_type = body.get("response_format", {}).get("type", get_config()["api_return"]) # "b64_json" or "url" - - # latest user message - user_text = "" - for m in reversed(messages): - if m.get("role") == "user": - content = m.get("content", "") - if isinstance(content, list): - user_text = " ".join([item.get("text", "") for item in content if item.get("type") == "text"]).strip() - else: - user_text = str(content) - break - - cfg = get_config() - # Use a random seed for each request and randomize across images in n - req_seed = random.randint(0, 2**32 - 1) - filepaths = generate_image_sync( - cfg["server_address"], cfg["positive_prefix"], cfg["negative_prefix"], user_text, cfg["negative_prompt"], - cfg["model"], cfg["sampler"], cfg["scheduler"], cfg["steps"], cfg["cfg"], cfg["width"], cfg["height"], - int(req_seed), "randomize", 1, n, cfg["current_workflow"] - ) - - choices = [] - for idx, fp in enumerate(filepaths): - if return_type == "url": - content = f"/v1/files/{os.path.basename(fp)}" - else: - content = _encode_image_b64(fp) - choices.append({ - "index": idx, - "finish_reason": "stop", - "message": {"role": "assistant", "content": content} - }) - - resp = { - "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", - "object": "chat.completion", - "created": int(time.time()), - "model": model, - "choices": choices, - "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - "x_comfy": {"seed": int(req_seed), "after_generate": "randomize"} - } - return JSONResponse(content=resp) - - return app - -def start_openai_server(host: str, port: int, get_config): - global OPENAI_SERVER_THREAD, OPENAI_UVICORN_SERVER, OPENAI_API_APP - _ensure_fastapi_available() - if OPENAI_SERVER_THREAD and OPENAI_SERVER_THREAD.is_alive(): - return True, f"Already running on {host}:{port}" - OPENAI_API_APP = _create_openai_app(get_config) - config = uvicorn.Config(OPENAI_API_APP, host=host, port=port, log_level="info") - server = uvicorn.Server(config) - OPENAI_UVICORN_SERVER = server - def _run(): - server.run() - t = threading.Thread(target=_run, daemon=True) - OPENAI_SERVER_THREAD = t - t.start() - return True, f"OpenAI API running on http://{host}:{port}" - -def stop_openai_server(): - global OPENAI_UVICORN_SERVER, OPENAI_SERVER_THREAD - if OPENAI_UVICORN_SERVER: - OPENAI_UVICORN_SERVER.should_exit = True - return True, "OpenAI API stopping..." - -# --- Debounced Save Functions (queued, saved every 20s) --- -def save_server_address(server_address): - queue_config_update(server_address=server_address) - return server_address - -def save_model(model): - queue_config_update(model=model) - return model - -def save_sampler(sampler): - queue_config_update(sampler=sampler) - return sampler - -def save_scheduler(scheduler): - queue_config_update(scheduler=scheduler) - return scheduler - -def save_steps(steps): - queue_config_update(steps=steps) - return steps - -def save_cfg(cfg): - queue_config_update(cfg=cfg) - return cfg - -def save_width(width): - queue_config_update(width=width) - return width - -def save_height(height): - queue_config_update(height=height) - return height - -def save_batch_size(batch_size): - queue_config_update(batch_size=batch_size) - return batch_size - -def save_batch_count(batch_count): - queue_config_update(batch_count=batch_count) - return batch_count - -def save_seed(seed): - queue_config_update(seed=seed) - return seed - -def save_after_generate(after_generate): - queue_config_update(after_generate=after_generate) - return after_generate - -def save_positive_prefix(positive_prefix): - queue_config_update(positive_prefix=positive_prefix) - return positive_prefix - -def save_negative_prefix(negative_prefix): - queue_config_update(negative_prefix=negative_prefix) - return negative_prefix - -def save_positive_prompt(positive_prompt): - queue_config_update(positive_prompt=positive_prompt) - return positive_prompt - -def save_negative_prompt(negative_prompt): - queue_config_update(negative_prompt=negative_prompt) - return negative_prompt - -def save_preset_name(preset_name): - queue_config_update(preset_name=preset_name) - return preset_name - -def save_current_workflow(current_workflow): - queue_config_update(current_workflow=current_workflow) - return current_workflow - -def load_ui_config(): - """Loads user configuration and returns it for UI initialization.""" - config = load_user_config() - - # Get server address and fetch available options - server_address = config.get("server_address", "127.0.0.1:8188") - if not server_address.startswith("http://") and not server_address.startswith("https://"): - server_address = "http://" + server_address - - object_info = get_object_info(server_address) - available_models = get_models(object_info) - available_samplers = get_samplers(object_info) - available_schedulers = get_schedulers(object_info) - - return ( - config.get("server_address", "127.0.0.1:8188"), - config.get("model", ""), - config.get("sampler", "euler"), - config.get("scheduler", "normal"), - config.get("steps", 30), - config.get("cfg", 6.0), - config.get("width", 768), - config.get("height", 1280), - config.get("batch_size", 1), - config.get("batch_count", 1), - config.get("seed", 757831338432565), - config.get("after_generate", "randomize"), - config.get("positive_prefix", ""), - config.get("negative_prefix", ""), - config.get("positive_prompt", "best quality,very aesthetic,highres,absurdres,sensitive,A girl dressed in a maid costume with a personality, kneeling in front of her master,"), - config.get("negative_prompt", "lowres,(bad),bad feet,text,error,fewer,extra,missing,worst quality,jpeg artifacts,low quality,watermark,unfinished,displeasing,oldest,early,chromatic aberration,signature,artistic error,username,scan,[abstract],english text,shiny_skin,"), - config.get("preset_name", "None"), - config.get("current_workflow", "workflow_template"), - gr.update(choices=available_models), - gr.update(choices=available_samplers), - gr.update(choices=available_schedulers), - gr.update(choices=list(GLOBAL_WORKFLOWS.keys())) - ) - -# --- ComfyUI API Functions --- -def get_image(filename, subfolder, folder_type, server_address): - """Fetches an image from the ComfyUI server.""" - data = {"filename": filename, "subfolder": subfolder, "type": folder_type} - url_values = urllib.parse.urlencode(data) - with urllib.request.urlopen(f"{server_address}/view?{url_values}") as response: - return response.read() - -def queue_prompt(prompt, client_id, server_address): - """Queues a prompt on the ComfyUI server.""" - p = {"prompt": prompt, "client_id": client_id} - data = json.dumps(p).encode('utf-8') - req = urllib.request.Request(f"{server_address}/prompt", data=data) - response = urllib.request.urlopen(req) - return json.loads(response.read()) - -def get_history(prompt_id, server_address): - """Gets the history for a given prompt ID.""" - with urllib.request.urlopen(f"{server_address}/history/{prompt_id}") as response: - return json.loads(response.read()) - -def get_object_info(server_address): - """Gets object info from the ComfyUI server.""" - try: - with urllib.request.urlopen(f"{server_address}/object_info") as response: - return json.loads(response.read()) - except Exception as e: - print(f"Failed to fetch object info: {e}") - return None - -def get_models(object_info): - """Extracts a comprehensive list of models from object_info.""" - models = [] - if not object_info: - return ["model.safetensors"] - - if "CheckpointLoaderSimple" in object_info and "ckpt_name" in object_info["CheckpointLoaderSimple"]["input"]["required"]: - models.extend(object_info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0]) - if "UNETLoader" in object_info and "unet_name" in object_info["UNETLoader"]["input"]["required"]: - models.extend(object_info["UNETLoader"]["input"]["required"]["unet_name"][0]) - if "UnetLoaderGGUF" in object_info and "unet_name" in object_info["UnetLoaderGGUF"]["input"]["required"]: - models.extend(object_info["UnetLoaderGGUF"]["input"]["required"]["unet_name"][0]) - - if not models: - return ["model.safetensors"] - - return list(dict.fromkeys(models)) - -def get_samplers(object_info): - if object_info and "KSampler" in object_info: - return object_info["KSampler"]["input"]["required"]["sampler_name"][0] - return ["euler"] - -def get_schedulers(object_info): - if object_info and "KSampler" in object_info: - return object_info["KSampler"]["input"]["required"]["scheduler"][0] - return ["normal"] - -# --- UI Callback Functions --- -def update_choices(server_address): - """Callback function to update dropdown choices.""" - if not server_address: - return (gr.update(choices=[]), gr.update(choices=[]), gr.update(choices=[])) - - if not server_address.startswith("http://") and not server_address.startswith("https://"): - http_server_address = "http://" + server_address - else: - http_server_address = server_address - http_server_address = http_server_address.rstrip('/') - - object_info = get_object_info(http_server_address) - available_models = get_models(object_info) - available_samplers = get_samplers(object_info) - available_schedulers = get_schedulers(object_info) - - return ( - gr.update(choices=available_models, value=available_models[0] if available_models else None), - gr.update(choices=available_samplers, value=available_samplers[0] if available_samplers else None), - gr.update(choices=available_schedulers, value=available_schedulers[0] if available_schedulers else None) - ) - -# --- Scheduler Functions --- -def stop_scheduler(): - """Sets the global flag to stop the scheduler.""" - global SCHEDULER_STOP - SCHEDULER_STOP = True - print("Scheduler stop requested.") - return "Scheduler stopping..." - -def run_scheduled_generation(interval, server_address, *gen_args): - """Runs the generation task on a schedule.""" - global SCHEDULER_STOP - SCHEDULER_STOP = False - print("Scheduler started.") - - # Prepend server_address back to gen_args for generate_images call - full_gen_args = [server_address] + list(gen_args) - - while not SCHEDULER_STOP: - yield "Running scheduled generation...", None - - # Call the main generation function, but for a single image/batch - gen_with_single_batch = list(full_gen_args) - # The index for batch_count is the last one - gen_with_single_batch[-1] = 1 - - gen = generate_images(*gen_with_single_batch) - final_gallery = None - for _, gallery in gen: - if SCHEDULER_STOP: break - final_gallery = gallery - - if SCHEDULER_STOP: break - - yield "Generation complete. Waiting for next run...", final_gallery - - wait_seconds = int(interval * 60) - for i in range(wait_seconds): - if SCHEDULER_STOP: break - if (wait_seconds - i) % 60 == 0: - remaining_minutes = (wait_seconds - i) // 60 - yield f"Next run in {remaining_minutes} minute(s)...", final_gallery - time.sleep(1) - - if SCHEDULER_STOP: break - - print("Scheduler stopped.") - yield "Scheduler stopped.", None - - -# --- History Management Functions --- -def get_history_images(): - """Returns a sorted list of images from the output directory.""" - if not os.path.exists(OUTPUT_DIR): - return [] - images = [os.path.join(OUTPUT_DIR, f) for f in os.listdir(OUTPUT_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.webp'))] - images.sort(key=os.path.getmtime, reverse=True) - return images - -def delete_image(filepaths): - """Deletes selected images and returns the updated list of images.""" - if isinstance(filepaths, list) and filepaths: - for item in filepaths: - # Handle both string paths and tuple (path, metadata) formats - if isinstance(item, tuple): - filepath = item[0] # Extract the file path from tuple - else: - filepath = item - - if filepath and os.path.exists(filepath): - try: - os.remove(filepath) - except Exception as e: - print(f"Error deleting file {filepath}: {e}") - return get_history_images() - -# --- Core Generation Logic --- -def generate_images(server_address, positive_prefix, negative_prefix, positive_prompt, negative_prompt, model, sampler, scheduler, steps, cfg, width, height, seed, after_generate, batch_size, batch_count, current_workflow): - """Main function to generate images based on UI inputs.""" - # Normalize server address - if not server_address.startswith("http://") and not server_address.startswith("https://"): - server_address = "http://" + server_address - server_address = server_address.rstrip('/') - - ws_address = "ws://" + server_address[len("http://"):] - if server_address.startswith("https://"): - ws_address = "wss://" + server_address[len("https://"):] - - client_id = str(uuid.uuid4()) - all_generated_images = [] - initial_seed = seed - - for i in range(batch_count): - yield f"Running batch {i+1}/{batch_count}...", all_generated_images - - if after_generate == "randomize": - current_seed = random.randint(0, 2**32 - 1) - elif after_generate == "increment": - current_seed = initial_seed + i - elif after_generate == "decrement": - current_seed = initial_seed - i - else: # "fixed" - current_seed = initial_seed - - ws = websocket.WebSocket() - try: - yield f"Batch {i+1}: Connecting...", all_generated_images - ws.connect(f"{ws_address}/ws?clientId={client_id}") - - # Load workflow content - workflow_content = load_workflow_content(current_workflow) - if workflow_content is None: - yield f"Error: Could not load workflow '{current_workflow}'", all_generated_images - break - workflow_content = json.dumps(workflow_content) - - # Combine prefix and main prompts - final_positive_prompt = combine_prompts(positive_prefix, positive_prompt) - final_negative_prompt = combine_prompts(negative_prefix, negative_prompt) - - # Replace placeholders with actual values - workflow_content = workflow_content.replace('%prompt%', final_positive_prompt) - workflow_content = workflow_content.replace('%negative_prompt%', final_negative_prompt) - workflow_content = workflow_content.replace('%model%', model) - workflow_content = workflow_content.replace('%width%', str(width)) - workflow_content = workflow_content.replace('%height%', str(height)) - workflow_content = workflow_content.replace('%batch_size%', str(batch_size)) - workflow_content = workflow_content.replace('%seed%', str(current_seed)) - workflow_content = workflow_content.replace('%steps%', str(steps)) - workflow_content = workflow_content.replace('%cfg%', str(cfg)) - workflow_content = workflow_content.replace('%sampler%', sampler) - workflow_content = workflow_content.replace('%scheduler%', scheduler) - - # Parse the modified workflow - prompt_workflow = json.loads(workflow_content) - - prompt_data = queue_prompt(prompt_workflow, client_id, server_address) - prompt_id = prompt_data['prompt_id'] - - while True: - out = ws.recv() - if not isinstance(out, str): continue - message = json.loads(out) - if message['type'] == 'executing': - data = message['data'] - if data['node'] is None and data['prompt_id'] == prompt_id: - break - else: - node_id = data['node'] - node_title = prompt_workflow.get(node_id, {}).get('_meta', {}).get('title', f"Node {node_id}") - yield f"Batch {i+1}: Executing {node_title}...", all_generated_images - - history = get_history(prompt_id, server_address)[prompt_id] - images_output = [] - for node_id in history['outputs']: - if 'images' in history['outputs'][node_id]: - for image in history['outputs'][node_id]['images']: - image_data = get_image(image['filename'], image['subfolder'], image['type'], server_address) - images_output.append(image_data) - - if not images_output: - continue - - pil_images = [Image.open(io.BytesIO(data)) for data in images_output] - for img_idx, img in enumerate(pil_images): - filename = f"{int(time.time())}_{current_seed}_{img_idx}.png" - filepath = os.path.join(OUTPUT_DIR, filename) - img.save(filepath) - all_generated_images.insert(0, filepath) # Insert at beginning to show newest first - - except Exception as e: - yield f"Error in batch {i+1}: {e}", all_generated_images - break # Stop on error - finally: - if ws.connected: - ws.close() - - yield "Done!", all_generated_images - -# --- Gradio UI --- -def create_ui(): - # Load initial configuration - config = load_user_config() - # Start auto-save background thread (debounced every 20s) - start_config_saver() - # Set initial default values (will be overridden by load_ui_config on page load) - # Don't fetch from server during initialization to avoid validation errors - available_models = [] - available_samplers = [] - available_schedulers = [] - - # Set initial default values (will be overridden by load_ui_config on page load) - default_server_address = "127.0.0.1:8188" - default_model = "" - default_sampler = "euler" - default_scheduler = "normal" - default_steps = 30 - default_cfg = 6.0 - default_width = 768 - default_height = 1280 - default_batch_size = 1 - default_batch_count = 1 - default_seed = 757831338432565 - default_after_generate = "randomize" - default_positive_prefix = "" - default_negative_prefix = "" - default_positive = "best quality,very aesthetic,highres,absurdres,sensitive,A girl dressed in a maid costume with a personality, kneeling in front of her master," - default_negative = "lowres,(bad),bad feet,text,error,fewer,extra,missing,worst quality,jpeg artifacts,low quality,watermark,unfinished,displeasing,oldest,early,chromatic aberration,signature,artistic error,username,scan,[abstract],english text,shiny_skin," - default_preset_name = "None" - default_workflow = "workflow_template" - - css = """ - :root { font-family: sans-serif; } - #output_gallery img, #history_gallery img { border: 2px solid #e0e0e0; border-radius: 8px; } - """ - - with gr.Blocks(css=css, theme=gr.themes.Soft()) as app: - gr.Markdown("

ComfyUI Web Interface

") - - with gr.Tabs(): - with gr.TabItem("Generator"): - with gr.Row(): - with gr.Column(scale=1): - gr.Markdown("

⚙️ Settings

") - with gr.Row(): - server_addr = gr.Textbox(label="Server Address", value=default_server_address, scale=3) - refresh_btn = gr.Button("🔄 Refresh", scale=1) - model = gr.Dropdown(label="Model (Checkpoint Name)", choices=[], value="") - - with gr.Accordion("Workflow", open=True): - workflow_selector = gr.Dropdown(label="Workflow Template", choices=list(GLOBAL_WORKFLOWS.keys()), value=default_workflow) - - with gr.Accordion("Sampling Parameters", open=True): - sampler = gr.Dropdown(label="Sampler", choices=[], value="euler") - scheduler = gr.Dropdown(label="Scheduler", choices=[], value="normal") - steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=default_steps) - cfg = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=default_cfg) - - with gr.Accordion("Image Dimensions", open=True): - width = gr.Slider(label="Width", minimum=64, maximum=2048, step=64, value=default_width) - height = gr.Slider(label="Height", minimum=64, maximum=2048, step=64, value=default_height) - batch_size = gr.Slider(label="Batch Size (Images per generation)", minimum=1, maximum=16, step=1, value=default_batch_size) - batch_count = gr.Slider(label="Batch Count (Executions)", minimum=1, maximum=20, step=1, value=default_batch_count) - - # Place seed and after_generate within the left settings column to keep two-column layout - with gr.Row(): - seed = gr.Number(label="Seed", value=default_seed, precision=0) - after_generate = gr.Dropdown( - label="After Generate", - choices=["randomize", "increment", "decrement", "fixed"], - value=default_after_generate - ) - - - with gr.Column(scale=2): - gr.Markdown("

🎨 Prompts & Generation

") - - with gr.Accordion("Style Presets", open=True): - preset_selector = gr.Dropdown(label="Select Style", choices=list(GLOBAL_PRESETS.keys()), value=default_preset_name) - preset_name_input = gr.Textbox(label="Style Name (for saving)", lines=1) - positive_prefix_input = gr.Textbox(label="Positive Prefix", lines=3, interactive=True, value=default_positive_prefix) - negative_prefix_input = gr.Textbox(label="Negative Prefix", lines=3, interactive=True, value=default_negative_prefix) - with gr.Row(): - save_preset_btn = gr.Button("💾 Save / Update Style") - delete_preset_btn = gr.Button("🗑️ Delete Style", variant="stop") - preset_status_label = gr.Label(value="Select a style to apply, or edit the fields and save a new one.") - - positive_prompt = gr.Textbox(label="Positive Prompt (Your content)", lines=6, value=default_positive) - negative_prompt = gr.Textbox(label="Negative Prompt (Your content)", lines=3, value=default_negative) - generate_btn = gr.Button("Generate Image", variant="primary") - status_label = gr.Label(value="Idle", label="Status") - output_gallery = gr.Gallery(label="Generated Images", elem_id="output_gallery", columns=4) - - with gr.TabItem("Scheduler / Keep-Alive"): - gr.Markdown("## Scheduled Generation") - gr.Markdown("This feature will periodically run a generation task with the settings from the 'Generator' tab to keep a remote server active. It will always run with a 'Batch Count' of 1.") - scheduler_interval = gr.Number(label="Interval (minutes)", value=10, minimum=1, step=1) - with gr.Row(): - start_scheduler_btn = gr.Button("Start Scheduler") - stop_scheduler_btn = gr.Button("Stop Scheduler") - scheduler_status = gr.Label("Scheduler is stopped.") - scheduler_output = gr.Gallery(label="Last Scheduled Image", columns=1, height="auto") - - with gr.TabItem("History"): - with gr.Row(): - refresh_history_btn = gr.Button("🔄 Refresh History") - delete_btn = gr.Button("🗑️ Delete Selected Images") - history_gallery = gr.Gallery(label="Image History", elem_id="history_gallery", columns=8, allow_preview=True, preview=True) - - with gr.TabItem("OpenAI API"): - gr.Markdown("## OpenAI-compatible API") - with gr.Row(): - api_host = gr.Textbox(label="Host", value="127.0.0.1") - api_port = gr.Number(label="Port", value=9000, precision=0) - gr.Markdown("### Request Mapping and Generation Parameters") - with gr.Row(): - api_return = gr.Dropdown(label="Response Type", choices=["url", "b64_json"], value=USER_CONFIG.get("api_return", "url")) - api_n = gr.Slider(label="Images per request (n)", minimum=1, maximum=8, step=1, value=USER_CONFIG.get("api_n", 1)) - with gr.Accordion("Override Generation Parameters (optional)", open=False): - with gr.Row(): - api_server_addr = gr.Textbox(label="Server Address (override)", value=USER_CONFIG.get("api_server_address", "")) - api_model = gr.Textbox(label="Model (ckpt)", value=USER_CONFIG.get("api_model", "")) - with gr.Row(): - api_sampler = gr.Textbox(label="Sampler", value=USER_CONFIG.get("api_sampler", "")) - api_scheduler = gr.Textbox(label="Scheduler", value=USER_CONFIG.get("api_scheduler", "")) - with gr.Row(): - api_steps = gr.Number(label="Steps", value=USER_CONFIG.get("api_steps", 30), precision=0) - api_cfg = gr.Number(label="CFG", value=USER_CONFIG.get("api_cfg", 6.0)) - with gr.Row(): - api_width = gr.Number(label="Width", value=USER_CONFIG.get("api_width", 768), precision=0) - api_height = gr.Number(label="Height", value=USER_CONFIG.get("api_height", 1280), precision=0) - with gr.Row(): - api_seed = gr.Number(label="Seed", value=USER_CONFIG.get("api_seed", 757831338432565), precision=0) - api_after = gr.Dropdown(label="After Generate", choices=["randomize", "increment", "decrement", "fixed"], value=USER_CONFIG.get("api_after_generate", "randomize")) - with gr.Row(): - api_pos_prefix = gr.Textbox(label="Positive Prefix", lines=2, value=USER_CONFIG.get("api_positive_prefix", "")) - api_neg_prefix = gr.Textbox(label="Negative Prefix", lines=2, value=USER_CONFIG.get("api_negative_prefix", "")) - with gr.Row(): - api_workflow = gr.Dropdown(label="Workflow Template", choices=list(GLOBAL_WORKFLOWS.keys()), value=USER_CONFIG.get("api_workflow", "workflow_template")) - api_status = gr.Label("Server is stopped.") - with gr.Row(): - api_save_cfg_btn = gr.Button("Save API Config") - api_start_btn = gr.Button("Start API Server") - api_stop_btn = gr.Button("Stop API Server") - - with gr.TabItem("Settings"): - gr.Markdown("## Workflow Management") - with gr.Row(): - with gr.Column(scale=1): - gr.Markdown("### Workflow List") - workflow_list = gr.Dropdown(label="Select Workflow", choices=list(GLOBAL_WORKFLOWS.keys()), value=default_workflow) - with gr.Row(): - load_workflow_btn = gr.Button("📂 Load Workflow") - delete_workflow_btn = gr.Button("🗑️ Delete Workflow", variant="stop") - workflow_status = gr.Label(value="Select a workflow to edit or create a new one.") - - with gr.Column(scale=2): - gr.Markdown("### Workflow Editor") - workflow_name_input = gr.Textbox(label="Workflow Name", lines=1, value="workflow_template") - workflow_content_input = gr.Textbox(label="Workflow JSON Content", lines=20, value="", max_lines=30) - with gr.Row(): - save_workflow_btn = gr.Button("💾 Save Workflow", variant="primary") - new_workflow_btn = gr.Button("➕ New Workflow") - workflow_editor_status = gr.Label(value="Edit the workflow JSON content above.") - - gr.Markdown("## Preferences") - with gr.Row(): - language_dropdown = gr.Dropdown(label="Language", choices=["en", "zh"], value=config.get("language", "en")) - autosave_interval = gr.Number(label="Autosave Interval (seconds)", value=config.get("config_save_interval", 20), minimum=5, step=1) - with gr.Row(): - save_prefs_btn = gr.Button("Save Preferences") - prefs_status = gr.Label("") - - # Define Inputs/Outputs for main generation - gen_inputs = [server_addr, positive_prefix_input, negative_prefix_input, positive_prompt, negative_prompt, model, sampler, scheduler, steps, cfg, width, height, seed, after_generate, batch_size, batch_count, workflow_selector] - gen_outputs = [status_label, output_gallery] - - # Wire up events - refresh_btn.click(fn=update_choices, inputs=server_addr, outputs=[model, sampler, scheduler]) - - # Real-time save events - server_addr.change(fn=save_server_address, inputs=server_addr, outputs=server_addr) - model.change(fn=save_model, inputs=model, outputs=model) - sampler.change(fn=save_sampler, inputs=sampler, outputs=sampler) - scheduler.change(fn=save_scheduler, inputs=scheduler, outputs=scheduler) - steps.change(fn=save_steps, inputs=steps, outputs=steps) - cfg.change(fn=save_cfg, inputs=cfg, outputs=cfg) - width.change(fn=save_width, inputs=width, outputs=width) - height.change(fn=save_height, inputs=height, outputs=height) - batch_size.change(fn=save_batch_size, inputs=batch_size, outputs=batch_size) - batch_count.change(fn=save_batch_count, inputs=batch_count, outputs=batch_count) - seed.change(fn=save_seed, inputs=seed, outputs=seed) - after_generate.change(fn=save_after_generate, inputs=after_generate, outputs=after_generate) - # Save text fields on blur instead of every keystroke - positive_prefix_input.blur(fn=save_positive_prefix, inputs=positive_prefix_input, outputs=positive_prefix_input) - negative_prefix_input.blur(fn=save_negative_prefix, inputs=negative_prefix_input, outputs=negative_prefix_input) - positive_prompt.blur(fn=save_positive_prompt, inputs=positive_prompt, outputs=positive_prompt) - negative_prompt.blur(fn=save_negative_prompt, inputs=negative_prompt, outputs=negative_prompt) - preset_selector.change(fn=save_preset_name, inputs=preset_selector, outputs=preset_selector) - workflow_selector.change(fn=save_current_workflow, inputs=workflow_selector, outputs=workflow_selector) - - # Preset events - preset_selector.change(fn=select_preset, inputs=preset_selector, outputs=[preset_name_input, positive_prefix_input, negative_prefix_input]) - save_preset_btn.click(fn=save_or_update_preset, inputs=[preset_name_input, positive_prefix_input, negative_prefix_input], outputs=[preset_selector, preset_status_label]) - delete_preset_btn.click(fn=delete_preset, inputs=[preset_name_input], outputs=[preset_selector, preset_name_input, positive_prefix_input, negative_prefix_input, preset_status_label]) - - gen_event = generate_btn.click(fn=generate_images, inputs=gen_inputs, outputs=gen_outputs) - gen_event.then(fn=get_history_images, outputs=history_gallery) - - # Scheduler Tab Events - scheduler_inputs = [scheduler_interval, server_addr, positive_prefix_input, negative_prefix_input, positive_prompt, negative_prompt, model, sampler, scheduler, steps, cfg, width, height, seed, after_generate, batch_size, batch_count, workflow_selector] - scheduler_outputs = [scheduler_status, scheduler_output] - - start_event = start_scheduler_btn.click(fn=run_scheduled_generation, inputs=scheduler_inputs, outputs=scheduler_outputs) - stop_scheduler_btn.click(fn=stop_scheduler, inputs=None, outputs=scheduler_status, cancels=[start_event]) - - # History Tab Events - app.load(fn=get_history_images, outputs=history_gallery) - refresh_history_btn.click(fn=get_history_images, outputs=history_gallery) - delete_btn.click(fn=delete_image, inputs=history_gallery, outputs=history_gallery) - - # OpenAI API tab events - def _get_api_config(): - cfg = load_user_config() - return { - "server_address": cfg.get("api_server_address") or cfg.get("server_address", "127.0.0.1:8188"), - "model": cfg.get("api_model") or cfg.get("model", ""), - "sampler": cfg.get("api_sampler") or cfg.get("sampler", "euler"), - "scheduler": cfg.get("api_scheduler") or cfg.get("scheduler", "normal"), - "steps": cfg.get("api_steps") or cfg.get("steps", 30), - "cfg": cfg.get("api_cfg") or cfg.get("cfg", 6.0), - "width": cfg.get("api_width") or cfg.get("width", 768), - "height": cfg.get("api_height") or cfg.get("height", 1280), - "seed": cfg.get("api_seed") or cfg.get("seed", 757831338432565), - "after_generate": cfg.get("api_after_generate") or cfg.get("after_generate", "randomize"), - "positive_prefix": cfg.get("api_positive_prefix") or cfg.get("positive_prefix", ""), - "negative_prefix": cfg.get("api_negative_prefix") or cfg.get("negative_prefix", ""), - "negative_prompt": cfg.get("negative_prompt", ""), - "current_workflow": cfg.get("api_workflow") or cfg.get("current_workflow", "workflow_template"), - "api_return": cfg.get("api_return", "url"), - "api_n": cfg.get("api_n", 1) - } - - def start_api(host, port, return_type, n): - queue_config_update(api_return=return_type, api_n=int(n)) - try: - ok, msg = start_openai_server(str(host), int(port), _get_api_config) - return msg - except Exception as e: - return f"Failed to start: {e}" - - def stop_api(): - ok, msg = stop_openai_server() - return msg - - def save_api_config(*vals): - queue_config_update( - api_server_address=vals[0], api_model=vals[1], api_sampler=vals[2], api_scheduler=vals[3], - api_steps=int(vals[4]), api_cfg=float(vals[5]), api_width=int(vals[6]), api_height=int(vals[7]), - api_seed=int(vals[8]), api_after_generate=vals[9], api_positive_prefix=vals[10], api_negative_prefix=vals[11], - api_workflow=vals[12] - ) - return "API config saved (debounced)." - - api_save_cfg_btn.click(fn=save_api_config, inputs=[api_server_addr, api_model, api_sampler, api_scheduler, api_steps, api_cfg, api_width, api_height, api_seed, api_after, api_pos_prefix, api_neg_prefix, api_workflow], outputs=api_status) - api_start_btn.click(fn=start_api, inputs=[api_host, api_port, api_return, api_n], outputs=api_status) - api_stop_btn.click(fn=stop_api, outputs=api_status) - - # Workflow management events - def load_workflow_to_editor(workflow_name): - """Loads a workflow into the editor.""" - if not workflow_name or workflow_name == "workflow_template": - workflow_path = os.path.join(WORKFLOWS_DIR, "workflow_template.json") - else: - workflow_path = os.path.join(WORKFLOWS_DIR, f"{workflow_name}.json") - - if os.path.exists(workflow_path): - try: - with open(workflow_path, 'r', encoding='utf-8') as f: - content = f.read() - return workflow_name, content, f"Loaded workflow '{workflow_name}'" - except Exception as e: - return workflow_name, "", f"Error loading workflow: {e}" - else: - return workflow_name, "", f"Workflow '{workflow_name}' not found" - - def save_workflow_from_editor(workflow_name, workflow_content): - """Saves workflow from editor.""" - success, message = save_workflow(workflow_name, workflow_content) - if success: - # Reload workflows - global GLOBAL_WORKFLOWS - GLOBAL_WORKFLOWS = load_workflows() - return gr.update(choices=list(GLOBAL_WORKFLOWS.keys()), value=workflow_name), message - else: - return gr.update(), message - - def delete_workflow_from_editor(workflow_name): - """Deletes a workflow.""" - success, message = delete_workflow(workflow_name) - if success: - # Reload workflows - global GLOBAL_WORKFLOWS - GLOBAL_WORKFLOWS = load_workflows() - return (gr.update(choices=list(GLOBAL_WORKFLOWS.keys()), value="workflow_template"), - "workflow_template", "", message) - else: - return gr.update(), workflow_name, "", message - - def create_new_workflow(): - """Creates a new empty workflow.""" - return "new_workflow", "", "New workflow created. Enter a name and JSON content." - - # Wire up workflow management events - load_workflow_btn.click(fn=load_workflow_to_editor, inputs=workflow_list, outputs=[workflow_name_input, workflow_content_input, workflow_status]) - save_workflow_btn.click(fn=save_workflow_from_editor, inputs=[workflow_name_input, workflow_content_input], outputs=[workflow_list, workflow_editor_status]) - delete_workflow_btn.click(fn=delete_workflow_from_editor, inputs=workflow_name_input, outputs=[workflow_list, workflow_name_input, workflow_content_input, workflow_editor_status]) - new_workflow_btn.click(fn=create_new_workflow, outputs=[workflow_name_input, workflow_content_input, workflow_editor_status]) - - # Preferences events - def save_preferences(lang, interval): - queue_config_update(language=lang) - seconds = set_config_save_interval(interval) - return f"Saved. Language: {lang}, autosave: {seconds}s" - - save_prefs_btn.click(fn=save_preferences, inputs=[language_dropdown, autosave_interval], outputs=prefs_status) - - # Load user config on page load - app.load(fn=load_ui_config, outputs=[ - server_addr, model, sampler, scheduler, steps, cfg, width, height, - batch_size, batch_count, seed, after_generate, positive_prefix_input, - negative_prefix_input, positive_prompt, negative_prompt, preset_selector, workflow_selector, - model, sampler, scheduler, workflow_selector - ]) - - return app - -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default=None, help="Host to run the server on. Defaults to 127.0.0.1.") - args = parser.parse_args() - - webui = create_ui() - try: - # Pass server_name to launch() - webui.launch(server_name=args.host) - finally: - # Flush any pending config updates on exit - _flush_pending_config() +import gradio as gr +import websocket +import uuid +import json +import urllib.request +import urllib.parse +from PIL import Image +import io +import os +import random +import time +import threading +import base64 + +try: + from fastapi import FastAPI, Request + from fastapi.responses import JSONResponse, FileResponse + import uvicorn +except Exception: + FastAPI = None + uvicorn = None + +# --- Constants and Setup --- +BASE_DIR = os.path.dirname(__file__) +# Allow overriding data directory via environment variable WEBUI_DATA_DIR +DATA_DIR = os.path.abspath(os.getenv('WEBUI_DATA_DIR', os.path.join(BASE_DIR, 'data'))) +WORKFLOWS_DIR = os.path.join(DATA_DIR, 'workflows') +OUTPUT_DIR = os.path.join(DATA_DIR, 'outputs') +PRESETS_FILE = os.path.join(DATA_DIR, 'presets.json') +USER_CONFIG_FILE = os.path.join(DATA_DIR, 'user_config.json') +os.makedirs(OUTPUT_DIR, exist_ok=True) +os.makedirs(WORKFLOWS_DIR, exist_ok=True) +# --- Scheduler State --- +SCHEDULER_THREAD = None +SCHEDULER_STOP_EVENT = threading.Event() +SCHEDULER_STATUS = { + "running": False, + "interval": 10, + "last_run_time": "N/A", + "last_run_status": "Stopped", + "last_image": None +} +SCHEDULER_LOCK = threading.Lock() + +# --- Auto-save Config Manager (20s interval, debounced) --- +CONFIG_SAVE_INTERVAL = int(os.getenv('WEBUI_CONFIG_INTERVAL', '20')) # seconds +_pending_config = {} +_config_lock = threading.Lock() +_config_saver_thread = None +_config_changed = False + +def start_config_saver(): + """Start the background config saver thread.""" + global _config_saver_thread + if _config_saver_thread is None or not _config_saver_thread.is_alive(): + _config_saver_thread = threading.Thread(target=_config_saver_loop, daemon=True) + _config_saver_thread.start() + +def _flush_pending_config(): + """Flush any pending config changes immediately.""" + global _config_changed + with _config_lock: + if _config_changed and _pending_config: + try: + config = load_user_config() + config.update(_pending_config) + save_user_config(config) + print(f"[Auto-save] Configuration saved at {time.strftime('%H:%M:%S')}") + except Exception as e: + print(f"[Auto-save] Error saving config: {e}") + finally: + _pending_config.clear() + _config_changed = False + +def _config_saver_loop(): + """Background thread that saves config every CONFIG_SAVE_INTERVAL if changed.""" + while True: + time.sleep(CONFIG_SAVE_INTERVAL) + _flush_pending_config() + +def set_config_save_interval(seconds: int): + """Update autosave interval at runtime (min 5s).""" + global CONFIG_SAVE_INTERVAL + try: + seconds = int(seconds) + if seconds < 5: + seconds = 5 + except Exception: + seconds = 20 + CONFIG_SAVE_INTERVAL = seconds + queue_config_update(config_save_interval=seconds) + return seconds + +def queue_config_update(**kwargs): + """Queue config updates to be saved in the next interval.""" + global _config_changed + with _config_lock: + _pending_config.update(kwargs) + _config_changed = True + +# --- Preset Management Functions --- +def load_presets(): + """Loads presets from the JSON file. If not found, creates it with defaults.""" + default_presets = { + "None": {"positive": "", "negative": ""}, + "✨ 推荐风格": { + "positive": "best quality, very aesthetic, highres, absurdres, sensitive", + "negative": "lowres, (bad), bad feet, text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, artistic error, username, scan, [abstract], english text, shiny_skin" + }, + "🎨 动漫风格": { + "positive": "masterpiece, best quality, anime, 1girl, beautiful detailed eyes, detailed face", + "negative": "photorealistic, 3d, extra limbs, bad anatomy, ugly, deformed" + }, + "📸 写实风格": { + "positive": "photorealistic, high quality, detailed, professional photography", + "negative": "anime, cartoon, drawing, painting, sketch" + } + } + + if not os.path.exists(PRESETS_FILE): + with open(PRESETS_FILE, 'w', encoding='utf-8') as f: + json.dump(default_presets, f, indent=4) + return default_presets + else: + try: + with open(PRESETS_FILE, 'r', encoding='utf-8') as f: + presets = json.load(f) + if "None" not in presets: + presets["None"] = {"positive": "", "negative": ""} + return presets + except (json.JSONDecodeError, IOError): + with open(PRESETS_FILE, 'w', encoding='utf-8') as f: + json.dump(default_presets, f, indent=4) + return default_presets + +def save_presets(presets_dict): + """Saves the given dictionary to the presets JSON file.""" + with open(PRESETS_FILE, 'w', encoding='utf-8') as f: + json.dump(presets_dict, f, indent=4) + +def combine_prompts(prefix, main_prompt): + """Combines prefix and main prompt intelligently.""" + if prefix and main_prompt: + return f"{prefix.strip()}, {main_prompt.strip()}" + elif prefix: + return prefix.strip() + elif main_prompt: + return main_prompt.strip() + return "" + +def select_preset(preset_name): + """Selects a preset and returns its values.""" + preset_data = GLOBAL_PRESETS.get(preset_name, {"positive": "", "negative": ""}) + return preset_name, preset_data["positive"], preset_data["negative"] + +def save_or_update_preset(preset_name, positive_prefix, negative_prefix): + """Saves or updates a preset.""" + if not preset_name or not preset_name.strip(): + return gr.update(), "Preset name cannot be empty." + + preset_name = preset_name.strip() + GLOBAL_PRESETS[preset_name] = {"positive": positive_prefix, "negative": negative_prefix} + save_presets(GLOBAL_PRESETS) + return gr.update(choices=list(GLOBAL_PRESETS.keys()), value=preset_name), f"Preset '{preset_name}' saved." + +def delete_preset(preset_name): + """Deletes a preset.""" + if not preset_name or preset_name.strip() in ["", "None"]: + return gr.update(), gr.update(), gr.update(), gr.update(), "Cannot delete this preset." + + preset_name = preset_name.strip() + if preset_name in GLOBAL_PRESETS: + del GLOBAL_PRESETS[preset_name] + save_presets(GLOBAL_PRESETS) + return (gr.update(choices=list(GLOBAL_PRESETS.keys()), value="None"), + "None", "", "", f"Preset '{preset_name}' deleted.") + return gr.update(), gr.update(), gr.update(), gr.update(), f"Preset '{preset_name}' not found." + +# Load presets globally +GLOBAL_PRESETS = load_presets() + +# --- User Config Management Functions --- +def load_user_config(): + """Loads user configuration from JSON file.""" + default_config = { + "server_address": "127.0.0.1:8188", + "model": "", + "sampler": "euler", + "scheduler": "normal", + "steps": 30, + "cfg": 6.0, + "width": 768, + "height": 1280, + "batch_size": 1, + "batch_count": 1, + "seed": 757831338432565, + "after_generate": "randomize", + "positive_prefix": "", + "negative_prefix": "", + "positive_prompt": "best quality,very aesthetic,highres,absurdres,sensitive,A girl dressed in a maid costume with a personality, kneeling in front of her master,", + "negative_prompt": "lowres,(bad),bad feet,text,error,fewer,extra,missing,worst quality,jpeg artifacts,low quality,watermark,unfinished,displeasing,oldest,early,chromatic aberration,signature,artistic error,username,scan,[abstract],english text,shiny_skin,", + "preset_name": "None", + "current_workflow": "workflow_template", + "language": "en", + "config_save_interval": 20, + # OpenAI API defaults (can override generator settings) + "api_server_address": "", + "api_model": "", + "api_sampler": "", + "api_scheduler": "", + "api_steps": 30, + "api_cfg": 6.0, + "api_width": 768, + "api_height": 1280, + "api_seed": 757831338432565, + "api_after_generate": "randomize", + "api_positive_prefix": "", + "api_negative_prefix": "", + "api_workflow": "workflow_template", + "api_return": "url", + "api_n": 1 + } + + if not os.path.exists(USER_CONFIG_FILE): + save_user_config(default_config) + return default_config + + try: + with open(USER_CONFIG_FILE, 'r', encoding='utf-8') as f: + config = json.load(f) + # 合并默认配置,确保新字段存在 + for key, value in default_config.items(): + if key not in config: + config[key] = value + return config + except (json.JSONDecodeError, IOError): + save_user_config(default_config) + return default_config + +def save_user_config(config_dict): + """Saves user configuration to JSON file.""" + with open(USER_CONFIG_FILE, 'w', encoding='utf-8') as f: + json.dump(config_dict, f, indent=4) + +def update_user_config(**kwargs): + """Updates specific configuration values.""" + config = load_user_config() + for key, value in kwargs.items(): + config[key] = value + save_user_config(config) + +# Load user config globally +USER_CONFIG = load_user_config() + +# --- Workflow Management Functions --- +def load_workflows(): + """Loads all workflow files from the workflows directory.""" + workflows = {} + if not os.path.exists(WORKFLOWS_DIR): + return workflows + + for filename in os.listdir(WORKFLOWS_DIR): + if filename.endswith('.json'): + workflow_name = filename[:-5] # Remove .json extension + workflow_path = os.path.join(WORKFLOWS_DIR, filename) + try: + with open(workflow_path, 'r', encoding='utf-8') as f: + workflows[workflow_name] = json.load(f) + except (json.JSONDecodeError, IOError) as e: + print(f"Error loading workflow {workflow_name}: {e}") + + return workflows + +def save_workflow(workflow_name, workflow_content): + """Saves a workflow to the workflows directory.""" + if not workflow_name or not workflow_name.strip(): + return False, "Workflow name cannot be empty." + + workflow_name = workflow_name.strip() + workflow_path = os.path.join(WORKFLOWS_DIR, f"{workflow_name}.json") + + try: + # Validate JSON content + json.loads(workflow_content) + with open(workflow_path, 'w', encoding='utf-8') as f: + f.write(workflow_content) + return True, f"Workflow '{workflow_name}' saved successfully." + except json.JSONDecodeError as e: + return False, f"Invalid JSON format: {e}" + except IOError as e: + return False, f"Error saving workflow: {e}" + +def delete_workflow(workflow_name): + """Deletes a workflow file.""" + if not workflow_name or workflow_name.strip() in ["", "workflow_template"]: + return False, "Cannot delete this workflow." + + workflow_name = workflow_name.strip() + workflow_path = os.path.join(WORKFLOWS_DIR, f"{workflow_name}.json") + + if os.path.exists(workflow_path): + try: + os.remove(workflow_path) + return True, f"Workflow '{workflow_name}' deleted successfully." + except IOError as e: + return False, f"Error deleting workflow: {e}" + else: + return False, f"Workflow '{workflow_name}' not found." + +def load_workflow_content(workflow_name): + """Loads the content of a specific workflow.""" + if not workflow_name or workflow_name == "workflow_template": + # Load default template + workflow_path = os.path.join(WORKFLOWS_DIR, "workflow_template.json") + else: + workflow_path = os.path.join(WORKFLOWS_DIR, f"{workflow_name}.json") + + if os.path.exists(workflow_path): + try: + with open(workflow_path, 'r', encoding='utf-8') as f: + return json.load(f) + except (json.JSONDecodeError, IOError) as e: + print(f"Error loading workflow {workflow_name}: {e}") + return None + return None + +# Load workflows globally +GLOBAL_WORKFLOWS = load_workflows() + +# --- OpenAI-compatible API Server (FastAPI) --- + + +def _ensure_fastapi_available(): + if FastAPI is None or uvicorn is None: + raise RuntimeError("FastAPI/uvicorn not installed. Please install with: pip install fastapi uvicorn") + +def _encode_image_b64(path: str) -> str: + with open(path, 'rb') as f: + return base64.b64encode(f.read()).decode('utf-8') + +def get_api_config(): + """ + Loads the user config and creates a consolidated config dictionary for the API, + applying API-specific overrides over the main generator settings. + """ + cfg = load_user_config() + return { + "server_address": cfg.get("api_server_address") or cfg.get("server_address", "127.0.0.1:8188"), + "model": cfg.get("api_model") or cfg.get("model", ""), + "sampler": cfg.get("api_sampler") or cfg.get("sampler", "euler"), + "scheduler": cfg.get("api_scheduler") or cfg.get("scheduler", "normal"), + "steps": cfg.get("api_steps", 30), + "cfg": cfg.get("api_cfg", 6.0), + "width": cfg.get("api_width", 768), + "height": cfg.get("api_height", 1280), + "seed": cfg.get("api_seed", 757831338432565), + "after_generate": cfg.get("api_after_generate") or cfg.get("after_generate", "randomize"), + "positive_prefix": cfg.get("api_positive_prefix") or cfg.get("positive_prefix", ""), + "negative_prefix": cfg.get("api_negative_prefix") or cfg.get("negative_prefix", ""), + "negative_prompt": cfg.get("negative_prompt", ""), + "current_workflow": cfg.get("api_workflow") or cfg.get("current_workflow", "workflow_template"), + "api_return": cfg.get("api_return", "url"), + "api_n": cfg.get("api_n", 1) + } + +def generate_image_sync(server_address, positive_prefix, negative_prefix, positive_prompt, negative_prompt, model, sampler, scheduler, steps, cfg, width, height, seed, after_generate, batch_size, batch_count, current_workflow): + """Synchronous image generation that returns list of saved file paths.""" + # Normalize server address + if not server_address.startswith("http://") and not server_address.startswith("https://"): + server_address = "http://" + server_address + server_address = server_address.rstrip('/') + + ws_address = "ws://" + server_address[len("http://"):] + if server_address.startswith("https://"): + ws_address = "wss://" + server_address[len("https://"):] + + client_id = str(uuid.uuid4()) + all_generated_images = [] + initial_seed = seed + + for i in range(batch_count): + if after_generate == "randomize": + current_seed = random.randint(0, 2**32 - 1) + elif after_generate == "increment": + current_seed = initial_seed + i + elif after_generate == "decrement": + current_seed = initial_seed - i + else: # "fixed" + current_seed = initial_seed + + ws = websocket.WebSocket() + try: + ws.connect(f"{ws_address}/ws?clientId={client_id}") + + workflow_content = load_workflow_content(current_workflow) + if workflow_content is None: + break + workflow_content = json.dumps(workflow_content) + + final_positive_prompt = combine_prompts(positive_prefix, positive_prompt) + final_negative_prompt = combine_prompts(negative_prefix, negative_prompt) + + workflow_content = workflow_content.replace('%prompt%', final_positive_prompt) + workflow_content = workflow_content.replace('%negative_prompt%', final_negative_prompt) + workflow_content = workflow_content.replace('%model%', model) + workflow_content = workflow_content.replace('%width%', str(width)) + workflow_content = workflow_content.replace('%height%', str(height)) + workflow_content = workflow_content.replace('%batch_size%', str(batch_size)) + workflow_content = workflow_content.replace('%seed%', str(current_seed)) + workflow_content = workflow_content.replace('%steps%', str(steps)) + workflow_content = workflow_content.replace('%cfg%', str(cfg)) + workflow_content = workflow_content.replace('%sampler%', sampler) + workflow_content = workflow_content.replace('%scheduler%', scheduler) + + prompt_workflow = json.loads(workflow_content) + prompt_data = queue_prompt(prompt_workflow, client_id, server_address) + prompt_id = prompt_data['prompt_id'] + + while True: + out = ws.recv() + if not isinstance(out, str): + continue + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['node'] is None and data['prompt_id'] == prompt_id: + break + + history = get_history(prompt_id, server_address)[prompt_id] + images_output = [] + for node_id in history['outputs']: + if 'images' in history['outputs'][node_id]: + for image in history['outputs'][node_id]['images']: + image_data = get_image(image['filename'], image['subfolder'], image['type'], server_address) + images_output.append(image_data) + + if not images_output: + continue + + pil_images = [Image.open(io.BytesIO(data)) for data in images_output] + for img_idx, img in enumerate(pil_images): + filename = f"{int(time.time())}_{current_seed}_{img_idx}.png" + filepath = os.path.join(OUTPUT_DIR, filename) + img.save(filepath) + all_generated_images.append(filepath) + + finally: + if ws.connected: + ws.close() + + return all_generated_images + +def _create_openai_app(): + """Create FastAPI app with OpenAI-compatible routes.""" + app = FastAPI() + + @app.get("/v1/files/{filename}") + def get_file(filename: str): + path = os.path.join(OUTPUT_DIR, filename) + if os.path.exists(path): + return FileResponse(path) + return JSONResponse(status_code=404, content={"error": {"message": "File not found"}}) + + @app.post("/v1/chat/completions") + async def chat_completions(req: Request): + body = await req.json() + cfg = get_api_config() + model = body.get("model", "gpt-image-proxy") + messages = body.get("messages", []) + n = int(body.get("n", cfg["api_n"])) + return_type = body.get("response_format", {}).get("type", cfg["api_return"]) # "b64_json" or "url" + + # latest user message + user_text = "" + for m in reversed(messages): + if m.get("role") == "user": + content = m.get("content", "") + if isinstance(content, list): + user_text = " ".join([item.get("text", "") for item in content if item.get("type") == "text"]).strip() + else: + user_text = str(content) + break + + # Use a random seed for each request and randomize across images in n + req_seed = random.randint(0, 2**32 - 1) + filepaths = generate_image_sync( + cfg["server_address"], cfg["positive_prefix"], cfg["negative_prefix"], user_text, cfg["negative_prompt"], + cfg["model"], cfg["sampler"], cfg["scheduler"], cfg["steps"], cfg["cfg"], cfg["width"], cfg["height"], + int(req_seed), "randomize", 1, n, cfg["current_workflow"] + ) + + choices = [] + for idx, fp in enumerate(filepaths): + if return_type == "url": + content = f"/v1/files/{os.path.basename(fp)}" + else: + content = _encode_image_b64(fp) + choices.append({ + "index": idx, + "finish_reason": "stop", + "message": {"role": "assistant", "content": content} + }) + + resp = { + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": choices, + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + "x_comfy": {"seed": int(req_seed), "after_generate": "randomize"} + } + return JSONResponse(content=resp) + + return app + + + + + +# --- Debounced Save Functions (queued, saved every 20s) --- +def save_server_address(server_address): + queue_config_update(server_address=server_address) + return server_address + +def save_model(model): + queue_config_update(model=model) + return model + +def save_sampler(sampler): + queue_config_update(sampler=sampler) + return sampler + +def save_scheduler(scheduler): + queue_config_update(scheduler=scheduler) + return scheduler + +def save_steps(steps): + queue_config_update(steps=steps) + return steps + +def save_cfg(cfg): + queue_config_update(cfg=cfg) + return cfg + +def save_width(width): + queue_config_update(width=width) + return width + +def save_height(height): + queue_config_update(height=height) + return height + +def save_batch_size(batch_size): + queue_config_update(batch_size=batch_size) + return batch_size + +def save_batch_count(batch_count): + queue_config_update(batch_count=batch_count) + return batch_count + +def save_seed(seed): + queue_config_update(seed=seed) + return seed + +def save_after_generate(after_generate): + queue_config_update(after_generate=after_generate) + return after_generate + +def save_positive_prefix(positive_prefix): + queue_config_update(positive_prefix=positive_prefix) + return positive_prefix + +def save_negative_prefix(negative_prefix): + queue_config_update(negative_prefix=negative_prefix) + return negative_prefix + +def save_positive_prompt(positive_prompt): + queue_config_update(positive_prompt=positive_prompt) + return positive_prompt + +def save_negative_prompt(negative_prompt): + queue_config_update(negative_prompt=negative_prompt) + return negative_prompt + +def save_preset_name(preset_name): + queue_config_update(preset_name=preset_name) + return preset_name + +def save_current_workflow(current_workflow): + queue_config_update(current_workflow=current_workflow) + return current_workflow + +def load_ui_config(): + """Loads user configuration and returns it for UI initialization.""" + config = load_user_config() + + # Get server address and fetch available options + server_address = config.get("server_address", "127.0.0.1:8188") + if not server_address.startswith("http://") and not server_address.startswith("https://"): + server_address = "http://" + server_address + + object_info = get_object_info(server_address) + available_models = get_models(object_info) + available_samplers = get_samplers(object_info) + available_schedulers = get_schedulers(object_info) + + return ( + config.get("server_address", "127.0.0.1:8188"), + config.get("model", ""), + config.get("sampler", "euler"), + config.get("scheduler", "normal"), + config.get("steps", 30), + config.get("cfg", 6.0), + config.get("width", 768), + config.get("height", 1280), + config.get("batch_size", 1), + config.get("batch_count", 1), + config.get("seed", 757831338432565), + config.get("after_generate", "randomize"), + config.get("positive_prefix", ""), + config.get("negative_prefix", ""), + config.get("positive_prompt", "best quality,very aesthetic,highres,absurdres,sensitive,A girl dressed in a maid costume with a personality, kneeling in front of her master,"), + config.get("negative_prompt", "lowres,(bad),bad feet,text,error,fewer,extra,missing,worst quality,jpeg artifacts,low quality,watermark,unfinished,displeasing,oldest,early,chromatic aberration,signature,artistic error,username,scan,[abstract],english text,shiny_skin,"), + config.get("preset_name", "None"), + config.get("current_workflow", "workflow_template"), + gr.update(choices=available_models), + gr.update(choices=available_samplers), + gr.update(choices=available_schedulers), + gr.update(choices=list(GLOBAL_WORKFLOWS.keys())), + # API settings + config.get("api_return", "url"), + config.get("api_n", 1), + config.get("api_server_address", ""), + config.get("api_model", ""), + config.get("api_sampler", ""), + config.get("api_scheduler", ""), + config.get("api_steps", 30), + config.get("api_cfg", 6.0), + config.get("api_width", 768), + config.get("api_height", 1280), + config.get("api_seed", 757831338432565), + config.get("api_after_generate", "randomize"), + config.get("api_positive_prefix", ""), + config.get("api_negative_prefix", ""), + config.get("api_workflow", "workflow_template") + ) + +# --- ComfyUI API Functions --- +def get_image(filename, subfolder, folder_type, server_address): + """Fetches an image from the ComfyUI server.""" + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen(f"{server_address}/view?{url_values}") as response: + return response.read() + +def queue_prompt(prompt, client_id, server_address): + """Queues a prompt on the ComfyUI server.""" + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request(f"{server_address}/prompt", data=data) + response = urllib.request.urlopen(req) + return json.loads(response.read()) + +def get_history(prompt_id, server_address): + """Gets the history for a given prompt ID.""" + with urllib.request.urlopen(f"{server_address}/history/{prompt_id}") as response: + return json.loads(response.read()) + +def get_object_info(server_address): + """Gets object info from the ComfyUI server.""" + try: + with urllib.request.urlopen(f"{server_address}/object_info") as response: + return json.loads(response.read()) + except Exception as e: + print(f"Failed to fetch object info: {e}") + return None + +def get_models(object_info): + """Extracts a comprehensive list of models from object_info.""" + models = [] + if not object_info: + return ["model.safetensors"] + + if "CheckpointLoaderSimple" in object_info and "ckpt_name" in object_info["CheckpointLoaderSimple"]["input"]["required"]: + models.extend(object_info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0]) + if "UNETLoader" in object_info and "unet_name" in object_info["UNETLoader"]["input"]["required"]: + models.extend(object_info["UNETLoader"]["input"]["required"]["unet_name"][0]) + if "UnetLoaderGGUF" in object_info and "unet_name" in object_info["UnetLoaderGGUF"]["input"]["required"]: + models.extend(object_info["UnetLoaderGGUF"]["input"]["required"]["unet_name"][0]) + + if not models: + return ["model.safetensors"] + + return list(dict.fromkeys(models)) + +def get_samplers(object_info): + if object_info and "KSampler" in object_info: + return object_info["KSampler"]["input"]["required"]["sampler_name"][0] + return ["euler"] + +def get_schedulers(object_info): + if object_info and "KSampler" in object_info: + return object_info["KSampler"]["input"]["required"]["scheduler"][0] + return ["normal"] + +# --- UI Callback Functions --- +def update_choices(server_address): + """Callback function to update dropdown choices.""" + if not server_address: + return (gr.update(choices=[]), gr.update(choices=[]), gr.update(choices=[])) + + if not server_address.startswith("http://") and not server_address.startswith("https://"): + http_server_address = "http://" + server_address + else: + http_server_address = server_address + http_server_address = http_server_address.rstrip('/') + + object_info = get_object_info(http_server_address) + available_models = get_models(object_info) + available_samplers = get_samplers(object_info) + available_schedulers = get_schedulers(object_info) + + return ( + gr.update(choices=available_models, value=available_models[0] if available_models else None), + gr.update(choices=available_samplers, value=available_samplers[0] if available_samplers else None), + gr.update(choices=available_schedulers, value=available_schedulers[0] if available_schedulers else None) + ) + +# --- Scheduler Functions --- +# --- Scheduler Functions --- +def _scheduler_loop(interval_minutes, gen_args_dict): + """The actual background loop for the scheduler.""" + print(f"[Scheduler] Thread started. Interval: {interval_minutes} min.") + # Perform the first run immediately without waiting + first_run = True + + while not SCHEDULER_STOP_EVENT.is_set(): + wait_seconds = int(interval_minutes * 60) + + if not first_run: + print(f"[Scheduler] Waiting for {interval_minutes} minute(s)...") + # wait() returns True if the event was set, False if it timed out. + if SCHEDULER_STOP_EVENT.wait(timeout=wait_seconds): + break # Stop was requested during sleep. + + first_run = False + if SCHEDULER_STOP_EVENT.is_set(): break # Check again in case stop was called during generation + + try: + with SCHEDULER_LOCK: + SCHEDULER_STATUS["last_run_status"] = "Running generation..." + print("[Scheduler] Running scheduled generation...") + + # Use generate_image_sync as it's simpler and doesn't yield UI updates + filepaths = generate_image_sync(**gen_args_dict) + + with SCHEDULER_LOCK: + SCHEDULER_STATUS["last_run_time"] = time.strftime('%Y-%m-%d %H:%M:%S') + if filepaths: + SCHEDULER_STATUS["last_run_status"] = "Success" + SCHEDULER_STATUS["last_image"] = filepaths[0] + print(f"[Scheduler] Successfully generated {len(filepaths)} image(s). Last image: {filepaths[0]}") + else: + SCHEDULER_STATUS["last_run_status"] = "Success (no images)" + print("[Scheduler] Generation ran but produced no images.") + + except Exception as e: + error_message = f"Error: {type(e).__name__}" + with SCHEDULER_LOCK: + SCHEDULER_STATUS["last_run_time"] = time.strftime('%Y-%m-%d %H:%M:%S') + SCHEDULER_STATUS["last_run_status"] = error_message + print(f"[Scheduler] An error occurred during generation: {e}") + + print("[Scheduler] Loop finished.") + with SCHEDULER_LOCK: + SCHEDULER_STATUS["running"] = False + SCHEDULER_STATUS["last_run_status"] = "Stopped" + + +def start_scheduler(interval, server_address, *gen_args): + global SCHEDULER_THREAD + with SCHEDULER_LOCK: + if SCHEDULER_STATUS["running"]: + return "Scheduler is already running.", None + + arg_names = ["positive_prefix", "negative_prefix", "positive_prompt", "negative_prompt", "model", "sampler", "scheduler", "steps", "cfg", "width", "height", "seed", "after_generate", "batch_size", "batch_count", "current_workflow"] + + gen_args_dict = dict(zip(arg_names, gen_args)) + gen_args_dict["server_address"] = server_address + gen_args_dict["batch_count"] = 1 # Always 1 for scheduler + gen_args_dict["batch_size"] = 1 # Also force batch size to 1 to be safe + + SCHEDULER_STOP_EVENT.clear() + SCHEDULER_THREAD = threading.Thread(target=_scheduler_loop, args=(interval, gen_args_dict), daemon=True) + SCHEDULER_THREAD.start() + + SCHEDULER_STATUS["running"] = True + SCHEDULER_STATUS["interval"] = interval + SCHEDULER_STATUS["last_run_status"] = "Started, running first job..." + + print(f"[Scheduler] Started with interval {interval} minutes.") + # We can't return status here as it's not a generator anymore + return "Scheduler started. Status will update automatically.", None + +def stop_scheduler_global(): + global SCHEDULER_THREAD + with SCHEDULER_LOCK: + if not SCHEDULER_STATUS["running"]: + return "Scheduler is not running." + + print("[Scheduler] Stop requested.") + SCHEDULER_STOP_EVENT.set() + + if SCHEDULER_THREAD: + SCHEDULER_THREAD.join(timeout=10) + SCHEDULER_THREAD = None + + # The loop itself will update the status to "Stopped" + print("[Scheduler] Stopped.") + return "Scheduler stopping... Status will update." + +def get_scheduler_status_for_ui(): + with SCHEDULER_LOCK: + status = SCHEDULER_STATUS["last_run_status"] + image = SCHEDULER_STATUS["last_image"] + if SCHEDULER_STATUS["running"]: + status_text = f"Running (Interval: {SCHEDULER_STATUS['interval']} min). Last run: {SCHEDULER_STATUS['last_run_time']}. Status: {status}" + else: + status_text = f"Stopped. Last run: {SCHEDULER_STATUS['last_run_time']}. Status: {status}" + + image_list = [image] if image and os.path.exists(image) else None + return status_text, image_list + + +# --- History Management Functions --- +def get_history_images(): + """Returns a sorted list of images from the output directory.""" + if not os.path.exists(OUTPUT_DIR): + return [] + images = [os.path.join(OUTPUT_DIR, f) for f in os.listdir(OUTPUT_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.webp'))] + images.sort(key=os.path.getmtime, reverse=True) + return images + +def delete_image(filepaths): + """Deletes selected images and returns the updated list of images.""" + if isinstance(filepaths, list) and filepaths: + for item in filepaths: + # Handle both string paths and tuple (path, metadata) formats + if isinstance(item, tuple): + filepath = item[0] # Extract the file path from tuple + else: + filepath = item + + if filepath and os.path.exists(filepath): + try: + os.remove(filepath) + except Exception as e: + print(f"Error deleting file {filepath}: {e}") + return get_history_images() + +# --- Core Generation Logic --- +def generate_images(server_address, positive_prefix, negative_prefix, positive_prompt, negative_prompt, model, sampler, scheduler, steps, cfg, width, height, seed, after_generate, batch_size, batch_count, current_workflow): + """Main function to generate images based on UI inputs.""" + # Normalize server address + if not server_address.startswith("http://") and not server_address.startswith("https://"): + server_address = "http://" + server_address + server_address = server_address.rstrip('/') + + ws_address = "ws://" + server_address[len("http://"):] + if server_address.startswith("https://"): + ws_address = "wss://" + server_address[len("https://"):] + + client_id = str(uuid.uuid4()) + all_generated_images = [] + initial_seed = seed + + for i in range(batch_count): + yield f"Running batch {i+1}/{batch_count}...", all_generated_images + + if after_generate == "randomize": + current_seed = random.randint(0, 2**32 - 1) + elif after_generate == "increment": + current_seed = initial_seed + i + elif after_generate == "decrement": + current_seed = initial_seed - i + else: # "fixed" + current_seed = initial_seed + + ws = websocket.WebSocket() + try: + yield f"Batch {i+1}: Connecting...", all_generated_images + ws.connect(f"{ws_address}/ws?clientId={client_id}") + + # Load workflow content + workflow_content = load_workflow_content(current_workflow) + if workflow_content is None: + yield f"Error: Could not load workflow '{current_workflow}'", all_generated_images + break + workflow_content = json.dumps(workflow_content) + + # Combine prefix and main prompts + final_positive_prompt = combine_prompts(positive_prefix, positive_prompt) + final_negative_prompt = combine_prompts(negative_prefix, negative_prompt) + + # Replace placeholders with actual values + workflow_content = workflow_content.replace('%prompt%', final_positive_prompt) + workflow_content = workflow_content.replace('%negative_prompt%', final_negative_prompt) + workflow_content = workflow_content.replace('%model%', model) + workflow_content = workflow_content.replace('%width%', str(width)) + workflow_content = workflow_content.replace('%height%', str(height)) + workflow_content = workflow_content.replace('%batch_size%', str(batch_size)) + workflow_content = workflow_content.replace('%seed%', str(current_seed)) + workflow_content = workflow_content.replace('%steps%', str(steps)) + workflow_content = workflow_content.replace('%cfg%', str(cfg)) + workflow_content = workflow_content.replace('%sampler%', sampler) + workflow_content = workflow_content.replace('%scheduler%', scheduler) + + # Parse the modified workflow + prompt_workflow = json.loads(workflow_content) + + prompt_data = queue_prompt(prompt_workflow, client_id, server_address) + prompt_id = prompt_data['prompt_id'] + + while True: + out = ws.recv() + if not isinstance(out, str): continue + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['node'] is None and data['prompt_id'] == prompt_id: + break + else: + node_id = data['node'] + node_title = prompt_workflow.get(node_id, {}).get('_meta', {}).get('title', f"Node {node_id}") + yield f"Batch {i+1}: Executing {node_title}...", all_generated_images + + history = get_history(prompt_id, server_address)[prompt_id] + images_output = [] + for node_id in history['outputs']: + if 'images' in history['outputs'][node_id]: + for image in history['outputs'][node_id]['images']: + image_data = get_image(image['filename'], image['subfolder'], image['type'], server_address) + images_output.append(image_data) + + if not images_output: + continue + + pil_images = [Image.open(io.BytesIO(data)) for data in images_output] + for img_idx, img in enumerate(pil_images): + filename = f"{int(time.time())}_{current_seed}_{img_idx}.png" + filepath = os.path.join(OUTPUT_DIR, filename) + img.save(filepath) + all_generated_images.insert(0, filepath) # Insert at beginning to show newest first + + except Exception as e: + yield f"Error in batch {i+1}: {e}", all_generated_images + break # Stop on error + finally: + if ws.connected: + ws.close() + + yield "Done!", all_generated_images + +# --- Gradio UI --- +def create_ui(): + # Load initial configuration + config = load_user_config() + # Start auto-save background thread (debounced every 20s) + start_config_saver() + # Set initial default values (will be overridden by load_ui_config on page load) + # Don't fetch from server during initialization to avoid validation errors + available_models = [] + available_samplers = [] + available_schedulers = [] + + # Set initial default values (will be overridden by load_ui_config on page load) + default_server_address = "127.0.0.1:8188" + default_model = "" + default_sampler = "euler" + default_scheduler = "normal" + default_steps = 30 + default_cfg = 6.0 + default_width = 768 + default_height = 1280 + default_batch_size = 1 + default_batch_count = 1 + default_seed = 757831338432565 + default_after_generate = "randomize" + default_positive_prefix = "" + default_negative_prefix = "" + default_positive = "best quality,very aesthetic,highres,absurdres,sensitive,A girl dressed in a maid costume with a personality, kneeling in front of her master," + default_negative = "lowres,(bad),bad feet,text,error,fewer,extra,missing,worst quality,jpeg artifacts,low quality,watermark,unfinished,displeasing,oldest,early,chromatic aberration,signature,artistic error,username,scan,[abstract],english text,shiny_skin," + default_preset_name = "None" + default_workflow = "workflow_template" + + css = """ + :root { font-family: sans-serif; } + #output_gallery img, #history_gallery img { border: 2px solid #e0e0e0; border-radius: 8px; } + """ + + with gr.Blocks(css=css, theme=gr.themes.Soft()) as app: + gr.Markdown("

ComfyUI Web Interface

") + + with gr.Tabs(): + with gr.TabItem("Generator"): + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("

⚙️ Settings

") + with gr.Row(): + server_addr = gr.Textbox(label="Server Address", value=default_server_address, scale=3) + refresh_btn = gr.Button("🔄 Refresh", scale=1) + model = gr.Dropdown(label="Model (Checkpoint Name)", choices=[], value="") + + with gr.Accordion("Workflow", open=True): + workflow_selector = gr.Dropdown(label="Workflow Template", choices=list(GLOBAL_WORKFLOWS.keys()), value=default_workflow) + + with gr.Accordion("Sampling Parameters", open=True): + sampler = gr.Dropdown(label="Sampler", choices=[], value="euler") + scheduler = gr.Dropdown(label="Scheduler", choices=[], value="normal") + steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=default_steps) + cfg = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=default_cfg) + + with gr.Accordion("Image Dimensions", open=True): + width = gr.Slider(label="Width", minimum=64, maximum=2048, step=64, value=default_width) + height = gr.Slider(label="Height", minimum=64, maximum=2048, step=64, value=default_height) + batch_size = gr.Slider(label="Batch Size (Images per generation)", minimum=1, maximum=16, step=1, value=default_batch_size) + batch_count = gr.Slider(label="Batch Count (Executions)", minimum=1, maximum=20, step=1, value=default_batch_count) + + # Place seed and after_generate within the left settings column to keep two-column layout + with gr.Row(): + seed = gr.Number(label="Seed", value=default_seed, precision=0) + after_generate = gr.Dropdown( + label="After Generate", + choices=["randomize", "increment", "decrement", "fixed"], + value=default_after_generate + ) + + + with gr.Column(scale=2): + gr.Markdown("

🎨 Prompts & Generation

") + + with gr.Accordion("Style Presets", open=True): + preset_selector = gr.Dropdown(label="Select Style", choices=list(GLOBAL_PRESETS.keys()), value=default_preset_name) + preset_name_input = gr.Textbox(label="Style Name (for saving)", lines=1) + positive_prefix_input = gr.Textbox(label="Positive Prefix", lines=3, interactive=True, value=default_positive_prefix) + negative_prefix_input = gr.Textbox(label="Negative Prefix", lines=3, interactive=True, value=default_negative_prefix) + with gr.Row(): + save_preset_btn = gr.Button("💾 Save / Update Style") + delete_preset_btn = gr.Button("🗑️ Delete Style", variant="stop") + preset_status_label = gr.Label(value="Select a style to apply, or edit the fields and save a new one.") + + positive_prompt = gr.Textbox(label="Positive Prompt (Your content)", lines=6, value=default_positive) + negative_prompt = gr.Textbox(label="Negative Prompt (Your content)", lines=3, value=default_negative) + generate_btn = gr.Button("Generate Image", variant="primary") + status_label = gr.Label(value="Idle", label="Status") + output_gallery = gr.Gallery(label="Generated Images", elem_id="output_gallery", columns=4) + + with gr.TabItem("Scheduler / Keep-Alive"): + gr.Markdown("## Scheduled Generation") + gr.Markdown("This feature will periodically run a generation task with the settings from the 'Generator' tab to keep a remote server active. It will always run with a 'Batch Count' of 1.") + scheduler_interval = gr.Number(label="Interval (minutes)", value=10, minimum=1, step=1) + with gr.Row(): + start_scheduler_btn = gr.Button("Start Scheduler") + stop_scheduler_btn = gr.Button("Stop Scheduler") + scheduler_status = gr.Label("Scheduler is stopped.") + scheduler_output = gr.Gallery(label="Last Scheduled Image", columns=1, height="auto") + + with gr.TabItem("History"): + with gr.Row(): + refresh_history_btn = gr.Button("🔄 Refresh History") + delete_btn = gr.Button("🗑️ Delete Selected Images") + history_gallery = gr.Gallery(label="Image History", elem_id="history_gallery", columns=8, allow_preview=True, preview=True) + + with gr.TabItem("API Settings"): + gr.Markdown("## OpenAI-compatible API Settings") + gr.Markdown("Here you can override the main generator settings for requests made to the OpenAI-compatible API. If a field is left blank, it will use the value from the main 'Generator' tab.") + with gr.Row(): + api_return = gr.Dropdown(label="Response Type", choices=["url", "b64_json"], value=USER_CONFIG.get("api_return", "url")) + api_n = gr.Slider(label="Images per request (n)", minimum=1, maximum=8, step=1, value=USER_CONFIG.get("api_n", 1)) + with gr.Accordion("Override Generation Parameters", open=False): + with gr.Row(): + api_server_addr = gr.Textbox(label="Server Address (override)", value=USER_CONFIG.get("api_server_address", "")) + api_model = gr.Textbox(label="Model (ckpt)", value=USER_CONFIG.get("api_model", "")) + with gr.Row(): + api_sampler = gr.Textbox(label="Sampler", value=USER_CONFIG.get("api_sampler", "")) + api_scheduler = gr.Textbox(label="Scheduler", value=USER_CONFIG.get("api_scheduler", "")) + with gr.Row(): + api_steps = gr.Number(label="Steps", value=USER_CONFIG.get("api_steps", 30), precision=0) + api_cfg = gr.Number(label="CFG", value=USER_CONFIG.get("api_cfg", 6.0)) + with gr.Row(): + api_width = gr.Number(label="Width", value=USER_CONFIG.get("api_width", 768), precision=0) + api_height = gr.Number(label="Height", value=USER_CONFIG.get("api_height", 1280), precision=0) + with gr.Row(): + api_seed = gr.Number(label="Seed", value=USER_CONFIG.get("api_seed", 757831338432565), precision=0) + api_after = gr.Dropdown(label="After Generate", choices=["randomize", "increment", "decrement", "fixed"], value=USER_CONFIG.get("api_after_generate", "randomize")) + with gr.Row(): + api_pos_prefix = gr.Textbox(label="Positive Prefix", lines=2, value=USER_CONFIG.get("api_positive_prefix", "")) + api_neg_prefix = gr.Textbox(label="Negative Prefix", lines=2, value=USER_CONFIG.get("api_negative_prefix", "")) + with gr.Row(): + api_workflow = gr.Dropdown(label="Workflow Template", choices=list(GLOBAL_WORKFLOWS.keys()), value=USER_CONFIG.get("api_workflow", "workflow_template")) + api_status = gr.Label("") + with gr.Row(): + api_save_cfg_btn = gr.Button("Save API Settings") + + with gr.TabItem("Settings"): + gr.Markdown("## Workflow Management") + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("### Workflow List") + workflow_list = gr.Dropdown(label="Select Workflow", choices=list(GLOBAL_WORKFLOWS.keys()), value=default_workflow) + with gr.Row(): + load_workflow_btn = gr.Button("📂 Load Workflow") + delete_workflow_btn = gr.Button("🗑️ Delete Workflow", variant="stop") + workflow_status = gr.Label(value="Select a workflow to edit or create a new one.") + + with gr.Column(scale=2): + gr.Markdown("### Workflow Editor") + workflow_name_input = gr.Textbox(label="Workflow Name", lines=1, value="workflow_template") + workflow_content_input = gr.Textbox(label="Workflow JSON Content", lines=20, value="", max_lines=30) + with gr.Row(): + save_workflow_btn = gr.Button("💾 Save Workflow", variant="primary") + new_workflow_btn = gr.Button("➕ New Workflow") + workflow_editor_status = gr.Label(value="Edit the workflow JSON content above.") + + gr.Markdown("## Preferences") + with gr.Row(): + language_dropdown = gr.Dropdown(label="Language", choices=["en", "zh"], value=config.get("language", "en")) + autosave_interval = gr.Number(label="Autosave Interval (seconds)", value=config.get("config_save_interval", 20), minimum=5, step=1) + with gr.Row(): + save_prefs_btn = gr.Button("Save Preferences") + prefs_status = gr.Label("") + + # Define Inputs/Outputs for main generation + gen_inputs = [server_addr, positive_prefix_input, negative_prefix_input, positive_prompt, negative_prompt, model, sampler, scheduler, steps, cfg, width, height, seed, after_generate, batch_size, batch_count, workflow_selector] + gen_outputs = [status_label, output_gallery] + + # Wire up events + refresh_btn.click(fn=update_choices, inputs=server_addr, outputs=[model, sampler, scheduler]) + + # Real-time save events + server_addr.change(fn=save_server_address, inputs=server_addr, outputs=server_addr) + model.change(fn=save_model, inputs=model, outputs=model) + sampler.change(fn=save_sampler, inputs=sampler, outputs=sampler) + scheduler.change(fn=save_scheduler, inputs=scheduler, outputs=scheduler) + steps.change(fn=save_steps, inputs=steps, outputs=steps) + cfg.change(fn=save_cfg, inputs=cfg, outputs=cfg) + width.change(fn=save_width, inputs=width, outputs=width) + height.change(fn=save_height, inputs=height, outputs=height) + batch_size.change(fn=save_batch_size, inputs=batch_size, outputs=batch_size) + batch_count.change(fn=save_batch_count, inputs=batch_count, outputs=batch_count) + seed.change(fn=save_seed, inputs=seed, outputs=seed) + after_generate.change(fn=save_after_generate, inputs=after_generate, outputs=after_generate) + # Save text fields on blur instead of every keystroke + positive_prefix_input.blur(fn=save_positive_prefix, inputs=positive_prefix_input, outputs=positive_prefix_input) + negative_prefix_input.blur(fn=save_negative_prefix, inputs=negative_prefix_input, outputs=negative_prefix_input) + positive_prompt.blur(fn=save_positive_prompt, inputs=positive_prompt, outputs=positive_prompt) + negative_prompt.blur(fn=save_negative_prompt, inputs=negative_prompt, outputs=negative_prompt) + preset_selector.change(fn=save_preset_name, inputs=preset_selector, outputs=preset_selector) + workflow_selector.change(fn=save_current_workflow, inputs=workflow_selector, outputs=workflow_selector) + + # Preset events + preset_selector.change(fn=select_preset, inputs=preset_selector, outputs=[preset_name_input, positive_prefix_input, negative_prefix_input]) + save_preset_btn.click(fn=save_or_update_preset, inputs=[preset_name_input, positive_prefix_input, negative_prefix_input], outputs=[preset_selector, preset_status_label]) + delete_preset_btn.click(fn=delete_preset, inputs=[preset_name_input], outputs=[preset_selector, preset_name_input, positive_prefix_input, negative_prefix_input, preset_status_label]) + + gen_event = generate_btn.click(fn=generate_images, inputs=gen_inputs, outputs=gen_outputs) + gen_event.then(fn=get_history_images, outputs=history_gallery) + + # Scheduler Tab Events + scheduler_inputs = [scheduler_interval, server_addr, positive_prefix_input, negative_prefix_input, positive_prompt, negative_prompt, model, sampler, scheduler, steps, cfg, width, height, seed, after_generate, batch_size, batch_count, workflow_selector] + + start_scheduler_btn.click(fn=start_scheduler, inputs=scheduler_inputs, outputs=[scheduler_status, scheduler_output]) + stop_scheduler_btn.click(fn=stop_scheduler_global, inputs=None, outputs=scheduler_status) + + # History Tab Events + refresh_history_btn.click(fn=get_history_images, outputs=history_gallery) + delete_btn.click(fn=delete_image, inputs=history_gallery, outputs=history_gallery) + + # API Settings Tab Events + def save_api_settings(*api_args): + keys = [ + "api_return", "api_n", "api_server_address", "api_model", "api_sampler", + "api_scheduler", "api_steps", "api_cfg", "api_width", "api_height", + "api_seed", "api_after_generate", "api_positive_prefix", "api_negative_prefix", "api_workflow" + ] + # Convert to correct types + typed_args = list(api_args) + typed_args[1] = int(typed_args[1]) # api_n + typed_args[6] = int(typed_args[6]) # api_steps + typed_args[7] = float(typed_args[7]) # api_cfg + typed_args[8] = int(typed_args[8]) # api_width + typed_args[9] = int(typed_args[9]) # api_height + typed_args[10] = int(typed_args[10]) # api_seed + + api_config_dict = dict(zip(keys, typed_args)) + queue_config_update(**api_config_dict) + return "API settings saved (will be applied on next auto-save)." + + api_inputs = [ + api_return, api_n, api_server_addr, api_model, api_sampler, api_scheduler, + api_steps, api_cfg, api_width, api_height, api_seed, api_after, + api_pos_prefix, api_neg_prefix, api_workflow + ] + api_save_cfg_btn.click(fn=save_api_settings, inputs=api_inputs, outputs=api_status) + + # Workflow management events + def load_workflow_to_editor(workflow_name): + """Loads a workflow into the editor.""" + if not workflow_name or workflow_name == "workflow_template": + workflow_path = os.path.join(WORKFLOWS_DIR, "workflow_template.json") + else: + workflow_path = os.path.join(WORKFLOWS_DIR, f"{workflow_name}.json") + + if os.path.exists(workflow_path): + try: + with open(workflow_path, 'r', encoding='utf-8') as f: + content = f.read() + return workflow_name, content, f"Loaded workflow '{workflow_name}'" + except Exception as e: + return workflow_name, "", f"Error loading workflow: {e}" + else: + return workflow_name, "", f"Workflow '{workflow_name}' not found" + + def save_workflow_from_editor(workflow_name, workflow_content): + """Saves workflow from editor.""" + success, message = save_workflow(workflow_name, workflow_content) + if success: + # Reload workflows + global GLOBAL_WORKFLOWS + GLOBAL_WORKFLOWS = load_workflows() + return gr.update(choices=list(GLOBAL_WORKFLOWS.keys()), value=workflow_name), message + else: + return gr.update(), message + + def delete_workflow_from_editor(workflow_name): + """Deletes a workflow.""" + success, message = delete_workflow(workflow_name) + if success: + # Reload workflows + global GLOBAL_WORKFLOWS + GLOBAL_WORKFLOWS = load_workflows() + return (gr.update(choices=list(GLOBAL_WORKFLOWS.keys()), value="workflow_template"), + "workflow_template", "", message) + else: + return gr.update(), workflow_name, "", message + + def create_new_workflow(): + """Creates a new empty workflow.""" + return "new_workflow", "", "New workflow created. Enter a name and JSON content." + + # Wire up workflow management events + load_workflow_btn.click(fn=load_workflow_to_editor, inputs=workflow_list, outputs=[workflow_name_input, workflow_content_input, workflow_status]) + save_workflow_btn.click(fn=save_workflow_from_editor, inputs=[workflow_name_input, workflow_content_input], outputs=[workflow_list, workflow_editor_status]) + delete_workflow_btn.click(fn=delete_workflow_from_editor, inputs=workflow_name_input, outputs=[workflow_list, workflow_name_input, workflow_content_input, workflow_editor_status]) + new_workflow_btn.click(fn=create_new_workflow, outputs=[workflow_name_input, workflow_content_input, workflow_editor_status]) + + # Preferences events + def save_preferences(lang, interval): + queue_config_update(language=lang) + seconds = set_config_save_interval(interval) + return f"Saved. Language: {lang}, autosave: {seconds}s" + + save_prefs_btn.click(fn=save_preferences, inputs=[language_dropdown, autosave_interval], outputs=prefs_status) + + # --- App Load and Polling Events --- + + # This function will be polled to update dynamic UI elements + def poll_updates(): + history = get_history_images() + status_text, image_list = get_scheduler_status_for_ui() + return history, status_text, image_list + + # Load user config on page load (runs once) + app.load(fn=load_ui_config, outputs=[ + server_addr, model, sampler, scheduler, steps, cfg, width, height, + batch_size, batch_count, seed, after_generate, positive_prefix_input, + negative_prefix_input, positive_prompt, negative_prompt, preset_selector, workflow_selector, + model, sampler, scheduler, workflow_selector, + # API settings + api_return, api_n, api_server_addr, api_model, api_sampler, api_scheduler, + api_steps, api_cfg, api_width, api_height, api_seed, api_after, + api_pos_prefix, api_neg_prefix, api_workflow + ]) + + # Poll for history and scheduler status updates every 5 seconds + # Use a backward-compatible method for creating a timer + if hasattr(gr, 'Timer'): + # New way for Gradio 4.x and later + timer = gr.Timer(5) + timer.tick(fn=poll_updates, outputs=[history_gallery, scheduler_status, scheduler_output]) + else: + # Old way for Gradio 3.x + app.load(fn=poll_updates, outputs=[history_gallery, scheduler_status, scheduler_output], every=5) + return app + +if __name__ == "__main__": + import argparse + + if FastAPI is None or uvicorn is None: + raise RuntimeError("FastAPI/uvicorn not installed. Please install with: pip install fastapi uvicorn") + + # 1. Create the FastAPI app that will host the API. + api_app = _create_openai_app() + + # 2. Create the Gradio UI app. This also starts the config saver. + webui_app = create_ui() + + # 3. Mount the Gradio app at the root of the FastAPI app. + # The FastAPI app becomes the main entry point. + final_app = gr.mount_gradio_app(api_app, webui_app, path="/") + + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on.") + default_port = int(os.getenv('PORT', 7860)) + parser.add_argument("--port", type=int, default=default_port, help="Port to run the server on.") + args = parser.parse_args() + + print("---") + print(f"Starting server on {args.host}:{args.port}") + print("The Gradio UI will be at the root path.") + print(f"OpenAI-compatible API will be available under http://{args.host}:{args.port}/v1") + print("---") + + try: + uvicorn.run(final_app, host=args.host, port=args.port) + finally: + # Flush any pending config updates on exit + _flush_pending_config()