1
File size: 6,162 Bytes
af0de5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import torch
import json
from diffusers import AutoPipelineForInpainting, StableDiffusionPipeline
from PIL import Image

# ==========================
# Dropdown display name -> Hugging Face Hub repo ID for every selectable style.
# Keys are the Chinese labels shown in the UI (runtime strings — do not rename);
# values are passed to diffusers' from_pretrained() in get_pipe().
# ==========================
MODEL_LIST = {
    "SD1.5 局部重绘": "runwayml/stable-diffusion-inpainting",
    "SD1.5 基础通用": "runwayml/stable-diffusion-v1-5",
    "游戏资产贴图专用": "Yntec/GameAsset",
    "2D手绘场景贴图": "Yaeger/Flat2D",
    "低多边形LowPoly": "NiamaOfficial/LowPolyStyle",
    "体素Voxel艺术": "VoxelArtModels/VoxelBase15",
    "唯美插画风": "andite/pastel-mix",
    "轻量二次元V3": "gsdf/Counterfeit-V3.0",
    "吉卜力手绘动画": "nitrosocke/Ghibli-Style",
    "原神二次元场景": "Linaqruf/giant-squid-ai",
    "王者荣耀国风原画": "WarriorMama777/AbyssOrangeMix2",
    # NOTE(review): same repo as "唯美插画风" above — confirm this duplication is intended.
    "塞尔达旷野手绘风": "andite/pastel-mix",
    "复古像素艺术": "OnodoAI/PixelArtModel",
    "RPG像素场景贴图": "ZinPixel/RPGMakerTiles",
    "写实建筑场景": "XpucT/Deliberate",
    "写实环境无VAE": "SG161222/Realistic_Vision_V5.1_noVAE",
    "卡通写实融合": "lilpotat/ToonYou",
    "FastSD 极速CPU": "rupeshs/fastsdcpu",
    "BK-SDM 轻量化": "nota-ai/bk-sdm-small",
    "SD-Turbo 超快出图": "stabilityai/sd-turbo"
}

def load_prompt_presets():
    """Load per-model prompt presets from ./prompts.json.

    Returns:
        dict: mapping of model display name -> {"positive": str, "negative": str}
        (as consumed by update_prompts). An empty dict when the file is missing,
        unreadable, or not valid JSON — presets are optional.
    """
    try:
        with open("prompts.json", "r", encoding="utf-8") as f:
            return json.load(f)
    # Narrowed from a bare `except:`, which also swallowed KeyboardInterrupt
    # and hid real programming errors. Only missing/unreadable files and
    # malformed JSON are expected, non-fatal conditions here.
    except (OSError, json.JSONDecodeError):
        return {}

# Presets loaded once at import time; keyed by model display name (see update_prompts).
PROMPT_PRESETS = load_prompt_presets()
# Memoized diffusers pipelines, keyed by f"{model_name}_{mode}" (see get_pipe).
# Grows unboundedly as models are selected — each entry holds a full pipeline in memory.
pipe_cache = {}

def get_pipe(model_name, mode="txt2img"):
    """Return a diffusion pipeline for `model_name`, building and caching it on first use.

    `mode` selects the pipeline class: "inpaint" loads an inpainting pipeline,
    anything else loads a plain text-to-image pipeline. The pipeline is moved
    to CUDA (fp16) when available, otherwise kept on CPU (fp32), and memoized
    in the module-level pipe_cache.
    """
    cache_key = f"{model_name}_{mode}"
    cached = pipe_cache.get(cache_key)
    if cached is not None:
        return cached

    repo_id = MODEL_LIST[model_name]
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32

    # Pick the pipeline class for the requested mode, then load it once.
    loader = AutoPipelineForInpainting if mode == "inpaint" else StableDiffusionPipeline
    pipe = loader.from_pretrained(
        repo_id, torch_dtype=dtype, low_cpu_mem_usage=True
    ).to(device)

    pipe_cache[cache_key] = pipe
    return pipe

def update_prompts(model_name):
    """Return Gradio updates filling the positive/negative prompt boxes from presets.

    Falls back to empty strings when `model_name` has no entry in PROMPT_PRESETS.
    """
    preset = PROMPT_PRESETS.get(model_name, {})
    positive = preset.get("positive", "")
    negative = preset.get("negative", "")
    return gr.update(value=positive), gr.update(value=negative)

def toggle_input_mode(mode):
    """Show/hide the image editor and mode hint when the generation mode changes.

    Text-to-image needs no uploaded picture, so the editor is hidden and a hint
    is shown; repaint mode restores the editor and clears the hint.
    """
    is_txt2img = mode == "文生图(纯文本)"
    if is_txt2img:
        hint = gr.update(visible=True, value="✨ 文生图模式:无需上传图片,直接使用提示词生成")
        return gr.update(visible=False), hint
    return gr.update(visible=True), gr.update(visible=False, value="")

