"""Gradio front end for text-to-image generation and inpainting.

The app lets the user pick one of several Stable Diffusion checkpoints,
optionally upload an image and paint a mask over it, and then generate a
result either from the prompt alone or by repainting the masked region.
"""

import json

import gradio as gr
import torch
from diffusers import AutoPipelineForInpainting, StableDiffusionPipeline
from PIL import Image

# ==========================
# Model ID mapping
# ==========================
# Maps the label shown in the dropdown to the Hugging Face repository id.
MODEL_LIST = {
    "SD1.5 局部重绘": "runwayml/stable-diffusion-inpainting",
    "SD1.5 基础通用": "runwayml/stable-diffusion-v1-5",
    "游戏资产贴图专用": "Yntec/GameAsset",
    "2D手绘场景贴图": "Yaeger/Flat2D",
    "低多边形LowPoly": "NiamaOfficial/LowPolyStyle",
    "体素Voxel艺术": "VoxelArtModels/VoxelBase15",
    "唯美插画风": "andite/pastel-mix",
    "轻量二次元V3": "gsdf/Counterfeit-V3.0",
    "吉卜力手绘动画": "nitrosocke/Ghibli-Style",
    "原神二次元场景": "Linaqruf/giant-squid-ai",
    "王者荣耀国风原画": "WarriorMama777/AbyssOrangeMix2",
    "塞尔达旷野手绘风": "andite/pastel-mix",
    "复古像素艺术": "OnodoAI/PixelArtModel",
    "RPG像素场景贴图": "ZinPixel/RPGMakerTiles",
    "写实建筑场景": "XpucT/Deliberate",
    "写实环境无VAE": "SG161222/Realistic_Vision_V5.1_noVAE",
    "卡通写实融合": "lilpotat/ToonYou",
    "FastSD 极速CPU": "rupeshs/fastsdcpu",
    "BK-SDM 轻量化": "nota-ai/bk-sdm-small",
    "SD-Turbo 超快出图": "stabilityai/sd-turbo",
}

# UI mode labels; the same strings are compared in the event handlers, so
# they are defined once to keep the radio choices and the checks in sync.
MODE_TXT2IMG = "文生图(纯文本)"
MODE_INPAINT = "图像编辑(重绘)"

# Model selected when the page first loads.
DEFAULT_MODEL = "游戏资产贴图专用"

# Resolution every input image and mask is resized to before inference.
INPUT_SIZE = (512, 512)


def load_prompt_presets():
    """Load per-model prompt presets from ``prompts.json``.

    Returns:
        The parsed JSON object, or an empty dict when the file is missing,
        unreadable, not valid JSON, or its root is not a JSON object.
    """
    try:
        with open("prompts.json", "r", encoding="utf-8") as f:
            presets = json.load(f)
    except (OSError, ValueError):
        # Presets are optional: a missing or malformed file (JSON and
        # decoding errors are ValueError subclasses) just means "none".
        return {}
    return presets if isinstance(presets, dict) else {}


PROMPT_PRESETS = load_prompt_presets()
pipe_cache = {}


def get_pipe(model_name, mode="txt2img"):
    """Return the pipeline for ``model_name`` and ``mode``, loading it once.

    Args:
        model_name: A key of ``MODEL_LIST``.
        mode: ``"inpaint"`` selects ``AutoPipelineForInpainting``; any other
            value selects ``StableDiffusionPipeline``.

    Returns:
        The loaded pipeline, moved to CUDA when available, otherwise CPU.

    Raises:
        KeyError: If ``model_name`` is not in ``MODEL_LIST``.

    NOTE(review): ``pipe_cache`` is never evicted, so every model/mode pair
    that has been used stays resident on the device.
    """
    cache_key = f"{model_name}_{mode}"
    if cache_key in pipe_cache:
        return pipe_cache[cache_key]

    model_id = MODEL_LIST[model_name]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is only used on CUDA; CPU inference stays in float32.
    dtype = torch.float16 if device == "cuda" else torch.float32

    pipeline_cls = (
        AutoPipelineForInpainting if mode == "inpaint" else StableDiffusionPipeline
    )
    pipe = pipeline_cls.from_pretrained(
        model_id,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    ).to(device)

    pipe_cache[cache_key] = pipe
    return pipe


def update_prompts(model_name):
    """Return updates that fill both prompt boxes with the model's preset.

    Args:
        model_name: The label currently selected in the model dropdown.

    Returns:
        A pair of ``gr.update`` objects (positive prompt, negative prompt);
        each value is an empty string when no preset exists.
    """
    preset = PROMPT_PRESETS.get(model_name)
    if not isinstance(preset, dict):
        preset = {}
    return (
        gr.update(value=preset.get("positive", "")),
        gr.update(value=preset.get("negative", "")),
    )


def toggle_input_mode(mode):
    """Show either the image editor or the text-to-image hint.

    Returns:
        A pair of ``gr.update`` objects for (image editor, hint markdown).
    """
    if mode == MODE_TXT2IMG:
        return (
            gr.update(visible=False),
            gr.update(visible=True, value="✨ 文生图模式:无需上传图片,直接使用提示词生成"),
        )
    return gr.update(visible=True), gr.update(visible=False, value="")


def update_slider_visibility(mode):
    """Show the repaint-strength slider only in image-editing mode."""
    return gr.update(visible=(mode == MODE_INPAINT))


