File size: 14,319 Bytes
886b3d1
 
 
 
 
484e146
 
 
886b3d1
 
 
484e146
886b3d1
 
9bd1a7c
886b3d1
 
484e146
886b3d1
 
 
 
484e146
 
 
 
04ad0ee
 
484e146
9bd1a7c
886b3d1
484e146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886b3d1
484e146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886b3d1
 
 
484e146
886b3d1
484e146
886b3d1
 
484e146
 
886b3d1
 
484e146
886b3d1
 
 
 
 
 
 
 
 
 
484e146
886b3d1
484e146
 
886b3d1
484e146
 
 
 
 
 
 
 
 
 
 
886b3d1
484e146
 
 
 
 
 
886b3d1
9bd1a7c
886b3d1
 
484e146
 
 
 
886b3d1
 
 
 
 
 
 
484e146
886b3d1
 
 
 
484e146
886b3d1
 
484e146
 
886b3d1
 
484e146
 
7dada90
 
 
484e146
 
 
7dada90
886b3d1
 
484e146
 
 
 
 
886b3d1
484e146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886b3d1
 
 
484e146
 
 
 
 
3762242
484e146
 
 
 
 
 
 
 
886b3d1
484e146
 
 
 
 
81963dd
 
484e146
 
 
 
 
 
 
 
 
 
 
 
 
886b3d1
484e146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886b3d1
 
 
 
 
 
484e146
 
 
886b3d1
 
 
 
 
484e146
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import gradio as gr
import numpy as np
import random
import torch
import spaces
import os
import base64
import math

from PIL import Image
from diffusers import QwenImageEditPlusPipeline
from pillow_heif import register_heif_opener

from huggingface_hub import login
from prompt_augment import PromptAugment
# Authenticate to the Hugging Face Hub using the Space secret `hf`
# (required to download the gated FireRed model weights).
login(token=os.environ.get('hf'))

# Allow PIL to open HEIF/HEIC uploads (common from iPhones).
register_heif_opener()

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Global edit pipeline, shared by all requests in this process.
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "FireRedTeam/FireRed-Image-Edit-1.1", 
    torch_dtype=dtype
).to(device)
# Tiled/sliced VAE decoding reduces peak VRAM for large outputs.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

# Prompt rewriter used when the "Rewrite Prompt" checkbox is on.
prompt_handler = PromptAugment()

# LoRA catalogue: UI label -> Hub repo, weight file, and diffusers adapter id.
ADAPTER_SPECS = {
    "Covercraft": {
        "repo": "FireRedTeam/FireRed-Image-Edit-LoRA-Zoo",
        "weights": "FireRed-Image-Edit-Covercraft.safetensors",
        "adapter_name": "covercraft",
    },
    "Lightning": {
        "repo": "FireRedTeam/FireRed-Image-Edit-LoRA-Zoo",
        "weights": "FireRed-Image-Edit-Lightning-8steps-v1.0.safetensors",
        "adapter_name": "lightning",
    },
    "Makeup": {
        "repo": "FireRedTeam/FireRed-Image-Edit-LoRA-Zoo",
        "weights": "FireRed-Image-Edit-Makeup.safetensors",
        "adapter_name": "makeup",
    }
}

# Adapter names already downloaded into the pipeline (avoids re-downloading).
LOADED_ADAPTERS = set()
LORA_OPTIONS = ["None"] + list(ADAPTER_SPECS.keys())


def load_lora(lora_name):
    """Activate the requested LoRA adapter on the global pipeline.

    Downloads and caches the adapter weights on first use. Passing "None"
    (or an empty value) deactivates every adapter instead.
    """
    # "None" / empty selection: drop any active adapters and bail out.
    if not lora_name or lora_name == "None":
        if LOADED_ADAPTERS:
            pipe.set_adapters([], adapter_weights=[])
        return

    spec = ADAPTER_SPECS.get(lora_name)
    if not spec:
        raise gr.Error(f"LoRA 配置未找到: {lora_name}")

    adapter = spec["adapter_name"]

    if adapter in LOADED_ADAPTERS:
        print(f"--- Adapter {lora_name} is already loaded ---")
    else:
        # First use of this adapter: fetch the weights from the Hub.
        print(f"--- Downloading and Loading Adapter: {lora_name} ---")
        try:
            pipe.load_lora_weights(
                spec["repo"],
                weight_name=spec["weights"],
                adapter_name=adapter,
            )
        except Exception as e:
            raise gr.Error(f"Failed to load adapter {lora_name}: {e}")
        LOADED_ADAPTERS.add(adapter)

    # Make the chosen adapter the only active one, at full strength.
    pipe.set_adapters([adapter], adapter_weights=[1.0])