def update_slider_visibility(mode):
    """Show the denoise-strength slider only in repaint (inpainting) mode."""
    show_slider = mode == "图像编辑(重绘)"
    return gr.update(visible=show_slider)

def generate_image(image_data, prompt, neg_prompt, model_name, mode, denoise_str, steps, cfg):
    """Generate an image via text-to-image or inpainting; return None on any failure.

    Args:
        image_data: Payload from the image editor. Expected to be a dict with an
            "image" PIL image and an optional "mask".
            NOTE(review): recent gr.ImageEditor versions emit
            "background"/"layers"/"composite" keys instead of "image"/"mask" —
            confirm against the installed Gradio version, otherwise repaint mode
            always returns None.
        prompt, neg_prompt: positive / negative prompt strings.
        model_name: key into MODEL_LIST selecting the style model.
        mode: "文生图(纯文本)" for text-to-image; anything else runs inpainting.
        denoise_str: inpainting strength (unused in text-to-image).
        steps: number of inference steps.
        cfg: classifier-free guidance scale.

    Returns:
        The first generated PIL image, or None on invalid input or any error.
    """
    try:
        if mode == "文生图(纯文本)":
            pipe = get_pipe(model_name, mode="txt2img")
            output = pipe(
                prompt=prompt,
                negative_prompt=neg_prompt,
                num_inference_steps=steps,
                guidance_scale=cfg
            ).images[0]
        else:
            # Repaint requires an uploaded image; bail out quietly otherwise.
            if not image_data or "image" not in image_data:
                return None
            img = image_data["image"].convert("RGB").resize((512, 512))
            # Fetch the mask once with .get() so a missing key can never raise
            # (the original mixed image_data["mask"] with image_data.get("mask")).
            raw_mask = image_data.get("mask")
            mask = raw_mask.convert("L").resize((512, 512)) if raw_mask else None
            pipe = get_pipe(model_name, mode="inpaint")
            output = pipe(
                prompt=prompt,
                negative_prompt=neg_prompt,
                image=img,
                mask_image=mask,
                strength=denoise_str,
                num_inference_steps=steps,
                guidance_scale=cfg
            ).images[0]
        return output
    except Exception as e:
        # Broad catch is deliberate: report any generation failure to the console
        # and show "no image" in the UI instead of crashing the Gradio worker.
        print(f"生成错误: {e}")
        return None

# ---- UI layout and event wiring ----
with gr.Blocks(title="全风格AI图像工具", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 全风格AI图像生成 + 重绘工具")
    with gr.Row():
        # Left column: generation mode, image editor, model picker, prompts.
        with gr.Column(scale=1):
            mode_radio = gr.Radio(
                choices=["文生图(纯文本)", "图像编辑(重绘)"],
                label="✨ 选择生成模式",
                value="图像编辑(重绘)"
            )
            # Hidden by default; toggle_input_mode shows it in txt2img mode.
            mode_hint = gr.Markdown("", visible=False)
            input_img = gr.ImageEditor(
                label="上传图片 + 涂抹重绘区域(重绘模式)",
                type="pil",
                height=420,
                visible=True
            )
            model_select = gr.Dropdown(
                label="🎮 选择风格模型",
                choices=list(MODEL_LIST.keys()),
                value="游戏资产贴图专用"
            )
            prompt = gr.Textbox(label="正向提示词(可追加修改)", lines=3)
            neg_prompt = gr.Textbox(label="反向提示词(可追加修改)", lines=2)
        # Right column: sampling controls, generate button, result preview.
        with gr.Column(scale=1):
            denoise_str = gr.Slider(label="重绘强度", minimum=0.1, maximum=0.9, value=0.5, visible=True)
            steps = gr.Slider(label="采样步数", minimum=10, maximum=30, value=20)
            cfg = gr.Slider(label="提示词强度", minimum=1, maximum=15, value=7)
            generate_btn = gr.Button("✨ 一键生成", variant="primary")
            output_img = gr.Image(label="生成结果", height=420)

    # Switching mode toggles both the editor/hint pair and the denoise slider.
    mode_radio.change(toggle_input_mode, inputs=mode_radio, outputs=[input_img, mode_hint])
    mode_radio.change(update_slider_visibility, inputs=mode_radio, outputs=denoise_str)
    # Changing the model refreshes the prompt boxes from PROMPT_PRESETS.
    model_select.change(update_prompts, inputs=model_select, outputs=[prompt, neg_prompt])
    # Seed the prompt boxes on page load to match the dropdown's default value.
    demo.load(fn=lambda: update_prompts("游戏资产贴图专用"), outputs=[prompt, neg_prompt])
    generate_btn.click(
        fn=generate_image,
        inputs=[input_img, prompt, neg_prompt, model_select, mode_radio, denoise_str, steps, cfg],
        outputs=output_img
    )

# NOTE(review): binds to all interfaces AND share=True opens a public Gradio
# tunnel — confirm this double exposure is intended before deploying.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)