# app.py -- page header of the hosted copy: "Update app.py", commit 6a8e8f4 (verified), uploaded by ZhouZJ36DL.
import os
import re
import time
from io import BytesIO
import uuid
from dataclasses import dataclass
from glob import iglob
import argparse
from einops import rearrange
#from fire import Fire
from PIL import ExifTags, Image
from safetensors.torch import load_file, save_file
import spaces
import torch
import torch.nn.functional as F
import gradio as gr
import numpy as np
from transformers import pipeline
from src.flux.sampling import denoise_fireflow, get_schedule, prepare, prepare_image, unpack, denoise_rf, denoise_rf_solver, denoise_midpoint, denoise_rf_inversion, denoise_multi_turn_consistent, get_noise
from src.flux.util import (configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5)
@dataclass
class SamplingOptions:
source_prompt: str
target_prompt: str
# prompt: str
width: int
height: int
num_steps: int
guidance: float
seed: int | None
# ----------------------------------------------------------------------------
# Module-level setup: device selection, model loading, output folders, state.
# ----------------------------------------------------------------------------

# Both names are kept because the functions below refer to each of them.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch_device
offload = False
name = 'flux-dev'

# Load every network once at import time and switch it to eval mode.
ae = load_ae(name, device="cpu" if offload else torch_device)
t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
clip = load_clip(device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
t5.eval()
clip.eval()
ae.eval()
model.eval()

is_schnell = False
add_sampling_metadata = True

# Clear history left over from a previous run. Attempting the removal directly
# avoids the window between an existence check and the delete.
try:
    os.remove("history_gradio/history.safetensors")
except FileNotFoundError:
    pass

out_root = 'src/gradio_utils/gradio_outputs'
out_root_prompt = 'src/gradio_utils/gradio_prompts'
os.makedirs(out_root, exist_ok=True)
os.makedirs(out_root_prompt, exist_ok=True)

# Each run gets the next free "exp_<n>" folder (exp_0 when none exist yet).
exp_folders = [d for d in os.listdir(out_root) if d.startswith("exp_") and d[4:].isdigit()]
name_dir = f"exp_{max((int(d[4:]) for d in exp_folders), default=-1) + 1}"

output_dir = os.path.join(out_root, name_dir)
output_prompt = os.path.join(out_root_prompt, name_dir)
os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_prompt, exist_ok=True)
# Creating the nested folder also creates its parent "heatmap".
os.makedirs("heatmap/average_heatmaps", exist_ok=True)

# Mutable session state shared by the handlers below. The all-zero tensors act
# as the "nothing stored yet" marker that edit() tests for.
source_image = None
history_tensors = {
    "source img": torch.zeros((1, 1, 1)),
    "prev img": torch.zeros((1, 1, 1)),
}
instructions = ['']
def read_sorted_prompts(folder_path):
    """Collect the text of every ``.txt`` file found directly in *folder_path*.

    Files are visited in ascending filename order (plain string sort) and the
    content of each one is stripped of surrounding whitespace.

    Returns:
        A list with one string per ``.txt`` file; empty when there are none.
    """
    txt_names = [entry for entry in os.listdir(folder_path) if entry.endswith('.txt')]
    txt_names.sort()

    collected = []
    for txt_name in txt_names:
        with open(os.path.join(folder_path, txt_name), 'r') as handle:
            collected.append(handle.read().strip())
    return collected
