File size: 10,142 Bytes
aa95e47
 
 
 
 
 
 
 
4ad6fb2
aa95e47
 
 
f3d4c16
aa95e47
 
 
f3d4c16
4ad6fb2
 
 
f3d4c16
4ad6fb2
aa95e47
4ad6fb2
 
 
f3d4c16
4ad6fb2
f3d4c16
4ad6fb2
 
f3d4c16
4ad6fb2
 
 
f3d4c16
 
4ad6fb2
 
 
 
 
f3d4c16
 
4ad6fb2
 
 
 
f3d4c16
 
 
beb6b67
f3d4c16
 
 
 
 
 
 
 
 
beb6b67
f3d4c16
beb6b67
f3d4c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ad6fb2
 
7206c75
 
f3d4c16
 
 
 
 
 
7206c75
 
 
 
f3d4c16
 
 
7206c75
 
 
 
4ad6fb2
 
 
 
f3d4c16
aa95e47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3d4c16
aa95e47
 
 
4ad6fb2
 
 
aa95e47
4ad6fb2
 
 
 
7206c75
aa95e47
 
 
4ad6fb2
aa95e47
 
 
 
 
 
 
 
beb6b67
f3d4c16
 
4ad6fb2
 
aa95e47
4ad6fb2
f3d4c16
4ad6fb2
 
 
f3d4c16
 
 
4ad6fb2
 
aa95e47
4ad6fb2
f3d4c16
aa95e47
 
 
 
 
 
 
4ad6fb2
 
aa95e47
 
1041613
aa95e47
 
 
 
 
4ad6fb2
 
 
 
 
 
 
 
aa95e47
 
 
4ad6fb2
 
 
 
 
 
 
 
 
aa95e47
 
 
 
 
 
 
 
f3d4c16
4ad6fb2
 
 
f3d4c16
 
aa95e47
 
 
 
 
 
 
 
4ad6fb2
aa95e47
 
 
 
 
 
 
 
f3d4c16
aa95e47
 
f3d4c16
 
 
 
 
aa95e47
 
f3d4c16
aa95e47
 
 
 
 
 
 
 
4ad6fb2
 
 
aa95e47
 
 
f3d4c16
 
 
 
aa95e47
 
 
4ad6fb2
 
f3d4c16
aa95e47
 
 
 
 
 
4ad6fb2
aa95e47
 
 
 
 
 
 
 
f3d4c16
aa95e47
 
 
4ad6fb2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxFillPipeline
from loras import LoRA, loras
from PIL import Image

# Largest signed 32-bit integer; used as the upper bound of the seed slider
# and of randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max

# The Fill pipeline is loaded once at import time in bfloat16; it is moved to
# CUDA inside `inpaint`.
pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)

# Choices offered by the "Flux Keywords" checkbox group; the selected ones are
# prepended to the prompt in `inpaint_api`.
flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]

# --- LATENT MANIPULATION FUNCTIONS ---
def pack_latents(latents, batch_size, num_channels, height, width):
    """Fold a latent map into a sequence of flattened 2x2 patches.

    ``latents`` is viewed as ``(batch_size, num_channels, height, width)`` and
    returned with shape
    ``(batch_size, (height // 2) * (width // 2), num_channels * 4)``.
    Inside each token the values are ordered channel, patch row, patch column.
    ``height`` and ``width`` have to be even, otherwise the initial view does
    not match the number of elements and raises.
    """
    rows, cols = height // 2, width // 2
    patched = latents.view(batch_size, num_channels, rows, 2, cols, 2)
    # Move the patch-grid axes to the front; (channel, dy, dx) stay together last.
    tokens_first = patched.permute(0, 2, 4, 1, 3, 5)
    return tokens_first.reshape(batch_size, rows * cols, num_channels * 4)


def unpack_latents(latents, height, width, h_scale=2, w_scale=2):
    """Rebuild a latent map from a sequence of flattened patches.

    ``latents`` must be 3-D, ``(batch, tokens, channels)``. Each token is split
    into ``channels // (h_scale * w_scale)`` feature channels of an
    ``h_scale x w_scale`` patch, and the patches are laid back onto a
    ``height x width`` grid. The result has shape
    ``(batch, channels // (h_scale * w_scale), height, width)``.
    """
    batch_size, _, channels = latents.shape
    patch_area = h_scale * w_scale  # 2x2 patches with the default scales
    out_channels = channels // patch_area
    grid = latents.view(batch_size, height // h_scale, width // w_scale, out_channels, h_scale, w_scale)
    # Interleave patch rows/columns with grid rows/columns: (b, c, row, dy, col, dx).
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch_size, out_channels, height, width)