MAX_SEED = np.iinfo(np.int32).max
MAX_INPUT_IMAGES = 3


def limit_images(images):
    """Cap a gallery list at MAX_INPUT_IMAGES, warning the user on truncation.

    Returns None unchanged; otherwise returns at most MAX_INPUT_IMAGES items.
    """
    if images is None:
        return None
    if len(images) <= MAX_INPUT_IMAGES:
        return images
    # Too many uploads: notify via a toast and keep only the first N.
    gr.Info(f"最多支持 {MAX_INPUT_IMAGES} 张图片,已自动移除多余图片")
    return images[:MAX_INPUT_IMAGES]


def calculate_dimensions(target_area, ratio):
    """Return (width, height) ints matching `ratio` (w/h) with an area of
    roughly `target_area` pixels, each side snapped to a multiple of 32."""
    snap = 32
    # From area = w * h and ratio = w / h: w = sqrt(area * ratio).
    w = math.sqrt(target_area * ratio)
    h = w / ratio
    return int(round(w / snap) * snap), int(round(h / snap) * snap)


def update_dimensions_on_upload(images, max_area=1024*1024):
    """Suggest (height, width) slider values after an upload.

    Only multi-image uploads get an explicit size, derived from the first
    image's aspect ratio and `max_area`; any other case returns (0, 0),
    which the rest of the app treats as "auto".
    """
    if images is None or len(images) == 0:
        return 0, 0

    try:
        head = images[0]
        # Gallery items may arrive as (image, caption) tuples.
        candidate = head[0] if isinstance(head, tuple) else head

        if isinstance(candidate, Image.Image):
            pil_img = candidate
        elif isinstance(candidate, str):
            pil_img = Image.open(candidate)
        else:
            return 0, 0

        # A single image keeps automatic sizing.
        if len(images) <= 1:
            return 0, 0

        aspect = pil_img.width / pil_img.height
        new_w, new_h = calculate_dimensions(max_area, aspect)
        return new_h, new_w
    except Exception as e:
        print(f"获取图片尺寸失败: {e}")
        return 0, 0