@torch.inference_mode()
def reset():
    """Start a fresh editing session and return the UI reset values.

    Removes the stored history file if present, allocates the next free
    ``exp_<n>`` folder for images and prompts, and resets the module-level
    session state (``instructions``, ``source_image``, ``history_tensors``).

    Returns:
        A 5-tuple ``(source_prompt, target_prompt, gallery, output_image,
        init_image)``: the two placeholder prompt strings followed by three
        ``None`` values that clear the corresponding components.
    """
    global out_root, out_root_prompt, output_dir, output_prompt, history_tensors, source_image, instructions

    # Clear history. Removing directly avoids an exists-then-remove race.
    try:
        os.remove("history_gradio/history.safetensors")
    except FileNotFoundError:
        pass

    os.makedirs(out_root, exist_ok=True)
    os.makedirs(out_root_prompt, exist_ok=True)

    # Pick the next unused experiment index (exp_0 when none exist yet).
    used_indices = [
        int(d[4:])
        for d in os.listdir(out_root)
        if d.startswith("exp_") and d[4:].isdigit()
    ]
    name_dir = f"exp_{max(used_indices, default=-1) + 1}"

    output_dir = os.path.join(out_root, name_dir)
    output_prompt = os.path.join(out_root_prompt, name_dir)
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(output_prompt, exist_ok=True)
    # Creating the nested folder also creates its parent "heatmap".
    os.makedirs("heatmap/average_heatmaps", exist_ok=True)

    instructions = ['']
    source_image = None
    # All-zero tensors are the "nothing stored yet" marker checked by edit().
    history_tensors = {
        "source img": torch.zeros((1, 1, 1)),
        "prev img": torch.zeros((1, 1, 1)),
    }

    source_prompt = "(Optional) Describe the content of the uploaded image."
    target_prompt = "(Required) Describe the desired content of the edited image."
    return source_prompt, target_prompt, None, None, None
@torch.inference_mode()
def process_image(
    init_image,
    source_prompt,
    target_prompt,
    editing_strategy,
    denoise_strategy,
    num_steps,
    guidance,
    attn_guidance_start_block,
    inject_step,
    init_image_2=None,
):
    """Dispatch a button click to text-to-image generation or to editing.

    With no ``init_image`` the target prompt is handed to ``generate_image``;
    otherwise every argument is forwarded unchanged to ``edit``.

    Returns:
        ``(image, gallery)`` exactly as produced by the function called.
    """
    if init_image is not None:
        result_image, gallery_update = edit(
            init_image,
            source_prompt,
            target_prompt,
            editing_strategy,
            denoise_strategy,
            num_steps,
            guidance,
            attn_guidance_start_block,
            inject_step,
            init_image_2,
        )
        return result_image, gallery_update

    result_image, gallery_update = generate_image(prompt=target_prompt)
    return result_image, gallery_update
