leewheel's picture
Massive sync batch 1
a57c95d verified
import os
import sys
import gc
import uuid
import random
import re
import datetime
import json
import tempfile
import locale
# =========================
# 语言检测
# =========================
try:
system_lang = locale.getdefaultlocale()[0]
is_chinese = system_lang and system_lang.startswith('zh')
except:
is_chinese = False
def get_message(key, *args):
messages = {
"peft_loaded": ("✅ PEFT 库已加载,LoRA 功能可用。", "✅ PEFT library loaded, LoRA functionality available."),
"peft_not_detected": ("⚠️ 警告: 未检测到 PEFT 库。LoRA 功能将禁用。", "⚠️ Warning: PEFT library not detected. LoRA functionality will be disabled."),
"lora_skipped": ("⚠️ [LoRA] 已跳过加载:PEFT 库未安装。", "⚠️ [LoRA] Skipped loading: PEFT library not installed."),
"transformer_not_loaded": ("⚠️ Transformer 未加载,无法应用 LoRA", "⚠️ Transformer not loaded, cannot apply LoRA"),
"lora_file_not_exist": ("⚠️ LoRA 文件不存在: {}", "⚠️ LoRA file does not exist: {}"),
"lora_loading": (" [LoRA] 正在加载: {} (权重: {} * {} = {:.2f})", " [LoRA] Loading: {} (weight: {} * {} = {:.2f})"),
"lora_loaded": ("✅ LoRA 加载成功: {}", "✅ LoRA loaded successfully: {}"),
"lora_failed": ("❌ LoRA 加载严重失败: {}", "❌ LoRA loading failed critically: {}"),
"applying_vae": ("正在应用自定义 VAE: {}", "Applying custom VAE: {}"),
"vae_loaded": ("✅ 自定义 VAE 加载成功", "✅ Custom VAE loaded successfully"),
"vae_failed": ("⚠️ 自定义 VAE 加载失败: {}", "⚠️ Custom VAE loading failed: {}"),
"forcing_to_ram": (" [System] 正在强制将模型搬运至 RAM (请稍候)...", " [System] Forcing model to RAM (please wait)..."),
"model_to_ram": (" [System] 模型已加载至 RAM。", " [System] Model loaded to RAM."),
"t2i_low_vram": (" [T2I] 已启用低显存优化模式", " [T2I] Low VRAM optimization mode enabled"),
"t2i_high_end": (" [T2I] 已启用高端机模式", " [T2I] High-end GPU mode enabled"),
"t2i_pipeline_loaded": ("✅ 文生图 Pipeline 加载完成", "✅ Text-to-Image Pipeline loaded"),
"i2i_pipeline_failed": ("加载图生图 Pipeline 失败:{}", "Failed to load Image-to-Image Pipeline: {}"),
"i2i_pipeline_loaded": ("✅ 图生图 Pipeline 加载完成", "✅ Image-to-Image Pipeline loaded"),
"i2i_low_vram": (" [I2I] 已启用低显存优化模式", " [I2I] Low VRAM optimization mode enabled"),
"i2i_high_end": (" [I2I] 已启用高端机模式", " [I2I] High-end GPU mode enabled"),
"generation_stopped": ("🛑 生成已被用户手动停止", "🛑 Generation stopped by user"),
"upload_image_first": ("⚠️ 请先上传图片!", "⚠️ Please upload an image first!"),
"i2i_model_failed": ("加载图生图模型失败: {}", "Failed to load Image-to-Image model: {}"),
"native_inpaint_failed": ("⚠️ 原生 Inpaint 失败 ({}),使用手动混合模式...", "⚠️ Native Inpaint failed ({}), using manual blending mode..."),
"paint_area": ("⚠️ 请使用画笔在图片上涂抹要修改的区域。", "⚠️ Please use the brush to paint the area to modify on the image."),
"mask_invalid": ("⚠️ Mask 无效,请确保涂抹了区域。", "⚠️ Mask invalid, please ensure an area is painted."),
"model_load_failed": ("模型加载失败: {}", "Model loading failed: {}"),
"inpainting_failed": ("局部重绘失败: {}", "Inpainting failed: {}"),
"generating": ("生成中", "Generating"),
"img2img_processing": ("图生图中", "Img2Img processing"),
}
zh, en = messages[key]
return (zh if is_chinese else en).format(*args)
# 环境配置
os.environ.pop("PYTHONHOME", None)
os.environ.pop("PYTHONPATH", None)
os.environ["DIFFUSERS_USE_PEFT_BACKEND"] = "true"
os.environ["PEFT_DEBUG"] = "false"
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps, ImageEnhance, ImageDraw
import gradio as gr
from diffusers import (
ZImagePipeline,
ZImageImg2ImgPipeline,
AutoencoderKL,
ZImageTransformer2DModel,
FlowMatchEulerDiscreteScheduler
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from safetensors.torch import load_file
# =========================
# 检测 PEFT 环境
# =========================
PEFT_AVAILABLE = False
try:
import peft
from diffusers.utils import is_peft_available
if is_peft_available():
PEFT_AVAILABLE = True
print(get_message("peft_loaded"))
else:
raise ImportError
except ImportError:
print(get_message("peft_not_detected"))
# =========================
# 双语文本字典
# =========================
TEXT = {
"zh": {
"title": "# 🎨 Z-Image-Turbo Low Vram Edition",
"lang_btn": "EN",
"tab_generate": "图像生成", "tab_edit": "图片编辑", "tab_img2img": "图生图 (增强版)", "tab_inpaint": "局部重绘",
"prompt": "Prompt", "prompt_placeholder": "输入你的描述...", "negative_prompt": "负面提示词", "negative_placeholder": "low quality, blurry, bad anatomy",
"refresh_lora": "🔄 刷新 LoRA", "refresh_model": "🔄 刷新模型", "lora_label": "LoRA", "lora_strength": "LoRA 强度", "lora_weight": "权重",
"model_section": "### 模型选择/Model Selection", "transformer": "Transformer", "vae": "VAE", "vram_type": "显存类型",
"vram_low": "24GB以下 (优化模式)", "vram_high": "高端机模式 (>=24GB)", "device": "设备", "num_images": "生成张数",
"output_format": "输出格式", "width": "宽度", "height": "高度", "steps": "步数", "cfg": "CFG", "seed": "种子", "random_seed": "随机种子",
"generate": "🚀 生成", "stop": "🛑 停止生成", "gallery": "生成结果", "used_seed": "使用种子",
"edit_upload": "上传图片", "rotate": "旋转角度 (度)", "crop_x": "裁剪 X (%)", "crop_y": "裁剪 Y (%)", "crop_w": "裁剪宽度 (%)", "crop_h": "裁剪高度 (%)",
"hflip": "水平翻转", "vflip": "垂直翻转", "edit_btn": "开始编辑", "edited_image": "编辑后的图片",
"filter": "应用滤镜", "brightness": "亮度调整 (%)", "contrast": "对比度调整 (%)", "saturation": "饱和度调整 (%)",
"i2i_ref": "上传参考图", "i2i_prompt": "修改提示词", "i2i_ph": "描述你希望图中发生的变化...", "i2i_mode": "Img2Img 模式",
"i2i_mode_a": "A. 严格保结构(微调风格)", "i2i_mode_b": "B. 强烈听 prompt(允许大改)", "i2i_out_w": "输出宽 (0=自动)", "i2i_out_h": "输出高 (0=自动)",
"i2i_tip": "**提示:** 宽高都为0时自动保持上传图比例并接近1024。", "i2i_strength": "重绘强度", "i2i_btn": "🎨 开始修改", "i2i_note": "注:使用官方 Z-Image Img2Img 引擎。",
"inpaint_editor": "绘制 Mask (白色为修改区,黑色为保留区)", "inpaint_tip": "提示:先上传图片,然后用画笔涂抹要修改的区域。", "inpaint_upload": "上传原图并绘制", "inpaint_desc": "📖 使用指南:涂抹区域(白色/彩色)将被重新生成,未涂抹区域保持原样。",
},
"en": {
"title": "# 🎨 Z-Image-Turbo Low Vram Edition", "lang_btn": "中文",
"tab_generate": "Image Generation", "tab_edit": "Image Editing", "tab_img2img": "Img2Img (Enhanced)", "tab_inpaint": "Inpainting",
"prompt": "Prompt", "prompt_placeholder": "Enter your description...", "negative_prompt": "Negative Prompt", "negative_placeholder": "low quality, blurry",
"refresh_lora": "🔄 Refresh LoRA", "refresh_model": "🔄 Refresh Models", "lora_label": "LoRA", "lora_strength": "LoRA Strength", "lora_weight": "Weight",
"model_section": "### Model Selection", "transformer": "Transformer", "vae": "VAE", "vram_type": "VRAM Type",
"vram_low": "Under 24GB (Optimized)", "vram_high": "High-End GPU Mode (>=24GB)", "device": "Device", "num_images": "Number of Images",
"output_format": "Output Format", "width": "Width", "height": "Height", "steps": "Steps", "cfg": "CFG", "seed": "Seed", "random_seed": "Random Seed",
"generate": "🚀 Generate", "stop": "🛑 Stop Generation", "gallery": "Generated Images", "used_seed": "Used Seed",
"edit_upload": "Upload Image", "rotate": "Rotation (degrees)", "crop_x": "Crop X (%)", "crop_y": "Crop Y (%)", "crop_w": "Crop Width (%)", "crop_h": "Crop Height (%)",
"hflip": "Horizontal Flip", "vflip": "Vertical Flip", "edit_btn": "Apply Edit", "edited_image": "Edited Image",
"filter": "Apply Filter", "brightness": "Brightness (%)", "contrast": "Contrast (%)", "saturation": "Saturation (%)",
"i2i_ref": "Upload Reference", "i2i_prompt": "Modification Prompt", "i2i_ph": "Describe changes...", "i2i_mode": "Img2Img Mode",
"i2i_mode_a": "A. Strict Structure (Style tweak)", "i2i_mode_b": "B. Strong Prompt (Allow changes)", "i2i_out_w": "Output Width (0=Auto)", "i2i_out_h": "Output Height (0=Auto)",
"i2i_tip": "**Tip:** Auto ratio if both 0.", "i2i_strength": "Denoising Strength", "i2i_btn": "🎨 Start Modification", "i2i_note": "Using official Z-Image Img2Img engine.",
"inpaint_editor": "Draw Mask (White=Modify, Black=Keep)", "inpaint_tip": "Tip: Upload image, then paint area to modify.", "inpaint_upload": "Upload & Paint", "inpaint_desc": "📖 Guide: Painted areas (white/color) will be regenerated. Unpainted areas stay original.",
}
}
# =========================
# 路径配置
# =========================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_SNAPSHOT_DIR = os.path.join(BASE_DIR, "cache", "HF_HOME", "hub", "models--Tongyi-MAI--Z-Image-Turbo", "snapshots", "5f4b9cbb80cc95ba44fe6667dfd75710f7db2947")
if not os.path.exists(BASE_SNAPSHOT_DIR):
BASE_SNAPSHOT_DIR = os.path.join(BASE_DIR, "ckpts", "Z-Image-Turbo")
if not os.path.exists(BASE_SNAPSHOT_DIR):
BASE_SNAPSHOT_DIR = "."
TRANSFORMER_ROOT = os.path.join(BASE_SNAPSHOT_DIR, "transformer")
TEXT_ENCODER_ROOT = os.path.join(BASE_SNAPSHOT_DIR, "text_encoder")
VAE_ROOT = os.path.join(BASE_SNAPSHOT_DIR, "vae")
MOD_DIR = os.path.join(BASE_DIR, "MOD")
MOD_TRANSFORMER = os.path.join(MOD_DIR, "transformer")
MOD_VAE = os.path.join(MOD_DIR, "vae")
LORA_ROOT = os.path.join(BASE_DIR, "lora")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
for p in [MOD_TRANSFORMER, MOD_VAE, LORA_ROOT, OUTPUT_DIR]:
os.makedirs(p, exist_ok=True)
pipe_t2i = None
pipe_i2i = None
current_model_config = {"transformer": "default", "vae": "default", "is_low_vram": True}
is_generating_interrupted = False
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
def auto_flush_vram():
gc.collect()
if DEVICE == "cuda":
torch.cuda.empty_cache()
# =========================
# 核心优化:LoRA 加载逻辑
# =========================
def apply_lora_to_pipeline(pipe_local, lora_choice, lora_alpha, lora_scale=1.0):
if not PEFT_AVAILABLE:
print(get_message("lora_skipped"))
return pipe_local
if pipe_local is None:
return pipe_local
if pipe_local.transformer is None:
print(get_message("transformer_not_loaded"))
return pipe_local
if hasattr(pipe_local, "unload_lora_weights"):
try:
pipe_local.unload_lora_weights()
except Exception:
pass
if not lora_choice or lora_choice.lower() == "none":
return pipe_local
lora_path = os.path.join(LORA_ROOT, lora_choice)
if not os.path.exists(lora_path):
print(get_message("lora_file_not_exist", lora_path))
return pipe_local
try:
raw_alpha = float(lora_alpha)
effective_alpha = raw_alpha * lora_scale
if effective_alpha <= 0:
return pipe_local
adapter_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(lora_choice)[0])
print(get_message("lora_loading", lora_choice, raw_alpha, lora_scale, effective_alpha))
pipe_local.load_lora_weights(
LORA_ROOT,
weight_name=lora_choice,
adapter_name=adapter_name
)
pipe_local.set_adapters([adapter_name], adapter_weights=[effective_alpha])
print(get_message("lora_loaded", adapter_name))
except Exception as e:
import traceback
print(get_message("lora_failed", e))
return pipe_local
def scan_lora_items():
if not os.path.isdir(LORA_ROOT):
return []
return sorted([f for f in os.listdir(LORA_ROOT) if f.lower().endswith((".safetensors", ".pt", ".pth"))])
def update_prompt_with_lora(prompt, lora_choice, lora_alpha):
prompt = (prompt or "").strip()
prompt_clean = re.sub(r"<lora:[^>]+>", "", prompt).strip()
if lora_choice and lora_choice.lower() != "none":
try:
alpha = float(lora_alpha)
except: alpha = 1.0
if alpha > 0:
name = os.path.splitext(lora_choice)[0]
alpha_str = f"{alpha:.2f}".rstrip("0").rstrip(".")
return f"{prompt_clean} <lora:{name}:{alpha_str}>"
return prompt_clean
# =========================
# 模型加载逻辑
# =========================
def load_t2i_pipeline(transformer_choice, vae_choice, is_low_vram):
global pipe_t2i, current_model_config
config_key = ("t2i", transformer_choice, vae_choice, is_low_vram)
if pipe_t2i is not None and current_model_config.get("t2i") == config_key:
return pipe_t2i
auto_flush_vram()
pipe_t2i = None
transformer = ZImageTransformer2DModel.from_pretrained(TRANSFORMER_ROOT, torch_dtype=DTYPE, local_files_only=True)
if transformer_choice != "default":
t_path = resolve_model_path(transformer_choice, MOD_TRANSFORMER)
if t_path:
if os.path.isdir(t_path):
custom_t = ZImageTransformer2DModel.from_pretrained(t_path, torch_dtype=DTYPE, local_files_only=True)
transformer = custom_t
else:
state = load_file(t_path, device="cpu")
processed = {}
prefix = "model.diffusion_model."
for k, v in state.items():
new_k = k[len(prefix):] if k.startswith(prefix) else k
processed[new_k] = v.to(DTYPE)
transformer.load_state_dict(processed, strict=False)
del state, processed
text_encoder = AutoModelForCausalLM.from_pretrained(TEXT_ENCODER_ROOT, torch_dtype=DTYPE, local_files_only=True)
pipe_t2i = ZImagePipeline.from_pretrained(
BASE_SNAPSHOT_DIR,
local_files_only=True,
transformer=transformer,
text_encoder=text_encoder,
)
pipe_t2i.to(dtype=DTYPE)
if vae_choice != "default":
v_path = resolve_model_path(vae_choice, MOD_VAE)
if v_path:
print(get_message("applying_vae", vae_choice))
vae_device_map = {"": "cpu"} if is_low_vram else None
try:
if os.path.isfile(v_path):
with tempfile.TemporaryDirectory() as tmpdir:
config_file_path = os.path.join(tmpdir, "config.json")
vae_config_dict = dict(pipe_t2i.vae.config)
with open(config_file_path, "w", encoding="utf-8") as f:
json.dump(vae_config_dict, f, indent=2)
try:
pipe_t2i.vae = AutoencoderKL.from_single_file(v_path, dtype=DTYPE, config=tmpdir, device_map=vae_device_map)
except TypeError:
pipe_t2i.vae = AutoencoderKL.from_single_file(v_path, torch_dtype=DTYPE, config=tmpdir, device_map=vae_device_map)
print(get_message("vae_loaded"))
else:
pipe_t2i.vae = AutoencoderKL.from_pretrained(v_path, torch_dtype=DTYPE, device_map=vae_device_map)
except Exception as e:
print(get_message("vae_failed", e))
if DEVICE == "cuda":
if is_low_vram:
print(get_message("forcing_to_ram"))
pipe_t2i.to("cpu")
print(get_message("model_to_ram"))
pipe_t2i.enable_sequential_cpu_offload()
print(get_message("t2i_low_vram"))
else:
pipe_t2i.to("cuda")
print(get_message("t2i_high_end"))
current_model_config["t2i"] = config_key
print("✅ 文生图 Pipeline 加载完成")
return pipe_t2i
def load_i2i_pipeline(transformer_choice, vae_choice, is_low_vram):
global pipe_i2i, current_model_config
config_key = ("i2i", transformer_choice, vae_choice, is_low_vram)
if pipe_i2i is not None and current_model_config.get("i2i") == config_key:
return pipe_i2i
auto_flush_vram()
pipe_i2i = None
transformer = ZImageTransformer2DModel.from_pretrained(TRANSFORMER_ROOT, torch_dtype=DTYPE, local_files_only=True)
if transformer_choice != "default":
t_path = resolve_model_path(transformer_choice, MOD_TRANSFORMER)
if t_path:
if os.path.isdir(t_path):
custom_t = ZImageTransformer2DModel.from_pretrained(t_path, torch_dtype=DTYPE, local_files_only=True)
transformer = custom_t
else:
state = load_file(t_path, device="cpu")
processed = {}
prefix = "model.diffusion_model."
for k, v in state.items():
new_k = k[len(prefix):] if k.startswith(prefix) else k
processed[new_k] = v.to(DTYPE)
transformer.load_state_dict(processed, strict=False)
del state, processed
try:
pipe_i2i = ZImageImg2ImgPipeline.from_pretrained(
BASE_SNAPSHOT_DIR,
local_files_only=True,
transformer=transformer,
)
except Exception as e:
raise gr.Error(f"加载图生图 Pipeline 失败:{str(e)}")
pipe_i2i.to(dtype=DTYPE)
if vae_choice != "default":
v_path = resolve_model_path(vae_choice, MOD_VAE)
if v_path:
print(get_message("applying_vae", vae_choice))
vae_device_map = {"": "cpu"} if is_low_vram else None
try:
if os.path.isfile(v_path):
with tempfile.TemporaryDirectory() as tmpdir:
config_file_path = os.path.join(tmpdir, "config.json")
vae_config_dict = dict(pipe_i2i.vae.config)
with open(config_file_path, "w", encoding="utf-8") as f:
json.dump(vae_config_dict, f, indent=2)
try:
pipe_i2i.vae = AutoencoderKL.from_single_file(v_path, dtype=DTYPE, config=tmpdir, device_map=vae_device_map)
except TypeError:
pipe_i2i.vae = AutoencoderKL.from_single_file(v_path, torch_dtype=DTYPE, config=tmpdir, device_map=vae_device_map)
print(get_message("vae_loaded"))
else:
pipe_i2i.vae = AutoencoderKL.from_pretrained(v_path, torch_dtype=DTYPE, device_map=vae_device_map)
except Exception as e:
print(get_message("vae_failed", e))
if DEVICE == "cuda":
if is_low_vram:
print(get_message("forcing_to_ram"))
pipe_i2i.to("cpu")
print(get_message("model_to_ram"))
pipe_i2i.enable_sequential_cpu_offload()
print(get_message("i2i_low_vram"))
else:
pipe_i2i.to("cuda")
print(get_message("i2i_high_end"))
current_model_config["i2i"] = config_key
print("✅ 图生图 Pipeline 加载完成")
return pipe_i2i
def interrupt_callback(pipe, step, timestep, callback_kwargs):
global is_generating_interrupted
if is_generating_interrupted:
raise gr.Error("🛑 生成已被用户手动停止")
return callback_kwargs
def scan_model_variants(root_dir):
if not os.path.isdir(root_dir):
return []
items = []
for name in os.listdir(root_dir):
path = os.path.join(root_dir, name)
if os.path.isdir(path):
if os.path.isfile(os.path.join(path, "config.json")):
items.append(name)
elif name.lower().endswith((".safetensors", ".bin")):
items.append(name)
return sorted(items)
def get_choices(mod_root):
return ["default"] + scan_model_variants(mod_root)
def resolve_model_path(choice, mod_root):
if choice == "default":
return None
path = os.path.join(mod_root, choice)
if os.path.exists(path):
return path
return None
def process_mask_for_inpaint(mask_image):
if mask_image is None:
return None
if mask_image.mode == 'RGBA':
import numpy as np
mask_array = np.array(mask_image)
alpha = mask_array[:, :, 3] if mask_array.shape[2] > 3 else None
rgb = mask_array[:, :, :3]
rgb_gray = np.dot(rgb, [0.299, 0.587, 0.114])
if alpha is not None:
mask_gray = np.where(alpha > 10, 255, 0).astype(np.uint8)
else:
mask_gray = np.where(rgb_gray > 10, 255, 0).astype(np.uint8)
mask = Image.fromarray(mask_gray, mode='L')
else:
if mask_image.mode != 'L':
mask_image = mask_image.convert('L')
mask = mask_image.point(lambda p: 255 if p > 10 else 0)
if mask.getextrema()[1] == 0:
return None
return mask
# =========================
# 生成与编辑函数
# =========================
def generate_image(prompt, lora_choice, lora_alpha, num_images, image_format,
width, height, num_inference_steps, guidance_scale, seed, randomize_seed,
transformer_choice, vae_choice, vram_type_str, progress=gr.Progress()):
global is_generating_interrupted
is_generating_interrupted = False
is_low_vram = "24GB" in vram_type_str or "Under 24GB" in vram_type_str or "24G以下" in vram_type_str or "24GB以下" in vram_type_str
pipe_local = load_t2i_pipeline(transformer_choice, vae_choice, is_low_vram)
pipe_local = apply_lora_to_pipeline(pipe_local, lora_choice, lora_alpha)
if randomize_seed:
seed = random.randint(0, 2**32 - 1)
generator = torch.Generator(DEVICE).manual_seed(int(seed))
date_str = datetime.datetime.now().strftime("%Y-%m-%d")
day_dir = os.path.join(OUTPUT_DIR, date_str)
os.makedirs(day_dir, exist_ok=True)
fmt_map = {"png": ("PNG", "png"), "jpeg": ("JPEG", "jpeg"), "webp": ("WEBP", "webp")}
pil_fmt, ext = fmt_map[image_format.lower()]
results = []
try:
for _ in progress.tqdm(range(int(num_images)), desc="生成中"):
if is_generating_interrupted:
break
img = pipe_local(
prompt=prompt.strip(),
width=width,
height=height,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
callback_on_step_end=interrupt_callback,
).images[0]
filename = os.path.join(day_dir, f"{datetime.datetime.now():%H%M%S}_{uuid.uuid4().hex[:4]}.{ext}")
img.save(filename, format=pil_fmt)
results.append(filename)
finally:
auto_flush_vram()
return results, seed
def run_img2img_enhanced(input_image, prompt, negative_prompt, lora_choice, lora_alpha,
num_images, image_format,
out_w, out_h, i2i_mode, strength_ui, steps_ui, cfg_ui,
seed, randomize_seed,
transformer_choice, vae_choice, vram_type_str, progress=gr.Progress()):
global is_generating_interrupted
is_generating_interrupted = False
is_low_vram = "24GB" in vram_type_str or "Under 24GB" in vram_type_str or "24G以下" in vram_type_str or "24GB以下" in vram_type_str
if input_image is None:
raise gr.Error("⚠️ 请先上传图片!")
try:
pipe_local = load_i2i_pipeline(transformer_choice, vae_choice, is_low_vram)
except Exception as e:
if isinstance(e, gr.Error): raise e
raise gr.Error(f"加载图生图模型失败: {str(e)}")
if i2i_mode.startswith("A"):
lora_scale = 0.35
strength = 0.30
steps = 8
cfg = 1.0
else:
lora_scale = 0.65
strength = 0.45
steps = 6
cfg = 1.5
pipe_local = apply_lora_to_pipeline(pipe_local, lora_choice, lora_alpha, lora_scale)
final_strength = strength_ui
final_steps = int(steps_ui)
final_cfg = cfg_ui
if randomize_seed:
seed = random.randint(0, 2**32 - 1)
generator = torch.Generator(DEVICE).manual_seed(int(seed))
orig_w, orig_h = input_image.size
if out_w == 0 or out_h == 0:
target_size = 1024
ratio = orig_w / orig_h
if ratio > 1:
w, h = target_size, int(target_size / ratio)
else:
w, h = int(target_size * ratio), target_size
else:
w, h = out_w, out_h
w = (w // 16) * 16
h = (h // 16) * 16
input_image = input_image.resize((w, h), Image.LANCZOS)
date_str = datetime.datetime.now().strftime("%Y-%m-%d")
day_dir = os.path.join(OUTPUT_DIR, date_str)
os.makedirs(day_dir, exist_ok=True)
fmt_map = {"png": ("PNG", "png"), "jpeg": ("JPEG", "jpeg"), "webp": ("WEBP", "webp")}
pil_fmt, ext = fmt_map[image_format.lower()]
results = []
try:
for _ in progress.tqdm(range(int(num_images)), desc="图生图中"):
if is_generating_interrupted:
break
img = pipe_local(
prompt=prompt.strip(),
negative_prompt=negative_prompt.strip(),
image=input_image,
strength=final_strength,
num_inference_steps=final_steps,
guidance_scale=final_cfg,
generator=generator,
callback_on_step_end=interrupt_callback,
).images[0]
filename = os.path.join(day_dir, f"i2i_{datetime.datetime.now():%H%M%S}_{uuid.uuid4().hex[:4]}.{ext}")
img.save(filename, format=pil_fmt)
results.append(filename)
finally:
auto_flush_vram()
return results, seed
def run_inpainting(image_editor_data, prompt, negative_prompt, lora_choice, lora_alpha,
strength, steps, cfg, seed, randomize_seed,
transformer_choice, vae_choice, vram_type_str, progress=gr.Progress()):
global is_generating_interrupted
is_generating_interrupted = False
is_low_vram = "24GB" in vram_type_str or "Under 24GB" in vram_type_str or "24G以下" in vram_type_str or "24GB以下" in vram_type_str
input_image = None
mask_layer = None
if isinstance(image_editor_data, dict):
if 'background' in image_editor_data:
input_image = image_editor_data['background']
if image_editor_data.get('layers'):
mask_layer = image_editor_data['layers'][0]
elif isinstance(image_editor_data, (tuple, list)):
input_image = image_editor_data[0]
mask_layer = image_editor_data[1]
elif isinstance(image_editor_data, Image.Image):
input_image = image_editor_data
if input_image is None:
raise gr.Error("⚠️ 请先上传图片!")
if input_image.mode == 'RGBA':
background = Image.new('RGB', input_image.size, (255,255,255))
background.paste(input_image, (0, 0), input_image)
input_image = background
else:
input_image = input_image.convert("RGB")
if mask_layer is None:
raise gr.Error("⚠️ 请使用画笔在图片上涂抹要修改的区域。")
mask = process_mask_for_inpaint(mask_layer)
if mask is None:
raise gr.Error("⚠️ Mask 无效,请确保涂抹了区域。")
try:
pipe_local = load_i2i_pipeline(transformer_choice, vae_choice, is_low_vram)
except Exception as e:
raise gr.Error(f"模型加载失败: {str(e)}")
pipe_local = apply_lora_to_pipeline(pipe_local, lora_choice, lora_alpha, lora_scale=0.6)
if randomize_seed:
seed = random.randint(0, 2**32 - 1)
generator = torch.Generator(DEVICE).manual_seed(int(seed))
orig_w, orig_h = input_image.size
if mask.size != (orig_w, orig_h):
mask = mask.resize((orig_w, orig_h), Image.LANCZOS)
date_str = datetime.datetime.now().strftime("%Y-%m-%d")
day_dir = os.path.join(OUTPUT_DIR, date_str)
os.makedirs(day_dir, exist_ok=True)
result_img = None
try:
try:
result_img = pipe_local(
prompt=prompt.strip(),
negative_prompt=negative_prompt.strip(),
image=input_image,
mask_image=mask,
strength=float(strength),
num_inference_steps=int(steps),
guidance_scale=float(cfg),
generator=generator,
callback_on_step_end=interrupt_callback
).images[0]
except (TypeError, AttributeError) as e:
print(f"⚠️ 原生 Inpaint 失败 ({e}),使用手动混合模式...")
img_array = np.array(input_image).astype(np.float32) /255.0
mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
mask_3d = np.expand_dims(mask_array, axis=2)
mask_3d = np.repeat(mask_3d,3, axis=2)
noise = np.random.randn(*img_array.shape).astype(np.float32) * 0.1
inpaint_input_array = img_array * (1 - mask_3d) + (img_array + noise) * mask_3d
inpaint_input_array = np.clip(inpaint_input_array, 0, 1)
inpaint_input = Image.fromarray((inpaint_input_array * 255).astype(np.uint8))
generated = pipe_local(
prompt=prompt.strip(),
negative_prompt=negative_prompt.strip(),
image=inpaint_input,
strength=float(strength),
num_inference_steps=int(steps),
guidance_scale=float(cfg),
generator=generator,
callback_on_step_end=interrupt_callback
).images[0]
if generated.size != (orig_w, orig_h):
generated = generated.resize((orig_w, orig_h), Image.LANCZOS)
gen_array = np.array(generated).astype(np.float32) / 255.0
orig_array = np.array(input_image).astype(np.float32) / 255.0
final_array = orig_array * (1 - mask_3d) + gen_array * mask_3d
final_array = np.clip(final_array, 0, 1)
result_img = Image.fromarray((final_array * 255).astype(np.uint8))
filename = os.path.join(day_dir, f"inpaint_{datetime.datetime.now():%H%M%S}_{uuid.uuid4().hex[:4]}.png")
result_img.save(filename)
except Exception as e:
if "任务已手动停止" in str(e): raise
import traceback
traceback.print_exc()
raise gr.Error(f"局部重绘失败: {str(e)}")
finally:
auto_flush_vram()
return [result_img], seed
def edit_image(image, angle, x, y, w, h, hflip, vflip, filter_name, brightness, contrast, saturation):
if image is None:
return None
img = image.copy()
if angle != 0:
img = img.rotate(angle, expand=True)
if x or y or w < 100 or h < 100:
ow, oh = img.size
left = int(ow * x / 100)
top = int(oh * y / 100)
right = int(ow * (x + w) / 100)
bottom = int(oh * (y + h) / 100)
img = img.crop((left, top, right, bottom))
if hflip:
img = ImageOps.mirror(img)
if vflip:
img = ImageOps.flip(img)
if filter_name:
filter_map = {
"模糊": ImageFilter.BLUR, "轮廓": ImageFilter.CONTOUR, "细节": ImageFilter.DETAIL,
"边缘增强": ImageFilter.EDGE_ENHANCE, "更多边缘增强": ImageFilter.EDGE_ENHANCE_MORE,
"浮雕": ImageFilter.EMBOSS, "查找边缘": ImageFilter.FIND_EDGES,
"锐化": ImageFilter.SHARPEN, "平滑": ImageFilter.SMOOTH, "更多平滑": ImageFilter.SMOOTH_MORE,
}
f = filter_map.get(filter_name)
if f:
img = img.filter(f)
if brightness != 0:
img = ImageEnhance.Brightness(img).enhance(1 + brightness / 100)
if contrast != 0:
img = ImageEnhance.Contrast(img).enhance(1 + contrast / 100)
if saturation != 0:
img = ImageEnhance.Color(img).enhance(1 + saturation / 100)
return img
# =========================
# Gradio 界面构建
# =========================
TOTAL_VRAM = torch.cuda.get_device_properties(0).total_memory if DEVICE == "cuda" else 0
DEFAULT_PERF_MODE = "高端机模式 (>=24GB)" if TOTAL_VRAM >= 24 * 1024**3 else "24GB以下 (优化模式)"
with gr.Blocks() as demo:
lang_state = gr.State("zh")
with gr.Row():
title_md = gr.Markdown(TEXT["zh"]["title"])
lang_btn = gr.Button(TEXT["zh"]["lang_btn"], size="sm")
with gr.Tabs() as tabs:
with gr.Tab(TEXT["zh"]["tab_generate"]) as tab_gen:
with gr.Row():
with gr.Column(scale=4):
prompt = gr.Textbox(label=TEXT["zh"]["prompt"], lines=4, placeholder=TEXT["zh"]["prompt_placeholder"])
with gr.Row():
refresh_lora = gr.Button(TEXT["zh"]["refresh_lora"], size="sm")
refresh_model_t2i = gr.Button(TEXT["zh"]["refresh_model"], size="sm")
lora_choices = ["None"] + scan_lora_items()
lora_drop = gr.Dropdown(label=TEXT["zh"]["lora_label"], choices=lora_choices, value="None")
lora_alpha = gr.Slider(0, 2, 1, step=0.05, label=TEXT["zh"]["lora_strength"])
model_section_md = gr.Markdown(TEXT["zh"]["model_section"])
with gr.Row():
transformer_choice = gr.Dropdown(label=TEXT["zh"]["transformer"], choices=get_choices(MOD_TRANSFORMER), value="default")
vae_choice = gr.Dropdown(label=TEXT["zh"]["vae"], choices=get_choices(MOD_VAE), value="default")
vram_type = gr.Radio(
[TEXT["zh"]["vram_low"], TEXT["zh"]["vram_high"]],
label=TEXT["zh"]["vram_type"],
value=DEFAULT_PERF_MODE
)
device_ui = gr.Radio(["cuda", "cpu"], label=TEXT["zh"]["device"], value="cuda" if torch.cuda.is_available() else "cpu", visible=False)
num_images = gr.Slider(1, 8, 1, step=1, label=TEXT["zh"]["num_images"])
image_format = gr.Dropdown(["png", "jpeg", "webp"], value="png", label=TEXT["zh"]["output_format"])
with gr.Row():
width = gr.Slider(512, 2048, 1024, step=64, label=TEXT["zh"]["width"])
height = gr.Slider(512, 2048, 1024, step=64, label=TEXT["zh"]["height"])
num_inference_steps = gr.Slider(1, 50, 10, step=1, label=TEXT["zh"]["steps"])
guidance_scale = gr.Slider(0, 10, 0, step=0.1, label=TEXT["zh"]["cfg"])
seed = gr.Number(label=TEXT["zh"]["seed"], value=42, precision=0)
randomize_seed = gr.Checkbox(label=TEXT["zh"]["random_seed"], value=True)
with gr.Row():
generate_btn = gr.Button(TEXT["zh"]["generate"], variant="primary", size="lg")
stop_btn = gr.Button(TEXT["zh"]["stop"], variant="stop", size="lg", interactive=False)
with gr.Column(scale=6):
gallery = gr.Gallery(label=TEXT["zh"]["gallery"], columns=2, height="80vh")
used_seed = gr.Number(label=TEXT["zh"]["used_seed"], interactive=False)
with gr.Tab(TEXT["zh"]["tab_edit"]) as tab_edit:
with gr.Row():
with gr.Column():
image_input = gr.Image(label=TEXT["zh"]["edit_upload"], type="pil")
with gr.Group():
rotate_angle = gr.Slider(-360, 360, 0, step=1, label=TEXT["zh"]["rotate"])
crop_x = gr.Slider(0, 100, 0, step=1, label=TEXT["zh"]["crop_x"])
crop_y = gr.Slider(0, 100, 0, step=1, label=TEXT["zh"]["crop_y"])
crop_width = gr.Slider(0, 100, 100, step=1, label=TEXT["zh"]["crop_w"])
crop_height = gr.Slider(0, 100, 100, step=1, label=TEXT["zh"]["crop_h"])
flip_horizontal = gr.Checkbox(label=TEXT["zh"]["hflip"])
flip_vertical = gr.Checkbox(label=TEXT["zh"]["vflip"])
edit_btn = gr.Button(TEXT["zh"]["edit_btn"], variant="primary")
with gr.Column():
edited_image_output = gr.Image(label=TEXT["zh"]["edited_image"], type="pil")
with gr.Group():
apply_filter = gr.Dropdown(
["模糊", "轮廓", "细节", "边缘增强", "更多边缘增强", "浮雕", "查找边缘", "锐化", "平滑", "更多平滑"],
label=TEXT["zh"]["filter"]
)
brightness = gr.Slider(-100, 100, 0, step=1, label=TEXT["zh"]["brightness"])
contrast = gr.Slider(-100, 100, 0, step=1, label=TEXT["zh"]["contrast"])
saturation = gr.Slider(-100, 100, 0, step=1, label=TEXT["zh"]["saturation"])
with gr.Tab(TEXT["zh"]["tab_img2img"]) as tab_img2img:
i2i_status_md = gr.Markdown(TEXT["zh"]["i2i_note"])
with gr.Row():
with gr.Column(scale=4):
i2i_image_input = gr.Image(label=TEXT["zh"]["i2i_ref"], type="pil")
i2i_prompt = gr.Textbox(label=TEXT["zh"]["i2i_prompt"], lines=3, placeholder=TEXT["zh"]["i2i_ph"])
i2i_negative_prompt = gr.Textbox(label=TEXT["zh"]["negative_prompt"], lines=2, placeholder=TEXT["zh"]["negative_placeholder"])
with gr.Row():
i2i_refresh_lora = gr.Button(TEXT["zh"]["refresh_lora"], size="sm")
i2i_refresh_model = gr.Button(TEXT["zh"]["refresh_model"], size="sm")
i2i_lora_choices = ["None"] + scan_lora_items()
i2i_lora_drop = gr.Dropdown(label=TEXT["zh"]["lora_label"], choices=i2i_lora_choices, value="None")
i2i_lora_alpha = gr.Slider(0, 2, 1, step=0.05, label=TEXT["zh"]["lora_strength"])
with gr.Accordion(TEXT["zh"]["model_section"], open=False):
i2i_transformer_choice = gr.Dropdown(label=TEXT["zh"]["transformer"], choices=get_choices(MOD_TRANSFORMER), value="default")
i2i_vae_choice = gr.Dropdown(label=TEXT["zh"]["vae"], choices=get_choices(MOD_VAE), value="default")
i2i_vram_type = gr.Radio(
[TEXT["zh"]["vram_low"], TEXT["zh"]["vram_high"]],
label=TEXT["zh"]["vram_type"],
value=DEFAULT_PERF_MODE
)
i2i_mode = gr.Radio(
[TEXT["zh"]["i2i_mode_a"], TEXT["zh"]["i2i_mode_b"]],
label=TEXT["zh"]["i2i_mode"],
value=TEXT["zh"]["i2i_mode_a"]
)
with gr.Row():
i2i_out_w = gr.Slider(0, 2048, 0, step=16, label=TEXT["zh"]["i2i_out_w"])
i2i_out_h = gr.Slider(0, 2048, 0, step=16, label=TEXT["zh"]["i2i_out_h"])
i2i_tip_md = gr.Markdown(TEXT["zh"]["i2i_tip"])
i2i_strength = gr.Slider(0.1, 1.0, 0.4, step=0.05, label=TEXT["zh"]["i2i_strength"])
i2i_steps = gr.Slider(1, 50, 6, step=1, label=TEXT["zh"]["steps"])
i2i_cfg = gr.Slider(0.0, 5.0, 1.0, step=0.1, label=TEXT["zh"]["cfg"])
i2i_num_images = gr.Slider(1, 4, 1, step=1, label=TEXT["zh"]["num_images"])
i2i_image_format = gr.Dropdown(["png", "jpeg", "webp"], value="png", label=TEXT["zh"]["output_format"])
i2i_seed = gr.Number(label=TEXT["zh"]["seed"], value=42, precision=0)
i2i_randomize_seed = gr.Checkbox(label=TEXT["zh"]["random_seed"], value=True)
with gr.Row():
i2i_generate_btn = gr.Button(TEXT["zh"]["i2i_btn"], variant="primary", size="lg")
i2i_stop_btn = gr.Button(TEXT["zh"]["stop"], variant="stop", size="lg", interactive=False)
with gr.Column(scale=6):
i2i_gallery = gr.Gallery(label=TEXT["zh"]["gallery"], columns=2, height="80vh")
i2i_used_seed = gr.Number(label=TEXT["zh"]["used_seed"], interactive=False)
with gr.Tab(TEXT["zh"]["tab_inpaint"]) as tab_inpaint:
with gr.Row():
with gr.Column(scale=4):
inpaint_editor = gr.ImageEditor(
label=TEXT["zh"]["inpaint_upload"],
type="pil",
layers=True,
eraser=True,
brush=gr.Brush(colors=["#FFFFFF", "#000000", "#FF0000"], color_mode="fixed")
)
inpaint_tip_md = gr.Markdown(TEXT["zh"]["inpaint_desc"])
inpaint_prompt = gr.Textbox(label=TEXT["zh"]["i2i_prompt"], lines=3, placeholder=TEXT["zh"]["i2i_ph"])
inpaint_negative_prompt = gr.Textbox(label=TEXT["zh"]["negative_prompt"], lines=2, placeholder=TEXT["zh"]["negative_placeholder"])
with gr.Row():
inpaint_refresh_lora = gr.Button(TEXT["zh"]["refresh_lora"], size="sm")
inpaint_refresh_model = gr.Button(TEXT["zh"]["refresh_model"], size="sm")
inpaint_lora_choices = ["None"] + scan_lora_items()
inpaint_lora_drop = gr.Dropdown(label=TEXT["zh"]["lora_label"], choices=inpaint_lora_choices, value="None")
inpaint_lora_alpha = gr.Slider(0, 2, 1, step=0.05, label=TEXT["zh"]["lora_strength"])
with gr.Accordion(TEXT["zh"]["model_section"], open=False):
inpaint_transformer_choice = gr.Dropdown(label=TEXT["zh"]["transformer"], choices=get_choices(MOD_TRANSFORMER), value="default")
inpaint_vae_choice = gr.Dropdown(label=TEXT["zh"]["vae"], choices=get_choices(MOD_VAE), value="default")
inpaint_vram_type = gr.Radio(
[TEXT["zh"]["vram_low"], TEXT["zh"]["vram_high"]],
label=TEXT["zh"]["vram_type"],
value=DEFAULT_PERF_MODE
)
inpaint_strength = gr.Slider(0.1, 1.0, 0.7, step=0.05, label=TEXT["zh"]["i2i_strength"])
inpaint_steps = gr.Slider(1, 50, 8, step=1, label=TEXT["zh"]["steps"])
inpaint_cfg = gr.Slider(0.0, 5.0, 1.0, step=0.1, label=TEXT["zh"]["cfg"])
inpaint_seed = gr.Number(label=TEXT["zh"]["seed"], value=42, precision=0)
inpaint_randomize_seed = gr.Checkbox(label=TEXT["zh"]["random_seed"], value=True)
with gr.Row():
inpaint_generate_btn = gr.Button(TEXT["zh"]["i2i_btn"], variant="primary", size="lg")
inpaint_stop_btn = gr.Button(TEXT["zh"]["stop"], variant="stop", size="lg", interactive=False)
with gr.Column(scale=6):
inpaint_gallery = gr.Gallery(label=TEXT["zh"]["gallery"], columns=2, height="80vh")
inpaint_used_seed = gr.Number(label=TEXT["zh"]["used_seed"], interactive=False)
def switch_language_full(lang):
new_lang = "en" if lang == "zh" else "zh"
t = TEXT[new_lang]
# 修复:根据硬件显存大小,决定当前应该选哪个语言版本的选项
is_low_vram_hardware = TOTAL_VRAM < 24 * 1024**3
current_vram_val = t['vram_low'] if is_low_vram_hardware else t['vram_high']
# 修复:更新显存选项的值,不仅仅是选项列表
return (
new_lang, t['title'], t['lang_btn'],
gr.update(label=t['tab_generate']), gr.update(label=t['tab_edit']),
gr.update(label=t['tab_img2img']), gr.update(label=t['tab_inpaint']),
gr.update(label=t['prompt'], placeholder=t['prompt_placeholder']),
gr.update(value=t['refresh_lora']), gr.update(value=t['refresh_model']),
gr.update(label=t['lora_label']), gr.update(label=t['lora_strength']),
t['model_section'], gr.update(label=t['transformer']), gr.update(label=t['vae']),
# T2I VRAM: 更新选项和值
gr.update(label=t['vram_type'], choices=[t['vram_low'], t['vram_high']], value=current_vram_val),
gr.update(label=t['device']),
gr.update(label=t['num_images']), gr.update(label=t['output_format']),
gr.update(label=t['width']), gr.update(label=t['height']),
gr.update(label=t['steps']), gr.update(label=t['cfg']),
gr.update(label=t['seed']), gr.update(label=t['random_seed']),
gr.update(value=t['generate']), gr.update(value=t['stop']),
gr.update(label=t['gallery']), gr.update(label=t['used_seed']),
gr.update(label=t['edit_upload']),
gr.update(label=t['rotate']), gr.update(label=t['crop_x']), gr.update(label=t['crop_y']),
gr.update(label=t['crop_w']), gr.update(label=t['crop_h']),
gr.update(label=t['hflip']), gr.update(label=t['vflip']),
gr.update(value=t['edit_btn']), gr.update(label=t['edited_image']),
gr.update(label=t['filter']), gr.update(label=t['brightness']), gr.update(label=t['contrast']), gr.update(label=t['saturation']),
gr.update(value=t['i2i_note']),
gr.update(label=t['i2i_ref']),
gr.update(label=t['i2i_prompt'], placeholder=t['i2i_ph']),
gr.update(label=t['negative_prompt'], placeholder=t['negative_placeholder']),
gr.update(value=t['refresh_lora']), gr.update(value=t['refresh_model']),
gr.update(label=t['lora_label']), gr.update(label=t['lora_strength']),
gr.update(label=t['transformer']), gr.update(label=t['vae']),
# Img2Img VRAM: 更新选项和值
gr.update(label=t['vram_type'], choices=[t['vram_low'], t['vram_high']], value=current_vram_val),
gr.update(label=t['i2i_mode'], choices=[t['i2i_mode_a'], t['i2i_mode_b']]),
gr.update(label=t['i2i_out_w']), gr.update(label=t['i2i_out_h']),
gr.update(value=t['i2i_tip']),
gr.update(label=t['i2i_strength']),
gr.update(label=t['steps']), gr.update(label=t['cfg']),
gr.update(label=t['num_images']), gr.update(label=t['output_format']),
gr.update(label=t['seed']), gr.update(label=t['random_seed']),
gr.update(value=t['i2i_btn']), gr.update(value=t['stop']),
gr.update(label=t['gallery']), gr.update(label=t['used_seed']),
gr.update(label=t['inpaint_upload']),
gr.update(value=t['inpaint_desc']),
gr.update(label=t['i2i_prompt'], placeholder=t['i2i_ph']),
gr.update(label=t['negative_prompt'], placeholder=t['negative_placeholder']),
gr.update(value=t['refresh_lora']), gr.update(value=t['refresh_model']),
gr.update(label=t['lora_label']), gr.update(label=t['lora_strength']),
gr.update(label=t['transformer']), gr.update(label=t['vae']),
# Inpaint VRAM: 更新选项和值
gr.update(label=t['vram_type'], choices=[t['vram_low'], t['vram_high']], value=current_vram_val),
gr.update(label=t['i2i_strength']),
gr.update(label=t['steps']), gr.update(label=t['cfg']),
gr.update(label=t['seed']), gr.update(label=t['random_seed']),
gr.update(value=t['i2i_btn']), gr.update(value=t['stop']),
gr.update(label=t['gallery']), gr.update(label=t['used_seed']),
)
lang_btn.click(
fn=switch_language_full,
inputs=lang_state,
outputs=[
lang_state, title_md, lang_btn,
tab_gen, tab_edit, tab_img2img, tab_inpaint,
prompt, refresh_lora, refresh_model_t2i, lora_drop, lora_alpha, model_section_md,
transformer_choice, vae_choice, vram_type, device_ui, num_images, image_format,
width, height, num_inference_steps, guidance_scale, seed, randomize_seed,
generate_btn, stop_btn, gallery, used_seed,
image_input, rotate_angle, crop_x, crop_y, crop_width, crop_height,
flip_horizontal, flip_vertical, edit_btn, edited_image_output,
apply_filter, brightness, contrast, saturation,
i2i_status_md, i2i_image_input, i2i_prompt, i2i_negative_prompt,
i2i_refresh_lora, i2i_refresh_model, i2i_lora_drop, i2i_lora_alpha,
i2i_transformer_choice, i2i_vae_choice, i2i_vram_type, i2i_mode,
i2i_out_w, i2i_out_h, i2i_tip_md,
i2i_strength, i2i_steps, i2i_cfg,
i2i_num_images, i2i_image_format, i2i_seed, i2i_randomize_seed,
i2i_generate_btn, i2i_stop_btn, i2i_gallery, i2i_used_seed,
inpaint_editor, inpaint_tip_md,
inpaint_prompt, inpaint_negative_prompt,
inpaint_refresh_lora, inpaint_refresh_model, inpaint_lora_drop, inpaint_lora_alpha,
inpaint_transformer_choice, inpaint_vae_choice, inpaint_vram_type,
inpaint_strength, inpaint_steps, inpaint_cfg,
inpaint_seed, inpaint_randomize_seed,
inpaint_generate_btn, inpaint_stop_btn, inpaint_gallery, inpaint_used_seed
]
)
refresh_lora.click(fn=scan_lora_items, outputs=[lora_drop, i2i_lora_drop, inpaint_lora_drop])
lora_drop.change(update_prompt_with_lora, [prompt, lora_drop, lora_alpha], prompt)
def refresh_models_t2i():
return gr.update(choices=get_choices(MOD_TRANSFORMER)), gr.update(choices=get_choices(MOD_VAE))
refresh_model_t2i.click(fn=refresh_models_t2i, outputs=[transformer_choice, vae_choice])
def start_gen(): return gr.update(interactive=False), gr.update(interactive=True)
def end_gen(): return gr.update(interactive=True), gr.update(interactive=False)
def trigger_stop():
global is_generating_interrupted
is_generating_interrupted = True
generate_event = generate_btn.click(fn=start_gen, outputs=[generate_btn, stop_btn]).then(
fn=generate_image,
inputs=[prompt, lora_drop, lora_alpha, num_images, image_format,
width, height, num_inference_steps, guidance_scale, seed, randomize_seed,
transformer_choice, vae_choice, vram_type],
outputs=[gallery, used_seed]
).then(fn=end_gen, outputs=[generate_btn, stop_btn])
stop_btn.click(fn=trigger_stop).then(fn=end_gen, outputs=[generate_btn, stop_btn], cancels=[generate_event])
i2i_refresh_lora.click(fn=scan_lora_items, outputs=[lora_drop, i2i_lora_drop, inpaint_lora_drop])
i2i_lora_drop.change(update_prompt_with_lora, [i2i_prompt, i2i_lora_drop, i2i_lora_alpha], i2i_prompt)
def refresh_models_i2i():
return gr.update(choices=get_choices(MOD_TRANSFORMER)), gr.update(choices=get_choices(MOD_VAE))
i2i_refresh_model.click(fn=refresh_models_i2i, outputs=[i2i_transformer_choice, i2i_vae_choice])
def start_i2i(): return gr.update(interactive=False), gr.update(interactive=True)
def end_i2i(): return gr.update(interactive=True), gr.update(interactive=False)
i2i_generate_event = i2i_generate_btn.click(fn=start_i2i, outputs=[i2i_generate_btn, i2i_stop_btn]).then(
fn=run_img2img_enhanced,
inputs=[i2i_image_input, i2i_prompt, i2i_negative_prompt, i2i_lora_drop, i2i_lora_alpha,
i2i_num_images, i2i_image_format,
i2i_out_w, i2i_out_h, i2i_mode, i2i_strength, i2i_steps, i2i_cfg,
i2i_seed, i2i_randomize_seed,
i2i_transformer_choice, i2i_vae_choice, i2i_vram_type],
outputs=[i2i_gallery, i2i_used_seed]
).then(fn=end_i2i, outputs=[i2i_generate_btn, i2i_stop_btn])
i2i_stop_btn.click(fn=trigger_stop).then(fn=end_i2i, outputs=[i2i_generate_btn, i2i_stop_btn], cancels=[i2i_generate_event])
inpaint_refresh_lora.click(fn=scan_lora_items, outputs=[lora_drop, i2i_lora_drop, inpaint_lora_drop])
inpaint_lora_drop.change(update_prompt_with_lora, [inpaint_prompt, inpaint_lora_drop, inpaint_lora_alpha], inpaint_prompt)
def refresh_models_inpaint():
return gr.update(choices=get_choices(MOD_TRANSFORMER)), gr.update(choices=get_choices(MOD_VAE))
inpaint_refresh_model.click(fn=refresh_models_inpaint, outputs=[inpaint_transformer_choice, inpaint_vae_choice])
def start_inpaint(): return gr.update(interactive=False), gr.update(interactive=True)
def end_inpaint(): return gr.update(interactive=True), gr.update(interactive=False)
inpaint_generate_event = inpaint_generate_btn.click(fn=start_inpaint, outputs=[inpaint_generate_btn, inpaint_stop_btn]).then(
fn=run_inpainting,
inputs=[inpaint_editor, inpaint_prompt, inpaint_negative_prompt, inpaint_lora_drop, inpaint_lora_alpha,
inpaint_strength, inpaint_steps, inpaint_cfg,
inpaint_seed, inpaint_randomize_seed,
inpaint_transformer_choice, inpaint_vae_choice, inpaint_vram_type],
outputs=[inpaint_gallery, inpaint_used_seed]
).then(fn=end_inpaint, outputs=[inpaint_generate_btn, inpaint_stop_btn])
inpaint_stop_btn.click(fn=trigger_stop).then(fn=end_inpaint, outputs=[inpaint_generate_btn, inpaint_stop_btn], cancels=[inpaint_generate_event])
edit_btn.click(
fn=edit_image,
inputs=[image_input, rotate_angle, crop_x, crop_y, crop_width, crop_height,
flip_horizontal, flip_vertical, apply_filter, brightness, contrast, saturation],
outputs=edited_image_output
)
if __name__ == "__main__":
demo.queue(max_size=20)
demo.launch(show_error=True)