# app.py — added in commit af0de5a ("Create app.py") by youneeds.
import gradio as gr
import torch
import json
from diffusers import AutoPipelineForInpainting, StableDiffusionPipeline
from PIL import Image
# ==========================
# Model ID mapping
# ==========================
# Display name (shown in the model dropdown) -> Hugging Face Hub repo id that
# get_pipe passes to `from_pretrained`. Insertion order is the dropdown order.
# NOTE(review): "唯美插画风" and "塞尔达旷野手绘风" both resolve to
# andite/pastel-mix, and some repo ids here may not be diffusers-format
# checkpoints on the Hub — confirm each entry actually loads.
_MODEL_ENTRIES = (
    ("SD1.5 局部重绘", "runwayml/stable-diffusion-inpainting"),
    ("SD1.5 基础通用", "runwayml/stable-diffusion-v1-5"),
    ("游戏资产贴图专用", "Yntec/GameAsset"),
    ("2D手绘场景贴图", "Yaeger/Flat2D"),
    ("低多边形LowPoly", "NiamaOfficial/LowPolyStyle"),
    ("体素Voxel艺术", "VoxelArtModels/VoxelBase15"),
    ("唯美插画风", "andite/pastel-mix"),
    ("轻量二次元V3", "gsdf/Counterfeit-V3.0"),
    ("吉卜力手绘动画", "nitrosocke/Ghibli-Style"),
    ("原神二次元场景", "Linaqruf/giant-squid-ai"),
    ("王者荣耀国风原画", "WarriorMama777/AbyssOrangeMix2"),
    ("塞尔达旷野手绘风", "andite/pastel-mix"),
    ("复古像素艺术", "OnodoAI/PixelArtModel"),
    ("RPG像素场景贴图", "ZinPixel/RPGMakerTiles"),
    ("写实建筑场景", "XpucT/Deliberate"),
    ("写实环境无VAE", "SG161222/Realistic_Vision_V5.1_noVAE"),
    ("卡通写实融合", "lilpotat/ToonYou"),
    ("FastSD 极速CPU", "rupeshs/fastsdcpu"),
    ("BK-SDM 轻量化", "nota-ai/bk-sdm-small"),
    ("SD-Turbo 超快出图", "stabilityai/sd-turbo"),
)
MODEL_LIST = dict(_MODEL_ENTRIES)
def load_prompt_presets(path="prompts.json"):
    """Load per-model prompt presets from a JSON file.

    The UI looks presets up by model display name and reads the optional
    "positive" / "negative" strings from each entry.

    Args:
        path: JSON file to read. Defaults to "prompts.json", resolved against
            the current working directory.

    Returns:
        The decoded top-level JSON object. An empty dict is returned when the
        file is missing, cannot be read or decoded, or its top level is not a
        JSON object. Presets are optional, so these cases degrade to "no
        presets" instead of aborting startup.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # No presets file is a normal configuration; stay quiet.
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Failed to load prompt presets from {path}: {e}")
        return {}
    if not isinstance(data, dict):
        print(f"Ignoring prompt presets in {path}: top-level JSON value is not an object")
        return {}
    return data
# Prompt presets keyed by model display name, read once when the module loads.
PROMPT_PRESETS = load_prompt_presets()

# Cache of loaded diffusers pipelines; populated and read by get_pipe.
pipe_cache = dict()
def get_pipe(model_name, mode="txt2img"):
    """Return a cached diffusers pipeline for *model_name* in the given mode.

    Args:
        model_name: Display name; must be a key of MODEL_LIST (KeyError
            otherwise).
        mode: "inpaint" selects an inpainting pipeline; any other value
            selects a StableDiffusionPipeline for text-to-image.

    Returns:
        A pipeline on CUDA in float16 when CUDA is available, otherwise on CPU
        in float32.

    Memory behaviour:
        * The text-to-image and inpainting pipelines of one checkpoint share
          their modules via ``from_pipe``. Switching mode therefore does not
          load a second copy of the weights.
        * Only one checkpoint is kept resident. Loading a different model
          drops the cached pipelines of the previous one. With about 20
          selectable Stable Diffusion checkpoints, an unbounded cache would
          eventually exhaust GPU/CPU memory.
    """
    kind = "inpaint" if mode == "inpaint" else "txt2img"
    cache_key = (model_name, kind)
    cached = pipe_cache.get(cache_key)
    if cached is not None:
        return cached

    # Resolve the repo id first so an unknown name raises KeyError before any
    # cached pipeline is evicted.
    model_id = MODEL_LIST[model_name]

    # Keep a single checkpoint resident: release pipelines of other models.
    stale_keys = [key for key in pipe_cache if key[0] != model_name]
    if stale_keys:
        import gc

        for key in stale_keys:
            del pipe_cache[key]
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    other_kind = "txt2img" if kind == "inpaint" else "inpaint"
    sibling = pipe_cache.get((model_name, other_kind))
    if sibling is not None:
        # Same checkpoint already loaded for the other mode. Reuse its modules
        # (already on the target device/dtype) instead of loading them again.
        if kind == "inpaint":
            pipe = AutoPipelineForInpainting.from_pipe(sibling)
        else:
            from diffusers import AutoPipelineForText2Image

            pipe = AutoPipelineForText2Image.from_pipe(sibling)
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use half precision only on GPU; CPU stays in float32.
        dtype = torch.float16 if device == "cuda" else torch.float32
        pipeline_cls = AutoPipelineForInpainting if kind == "inpaint" else StableDiffusionPipeline
        pipe = pipeline_cls.from_pretrained(
            model_id, torch_dtype=dtype, low_cpu_mem_usage=True
        ).to(device)

    pipe_cache[cache_key] = pipe
    return pipe
def update_prompts(model_name):
    """Fill the positive/negative prompt boxes from the preset of *model_name*.

    Args:
        model_name: Display name currently selected in the model dropdown.

    Returns:
        A pair of ``gr.update`` objects (positive prompt, negative prompt).
        A missing preset or field yields empty strings. Malformed preset data
        (a non-object top level, or a non-object entry) is treated as missing
        instead of raising AttributeError inside the UI callback.
    """
    presets = PROMPT_PRESETS if isinstance(PROMPT_PRESETS, dict) else {}
    preset = presets.get(model_name)
    if not isinstance(preset, dict):
        preset = {}
    return gr.update(value=preset.get("positive", "")), gr.update(value=preset.get("negative", ""))
def toggle_input_mode(mode):
    """Swap the image editor and the text-to-image hint when the mode changes.

    Returns a pair of updates for (image editor, hint markdown). In
    text-to-image mode the editor is hidden and the hint is shown with its
    message. In any other mode the editor is shown and the hint is hidden and
    cleared.
    """
    text_only = mode == "文生图(纯文本)"
    if not text_only:
        return gr.update(visible=True), gr.update(visible=False, value="")
    editor_update = gr.update(visible=False)
    hint_update = gr.update(visible=True, value="✨ 文生图模式:无需上传图片,直接使用提示词生成")
    return editor_update, hint_update
def update_slider_visibility(mode):
    """Show the denoising-strength slider only in image-edit (inpaint) mode."""
    show_slider = mode == "图像编辑(重绘)"
    return gr.update(visible=show_slider)
def _mask_from_layers(layers, size):
    """Union the alpha channels of the editor's brush layers into an L-mode mask."""
    from PIL import ImageChops

    mask = Image.new("L", size, 0)
    for layer in layers:
        if layer is None:
            continue
        alpha = layer.convert("RGBA").getchannel("A")
        if alpha.size != size:
            alpha = alpha.resize(size)
        mask = ImageChops.lighter(mask, alpha)
    # Any painted pixel is repainted, regardless of brush opacity.
    return mask.point(lambda v: 255 if v else 0)


