"""FLUX.1 Fill inpainting demo with gradual preserved-area blending.

Gradio app around ``FluxFillPipeline`` that supports:
- an inpaint mask (area to regenerate),
- an optional "preserved area" mask whose original latents are re-injected
  with linearly decaying strength during denoising,
- optional LoRA adapters and prompt keywords,
- capture of intermediate decoded previews every 5 steps.
"""

import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxFillPipeline
from PIL import Image

from loras import LoRA, loras

MAX_SEED = np.iinfo(np.int32).max

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
)

flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]


# --- LATENT MANIPULATION FUNCTIONS ---
def pack_latents(latents, batch_size, num_channels, height, width):
    """Pack (B, C, H, W) latents into Flux's (B, H/2 * W/2, C*4) patch sequence.

    Flux groups 2x2 spatial patches into single sequence tokens, so each
    token carries ``num_channels * 4`` values.
    """
    latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
    return latents


def unpack_latents(latents, height, width, h_scale=2, w_scale=2):
    """Invert ``pack_latents``: (B, seq, C) -> (B, C / (h_scale*w_scale), height, width)."""
    batch_size, seq_len, channels = latents.shape
    # Flux uses a 2x2 patch, so the factor is 2.
    latents = latents.view(
        batch_size,
        height // h_scale,
        width // w_scale,
        channels // (h_scale * w_scale),
        h_scale,
        w_scale,
    )
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    latents = latents.reshape(batch_size, channels // (h_scale * w_scale), height, width)
    return latents


# --- CALLBACK (PRESERVED AREA + STEP CAPTURE) ---
def get_gradual_blend_callback(
    pipe,
    original_image,
    preserved_area_mask,
    total_steps,
    step_images_list,
    start_alpha=1.0,
    end_alpha=0.2,
):
    """Build a ``callback_on_step_end`` for ``FluxFillPipeline``.

    At each denoising step the callback blends the original image's packed
    latents back into the working latents inside ``preserved_area_mask``,
    with a blend strength that decays linearly from ``start_alpha`` to
    ``end_alpha`` across the schedule. Every 5th step (and the final step)
    a decoded PIL preview is appended to ``step_images_list``.

    Args:
        pipe: the live ``FluxFillPipeline`` (used for VAE encode/decode).
        original_image: RGB ``PIL.Image`` already resized to the working size.
        preserved_area_mask: ``PIL.Image`` mask (white = preserve) or ``None``.
        total_steps: number of inference steps (for the alpha schedule).
        step_images_list: mutable list that receives the step previews.
        start_alpha / end_alpha: blend strength at the first / last step.
    """
    device = pipe.device
    dtype = pipe.transformer.dtype

    packed_init_latents = None
    packed_preserved_mask = None
    h_latent = w_latent = None

    if preserved_area_mask is not None:
        with torch.no_grad():
            # Encode the original image into VAE latent space (inputs in [-1, 1]).
            img_tensor = (
                (torch.from_numpy(np.array(original_image).transpose(2, 0, 1)).float() / 127.5 - 1.0)
                .unsqueeze(0)
                .to(device, dtype)
            )
            init_latents = pipe.vae.encode(img_tensor).latent_dist.sample()
            init_latents = (
                init_latents - pipe.vae.config.shift_factor
            ) * pipe.vae.config.scaling_factor
            _, _, h_latent, w_latent = init_latents.shape
            packed_init_latents = pack_latents(
                init_latents, batch_size=1, num_channels=16, height=h_latent, width=w_latent
            )

            # Downsample the preservation mask to latent resolution and pack it
            # identically so it aligns with the packed latents token-for-token.
            mask_tensor = (
                (torch.from_numpy(np.array(preserved_area_mask.convert("L"))).float() / 255.0)
                .unsqueeze(0)
                .unsqueeze(0)
                .to(device, dtype)
            )
            latent_preserved_mask = torch.nn.functional.interpolate(
                mask_tensor, size=(h_latent, w_latent), mode="nearest"
            )
            packed_preserved_mask = pack_latents(
                latent_preserved_mask, batch_size=1, num_channels=1, height=h_latent, width=w_latent
            )

    def callback_fn(pipe, step, timestep, callback_kwargs):
        latents = callback_kwargs["latents"]

        if packed_preserved_mask is not None:
            # Linearly anneal the blend strength over the denoising schedule.
            progress = step / max(1, total_steps - 1)
            current_alpha = start_alpha - (start_alpha - end_alpha) * progress
            # Packed mask has 4 channels (1 channel x 2x2 patch); tile it to the
            # 64 packed latent channels (16 channels x 4).
            effective_mask = (packed_preserved_mask * current_alpha).repeat(1, 1, 16)
            latents = (1 - effective_mask) * latents + effective_mask * packed_init_latents

        # Capture a decoded preview every 5 steps and on the last step.
        # Guard on h_latent: without a preserved mask the latent grid size is
        # unknown and the original code would have crashed here.
        if h_latent is not None and (step % 5 == 0 or step == total_steps - 1):
            with torch.no_grad():
                unpacked = unpack_latents(latents, h_latent, w_latent)
                unpacked = (unpacked / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
                decoded = pipe.vae.decode(unpacked.to(pipe.vae.dtype)).sample
                img_step = pipe.image_processor.postprocess(decoded, output_type="pil")[0]
                step_images_list.append(img_step)

        callback_kwargs["latents"] = latents
        return callback_kwargs

    return callback_fn


# --- LoRA's FUNCTIONS ---
def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
    """Load and activate the given LoRA adapters on ``pipe``.

    Fix: ``load_lora_weights`` accepts no ``weight`` keyword — per-adapter
    weights are applied via ``set_adapters`` (already done below), so the
    stray kwarg is dropped.
    """
    adapter_names = []
    adapter_weights = []
    for lora, weight in loras_with_weights:
        pipe.load_lora_weights(lora.id, adapter_name=lora.name)
        adapter_names.append(lora.name)
        adapter_weights.append(weight)
    pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
    return pipe


def deactivate_loras(pipe):
    """Unload every LoRA adapter currently attached to ``pipe``."""
    pipe.unload_lora_weights()
    return pipe


# --- GENERATION ---
def calculate_optimal_dimensions(image):
    """Return (width, height) with the long side at 1024 px, both rounded down
    to multiples of 8 (the VAE's spatial granularity)."""
    original_width, original_height = image.size
    FIXED_DIMENSION = 1024
    aspect_ratio = original_width / original_height
    if aspect_ratio > 1:
        width, height = FIXED_DIMENSION, round(FIXED_DIMENSION / aspect_ratio)
    else:
        height, width = FIXED_DIMENSION, round(FIXED_DIMENSION * aspect_ratio)
    return (width // 8) * 8, (height // 8) * 8


@spaces.GPU(duration=60)
def inpaint(
    image,
    mask,
    preserved_area_mask=None,
    prompt: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    guidance_scale: float = 50,  # fix: callers pass floats; was annotated int
    strength: float = 1.0,
):
    """Run FLUX.1 Fill inpainting.

    Returns:
        (RGBA result image, list of step preview images, prompt, seed).
    """
    # gr.Number delivers floats; the pipeline/scheduler needs an int step count.
    num_inference_steps = int(num_inference_steps)

    image = image.convert("RGB")
    mask = mask.convert("L")
    width, height = calculate_optimal_dimensions(image)

    # Resize to match dimensions.
    image_resized = image.resize((width, height), Image.LANCZOS)
    pipe.to("cuda")

    # Setup callback if a preserved area mask is provided.
    step_images = []
    callback = None
    if preserved_area_mask is not None:
        preserved_area_resized = preserved_area_mask.resize((width, height), Image.NEAREST)
        callback = get_gradual_blend_callback(
            pipe, image_resized, preserved_area_resized, num_inference_steps, step_images
        )

    result = pipe(
        image=image_resized,
        mask_image=mask.resize((width, height)),
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
        generator=torch.Generator().manual_seed(seed),
        callback_on_step_end=callback,
        # Always pass the list (the pipeline's own default); the previous
        # ``None`` branch needlessly diverged from it when no callback was set.
        callback_on_step_end_tensor_inputs=["latents"],
    ).images[0]

    return result.convert("RGBA"), step_images, prompt, seed


def inpaint_api(
    image,
    mask,
    preserved_area_mask=None,
    prompt: str = "",
    seed: int = -1,
    num_inference_steps: int = 40,
    guidance_scale: float = 30.0,
    strength: float = 1.0,
    flux_keywords: list[str] = None,
    loras_selected: list[tuple[str, float]] = None,
):
    """UI entry point: resolve LoRA selections, assemble the final prompt,
    normalize the seed, then delegate to ``inpaint``.

    ``loras_selected`` is a list of ``[display_name, weight]`` rows from the
    Dataframe; rows with weight 0 or an unparsable weight are skipped.
    """
    selected_loras_with_weights = []
    if loras_selected:
        for name, weight_value in loras_selected:
            try:
                weight = float(weight_value)
            except (ValueError, TypeError):
                continue
            lora_obj = next((l for l in loras if l.display_name == name), None)
            if lora_obj and weight != 0.0:
                selected_loras_with_weights.append((lora_obj, weight))

    # Reset adapters from previous runs before (re)activating the new set.
    deactivate_loras(pipe)
    if selected_loras_with_weights:
        activate_loras(pipe, selected_loras_with_weights)

    # Prompt = keywords, then LoRA trigger words, then the user prompt.
    final_prompt = ""
    if flux_keywords:
        final_prompt += ", ".join(flux_keywords) + ", "
    if selected_loras_with_weights:
        for lora, _ in selected_loras_with_weights:
            if lora.keyword:
                final_prompt += (
                    lora.keyword if isinstance(lora.keyword, str) else ", ".join(lora.keyword)
                ) + ", "
    final_prompt += prompt

    # Fix: Gradio sliders deliver floats, so the original
    # ``isinstance(seed, int)`` check always failed and user-chosen seeds were
    # silently randomized. Coerce to int first; -1 (or junk) means "random".
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        seed = -1
    if seed < 0:
        seed = random.randint(0, MAX_SEED)

    return inpaint(
        image=image,
        mask=mask,
        preserved_area_mask=preserved_area_mask,
        prompt=final_prompt,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
    )


with gr.Blocks(title="FLUX.1 Fill dev + Area Preservation", theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Text(label="Prompt", lines=4, value="a 25 years old woman")
            seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
            num_inference_steps_input = gr.Number(label="Inference steps", value=40)
            guidance_scale_input = gr.Number(label="Guidance scale", value=30)
            strength_input = gr.Number(label="Strength", value=1.0, maximum=1.0)
            gr.Markdown("### Flux Keywords")
            flux_keywords_input = gr.CheckboxGroup(
                choices=flux_keywords_available, label="Flux Keywords"
            )
            if loras:
                gr.Markdown("### Available LoRAs")
            # Fix: the Dataframe was only created when ``loras`` was non-empty,
            # but it is referenced unconditionally in run_btn.click(inputs=...),
            # which raised NameError at startup for an empty LoRA list.
            lora_names = [l.display_name for l in loras]
            loras_selected_input = gr.Dataframe(
                type="array",
                headers=["LoRA", "Weight"],
                value=[[name, 0.0] for name in lora_names],
                datatype=["str", "number"],
                # NOTE(review): gr.Dataframe documents ``interactive`` as a
                # bool; confirm this per-column list is supported by the
                # pinned Gradio version.
                interactive=[False, True],
                label="LoRA selection",
            )
        with gr.Column(scale=3):
            image_input = gr.Image(label="Original Image", type="pil")
            mask_input = gr.Image(label="Inpaint Mask (Area to change)", type="pil")
            preserved_area_input = gr.Image(label="Preserved Area Mask (Area to keep)", type="pil")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=3):
            result_image = gr.Image(label="Result")
            used_prompt_box = gr.Text(label="Final Prompt")
            used_seed_box = gr.Number(label="Used Seed")
            steps_gallery = gr.Gallery(label="Evolution (Steps)", columns=3, preview=True)

    run_btn.click(
        fn=inpaint_api,
        inputs=[
            image_input,
            mask_input,
            preserved_area_input,
            prompt_input,
            seed_slider,
            num_inference_steps_input,
            guidance_scale_input,
            strength_input,
            flux_keywords_input,
            loras_selected_input,
        ],
        outputs=[result_image, steps_gallery, used_prompt_box, used_seed_box],
    )

if __name__ == "__main__":
    demo.launch()