@spaces.GPU(duration=120)
@torch.inference_mode()
def generate_image(
    width=512,
    height=512,
    num_steps=28,
    guidance=3.5,
    seed=None,
    prompt='',
    init_image=None,
    image2image_strength=0.0,
):
    """Sample an image for *prompt* and store it as round 0 of the history.

    Args:
        width, height: Size passed to ``get_noise`` and ``unpack``.
        num_steps: Number of steps requested from ``get_schedule``.
        guidance: Guidance value forwarded to ``denoise_rf``.
        seed: Seed for the initial noise. A random one is drawn when ``None``.
        prompt: Text prompt; also written to the EXIF description and to the
            prompt history file.
        init_image: Optional starting image. A ``numpy`` array is treated as
            HWC in 0..255; a tensor is assumed to be NCHW already (TODO
            confirm with callers). It is resized, encoded and blended with
            the noise.
        image2image_strength: With ``init_image``, sampling starts at
            schedule index ``int((1 - strength) * num_steps)``.

    Returns:
        ``(img, history_gallery)``: the PIL image and a ``gr.Gallery``
        pairing every file in the output folder with its stored prompt.
    """
    global ae, t5, clip, model, name, is_schnell, output_dir, output_prompt, add_sampling_metadata, offload, history_tensors
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.cuda.empty_cache()

    # Honour an explicit seed; otherwise draw a fresh random one. (The seed
    # argument used to be overwritten with None and was therefore ignored.)
    g_seed = torch.Generator(device="cpu").seed() if seed is None else seed
    print(f"Generating '{prompt}' with seed {g_seed}")
    t0 = time.perf_counter()

    if init_image is not None:
        if isinstance(init_image, np.ndarray):
            init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0
            init_image = init_image.unsqueeze(0)
        init_image = init_image.to(device)
        init_image = torch.nn.functional.interpolate(init_image, (height, width))
        if offload:
            ae.encoder.to(device)
        init_image = ae.encode(init_image)
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()

    # prepare input
    x = get_noise(
        1,
        height,
        width,
        device=device,
        dtype=torch.bfloat16,
        seed=g_seed,
    )
    timesteps = get_schedule(
        num_steps,
        x.shape[-1] * x.shape[-2] // 4,
        shift=(not is_schnell),
    )
    if init_image is not None:
        t_idx = int((1 - image2image_strength) * num_steps)
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        x = t * x + (1.0 - t) * init_image.to(x.dtype)

    if offload:
        t5, clip = t5.to(device), clip.to(device)
    inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)

    # offload TEs to CPU, load model to gpu
    if offload:
        t5, clip = t5.cpu(), clip.cpu()
        torch.cuda.empty_cache()
        model = model.to(device)

    # denoise initial noise
    info = {
        'feature': {},
        'inject_step': 0,
        'editing_strategy': "",
        'start_layer_index': 0,
        'end_layer_index': 37,
        'reuse_v': False,
        'qkv_ratio': [1.0, 1.0, 1.0],
    }
    # The first element of the return value is the latent. It is taken here,
    # not at the unpack call, so that ``x.device`` below is read from it.
    x = denoise_rf(model, **inp, timesteps=timesteps, guidance=guidance, inverse=False, info=info)[0]

    # offload model, load autoencoder to gpu
    if offload:
        model.cpu()
        torch.cuda.empty_cache()
        ae.decoder.to(x.device)

    # decode latents to pixel space
    x = unpack(x.float(), height, width)
    # Autocast on whichever device holds the latents, so CPU runs work too.
    with torch.autocast(device_type=x.device.type, dtype=torch.bfloat16):
        x = ae.decode(x)

    if offload:
        ae.decoder.cpu()
        torch.cuda.empty_cache()

    t1 = time.perf_counter()
    print(f"Done in {t1 - t0:.1f}s.")

    # bring into PIL format
    x = x.clamp(-1, 1)
    x = embed_watermark(x.float())
    x = rearrange(x[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

    # The prompt is user input: replace characters that are path separators or
    # invalid in filenames, and cap the byte length so the name stays legal.
    safe_prompt = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", prompt)
    safe_prompt = safe_prompt.encode("utf-8")[:200].decode("utf-8", errors="ignore")
    filename = os.path.join(output_dir, f"round_0000_[{safe_prompt}].jpg")
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    exif_data = Image.Exif()
    if init_image is None:
        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
    else:
        exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
    exif_data[ExifTags.Base.Model] = name
    if add_sampling_metadata:
        exif_data[ExifTags.Base.ImageDescription] = prompt
    img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)

    prompt_path = os.path.join(output_prompt, "round_0000.txt")
    with open(prompt_path, "w") as f:
        f.write(prompt)

    #-------------------- 6.4 save editing prompt, update gradio component: gallery ----------------------#
    # Sorted image files and sorted prompt files are paired positionally.
    history_imgs = sorted(os.listdir(output_dir))
    instructions = read_sorted_prompts(output_prompt)
    img_and_prompt = [
        (os.path.join(output_dir, img_file), prompt_txt)
        for img_file, prompt_txt in zip(history_imgs, instructions)
    ]
    history_gallery = gr.Gallery(value=img_and_prompt, label="History Image", interactive=True, columns=3)
    return img, history_gallery
@spaces.GPU(duration=200)
@torch.inference_mode()
def edit(init_image, source_prompt, target_prompt, editing_strategy, denoise_strategy, num_steps, guidance, attn_guidance_start_block, inject_step, init_image_2=None):
    """Run one editing round on *init_image* and append it to the history.

    The image is inverted with the selected denoiser using the source prompt,
    re-sampled with the target prompt, decoded, saved as the next
    ``round_<idx>`` file, and its prompt is stored alongside.

    Args:
        init_image: HWC ``numpy`` image; cropped to multiples of 16 per side.
        source_prompt: Prompt used for the inversion pass.
        target_prompt: Prompt used for the sampling pass. An empty prompt, or
            one equal to ``source_prompt``, is treated as a reconstruction.
        editing_strategy: Iterable of strategy names, joined with spaces
            into ``info['editing_strategy']``.
        denoise_strategy: One of ``fireflow``, ``rf``, ``rf_solver``,
            ``midpoint``, ``rf_inversion``, ``multi_turn_consistent``.
        num_steps: Number of steps requested from ``get_schedule``.
        guidance: Guidance value for the sampling pass (inversion uses 1).
        attn_guidance_start_block: Stored as ``info['attn_guidance']``.
        inject_step: Stored as ``info['inject_step']``.
        init_image_2: Optional second HWC image, resized to the cropped size
            of ``init_image`` and passed through ``prepare_image``.

    Returns:
        ``(img, history_gallery)``: the edited PIL image and a ``gr.Gallery``
        pairing every file in the output folder with its stored prompt.

    Raises:
        ValueError: If ``denoise_strategy`` is not a known strategy.
        gr.Error: If the image is smaller than 16 pixels on either side.
    """
    global ae, t5, clip, model, name, is_schnell, output_dir, output_prompt, add_sampling_metadata, offload, source_image, history_tensors, instructions

    # Check the strategy first, so an unknown name fails before any work.
    denoise_funcs = {
        'fireflow': denoise_fireflow,
        'rf': denoise_rf,
        'rf_solver': denoise_rf_solver,
        'midpoint': denoise_midpoint,
        'rf_inversion': denoise_rf_inversion,
        'multi_turn_consistent': denoise_multi_turn_consistent,
    }
    if denoise_strategy not in denoise_funcs:
        raise ValueError(f"Unknown denoise strategy: {denoise_strategy!r}")
    denoise_func = denoise_funcs[denoise_strategy]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.cuda.empty_cache()

    #----------------------------- 0.1 prepare multi-turn editing -------------------------------------#
    # Crop both sides down to the nearest multiple of 16.
    new_h = init_image.shape[0] - init_image.shape[0] % 16
    new_w = init_image.shape[1] - init_image.shape[1] % 16
    if new_h == 0 or new_w == 0:
        # An empty crop would otherwise fail later, inside the autoencoder.
        raise gr.Error("The input image must be at least 16x16 pixels.")

    # The first edit of a session stores the untouched input as round 0.
    if not any("round_0000" in fname for fname in os.listdir(output_dir)):
        Image.fromarray(init_image).save(os.path.join(output_dir, "round_0000_[source].jpg"))
        prompt_path = os.path.join(output_prompt, "round_0000.txt")
        with open(prompt_path, "w") as f:
            f.write('')

    init_image = init_image[:new_h, :new_w, :]
    # shape[0] is the row count (height) and shape[1] the column count (width).
    height, width = init_image.shape[0], init_image.shape[1]
    init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
    init_image = init_image.unsqueeze(0)
    init_image = init_image.to(device)

    if offload:
        model.cpu()
        torch.cuda.empty_cache()
        ae.encoder.to(device)

    with torch.no_grad():
        init_image = ae.encode(init_image).to(torch.bfloat16)

    if init_image_2 is None:
        print("init_image_2 is not provided, proceeding with single image processing.")
    else:
        # Match the second image to the cropped size of the first one.
        resized_2 = Image.fromarray(init_image_2).resize((new_w, new_h), Image.Resampling.LANCZOS)
        init_image_2 = torch.from_numpy(np.array(resized_2)).permute(2, 0, 1).float() / 127.5 - 1

    opts = SamplingOptions(
        source_prompt=source_prompt,
        target_prompt=target_prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=None,
    )

    print(f"Editing with prompt:\n{opts.source_prompt}")
    t0 = time.perf_counter()

    if offload:
        ae = ae.cpu()
        torch.cuda.empty_cache()
        t5, clip = t5.to(torch_device), clip.to(torch_device)

    #----------------------------- 0.2 prepare attention strategy -------------------------------------#
    info = {
        'feature': {},
        'inject_step': inject_step,
        'editing_strategy': " ".join(editing_strategy),
        'start_layer_index': 0,
        'end_layer_index': 37,
        'reuse_v': False,
        'qkv_ratio': [1.0, 1.0, 1.0],
        'attn_guidance': attn_guidance_start_block,
        'lqr_stop': 0.25,
    }

    #----------------------------- 0.3 prepare latents -------------------------------------#
    with torch.no_grad():
        inp = prepare(t5, clip, init_image, prompt=opts.source_prompt)
        inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt)

    # Remember the very first latent of the session as the source image.
    if source_image is None:
        source_image = inp['img']

    inp_target_2 = None
    if init_image_2 is not None:
        inp_target_2 = prepare_image(init_image_2)

    timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

    # offload TEs to CPU, load model to gpu
    if offload:
        t5, clip = t5.cpu(), clip.cpu()
        torch.cuda.empty_cache()
        model = model.to(torch_device)

    #----------------------------- 1 Inverting current image -------------------------------------#
    with torch.no_grad():
        z, info = denoise_func(model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)

    #----------------------------- 2 history_tensors used to implement dual-LQR guiding editing -------------------------------------#
    inp_target["img"] = z
    timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(name != "flux-schnell"))

    # An all-zero "source img" is the marker for "no history stored yet".
    if torch.all(history_tensors['source img'] == 0):
        history_tensors = {
            "source img": inp["img"],
            "prev img": inp_target_2}
    else:
        if inp_target_2 is None:
            history_tensors["prev img"] = inp["img"]
        else:
            history_tensors["source img"] = inp["img"]
            history_tensors["prev img"] = inp_target_2

    #----------------------------- 3 sampling -------------------------------------#
    if denoise_strategy in ['rf_inversion', 'multi_turn_consistent']:
        x, _ = denoise_func(model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info, img_LQR=history_tensors)
    else:
        x, _ = denoise_func(model, **inp_target, timesteps=timesteps, guidance=opts.guidance, inverse=False, info=info)

    #----------------------------- 4 update history_tensors -------------------------------------#
    # Rebinding drops this function's reference to the per-run info dict
    # (presumably so cached features can be freed before decoding).
    info = {}
    history_tensors["source img"] = source_image
    history_tensors["prev img"] = x

    #----------------------------- 5 decode x to image -------------------------------------#
    x = unpack(x.float(), opts.height, opts.width)

    if offload:
        model.cpu()
        torch.cuda.empty_cache()
        ae.decoder.to(x.device)

    # Autocast on whichever device holds the latents, so CPU runs work too.
    with torch.autocast(device_type=x.device.type, dtype=torch.bfloat16):
        x = ae.decode(x)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    # bring into PIL format and save
    x = x.clamp(-1, 1)
    x = embed_watermark(x.float())
    x = rearrange(x[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

    # This is an edit of an existing image, so it is tagged img2img, the same
    # label generate_image() uses when it is given an init image.
    exif_data = Image.Exif()
    exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
    exif_data[ExifTags.Base.Model] = name
    if add_sampling_metadata:
        exif_data[ExifTags.Base.ImageDescription] = source_prompt

    #-------------------------------- 6 save image -------------------------------------#
    #-------------------- 6.1 prepare output folder ----------------------#
    os.makedirs(output_dir, exist_ok=True)

    #-------------------- 6.2 editing round ----------------------#
    # Next round = highest existing "<prefix>_<number>_..." index + 1. Files
    # that do not follow that pattern are skipped instead of crashing.
    round_indices = []
    for existing in os.listdir(output_dir):
        parts = existing.split("_")
        if len(parts) > 1 and parts[1].isdigit():
            round_indices.append(int(parts[1]))
    idx = max(round_indices, default=0) + 1
    formatted_idx = str(idx).zfill(4)  # Format as a 4-digit string
    os.makedirs(output_prompt, exist_ok=True)

    #-------------------- 6.3 output name ----------------------#
    strategy_tag = 'MTC' if denoise_strategy == 'multi_turn_consistent' else denoise_strategy
    if target_prompt == '':
        target_prompt = 'Reconstruction'
    if target_prompt == source_prompt:
        target_prompt = 'Reconstruction: ' + target_prompt

    # Only the last five words go into the filename. They are user input, so
    # path separators and other invalid characters are replaced and the byte
    # length is capped.
    target_suffix = " ".join(target_prompt.split()[-5:])
    target_suffix = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", target_suffix)
    target_suffix = target_suffix.encode("utf-8")[:200].decode("utf-8", errors="ignore")
    output_name = f"round_{formatted_idx}_{target_suffix}_{strategy_tag}.jpg"
    fn = os.path.join(output_dir, output_name)

    print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
    # Write the EXIF block assembled above (it used to be built but dropped).
    img.save(fn, exif=exif_data)

    # Reconstructions are recorded under the source prompt.
    if 'Reconstruction' in target_prompt:
        target_prompt = source_prompt
    print("End Edit")

    prompt_path = os.path.join(output_prompt, f"round_{formatted_idx}.txt")
    with open(prompt_path, "w") as f:
        f.write(target_prompt)

    #-------------------- 6.4 save editing prompt, update gradio component: gallery ----------------------#
    # Sorted image files and sorted prompt files are paired positionally.
    history_imgs = sorted(os.listdir(output_dir))
    instructions = read_sorted_prompts(output_prompt)
    img_and_prompt = [
        (os.path.join(output_dir, img_file), prompt_txt)
        for img_file, prompt_txt in zip(history_imgs, instructions)
    ]
    history_gallery = gr.Gallery(value=img_and_prompt, label="History Image", interactive=True, columns=3)
    return img, history_gallery
def on_select(gallery, selected: gr.SelectData):
    """Return the first two fields of the gallery entry that was selected.

    ``gallery`` is indexed with ``selected.index``; the entry's element 0 and
    element 1 are returned as a pair (wired to the init image and the source
    prompt by the caller).
    """
    chosen_entry = gallery[selected.index]
    return chosen_entry[0], chosen_entry[1]
def on_upload(path, uploaded: gr.EventData):
    """Return element 0 of the first entry in *path* after a gallery upload."""
    first_entry = path[0]
    return first_entry[0]
def on_change(init_image, changed: gr.EventData):
    """Seed the history gallery from the component that triggered the event.

    Takes the first path in ``changed.target.temp_files`` and returns a new
    one-item ``gr.Gallery`` (empty caption) together with that same path.
    """
    temp_paths = list(changed.target.temp_files)
    seeded_gallery = gr.Gallery(
        value=[(temp_paths[0], "")],
        label="History Image",
        interactive=True,
        columns=3,
    )
    return seeded_gallery, temp_paths[0]
def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """Build the Gradio Blocks UI and wire its events.

    Args:
        model_name: Model identifier; the guidance and attention sliders are
            made non-interactive when it equals ``"flux-schnell"``.
        device: Accepted for interface compatibility; not read in this body.

    Returns:
        The assembled ``gr.Blocks`` object (not yet launched).
    """
    description = r"""
<h3>Tips 🔔:</h3>
<ol>
<li>The app starts with default settings. To begin: <strong>(1) Click Reset Button.</strong> (2)Try the example image (at the bottom of the page) / Upload your own / Generate one with a target prompt.</li>
<li> Adaptive Attention (attn_guidance): The option<i> Top activated attn-maps</i> is effective only when this editing technique is selected. </li>
<li> If you like this project, please ⭐ us on <a href='https://github.com/ZhouZJ-DL/Multi-turn_Consistent_Image_Editing' target='_blank'>GitHub</a> or cite our <a href='https://arxiv.org/abs/2505.04320' target='_blank'>paper</a>. Thanks for your support! </li>
</ol>
"""
    # NOTE(review): this stylesheet is defined but not handed to gr.Blocks.
    css = '''
.gradio-container {width: 85% !important}
'''
    is_schnell = model_name == "flux-schnell"
    sliders_editable = not is_schnell

    # Pre-defined examples
    examples = [
        ["src/gradio_utils/gradio_examples/000000000011.jpg", "", "an eagle standing on the branch", ['attn_guidance'], 15, 3.5, 11, 0],
    ]

    with gr.Blocks() as demo:
        gr.Markdown("# Multi-turn Consistent Image Editing (FLUX.1-dev)")
        gr.Markdown(description)

        with gr.Row():
            # Left column: prompts, history gallery and strategy selection.
            with gr.Column():
                reset_btn = gr.Button("Reset", variant="primary")
                source_prompt = gr.Textbox(label="Source Prompt", value="(Optional) Describe the content of the uploaded image.")
                target_prompt = gr.Textbox(label="Target Prompt", value="(Required) Describe the desired content of the edited image.")
                with gr.Row():
                    # Hidden holders for the images passed to the handlers.
                    init_image = gr.Image(label="Initial Image", visible=False, width=200)
                    init_image_2 = gr.Image(label="Input Image 2", visible=False, width=200)
                gallery = gr.Gallery(label="History Image", interactive=True, columns=3)
                editing_strategy = gr.CheckboxGroup(
                    label="Editing Technique",
                    choices=['attn_guidance', 'replace_v', 'add_q', 'add_k', 'add_v', 'replace_q', 'replace_k'],
                    value=['attn_guidance'],
                    interactive=True,
                )
                denoise_strategy = gr.Dropdown(
                    ['multi_turn_consistent', 'fireflow', 'rf', 'rf_solver', 'midpoint', 'rf_inversion'],
                    label="Denoising Technique",
                    value='multi_turn_consistent',
                )
                generate_btn = gr.Button("Generate", variant="primary")

            # Right column: numeric options and the result image.
            with gr.Column():
                with gr.Accordion("Advanced Options", open=True):
                    num_steps = gr.Slider(1, 30, 15, step=1, label="Number of steps")
                    guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Text Guidance", interactive=sliders_editable)
                    attn_guidance_start_block = gr.Slider(0, 18, 11, step=1, label="Top activated attn-maps", interactive=sliders_editable)
                    inject_step = gr.Slider(0, 15, 1, step=1, label="Number of inject steps")
                output_image = gr.Image(label="Generated/Edited Image")
                example_image = gr.Image(label="example Image", visible=False, width=200)

        # Event wiring.
        gallery.select(on_select, gallery, [init_image, source_prompt])
        gallery.upload(on_upload, gallery, init_image)
        example_image.change(on_change, example_image, [gallery, init_image])

        edit_inputs = [
            init_image,
            source_prompt,
            target_prompt,
            editing_strategy,
            denoise_strategy,
            num_steps,
            guidance,
            attn_guidance_start_block,
            inject_step,
            init_image_2,
        ]
        generate_btn.click(
            fn=process_image,
            inputs=edit_inputs,
            outputs=[output_image, gallery],
        )

        reset_outputs = [source_prompt, target_prompt, gallery, output_image, init_image]
        reset_btn.click(fn=reset, outputs=reset_outputs)

        # Add examples
        example_inputs = [
            example_image,
            source_prompt,
            target_prompt,
            editing_strategy,
            num_steps,
            guidance,
            attn_guidance_start_block,
            inject_step,
        ]
        gr.Examples(examples=examples, inputs=example_inputs)

    return demo
# Build the UI at import time so that ``demo`` exists as a module attribute.
# The detected device is passed instead of a hard-coded "cuda".
demo = create_demo(name, device)

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)