X-HighVoltage-X commited on
Commit
f3d4c16
·
verified ·
1 Parent(s): 1041613

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -180
app.py CHANGED
@@ -10,132 +10,91 @@ from PIL import Image
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
 
13
- pipe = FluxFillPipeline.from_pretrained(
14
- "black-forest-labs/FLUX.1-Fill-dev",
15
- torch_dtype=torch.bfloat16,
16
- )
17
 
18
  flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]
19
 
20
- # ------------------------------------------------------------------
21
- # LATENT MANIPULATION
22
- # ------------------------------------------------------------------
23
-
24
  def pack_latents(latents, batch_size, num_channels, height, width):
25
  latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
26
  latents = latents.permute(0, 2, 4, 1, 3, 5)
27
- latents = latents.reshape(
28
- batch_size,
29
- (height // 2) * (width // 2),
30
- num_channels * 4,
31
- )
32
  return latents
33
 
34
 
35
  def unpack_latents(latents, height, width, h_scale=2, w_scale=2):
36
  batch_size, seq_len, channels = latents.shape
 
37
  latents = latents.view(
38
- batch_size,
39
- height // h_scale,
40
- width // w_scale,
41
- channels // (h_scale * w_scale),
42
- h_scale,
43
- w_scale,
44
  )
45
  latents = latents.permute(0, 3, 1, 4, 2, 5)
46
- latents = latents.reshape(
47
- batch_size,
48
- channels // (h_scale * w_scale),
49
- height,
50
- width,
51
- )
52
  return latents
53
 
54
 
55
- # ------------------------------------------------------------------
56
- # HARD PRESERVE CALLBACK
57
- # ------------------------------------------------------------------
58
-
59
- def get_hard_preserve_callback(
60
  pipe,
61
  original_image,
62
  preserved_area_mask,
63
  total_steps,
64
  step_images_list,
 
 
65
  ):
66
  device = pipe.device
67
  dtype = pipe.transformer.dtype
68
 
69
- with torch.no_grad():
70
- img_tensor = (
71
- torch.from_numpy(np.array(original_image).transpose(2, 0, 1))
72
- .float()
73
- / 127.5
74
- - 1.0
75
- )
76
- img_tensor = img_tensor.unsqueeze(0).to(device, dtype)
77
-
78
- init_latents = pipe.vae.encode(img_tensor).latent_dist.sample()
79
- init_latents = (
80
- init_latents - pipe.vae.config.shift_factor
81
- ) * pipe.vae.config.scaling_factor
82
- init_latents = init_latents.to(dtype)
83
-
84
- _, _, h_latent, w_latent = init_latents.shape
85
-
86
- packed_init_latents = pack_latents(
87
- init_latents,
88
- batch_size=1,
89
- num_channels=16,
90
- height=h_latent,
91
- width=w_latent,
92
- ).to(dtype)
93
-
94
- mask_tensor = (
95
- torch.from_numpy(np.array(preserved_area_mask.convert("L")))
96
- .float()
97
- / 255.0
98
- )
99
- mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device, dtype)
100
 
101
- latent_mask = torch.nn.functional.interpolate(
102
- mask_tensor,
103
- size=(h_latent, w_latent),
104
- mode="nearest",
105
- )
 
 
 
 
106
 
107
- packed_preserved_mask = pack_latents(
108
- latent_mask,
109
- batch_size=1,
110
- num_channels=1,
111
- height=h_latent,
112
- width=w_latent,
113
- )
114
 