def _extract_image_and_mask(image_data, size=(512, 512)):
    """Split an image-editor value into an RGB source image and an L-mode mask.

    Accepts the ``gr.ImageEditor`` value (a dict with "background", "layers"
    and "composite"). For backward compatibility it also accepts the legacy
    ``{"image": ..., "mask": ...}`` dict from the old sketch tool.

    Returns:
        (image, mask), both resized to *size*. The mask is white where the
        user painted.

    Raises:
        gr.Error: if no image was uploaded or no area was painted.
    """
    if not image_data:
        raise gr.Error("请先上传图片")
    if "image" in image_data:
        # Legacy sketch-tool format.
        source = image_data.get("image")
        mask = image_data.get("mask")
        if mask is not None:
            mask = mask.convert("L")
    else:
        source = image_data.get("background")
        if source is None:
            source = image_data.get("composite")
        mask = None
        if source is not None:
            mask = _mask_from_layers(image_data.get("layers") or [], source.size)
    if source is None:
        raise gr.Error("请先上传图片")
    if mask is None or mask.getbbox() is None:
        raise gr.Error("请在图片上涂抹需要重绘的区域")
    return source.convert("RGB").resize(size), mask.resize(size)


def generate_image(image_data, prompt, neg_prompt, model_name, mode, denoise_str, steps, cfg):
    """Run text-to-image or inpainting and return the first generated image.

    Args:
        image_data: Value of the ``gr.ImageEditor`` (ignored in text-to-image
            mode).
        prompt / neg_prompt: Positive and negative prompts.
        model_name: Key of MODEL_LIST.
        mode: "文生图(纯文本)" for text-to-image; anything else means inpainting.
        denoise_str: Inpainting strength (only used when inpainting).
        steps: Number of inference steps (cast to int; sliders may send floats).
        cfg: Classifier-free guidance scale.

    Returns:
        A PIL image (``.images[0]`` of the pipeline output).

    Raises:
        gr.Error: on missing input or any generation failure, so the reason is
            shown in the UI instead of a silently empty result.
    """
    import traceback

    if mode == "文生图(纯文本)":
        task = "txt2img"
        extra_kwargs = {}
    else:
        task = "inpaint"
        img, mask = _extract_image_and_mask(image_data)
        extra_kwargs = {"image": img, "mask_image": mask, "strength": denoise_str}
    try:
        pipe = get_pipe(model_name, mode=task)
        return pipe(
            prompt=prompt,
            negative_prompt=neg_prompt,
            num_inference_steps=int(steps),
            guidance_scale=cfg,
            **extra_kwargs,
        ).images[0]
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"生成错误: {e}") from e
# Model preselected in the dropdown; the initial prompt presets are loaded for it.
DEFAULT_MODEL = "游戏资产贴图专用"

