|
|
import random |
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import spaces |
|
|
import torch |
|
|
from diffusers import FluxFillPipeline |
|
|
from loras import LoRA, loras |
|
|
from PIL import Image |
|
|
|
|
|
# Largest value accepted for a reproducible seed (int32 max).
MAX_SEED = np.iinfo(np.int32).max


# FLUX.1 Fill inpainting pipeline. Loaded once at import time (on CPU);
# `inpaint` moves it to CUDA per request.
pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)


# Prompt-prefix keywords offered as a checkbox group in the UI; selected
# entries are prepended to the user prompt by `inpaint_api`.
flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]
|
|
|
|
|
|
|
|
def pack_latents(latents, batch_size, num_channels, height, width):
    """Fold spatial VAE latents (B, C, H, W) into Flux's packed sequence layout.

    Each 2x2 spatial patch is folded into the channel dimension, yielding a
    tensor of shape (B, (H/2)*(W/2), C*4) as consumed by the Flux transformer.
    """
    half_h, half_w = height // 2, width // 2
    patched = latents.view(batch_size, num_channels, half_h, 2, half_w, 2)
    patched = patched.permute(0, 2, 4, 1, 3, 5)
    return patched.reshape(batch_size, half_h * half_w, num_channels * 4)
|
|
|
|
|
|
|
|
def unpack_latents(latents, height, width, h_scale=2, w_scale=2):
    """Inverse of `pack_latents`: restore packed (B, seq, C*s) latents to (B, C, H, W).

    `height`/`width` are the target spatial dims; `h_scale`/`w_scale` are the
    patch factors that were folded into the channel dimension when packing.
    """
    batch_size, _, packed_channels = latents.shape
    channels = packed_channels // (h_scale * w_scale)

    grid = latents.view(
        batch_size, height // h_scale, width // w_scale, channels, h_scale, w_scale
    )
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch_size, channels, height, width)
|
|
|
|
|
|
|
|
|
|
|
def get_gradual_blend_callback(
    pipe,
    original_image,
    preserved_area_mask,
    total_steps,
    step_images_list,
    start_alpha=1.0,
    end_alpha=0.2,
):
    """Build a `callback_on_step_end` for FluxFillPipeline.

    On each denoising step the callback:
      1. If a preservation mask was given, blends the current latents with the
         VAE-encoded latents of `original_image` inside the masked area. The
         blend weight decays linearly from `start_alpha` (first step) to
         `end_alpha` (last step), so preserved regions are pinned early and
         progressively released to the sampler.
      2. Every 5 steps (and on the final step) decodes the latents to a PIL
         preview image appended to `step_images_list`.

    Args:
        pipe: FluxFillPipeline (provides `vae`, `transformer`, `device`,
            `image_processor`).
        original_image: PIL image; assumed RGB (H, W, 3) — the code transposes
            HWC -> CHW before encoding.
        preserved_area_mask: PIL mask (white = keep original) or None.
        total_steps: number of inference steps the pipeline will run.
        step_images_list: list that receives preview images (mutated in place).
        start_alpha: blend strength at step 0.
        end_alpha: blend strength at the last step.

    Returns:
        A function with the `callback_on_step_end(pipe, step, timestep,
        callback_kwargs)` signature.
    """
    device = pipe.device
    dtype = pipe.transformer.dtype

    packed_init_latents = None
    packed_preserved_mask = None
    h_latent = w_latent = None

    if preserved_area_mask is not None:
        with torch.no_grad():
            # Map the uint8 image to a [-1, 1] CHW tensor as the VAE expects.
            img_tensor = (
                (torch.from_numpy(np.array(original_image).transpose(2, 0, 1)).float() / 127.5 - 1.0)
                .unsqueeze(0)
                .to(device, dtype)
            )
            init_latents = pipe.vae.encode(img_tensor).latent_dist.sample()
            # Normalize into the transformer's latent space.
            init_latents = (init_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor

            _, _, h_latent, w_latent = init_latents.shape

            packed_init_latents = pack_latents(
                init_latents, batch_size=1, num_channels=16, height=h_latent, width=w_latent
            )

            # Downsample the preservation mask (1.0 = keep original) to latent
            # resolution and pack it into the same sequence layout.
            mask_tensor = (
                (torch.from_numpy(np.array(preserved_area_mask.convert("L"))).float() / 255.0)
                .unsqueeze(0)
                .unsqueeze(0)
                .to(device, dtype)
            )
            latent_preserved_mask = torch.nn.functional.interpolate(
                mask_tensor, size=(h_latent, w_latent), mode="nearest"
            )
            packed_preserved_mask = pack_latents(
                latent_preserved_mask, batch_size=1, num_channels=1, height=h_latent, width=w_latent
            )

    def callback_fn(pipe, step, timestep, callback_kwargs):
        latents = callback_kwargs["latents"]

        if packed_preserved_mask is not None:
            # Blend weight decays linearly from start_alpha to end_alpha.
            progress = step / max(1, total_steps - 1)
            current_alpha = start_alpha - (start_alpha - end_alpha) * progress

            # Packed mask has 1*4 channels; repeat 16x to match the latents'
            # 16*4 packed channels.
            effective_mask = (packed_preserved_mask * current_alpha).repeat(1, 1, 16)
            latents = (1 - effective_mask) * latents + effective_mask * packed_init_latents

        # Decode a preview every 5 steps and on the last step. This needs the
        # latent dims, which are only known when a preservation mask was
        # provided — the unguarded version crashed with h_latent=None otherwise.
        if h_latent is not None and (step % 5 == 0 or step == total_steps - 1):
            with torch.no_grad():
                unpacked = unpack_latents(latents, h_latent, w_latent)
                unpacked = (unpacked / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
                decoded = pipe.vae.decode(unpacked.to(pipe.vae.dtype)).sample
                img_step = pipe.image_processor.postprocess(decoded, output_type="pil")[0]
                step_images_list.append(img_step)

        callback_kwargs["latents"] = latents
        return callback_kwargs

    return callback_fn
|
|
|
|
|
|
|
|
|
|
|
def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
    """Load each LoRA onto the pipeline and activate them with their weights.

    Args:
        pipe: the pipeline to attach adapters to.
        loras_with_weights: (LoRA, weight) pairs; each LoRA is loaded under
            its own adapter name.

    Returns:
        The same pipeline, with all adapters set active.
    """
    for lora, weight in loras_with_weights:
        pipe.load_lora_weights(lora.id, weight=weight, adapter_name=lora.name)
    pipe.set_adapters(
        [lora.name for lora, _ in loras_with_weights],
        adapter_weights=[weight for _, weight in loras_with_weights],
    )
    return pipe
|
|
|
|
|
|
|
|
def deactivate_loras(pipe):
    """Unload every LoRA adapter from the pipeline and return it."""
    pipe.unload_lora_weights()
    return pipe
|
|
|
|
|
|
|
|
|
|
|
def calculate_optimal_dimensions(image):
    """Fit the image inside a 1024px bound while preserving aspect ratio.

    The longer side is clamped to 1024 and the shorter side scaled to match;
    both are then floored to the nearest multiple of 8.

    Args:
        image: object exposing a PIL-style `.size` -> (width, height).

    Returns:
        (width, height) tuple of the target generation resolution.
    """
    FIXED_DIMENSION = 1024
    orig_w, orig_h = image.size
    ratio = orig_w / orig_h

    if ratio > 1:  # landscape: clamp width
        width = FIXED_DIMENSION
        height = round(FIXED_DIMENSION / ratio)
    else:  # portrait or square: clamp height
        height = FIXED_DIMENSION
        width = round(FIXED_DIMENSION * ratio)

    return (width // 8) * 8, (height // 8) * 8
|
|
|
|
|
|
|
|
@spaces.GPU(duration=60)
def inpaint(
    image,
    mask,
    preserved_area_mask=None,
    prompt: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    guidance_scale: int = 50,
    strength: float = 1.0,
):
    """Run FLUX.1 Fill inpainting on `image` over the white area of `mask`.

    If `preserved_area_mask` is given, a step callback gradually blends the
    original image back into the preserved region and collects preview frames.

    Returns:
        (result RGBA image, list of step preview images, prompt used, seed used).
    """
    rgb_image = image.convert("RGB")
    gray_mask = mask.convert("L")
    width, height = calculate_optimal_dimensions(rgb_image)

    resized_image = rgb_image.resize((width, height), Image.LANCZOS)

    pipe.to("cuda")

    step_images = []
    callback = None
    if preserved_area_mask is not None:
        # Nearest-neighbor keeps the preservation mask binary after resizing.
        resized_preserve = preserved_area_mask.resize((width, height), Image.NEAREST)
        callback = get_gradual_blend_callback(
            pipe, resized_image, resized_preserve, num_inference_steps, step_images
        )

    output = pipe(
        image=resized_image,
        mask_image=gray_mask.resize((width, height)),
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
        generator=torch.Generator().manual_seed(seed),
        callback_on_step_end=callback,
        callback_on_step_end_tensor_inputs=["latents"] if callback else None,
    ).images[0]

    return output.convert("RGBA"), step_images, prompt, seed
|
|
|
|
|
|
|
|
def inpaint_api(
    image,
    mask,
    preserved_area_mask=None,
    prompt: str = "",
    seed: int = -1,
    num_inference_steps: int = 40,
    guidance_scale: float = 30.0,
    strength: float = 1.0,
    flux_keywords: list[str] = None,
    loras_selected: list[tuple[str, float]] = None,
):
    """Gradio entry point: resolve LoRA selections, build the final prompt,
    pick a seed, and delegate to `inpaint`.

    Args:
        loras_selected: rows of (display name, weight) from the UI dataframe;
            rows with a non-numeric or zero weight are ignored.
        flux_keywords: checkbox keywords prepended to the prompt.
        seed: -1 (or any non-int / negative value) requests a random seed.
    """
    selected_loras_with_weights = []
    for display_name, raw_weight in loras_selected or []:
        try:
            weight = float(raw_weight)
        except (ValueError, TypeError):
            continue  # skip rows whose weight cell isn't numeric
        matched = next((entry for entry in loras if entry.display_name == display_name), None)
        if matched and weight != 0.0:
            selected_loras_with_weights.append((matched, weight))

    # Reset adapters first so a previous request's LoRAs never leak through.
    deactivate_loras(pipe)
    if selected_loras_with_weights:
        activate_loras(pipe, selected_loras_with_weights)

    # Assemble prompt as: flux keywords, LoRA trigger keywords, user prompt.
    prompt_parts = []
    if flux_keywords:
        prompt_parts.append(", ".join(flux_keywords))
    for lora, _ in selected_loras_with_weights:
        if lora.keyword:
            prompt_parts.append(
                lora.keyword if isinstance(lora.keyword, str) else ", ".join(lora.keyword)
            )
    final_prompt = ", ".join(prompt_parts + [prompt]) if prompt_parts else prompt

    if not isinstance(seed, int) or seed < 0:
        seed = random.randint(0, MAX_SEED)

    return inpaint(
        image=image,
        mask=mask,
        preserved_area_mask=preserved_area_mask,
        prompt=final_prompt,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
    )
|
|
|
|
|
|
|
|
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks(title="FLUX.1 Fill dev + Area Preservation", theme=gr.themes.Soft()) as demo:
    with gr.Row():
        # Left column: generation parameters.
        with gr.Column(scale=2):
            prompt_input = gr.Text(label="Prompt", lines=4, value="a 25 years old woman")
            # -1 asks `inpaint_api` to draw a random seed.
            seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
            num_inference_steps_input = gr.Number(label="Inference steps", value=40)
            guidance_scale_input = gr.Number(label="Guidance scale", value=30)
            strength_input = gr.Number(label="Strength", value=1.0, maximum=1.0)

            gr.Markdown("### Flux Keywords")
            flux_keywords_input = gr.CheckboxGroup(choices=flux_keywords_available, label="Flux Keywords")

            # LoRA table: one row per known LoRA; only the weight column is
            # editable. Weight 0.0 means the LoRA stays disabled.
            if loras:
                gr.Markdown("### Available LoRAs")
                lora_names = [l.display_name for l in loras]
                loras_selected_input = gr.Dataframe(
                    type="array",
                    headers=["LoRA", "Weight"],
                    value=[[name, 0.0] for name in lora_names],
                    datatype=["str", "number"],
                    interactive=[False, True],
                    label="LoRA selection",
                )

        # Middle column: input images (original, inpaint mask, preservation mask).
        with gr.Column(scale=3):
            image_input = gr.Image(label="Original Image", type="pil")
            mask_input = gr.Image(label="Inpaint Mask (Area to change)", type="pil")
            preserved_area_input = gr.Image(label="Preserved Area Mask (Area to keep)", type="pil")
            run_btn = gr.Button("Generate", variant="primary")

        # Right column: outputs.
        with gr.Column(scale=3):
            result_image = gr.Image(label="Result")
            used_prompt_box = gr.Text(label="Final Prompt")
            used_seed_box = gr.Number(label="Used Seed")
            steps_gallery = gr.Gallery(label="Evolution (Steps)", columns=3, preview=True)

    # NOTE(review): `loras_selected_input` is only defined when `loras` is
    # non-empty, but it is referenced unconditionally below — if `loras` were
    # ever empty this would raise NameError at import. Confirm `loras` is
    # guaranteed non-empty, or define the component unconditionally.
    run_btn.click(
        fn=inpaint_api,
        inputs=[
            image_input,
            mask_input,
            preserved_area_input,
            prompt_input,
            seed_slider,
            num_inference_steps_input,
            guidance_scale_input,
            strength_input,
            flux_keywords_input,
            loras_selected_input,
        ],
        outputs=[result_image, steps_gallery, used_prompt_box, used_seed_box],
    )


# Launch the Gradio server only when run as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|