115
- packed_preserved_mask = (packed_preserved_mask > 0.5).to(dtype)
116
- packed_preserved_mask = packed_preserved_mask.repeat(1, 1, 16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  def callback_fn(pipe, step, timestep, callback_kwargs):
119
  latents = callback_kwargs["latents"]
120
- latent_dtype = latents.dtype
121
 
122
- latents = (
123
- latents * (1.0 - packed_preserved_mask)
124
- + packed_init_latents * packed_preserved_mask
125
- ).to(latent_dtype)
 
 
126
 
127
  if step % 5 == 0 or step == total_steps - 1:
128
  with torch.no_grad():
129
  unpacked = unpack_latents(latents, h_latent, w_latent)
130
- unpacked = (
131
- unpacked / pipe.vae.config.scaling_factor
132
- ) + pipe.vae.config.shift_factor
133
- decoded = pipe.vae.decode(
134
- unpacked.to(pipe.vae.dtype)
135
- ).sample
136
- img_step = pipe.image_processor.postprocess(
137
- decoded, output_type="pil"
138
- )[0]
139
  step_images_list.append(img_step)
140
 
141
  callback_kwargs["latents"] = latents
@@ -144,10 +103,7 @@ def get_hard_preserve_callback(
144
  return callback_fn
145
 
146
 
147
- # ------------------------------------------------------------------
148
- # LORA UTILITIES
149
- # ------------------------------------------------------------------
150
-
151
  def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
152
  adapter_names = []
153
  adapter_weights = []
@@ -164,10 +120,7 @@ def deactivate_loras(pipe):
164
  return pipe
165
 
166
 
167
- # ------------------------------------------------------------------
168
- # GENERATION
169
- # ------------------------------------------------------------------
170
-
171
  def calculate_optimal_dimensions(image):
172
  original_width, original_height = image.size
173
  FIXED_DIMENSION = 1024
@@ -192,30 +145,25 @@ def inpaint(
192
  ):
193
  image = image.convert("RGB")
194
  mask = mask.convert("L")
195
-
196
  width, height = calculate_optimal_dimensions(image)
 
 
197
  image_resized = image.resize((width, height), Image.LANCZOS)
198
 
199
  pipe.to("cuda")
200
 
 
201
  step_images = []
202
  callback = None
203
-
204
  if preserved_area_mask is not None:
205
- preserved_area_resized = preserved_area_mask.resize(
206
- (width, height), Image.NEAREST
207
- )
208
- callback = get_hard_preserve_callback(
209
- pipe,
210
- image_resized,
211
- preserved_area_resized,
212
- num_inference_steps,
213
- step_images,
214
  )
215
 
216
  result = pipe(
217
  image=image_resized,
218
- mask_image=mask.resize((width, height), Image.NEAREST),
219
  prompt=prompt,
220
  width=width,
221
  height=height,
@@ -261,16 +209,12 @@ def inpaint_api(
261
  final_prompt = ""
262
  if flux_keywords:
263
  final_prompt += ", ".join(flux_keywords) + ", "
264
-
265
  if selected_loras_with_weights:
266
  for lora, _ in selected_loras_with_weights:
267
  if lora.keyword:
268
- final_prompt += (
269
- lora.keyword
270
- if isinstance(lora.keyword, str)
271
- else ", ".join(lora.keyword)
272
- ) + ", "
273
-
274
  final_prompt += prompt
275
 
276
  if not isinstance(seed, int) or seed < 0:
@@ -288,47 +232,17 @@ def inpaint_api(
288
  )
289
 
290
 
291
- # ------------------------------------------------------------------
292
- # UI
293
- # ------------------------------------------------------------------
294
-
295
- with gr.Blocks(
296
- title="FLUX.1 Fill dev + HARD Area Preservation",
297
- theme=gr.themes.Soft(),
298
- ) as demo:
299
  with gr.Row():
300
  with gr.Column(scale=2):
301
- prompt_input = gr.Text(
302
- label="Prompt",
303
- lines=4,
304
- value="a 25 years old woman",
305
- )
306
- seed_slider = gr.Slider(
307
- label="Seed",
308
- minimum=-1,
309
- maximum=MAX_SEED,
310
- step=1,
311
- value=-1,
312
- )
313
- num_inference_steps_input = gr.Number(
314
- label="Inference steps",
315
- value=40,
316
- )
317
- guidance_scale_input = gr.Number(
318
- label="Guidance scale",
319
- value=30,
320
- )
321
- strength_input = gr.Number(
322
- label="Strength",
323
- value=1.0,
324
- maximum=1.0,
325
- )
326
 
327
  gr.Markdown("### Flux Keywords")
328
- flux_keywords_input = gr.CheckboxGroup(
329
- choices=flux_keywords_available,
330
- label="Flux Keywords",
331
- )
332
 
333
  if loras:
334
  gr.Markdown("### Available LoRAs")
@@ -343,32 +257,16 @@ with gr.Blocks(
343
  )
344
 
345
  with gr.Column(scale=3):
346
- image_input = gr.Image(
347
- label="Original Image",
348
- type="pil",
349
- )
350
- mask_input = gr.Image(
351
- label="Inpaint Mask (Area to change)",
352
- type="pil",
353
- )
354
- preserved_area_input = gr.Image(
355
- label="Preserved Area Mask (Area to keep)",
356
- type="pil",
357
- )
358
- run_btn = gr.Button(
359
- "Generate",
360
- variant="primary",
361
- )
362
 
363
  with gr.Column(scale=3):
364
  result_image = gr.Image(label="Result")
365
  used_prompt_box = gr.Text(label="Final Prompt")
366
  used_seed_box = gr.Number(label="Used Seed")
367
- steps_gallery = gr.Gallery(
368
- label="Evolution (Steps)",
369
- columns=3,
370
- preview=True,
371
- )
372
 
373
  run_btn.click(
374
  fn=inpaint_api,
@@ -384,12 +282,7 @@ with gr.Blocks(
384
  flux_keywords_input,
385
  loras_selected_input,
386
  ],
387
- outputs=[
388
- result_image,
389
- steps_gallery,
390
- used_prompt_box,
391
- used_seed_box,
392
- ],
393
  )
394
 
395
  if __name__ == "__main__":
 
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
 
13
+ pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
 
 
 
14
 
15
  flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]
16
 
17
+ # --- LATENT MANIPULATION FUNCTIONS ---
 
 
 
18
  def pack_latents(latents, batch_size, num_channels, height, width):
19
  latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
20
  latents = latents.permute(0, 2, 4, 1, 3, 5)
21
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
 
 
 
 
22
  return latents
23
 
24
 
25
  def unpack_latents(latents, height, width, h_scale=2, w_scale=2):
26
  batch_size, seq_len, channels = latents.shape
27
+ # Flux uses a 2x2 patch, so the factor is 2
28
  latents = latents.view(
29
+ batch_size, height // h_scale, width // w_scale, channels // (h_scale * w_scale), h_scale, w_scale
 
 
 
 
 
30
  )
31
  latents = latents.permute(0, 3, 1, 4, 2, 5)
32
+ latents = latents.reshape(batch_size, channels // (h_scale * w_scale), height, width)
 
 
 
 
 
33
  return latents
34
 
35
 
36
+ # --- CALLBACK (PRESERVED AREA + STEP CAPTURE) ---
37
+ def get_gradual_blend_callback(
 
 
 
38
  pipe,
39
  original_image,
40
  preserved_area_mask,
41
  total_steps,
42
  step_images_list,
43
+ start_alpha=1.0,
44
+ end_alpha=0.2,
45
  ):
46
  device = pipe.device
47
  dtype = pipe.transformer.dtype
48
 
49
+ packed_init_latents = None
50
+ packed_preserved_mask = None
51
+ h_latent = w_latent = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ if preserved_area_mask is not None:
54
+ with torch.no_grad():
55
+ img_tensor = (
56
+ (torch.from_numpy(np.array(original_image).transpose(2, 0, 1)).float() / 127.5 - 1.0)
57
+ .unsqueeze(0)
58
+ .to(device, dtype)
59
+ )
60
+ init_latents = pipe.vae.encode(img_tensor).latent_dist.sample()
61
+ init_latents = (init_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
62
 
63
+ _, _, h_latent, w_latent = init_latents.shape
 
 
 
 
 
 
64
 
65
+ packed_init_latents = pack_latents(
66
+ init_latents, batch_size=1, num_channels=16, height=h_latent, width=w_latent
67
+ )
68
+
69
+ mask_tensor = (
70
+ (torch.from_numpy(np.array(preserved_area_mask.convert("L"))).float() / 255.0)
71
+ .unsqueeze(0)
72
+ .unsqueeze(0)
73
+ .to(device, dtype)
74
+ )
75
+ latent_preserved_mask = torch.nn.functional.interpolate(
76
+ mask_tensor, size=(h_latent, w_latent), mode="nearest"
77
+ )
78
+ packed_preserved_mask = pack_latents(
79
+ latent_preserved_mask, batch_size=1, num_channels=1, height=h_latent, width=w_latent
80
+ )
81
 
82
  def callback_fn(pipe, step, timestep, callback_kwargs):
83
  latents = callback_kwargs["latents"]
 
84
 
85
+ if packed_preserved_mask is not None:
86
+ progress = step / max(1, total_steps - 1)
87
+ current_alpha = start_alpha - (start_alpha - end_alpha) * progress
88
+
89
+ effective_mask = (packed_preserved_mask * current_alpha).repeat(1, 1, 16)
90
+ latents = (1 - effective_mask) * latents + effective_mask * packed_init_latents
91
 
92
  if step % 5 == 0 or step == total_steps - 1:
93
  with torch.no_grad():
94
  unpacked = unpack_latents(latents, h_latent, w_latent)
95
+ unpacked = (unpacked / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
96
+ decoded = pipe.vae.decode(unpacked.to(pipe.vae.dtype)).sample
97
+ img_step = pipe.image_processor.postprocess(decoded, output_type="pil")[0]
 
 
 
 
 
 
98
  step_images_list.append(img_step)
99
 
100
  callback_kwargs["latents"] = latents
 
103
  return callback_fn
104
 
105
 
106
+ # --- LoRA's FUNCTIONS ---
 
 
 
107
  def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
108
  adapter_names = []
109
  adapter_weights = []
 
120
  return pipe
121
 
122
 
123
+ # --- GENERATION
 
 
 
124
  def calculate_optimal_dimensions(image):
125
  original_width, original_height = image.size
126
  FIXED_DIMENSION = 1024
 
145
  ):
146
  image = image.convert("RGB")
147
  mask = mask.convert("L")
 
148
  width, height = calculate_optimal_dimensions(image)
149
+
150
+ # Resize to match dimensions
151
  image_resized = image.resize((width, height), Image.LANCZOS)
152
 
153
  pipe.to("cuda")
154
 
155
+ # Setup callback if a preserved area mask is provided
156
  step_images = []
157
  callback = None
 
158
  if preserved_area_mask is not None:
159
+ preserved_area_resized = preserved_area_mask.resize((width, height), Image.NEAREST)
160
+ callback = get_gradual_blend_callback(
161
+ pipe, image_resized, preserved_area_resized, num_inference_steps, step_images
 
 
 
 
 
 
162
  )
163
 
164
  result = pipe(
165
  image=image_resized,
166
+ mask_image=mask.resize((width, height)),
167
  prompt=prompt,
168
  width=width,
169
  height=height,
 
209
  final_prompt = ""
210
  if flux_keywords:
211
  final_prompt += ", ".join(flux_keywords) + ", "
212
+
213
  if selected_loras_with_weights:
214
  for lora, _ in selected_loras_with_weights:
215
  if lora.keyword:
216
+ final_prompt += (lora.keyword if isinstance(lora.keyword, str) else ", ".join(lora.keyword)) + ", "
217
+
 
 
 
 
218
  final_prompt += prompt
219
 
220
  if not isinstance(seed, int) or seed < 0:
 
232
  )
233
 
234
 
235
+ with gr.Blocks(title="FLUX.1 Fill dev + Area Preservation", theme=gr.themes.Soft()) as demo:
 
 
 
 
 
 
 
236
  with gr.Row():
237
  with gr.Column(scale=2):
238
+ prompt_input = gr.Text(label="Prompt", lines=4, value="a 25 years old woman")
239
+ seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
240
+ num_inference_steps_input = gr.Number(label="Inference steps", value=40)
241
+ guidance_scale_input = gr.Number(label="Guidance scale", value=30)
242
+ strength_input = gr.Number(label="Strength", value=1.0, maximum=1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
  gr.Markdown("### Flux Keywords")
245
+ flux_keywords_input = gr.CheckboxGroup(choices=flux_keywords_available, label="Flux Keywords")
 
 
 
246
 
247
  if loras:
248
  gr.Markdown("### Available LoRAs")
 
257
  )
258
 
259
  with gr.Column(scale=3):
260
+ image_input = gr.Image(label="Original Image", type="pil")
261
+ mask_input = gr.Image(label="Inpaint Mask (Area to change)", type="pil")
262
+ preserved_area_input = gr.Image(label="Preserved Area Mask (Area to keep)", type="pil")
263
+ run_btn = gr.Button("Generate", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
  with gr.Column(scale=3):
266
  result_image = gr.Image(label="Result")
267
  used_prompt_box = gr.Text(label="Final Prompt")
268
  used_seed_box = gr.Number(label="Used Seed")
269
+ steps_gallery = gr.Gallery(label="Evolution (Steps)", columns=3, preview=True)
 
 
 
 
270
 
271
  run_btn.click(
272
  fn=inpaint_api,
 
282
  flux_keywords_input,
283
  loras_selected_input,
284
  ],
285
+ outputs=[result_image, steps_gallery, used_prompt_box, used_seed_box],
 
 
 
 
 
286
  )
287
 
288
  if __name__ == "__main__":