# --- CALLBACK (PRESERVED AREA + STEP CAPTURE) ---
def get_gradual_blend_callback(
    pipe,
    original_image,
    preserved_area_mask,
    total_steps,
    step_images_list,
    start_alpha=1.0,
    end_alpha=0.2,
):
    """Build a ``callback_on_step_end`` that protects an area and records previews.

    When ``preserved_area_mask`` is given, ``original_image`` is VAE-encoded once
    and, after every denoising step, blended back into the latents wherever the
    mask is set. The blend weight moves linearly from ``start_alpha`` at step 0
    to ``end_alpha`` at step ``total_steps - 1``.

    Args:
        pipe: Pipeline providing ``device``, ``transformer.dtype``, ``vae`` and
            ``image_processor``.
        original_image: Image converted with ``np.array(...).transpose(2, 0, 1)``,
            so it must yield an H x W x C array. Only read when a mask is given.
        preserved_area_mask: Image whose bright pixels (255 -> weight 1.0) mark
            the area to keep, or ``None`` to disable blending.
        total_steps: Number of denoising steps; drives the alpha schedule and
            the "last step" snapshot.
        step_images_list: List that decoded preview images are appended to.
        start_alpha: Blend weight of the original latents at the first step.
        end_alpha: Blend weight of the original latents at the last step.

    Returns:
        A function ``callback_fn(pipe, step, timestep, callback_kwargs)`` that
        returns ``callback_kwargs`` with ``"latents"`` possibly replaced.
    """
    device = pipe.device
    dtype = pipe.transformer.dtype

    packed_init_latents = None
    packed_preserved_mask = None
    h_latent = w_latent = None

    if preserved_area_mask is not None:
        with torch.no_grad():
            # HWC uint8 -> 1xCxHxW in [-1, 1].
            img_tensor = (
                (torch.from_numpy(np.array(original_image).transpose(2, 0, 1)).float() / 127.5 - 1.0)
                .unsqueeze(0)
                .to(device, dtype)
            )
            init_latents = pipe.vae.encode(img_tensor).latent_dist.sample()
            init_latents = (init_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor

            _, _, h_latent, w_latent = init_latents.shape

            packed_init_latents = pack_latents(
                init_latents, batch_size=1, num_channels=16, height=h_latent, width=w_latent
            )

            # Mask in [0, 1], shrunk to the latent grid and packed like the latents.
            mask_tensor = (
                (torch.from_numpy(np.array(preserved_area_mask.convert("L"))).float() / 255.0)
                .unsqueeze(0)
                .unsqueeze(0)
                .to(device, dtype)
            )
            latent_preserved_mask = torch.nn.functional.interpolate(
                mask_tensor, size=(h_latent, w_latent), mode="nearest"
            )
            packed_preserved_mask = pack_latents(
                latent_preserved_mask, batch_size=1, num_channels=1, height=h_latent, width=w_latent
            )

    # The latent grid size is only known once the image has been encoded above.
    # Without it the packed latents cannot be unpacked, so previews are skipped
    # instead of failing on `None // 2`.
    can_capture_steps = h_latent is not None and w_latent is not None

    def callback_fn(pipe, step, timestep, callback_kwargs):
        latents = callback_kwargs["latents"]

        if packed_preserved_mask is not None:
            progress = step / max(1, total_steps - 1)
            current_alpha = start_alpha - (start_alpha - end_alpha) * progress

            # The packed mask has 4 values per token (one per patch pixel);
            # tiling it 16 times lines it up with the 16-channel packed latents.
            effective_mask = (packed_preserved_mask * current_alpha).repeat(1, 1, 16)
            latents = (1 - effective_mask) * latents + effective_mask * packed_init_latents

        # Snapshot every 5th step and the final step.
        if can_capture_steps and (step % 5 == 0 or step == total_steps - 1):
            with torch.no_grad():
                unpacked = unpack_latents(latents, h_latent, w_latent)
                unpacked = (unpacked / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
                decoded = pipe.vae.decode(unpacked.to(pipe.vae.dtype)).sample
                img_step = pipe.image_processor.postprocess(decoded, output_type="pil")[0]
                step_images_list.append(img_step)

        callback_kwargs["latents"] = latents
        return callback_kwargs

    return callback_fn


# --- LoRA FUNCTIONS ---
def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
    """Load each LoRA into ``pipe`` and enable them all with their weights.

    Every ``(lora, weight)`` pair is loaded under the adapter name
    ``lora.name``; afterwards ``set_adapters`` is called once with the
    collected names and weights, in input order. Returns ``pipe``.
    """
    names = []
    scales = []
    for entry, scale in loras_with_weights:
        pipe.load_lora_weights(entry.id, weight=scale, adapter_name=entry.name)
        names.append(entry.name)
        scales.append(scale)

    pipe.set_adapters(names, adapter_weights=scales)
    return pipe


def deactivate_loras(pipe):
    """Unload the LoRA weights from ``pipe`` and hand the same pipeline back."""
    unload = pipe.unload_lora_weights
    unload()
    return pipe


# --- GENERATION ---
def calculate_optimal_dimensions(image):
    """Return a ``(width, height)`` for generation that keeps the aspect ratio.

    The longer side is fixed at 1024 pixels and the shorter side is scaled to
    match. Both are then rounded down to a multiple of 16, with a floor of 16.

    A multiple of 16 is needed because the latents are packed in 2x2 patches
    (``pack_latents`` uses ``height // 2`` and ``width // 2``). Assuming the
    VAE's usual 8x spatial downsampling (TODO confirm against the loaded VAE),
    a pixel size that is only a multiple of 8 can give an odd latent size,
    which cannot be packed. The floor keeps very elongated images from
    collapsing to a zero-sized side.

    Args:
        image: Object exposing ``size`` as ``(width, height)``, e.g. a PIL image.

    Returns:
        Tuple ``(width, height)`` of ints, each a positive multiple of 16.
    """
    FIXED_DIMENSION = 1024
    MULTIPLE = 16

    original_width, original_height = image.size
    aspect_ratio = original_width / original_height
    if aspect_ratio > 1:
        width, height = FIXED_DIMENSION, round(FIXED_DIMENSION / aspect_ratio)
    else:
        height, width = FIXED_DIMENSION, round(FIXED_DIMENSION * aspect_ratio)

    def snap(value):
        # Round down to the grid, but never below one grid cell.
        return max(MULTIPLE, (value // MULTIPLE) * MULTIPLE)

    return snap(width), snap(height)


@spaces.GPU(duration=60)
def inpaint(
    image,
    mask,
    preserved_area_mask=None,
    prompt: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    guidance_scale: int = 50,
    strength: float = 1.0,
):
    """Run the Fill pipeline on ``image`` inside ``mask``.

    The image is converted to RGB and the mask to single-channel "L"; both are
    resized to the size chosen by ``calculate_optimal_dimensions``. If
    ``preserved_area_mask`` is given, a step-end callback is installed that
    blends the original image back into that area and collects intermediate
    previews.

    Returns:
        ``(result, step_images, prompt, seed)``: the generated image converted
        to RGBA, the list of captured step images (empty without a preserved
        area mask), and the prompt and seed that were used.
    """
    rgb_image = image.convert("RGB")
    gray_mask = mask.convert("L")
    width, height = calculate_optimal_dimensions(rgb_image)
    target_size = (width, height)

    # Bring the source image to the generation size.
    image_resized = rgb_image.resize(target_size, Image.LANCZOS)

    pipe.to("cuda")

    # The callback (and its tensor inputs) only exist with a preserved area mask.
    step_images = []
    if preserved_area_mask is None:
        callback = None
        tensor_inputs = None
    else:
        callback = get_gradual_blend_callback(
            pipe,
            image_resized,
            preserved_area_mask.resize(target_size, Image.NEAREST),
            num_inference_steps,
            step_images,
        )
        tensor_inputs = ["latents"]

    output = pipe(
        image=image_resized,
        mask_image=gray_mask.resize(target_size),
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
        generator=torch.Generator().manual_seed(seed),
        callback_on_step_end=callback,
        callback_on_step_end_tensor_inputs=tensor_inputs,
    )
    result = output.images[0]

    return result.convert("RGBA"), step_images, prompt, seed


def inpaint_api(
    image,
    mask,
    preserved_area_mask=None,
    prompt: str = "",
    seed: int = -1,
    num_inference_steps: int = 40,
    guidance_scale: float = 30.0,
    strength: float = 1.0,
    flux_keywords: list[str] = None,
    loras_selected: list[tuple[str, float]] = None,
):
    """UI/API entry point: resolve LoRAs, build the prompt, pick a seed, inpaint.

    ``loras_selected`` holds ``(display_name, weight)`` rows. Rows whose weight
    cannot be converted to ``float``, whose weight is 0, or whose name matches
    no entry in ``loras`` are ignored. Previously loaded LoRAs are always
    unloaded first; the surviving rows are then activated.

    The final prompt is the selected ``flux_keywords``, then the keywords of
    the active LoRAs, then ``prompt``, each group followed by ``", "``.

    A seed that is not an ``int`` or is negative is replaced by a random one in
    ``[0, MAX_SEED]``.

    Returns:
        Whatever ``inpaint`` returns for the resolved arguments.
    """
    active_loras = []
    for display_name, raw_weight in loras_selected or []:
        try:
            weight = float(raw_weight)
        except (ValueError, TypeError):
            continue
        match = next((entry for entry in loras if entry.display_name == display_name), None)
        if not match or weight == 0.0:
            continue
        active_loras.append((match, weight))

    # Start from a clean pipeline so LoRAs from a previous call do not linger.
    deactivate_loras(pipe)
    if active_loras:
        activate_loras(pipe, active_loras)

    prompt_parts = []
    if flux_keywords:
        prompt_parts.append(", ".join(flux_keywords) + ", ")
    for lora, _ in active_loras:
        keyword = lora.keyword
        if not keyword:
            continue
        # A LoRA keyword may be a single string or a collection of strings.
        keyword_text = keyword if isinstance(keyword, str) else ", ".join(keyword)
        prompt_parts.append(keyword_text + ", ")
    prompt_parts.append(prompt)
    final_prompt = "".join(prompt_parts)

    seed_is_usable = isinstance(seed, int) and seed >= 0
    if not seed_is_usable:
        seed = random.randint(0, MAX_SEED)

    return inpaint(
        image=image,
        mask=mask,
        preserved_area_mask=preserved_area_mask,
        prompt=final_prompt,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
    )


# Gradio UI: settings on the left, input images in the middle, results on the
# right. The "Generate" button forwards everything to `inpaint_api`.
with gr.Blocks(title="FLUX.1 Fill dev + Area Preservation", theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Text(label="Prompt", lines=4, value="a 25 years old woman")
            # -1 asks `inpaint_api` to draw a random seed.
            seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
            # precision=0 makes the field deliver an int; a step count has to be whole.
            num_inference_steps_input = gr.Number(label="Inference steps", value=40, precision=0)
            guidance_scale_input = gr.Number(label="Guidance scale", value=30)
            strength_input = gr.Number(label="Strength", value=1.0, maximum=1.0)

            gr.Markdown("### Flux Keywords")
            flux_keywords_input = gr.CheckboxGroup(choices=flux_keywords_available, label="Flux Keywords")

            if loras:
                gr.Markdown("### Available LoRAs")
                lora_names = [l.display_name for l in loras]
                # One row per LoRA; a weight of 0 leaves that LoRA inactive.
                loras_selected_input = gr.Dataframe(
                    type="array",
                    headers=["LoRA", "Weight"],
                    value=[[name, 0.0] for name in lora_names],
                    datatype=["str", "number"],
                    interactive=[False, True],
                    label="LoRA selection",
                )
            else:
                # No LoRAs configured: the click handler below still needs an
                # input component, so pass an invisible empty selection.
                loras_selected_input = gr.State(value=[])

        with gr.Column(scale=3):
            image_input = gr.Image(label="Original Image", type="pil")
            mask_input = gr.Image(label="Inpaint Mask (Area to change)", type="pil")
            preserved_area_input = gr.Image(label="Preserved Area Mask (Area to keep)", type="pil")
            run_btn = gr.Button("Generate", variant="primary")

        with gr.Column(scale=3):
            result_image = gr.Image(label="Result")
            used_prompt_box = gr.Text(label="Final Prompt")
            used_seed_box = gr.Number(label="Used Seed")
            steps_gallery = gr.Gallery(label="Evolution (Steps)", columns=3, preview=True)

    # Input order must match the positional parameters of `inpaint_api`.
    run_btn.click(
        fn=inpaint_api,
        inputs=[
            image_input,
            mask_input,
            preserved_area_input,
            prompt_input,
            seed_slider,
            num_inference_steps_input,
            guidance_scale_input,
            strength_input,
            flux_keywords_input,
            loras_selected_input,
        ],
        outputs=[result_image, steps_gallery, used_prompt_box, used_seed_box],
    )

# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()