with gr.Blocks(title="全风格AI图像工具", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 全风格AI图像生成 + 重绘工具")
    with gr.Row():
        # Left column: inputs (mode, source image, model, prompts).
        with gr.Column(scale=1):
            mode_radio = gr.Radio(
                choices=["文生图(纯文本)", "图像编辑(重绘)"],
                label="✨ 选择生成模式",
                value="图像编辑(重绘)",
            )
            mode_hint = gr.Markdown("", visible=False)
            input_img = gr.ImageEditor(
                label="上传图片 + 涂抹重绘区域(重绘模式)",
                type="pil",
                height=420,
                visible=True,
            )
            model_select = gr.Dropdown(
                label="🎮 选择风格模型",
                choices=list(MODEL_LIST.keys()),
                value=DEFAULT_MODEL,
            )
            prompt = gr.Textbox(label="正向提示词(可追加修改)", lines=3)
            neg_prompt = gr.Textbox(label="反向提示词(可追加修改)", lines=2)
        # Right column: sampling parameters, trigger button and result.
        with gr.Column(scale=1):
            denoise_str = gr.Slider(label="重绘强度", minimum=0.1, maximum=0.9, value=0.5, visible=True)
            steps = gr.Slider(label="采样步数", minimum=10, maximum=30, value=20)
            cfg = gr.Slider(label="提示词强度", minimum=1, maximum=15, value=7)
            generate_btn = gr.Button("✨ 一键生成", variant="primary")
            output_img = gr.Image(label="生成结果", height=420)

    mode_radio.change(toggle_input_mode, inputs=mode_radio, outputs=[input_img, mode_hint])
    mode_radio.change(update_slider_visibility, inputs=mode_radio, outputs=denoise_str)
    model_select.change(update_prompts, inputs=model_select, outputs=[prompt, neg_prompt])
    # Feed the dropdown's value rather than repeating the default model name.
    demo.load(fn=update_prompts, inputs=model_select, outputs=[prompt, neg_prompt])
    generate_btn.click(
        fn=generate_image,
        inputs=[input_img, prompt, neg_prompt, model_select, mode_radio, denoise_str, steps, cfg],
        outputs=output_img,
    )

if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio share link; confirm
    # that exposure is intended when this runs outside a hosted Space.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)