@spaces.GPU(duration=180)
def infer(
    input_images,
    prompt,
    lora_choice,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=4.0,
    num_inference_steps=40,
    height=None,
    width=None,
    rewrite_prompt=False,
    num_images_per_prompt=1,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the edit pipeline on the uploaded images and return (images, seed).

    Activates the selected LoRA, normalises the gallery input to RGB PIL
    images, optionally rewrites the prompt, then invokes the pipeline.
    """
    negative_prompt = " "

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    load_lora(lora_choice)

    # Normalise gallery items — PIL images, file paths, or (image, caption)
    # tuples — into RGB PIL images, capped at MAX_INPUT_IMAGES.
    pil_images = []
    for item in (input_images or [])[:MAX_INPUT_IMAGES]:
        try:
            candidate = item[0] if isinstance(item, tuple) else item
            if isinstance(candidate, Image.Image):
                pil_images.append(candidate.convert("RGB"))
            elif isinstance(candidate, str):
                pil_images.append(Image.open(candidate).convert("RGB"))
        except Exception as e:
            # Best-effort: a broken upload is skipped, not fatal.
            print(f"处理图片出错: {e}")

    # Slider value 0 means "let the pipeline choose the size".
    height = height or None
    width = width or None

    if rewrite_prompt and pil_images:
        prompt = prompt_handler.predict(prompt, [pil_images[0]])
        print(f"Rewritten Prompt: {prompt}")

    for i, img in enumerate(pil_images):
        print(f"    [{i}] size: {img.width}x{img.height}")

    outputs = pipe(
        image=pil_images or None,
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=1.0,
        true_cfg_scale=true_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
    ).images

    return outputs, seed

css = """
#col-container { margin: 0 auto; max-width: 1200px; }
#edit-btn { height: 100% !important; min-height: 42px; }
"""



def get_image_base64(image_path):
    """Return the contents of the file at `image_path` as a base64 string."""
    with open(image_path, "rb") as fh:
        payload = fh.read()
    return base64.b64encode(payload).decode('utf-8')


# Inline the logo as a data URI so the page needs no static file route.
logo_base64 = get_image_base64("logo.png") if os.path.exists("logo.png") else None

# UI definition: two galleries (input/output), a prompt box, a LoRA picker,
# advanced sampling settings, and a set of curated examples.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Prefer the inlined logo; fall back to a plain text heading.
        if logo_base64:
            gr.HTML(f'<img src="data:image/png;base64,{logo_base64}" alt="FireRed Logo" width="400" style="display: block; margin: 0 auto;">')
        else:
            gr.Markdown("# FireRed Image Edit")
        gr.Markdown(f"[Learn more](https://github.com/FireRedTeam/FireRed-Image-Edit) about the FireRed-Image-Edit series. Supports multi-image input (up to {MAX_INPUT_IMAGES} images.)")
        # Input gallery (left) and output gallery (right), side by side.
        with gr.Row():
            with gr.Column(scale=1):
                input_images = gr.Gallery(
                    label="Upload Images",
                    type="pil",
                    interactive=True,
                    height=300,
                    columns=3,
                    object_fit="contain",
                )
            
            with gr.Column(scale=1):
                result = gr.Gallery(
                    label="Output Images",
                    type="pil",
                    height=300,
                    columns=2,
                    object_fit="contain",
                )
        
        prompt = gr.Textbox(
            label="Edit Prompt",
            placeholder="e.g., transform into anime..",
        )
        
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                lora_choice = gr.Dropdown(
                    label="Choose Lora",
                    choices=LORA_OPTIONS,
                    value=LORA_OPTIONS[0] if LORA_OPTIONS else "None",
                )
            with gr.Column(scale=4):
                run_button = gr.Button("Edit Image", variant="primary", elem_id="edit-btn")

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            
            with gr.Row():
                true_guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=4.0)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=30, step=1, value=30)
            
            with gr.Row():
                # 0 means "auto": infer() maps 0 to None before calling the pipeline.
                height = gr.Slider(label="Height (0=auto)", minimum=0, maximum=2048, step=8, value=0)
                width = gr.Slider(label="Width (0=auto)", minimum=0, maximum=2048, step=8, value=0)
            
            with gr.Row():
                rewrite_prompt = gr.Checkbox(label="Rewrite Prompt", value=False)
                num_images_per_prompt = gr.Slider(label="Num Images", minimum=1, maximum=4, step=1, value=1)

        # Examples: each row is [image list, prompt, LoRA choice] matching
        # the `inputs` components below. Not cached, so clicking re-runs infer.
        gr.Examples(
            examples=[
                [["examples/master1.png"], "将背景换为带自然光效的浅蓝色,身穿浅米色蕾丝领上衣,将发型改为右侧佩戴精致珍珠发夹,同时单手向前抬起握着一把宝剑,另一只手自然摆放。面部微笑。",  "None"],
                [["examples/master2.png"], "替换背景为盛开的樱花树场景;更换衣服为黑色西装,为人物添加单肩蓝色书包,单手抓住包带。头发变为高马尾。色调明亮。蹲下。",  "None"],
                [["examples/master5.png"], "Change the background to pink and remove all bamboo leaves. Adjust the character's posture so she is lying in a pink plush basket, with her hands under her chin. Her head and body should be facing the camera, with the character positioned in the center of the frame. Her gaze should be fixed forward. Replace the hat with a headband adorned with pink flowers and pink ears. Change the clothing to beige plush attire. Remove the panda doll. Adjust the facial expression to a smiling face with the mouth open.",  "None"],
                [["examples/master6.png"], "Replace the background with a scene of an outdoor lake and trees. The character faces the camera, with his head slightly tilted to the left side of the screen. His arms are crossed, and he holds a small red drum in his right hand and under his left armpit. Zoom in on the camera view.",  "None"],
                [["examples/master3_1.png", "examples/master3_2.png"], "把图1中的模特换成图2里的长裙和高帮帆布鞋,保持原有姿态和配饰,整体风格统一。",  "None"],
                [["examples/master4_1.png", "examples/master4_2.png"], "把图1中的白色衬衫和棕色半裙,换成图2里的灰褐色连帽卫衣、黑色侧边条纹裤、卡其色工装靴和同色云朵包,保持模特姿态和背景不变。",  "None"],
                [["examples/makeup1.png"], "为人物添加纯欲厌世妆:使用冷白皮哑光粉底均匀肤色,描绘细挑的灰黑色野生眉,眼部晕染浅灰调眼影并加深眼尾,画出上扬的黑色眼线,粘贴浓密卷翘的假睫毛,在眼头和卧蚕处提亮,涂抹深紫调哑光口红并勾勒唇形,在颧骨处扫上浅粉腮红,鼻梁和眉骨处打高光,下颌线处轻扫阴影。", "Makeup"],
                [["examples/makeup2.png"], "为人物添加妆容:使用象牙白哑光粉底均匀肤色,描绘细长柳叶眉并填充浅棕色,眼部晕染浅棕色眼影并加深眼尾,画出自然黑色眼线,粘贴浓密假睫毛,用浅棕色眼影提亮卧蚕;涂抹豆沙色哑光口红并勾勒唇形,在两颊扫上浅粉色腮红,在鼻梁和颧骨处轻扫高光,在面部轮廓处轻扫阴影。", "Makeup"],
                [["examples/text1_1.png", "examples/text1_2.png"], "请在图1添加主标题文本 “谁说我们丑了”,字体样式参考图2中主标题《人!给我开个罐罐》;主标题整体采用横向排版多行错落(非严格对齐),置于图片左下角;在狗狗右下方、贴近前爪附近添加一个手绘“爱心”涂鸦贴纸;增加鱼眼镜头效果", "Covercraft"],
                [["examples/text2_1.png", "examples/text2_2.png"], "请在图1添加主标题文本 “崽子第一次玩冰”,副标题“坐标:东南休闲公园”,主标题和副标题的字体样式参考图2中主标题“无露营不冬天”,主标题整体采用横向排版多行,主标题添加在画面左侧上方;副标题添加在画面左侧下方,字的层级更小,避免修改和遮挡图1主体关键信息(人物/核心景物)和画面中心。", "Covercraft"],
            ],
            inputs=[input_images, prompt, lora_choice],
            outputs=[result, seed],
            fn=infer,
            cache_examples=False,
            label="Examples"
        )

    # React to LoRA selection changes: the Lightning LoRA is a distilled
    # 8-step model, so its sampling parameters are pinned while selected.
    def on_lora_change(lora_name):
        """Return gr.update values for (steps, guidance, seed, randomize_seed).

        Lightning: lock steps to 8, guidance to 1.0, and a fixed seed so
        results are reproducible. Any other choice restores the defaults
        and re-enables the controls.
        """
        if lora_name == "Lightning":
            return (
                gr.update(value=8, interactive=False),      # num_inference_steps
                gr.update(value=1.0, interactive=False),    # true_guidance_scale
                gr.update(value=0, interactive=True),       # seed
                gr.update(value=False, interactive=False),  # randomize_seed
            )
        # Fix: the previous reset used value=40, which exceeds the
        # num_inference_steps slider's declared maximum of 30; restore
        # the slider's own default (30) instead.
        return (
            gr.update(value=30, interactive=True),      # num_inference_steps
            gr.update(value=4.0, interactive=True),     # true_guidance_scale
            gr.update(value=42, interactive=True),      # seed
            gr.update(value=True, interactive=True),    # randomize_seed
        )
    
    # Re-apply parameter constraints whenever the LoRA selection changes.
    lora_choice.change(
        fn=on_lora_change,
        inputs=[lora_choice],
        outputs=[num_inference_steps, true_guidance_scale, seed, randomize_seed],
    )
    def on_image_upload(images):
        """Cap the upload count, then suggest height/width for multi-image edits."""
        limited = limit_images(images)
        h, w = update_dimensions_on_upload(limited)
        return limited, h, w
    
    input_images.upload(
        fn=on_image_upload,
        inputs=[input_images],
        outputs=[input_images, height, width],
    )

    # Run inference from either the Edit button or Enter in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            input_images,
            prompt, lora_choice, seed, randomize_seed,
            true_guidance_scale, num_inference_steps,
            height, width, rewrite_prompt, num_images_per_prompt,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    # Enable request queuing, then serve; allowed_paths exposes local
    # files (logo, example images) to the web UI.
    demo.queue()
    demo.launch(allowed_paths=["./"])