# app.py — FLUX.1 Fill dev inpainting demo with preserved-area latent blending.
import random
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxFillPipeline
from loras import LoRA, loras
from PIL import Image
# Largest value usable as a reproducible seed (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Loaded once at import time on CPU; moved to CUDA per request inside `inpaint`.
pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
# Preset keywords the UI offers to prepend to the prompt.
flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]
# --- LATENT MANIPULATION FUNCTIONS ---
def pack_latents(latents, batch_size, num_channels, height, width):
    """Convert spatial latents (B, C, H, W) into Flux's packed sequence form.

    Each 2x2 spatial patch is folded into the channel axis, producing a
    tensor of shape (B, (H//2)*(W//2), C*4). `height` and `width` must both
    be even for the initial `view` to succeed.
    """
    patched = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
    # Bring the two patch axes next to the channel axis before flattening.
    patched = patched.permute(0, 2, 4, 1, 3, 5)
    sequence = patched.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
    return sequence
def unpack_latents(latents, height, width, h_scale=2, w_scale=2):
    """Inverse of `pack_latents`: (B, S, C*scale) -> (B, C, height, width).

    `height`/`width` are the target spatial latent dimensions; each packed
    channel group of `h_scale * w_scale` values is unfolded back into an
    h_scale x w_scale spatial patch.
    """
    batch_size, _, packed_channels = latents.shape
    channels = packed_channels // (h_scale * w_scale)
    grid = latents.view(batch_size, height // h_scale, width // w_scale, channels, h_scale, w_scale)
    # Move channels first, then interleave the patch axes with the grid axes.
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch_size, channels, height, width)
# --- CALLBACK (PRESERVED AREA + STEP CAPTURE) ---
def get_gradual_blend_callback(
    pipe,
    original_image,
    preserved_area_mask,
    total_steps,
    step_images_list,
    start_alpha=1.0,
    end_alpha=0.2,
):
    """Build a `callback_on_step_end` that re-injects the original image's
    latents inside `preserved_area_mask`, with a blend strength decaying
    linearly from `start_alpha` to `end_alpha` over `total_steps`.

    Every 5th step (and the last step) the current latents are decoded and
    appended to `step_images_list` as PIL images so the UI can show the
    denoising evolution.

    Args:
        pipe: FluxFillPipeline (used for its VAE, transformer dtype and
            image post-processor).
        original_image: PIL RGB image, already resized to the generation size.
        preserved_area_mask: PIL mask (converted to "L"; nonzero = preserve),
            or None to skip the blending preparation entirely.
        total_steps: number of inference steps, used for the decay schedule.
        step_images_list: output list the preview images are appended to.
        start_alpha: blend strength at step 0 (1.0 = fully original latents).
        end_alpha: blend strength at the final step.

    Returns:
        A callback with the diffusers `(pipe, step, timestep, callback_kwargs)`
        signature that returns the (possibly blended) callback_kwargs.
    """
    device = pipe.device
    dtype = pipe.transformer.dtype
    # Precomputed packed latents of the original image / mask; remain None
    # when no preserved-area mask is supplied.
    packed_init_latents = None
    packed_preserved_mask = None
    h_latent = w_latent = None
    if preserved_area_mask is not None:
        with torch.no_grad():
            # Image -> float tensor (1, 3, H, W) scaled to [-1, 1].
            img_tensor = (
                (torch.from_numpy(np.array(original_image).transpose(2, 0, 1)).float() / 127.5 - 1.0)
                .unsqueeze(0)
                .to(device, dtype)
            )
            # Encode with the VAE and apply Flux's shift/scale normalization.
            init_latents = pipe.vae.encode(img_tensor).latent_dist.sample()
            init_latents = (init_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
            _, _, h_latent, w_latent = init_latents.shape
            packed_init_latents = pack_latents(
                init_latents, batch_size=1, num_channels=16, height=h_latent, width=w_latent
            )
            # Mask -> float tensor (1, 1, H, W) in [0, 1], downsampled to the
            # latent resolution with nearest-neighbor to keep edges hard.
            mask_tensor = (
                (torch.from_numpy(np.array(preserved_area_mask.convert("L"))).float() / 255.0)
                .unsqueeze(0)
                .unsqueeze(0)
                .to(device, dtype)
            )
            latent_preserved_mask = torch.nn.functional.interpolate(
                mask_tensor, size=(h_latent, w_latent), mode="nearest"
            )
            packed_preserved_mask = pack_latents(
                latent_preserved_mask, batch_size=1, num_channels=1, height=h_latent, width=w_latent
            )
    def callback_fn(pipe, step, timestep, callback_kwargs):
        latents = callback_kwargs["latents"]
        if packed_preserved_mask is not None:
            # Linear decay of the blend strength over the denoising schedule.
            progress = step / max(1, total_steps - 1)
            current_alpha = start_alpha - (start_alpha - end_alpha) * progress
            # The mask packs 1 channel -> 4 packed channels; repeat x16 to
            # cover the image's 16 latent channels (-> 64 packed channels).
            effective_mask = (packed_preserved_mask * current_alpha).repeat(1, 1, 16)
            latents = (1 - effective_mask) * latents + effective_mask * packed_init_latents
        if step % 5 == 0 or step == total_steps - 1:
            with torch.no_grad():
                # Decode a preview frame. Uses h_latent/w_latent, which are
                # only set when a preserved mask exists — the app installs
                # this callback only in that case (see `inpaint`).
                unpacked = unpack_latents(latents, h_latent, w_latent)
                unpacked = (unpacked / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
                decoded = pipe.vae.decode(unpacked.to(pipe.vae.dtype)).sample
                img_step = pipe.image_processor.postprocess(decoded, output_type="pil")[0]
                step_images_list.append(img_step)
        callback_kwargs["latents"] = latents
        return callback_kwargs
    return callback_fn
# --- LoRA's FUNCTIONS ---
def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
    """Load and enable the given LoRA adapters on `pipe`.

    Per-adapter scaling is applied by `set_adapters(adapter_weights=...)`;
    the previous `weight=` kwarg passed to `load_lora_weights` is not part
    of that API (it only accepts the checkpoint reference, `adapter_name`
    and loading kwargs such as `weight_name`), so it has been dropped.

    Args:
        pipe: pipeline to attach the adapters to.
        loras_with_weights: (LoRA, weight) pairs; each LoRA's `id` is the
            checkpoint to load and `name` the adapter name to register.

    Returns:
        The same pipeline, for chaining.
    """
    adapter_names = []
    adapter_weights = []
    for lora, weight in loras_with_weights:
        pipe.load_lora_weights(lora.id, adapter_name=lora.name)
        adapter_names.append(lora.name)
        adapter_weights.append(weight)
    # Activate every loaded adapter with its requested strength in one call.
    pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
    return pipe
def deactivate_loras(pipe):
    """Strip all LoRA adapters from `pipe` and return it for chaining."""
    pipe.unload_lora_weights()
    return pipe
# --- GENERATION
def calculate_optimal_dimensions(image):
    """Scale `image` so its longest side is 1024 px, preserving aspect ratio.

    Returns (width, height) snapped DOWN to multiples of 16 (never below 16).
    The VAE downsamples by 8 and the Flux transformer packs 2x2 latent
    patches, so each spatial dimension must be divisible by 8 * 2 = 16;
    the previous multiple-of-8 rounding could produce odd latent sizes
    (e.g. height 1016 -> latent 127) that break `pack_latents`' `view`.

    Args:
        image: any object with a PIL-style `.size` -> (width, height).

    Returns:
        (width, height) tuple of ints, both multiples of 16.
    """
    original_width, original_height = image.size
    FIXED_DIMENSION = 1024
    aspect_ratio = original_width / original_height
    if aspect_ratio > 1:  # landscape: clamp the width
        width = FIXED_DIMENSION
        height = round(FIXED_DIMENSION / aspect_ratio)
    else:  # portrait or square: clamp the height
        height = FIXED_DIMENSION
        width = round(FIXED_DIMENSION * aspect_ratio)
    return max(16, (width // 16) * 16), max(16, (height // 16) * 16)
@spaces.GPU(duration=60)
def inpaint(
    image,
    mask,
    preserved_area_mask=None,
    prompt: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    guidance_scale: float = 50,
    strength: float = 1.0,
):
    """Run FLUX.1 Fill inpainting and return the result plus step previews.

    Args:
        image: PIL image to inpaint (converted to RGB).
        mask: PIL mask of the area to REPAINT (converted to "L").
        preserved_area_mask: optional PIL mask of an area to KEEP; when given,
            a gradual latent-blend callback is installed (see
            `get_gradual_blend_callback`) and intermediate previews captured.
        prompt: text prompt passed to the pipeline.
        seed: RNG seed for the CPU generator (expected non-negative).
        num_inference_steps: denoising step count.
        guidance_scale: classifier-free guidance strength.
        strength: inpainting strength in [0, 1].

    Returns:
        (result RGBA image, list of step preview images, prompt, seed).
    """
    image = image.convert("RGB")
    mask = mask.convert("L")
    # Snap the working resolution to the model-friendly size.
    width, height = calculate_optimal_dimensions(image)
    # Resize to match dimensions
    image_resized = image.resize((width, height), Image.LANCZOS)
    pipe.to("cuda")
    # Setup callback if a preserved area mask is provided
    step_images = []
    callback = None
    if preserved_area_mask is not None:
        # NEAREST keeps the mask hard-edged instead of introducing gray values.
        preserved_area_resized = preserved_area_mask.resize((width, height), Image.NEAREST)
        callback = get_gradual_blend_callback(
            pipe, image_resized, preserved_area_resized, num_inference_steps, step_images
        )
    result = pipe(
        image=image_resized,
        mask_image=mask.resize((width, height)),
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
        generator=torch.Generator().manual_seed(seed),
        callback_on_step_end=callback,
        # Expose the latents tensor to the callback so it can blend in place.
        callback_on_step_end_tensor_inputs=["latents"] if callback else None,
    ).images[0]
    return result.convert("RGBA"), step_images, prompt, seed
def inpaint_api(
    image,
    mask,
    preserved_area_mask=None,
    prompt: str = "",
    seed: int = -1,
    num_inference_steps: int = 40,
    guidance_scale: float = 30.0,
    strength: float = 1.0,
    flux_keywords: list[str] = None,
    loras_selected: list[tuple[str, float]] = None,
):
    """Top-level entry point wired to the Gradio UI / API.

    Resolves the LoRA selection table, composes the final prompt
    (flux keywords, then LoRA trigger words, then the user prompt) and
    delegates generation to `inpaint`.

    Args:
        image / mask / preserved_area_mask: PIL images, see `inpaint`.
        prompt: user prompt, appended after keywords and trigger words.
        seed: negative or non-numeric -> a random seed is drawn.
        flux_keywords: optional list of preset keyword strings.
        loras_selected: (display_name, weight) rows from the UI dataframe.

    Returns:
        The tuple produced by `inpaint`:
        (result RGBA image, step preview images, final prompt, used seed).
    """
    # Resolve table rows into (LoRA, weight) pairs; rows with a
    # non-numeric or zero weight are ignored.
    selected_loras_with_weights = []
    if loras_selected:
        for name, weight_value in loras_selected:
            try:
                weight = float(weight_value)
            except (ValueError, TypeError):
                continue
            lora_obj = next((lora for lora in loras if lora.display_name == name), None)
            if lora_obj and weight != 0.0:
                selected_loras_with_weights.append((lora_obj, weight))
    # Always reset adapters first so a previous request's selection
    # cannot leak into this one.
    deactivate_loras(pipe)
    if selected_loras_with_weights:
        activate_loras(pipe, selected_loras_with_weights)
    # Compose the prompt; joining the parts avoids the trailing ", " that
    # plain concatenation left behind when the user prompt was empty.
    prompt_parts = list(flux_keywords) if flux_keywords else []
    for lora, _ in selected_loras_with_weights:
        if lora.keyword:
            prompt_parts.append(
                lora.keyword if isinstance(lora.keyword, str) else ", ".join(lora.keyword)
            )
    if prompt:
        prompt_parts.append(prompt)
    final_prompt = ", ".join(prompt_parts)
    # Gradio sliders / API calls can deliver the seed as a float (e.g. 42.0);
    # coerce before the sign check so an explicit seed is honored instead of
    # being silently randomized by the old isinstance(seed, int) test.
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        seed = -1
    if seed < 0:
        seed = random.randint(0, MAX_SEED)
    return inpaint(
        image=image,
        mask=mask,
        preserved_area_mask=preserved_area_mask,
        prompt=final_prompt,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
    )
# UI layout: settings column, input-image column, output column.
with gr.Blocks(title="FLUX.1 Fill dev + Area Preservation", theme=gr.themes.Soft()) as demo:
    with gr.Row():
        # Left column: prompt and generation settings.
        with gr.Column(scale=2):
            prompt_input = gr.Text(label="Prompt", lines=4, value="a 25 years old woman")
            # -1 means "draw a random seed" (handled in inpaint_api).
            seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
            num_inference_steps_input = gr.Number(label="Inference steps", value=40)
            guidance_scale_input = gr.Number(label="Guidance scale", value=30)
            strength_input = gr.Number(label="Strength", value=1.0, maximum=1.0)
            gr.Markdown("### Flux Keywords")
            flux_keywords_input = gr.CheckboxGroup(choices=flux_keywords_available, label="Flux Keywords")
            # NOTE(review): if `loras` is empty, `loras_selected_input` is never
            # defined yet is still referenced in `inputs` below — confirm the
            # loras list is always non-empty.
            if loras:
                gr.Markdown("### Available LoRAs")
                lora_names = [l.display_name for l in loras]
                # One row per LoRA; only the Weight column is editable,
                # and a weight of 0 disables that LoRA (see inpaint_api).
                loras_selected_input = gr.Dataframe(
                    type="array",
                    headers=["LoRA", "Weight"],
                    value=[[name, 0.0] for name in lora_names],
                    datatype=["str", "number"],
                    interactive=[False, True],
                    label="LoRA selection",
                )
        # Middle column: input images and the run button.
        with gr.Column(scale=3):
            image_input = gr.Image(label="Original Image", type="pil")
            mask_input = gr.Image(label="Inpaint Mask (Area to change)", type="pil")
            preserved_area_input = gr.Image(label="Preserved Area Mask (Area to keep)", type="pil")
            run_btn = gr.Button("Generate", variant="primary")
        # Right column: outputs.
        with gr.Column(scale=3):
            result_image = gr.Image(label="Result")
            used_prompt_box = gr.Text(label="Final Prompt")
            used_seed_box = gr.Number(label="Used Seed")
            steps_gallery = gr.Gallery(label="Evolution (Steps)", columns=3, preview=True)
    # The `inputs` order must match inpaint_api's positional parameters.
    run_btn.click(
        fn=inpaint_api,
        inputs=[
            image_input,
            mask_input,
            preserved_area_input,
            prompt_input,
            seed_slider,
            num_inference_steps_input,
            guidance_scale_input,
            strength_input,
            flux_keywords_input,
            loras_selected_input,
        ],
        outputs=[result_image, steps_gallery, used_prompt_box, used_seed_box],
    )
if __name__ == "__main__":
    demo.launch()