def _mask_from_layers(layers, size):
    """Merge the alpha channels of the drawn layers into one ``L`` mask.

    Any pixel touched in any layer becomes 255; untouched pixels stay 0.
    """
    mask = Image.new("L", size, 0)
    for layer in layers or []:
        if layer is None:
            continue
        alpha = layer.convert("RGBA").getchannel("A")
        if alpha.size != size:
            alpha = alpha.resize(size)
        mask.paste(255, mask=alpha)
    # Binarise so semi-transparent brush edges still count as "repaint".
    return mask.point(lambda value: 255 if value else 0)


def _extract_image_and_mask(image_data):
    """Pull the base image and repaint mask out of the editor payload.

    Accepts both the ``gr.ImageEditor`` payload (``background`` / ``layers``
    / ``composite``) and the older sketch payload (``image`` / ``mask``).

    Returns:
        ``(image, mask)`` resized to ``INPUT_SIZE``, or ``(None, None)`` when
        no base image was supplied. If nothing was painted the mask covers
        the whole image, because the inpainting pipeline requires a mask.
    """
    if not image_data:
        return None, None

    if "image" in image_data:
        base = image_data["image"]
        mask = image_data.get("mask")
        if mask is not None:
            mask = mask.convert("L")
    else:
        base = image_data.get("background")
        mask = None
        if base is not None:
            mask = _mask_from_layers(image_data.get("layers"), base.size)

    if base is None:
        return None, None

    # An absent or all-black mask would make inpainting fail or do nothing,
    # so fall back to repainting the entire image.
    if mask is None or mask.getbbox() is None:
        mask = Image.new("L", base.size, 255)

    return base.convert("RGB").resize(INPUT_SIZE), mask.resize(INPUT_SIZE)


def generate_image(image_data, prompt, neg_prompt, model_name, mode, denoise_str, steps, cfg):
    """Run text-to-image or inpainting and return the generated image.

    Args:
        image_data: Value of the image editor component (a dict, or falsy
            when nothing was uploaded). Ignored in text-to-image mode.
        prompt: Positive prompt.
        neg_prompt: Negative prompt.
        model_name: A key of ``MODEL_LIST``.
        mode: One of the mode radio labels.
        denoise_str: Repaint strength; used only when inpainting.
        steps: Number of inference steps; rounded to an int because the
            slider may deliver a float.
        cfg: Guidance scale.

    Returns:
        The first generated image, or ``None`` when no input image was given
        in inpainting mode or when generation raised an error.
    """
    try:
        num_steps = int(round(steps))

        if mode == MODE_TXT2IMG:
            pipe = get_pipe(model_name, mode="txt2img")
            return pipe(
                prompt=prompt,
                negative_prompt=neg_prompt,
                num_inference_steps=num_steps,
                guidance_scale=cfg,
            ).images[0]

        img, mask = _extract_image_and_mask(image_data)
        if img is None:
            return None

        pipe = get_pipe(model_name, mode="inpaint")
        return pipe(
            prompt=prompt,
            negative_prompt=neg_prompt,
            image=img,
            mask_image=mask,
            strength=denoise_str,
            num_inference_steps=num_steps,
            guidance_scale=cfg,
        ).images[0]
    except Exception as e:
        # Top-level UI boundary: report the failure and leave the output
        # empty instead of propagating the error into the event handler.
        print(f"生成错误: {e}")
        return None


with gr.Blocks(title="全风格AI图像工具", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 全风格AI图像生成 + 重绘工具")

    with gr.Row():
        with gr.Column(scale=1):
            mode_radio = gr.Radio(
                choices=[MODE_TXT2IMG, MODE_INPAINT],
                label="✨ 选择生成模式",
                value=MODE_INPAINT,
            )
            mode_hint = gr.Markdown("", visible=False)
            input_img = gr.ImageEditor(
                label="上传图片 + 涂抹重绘区域(重绘模式)",
                type="pil",
                height=420,
                visible=True,
            )
            model_select = gr.Dropdown(
                label="🎮 选择风格模型",
                choices=list(MODEL_LIST.keys()),
                value=DEFAULT_MODEL,
            )
            prompt = gr.Textbox(label="正向提示词(可追加修改)", lines=3)
            neg_prompt = gr.Textbox(label="反向提示词(可追加修改)", lines=2)

        with gr.Column(scale=1):
            denoise_str = gr.Slider(
                label="重绘强度", minimum=0.1, maximum=0.9, value=0.5, visible=True
            )
            # step=1 keeps the step count integral, as the pipeline requires.
            steps = gr.Slider(label="采样步数", minimum=10, maximum=30, value=20, step=1)
            cfg = gr.Slider(label="提示词强度", minimum=1, maximum=15, value=7)
            generate_btn = gr.Button("✨ 一键生成", variant="primary")
            output_img = gr.Image(label="生成结果", height=420)

    mode_radio.change(toggle_input_mode, inputs=mode_radio, outputs=[input_img, mode_hint])
    mode_radio.change(update_slider_visibility, inputs=mode_radio, outputs=denoise_str)
    model_select.change(update_prompts, inputs=model_select, outputs=[prompt, neg_prompt])
    # Pre-fill the prompt boxes for the default model on first page load.
    demo.load(fn=lambda: update_prompts(DEFAULT_MODEL), outputs=[prompt, neg_prompt])

    generate_btn.click(
        fn=generate_image,
        inputs=[input_img, prompt, neg_prompt, model_select, mode_radio, denoise_str, steps, cfg],
        outputs=output_img,
    )

demo.launch(server_name="0.0.0.0", server_port=7860, share=True)