X-HighVoltage-X committed on
Commit
beb6b67
·
verified ·
1 Parent(s): 4ad6fb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -61
app.py CHANGED
@@ -10,91 +10,135 @@ from PIL import Image
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
 
13
- pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
 
 
 
14
 
15
  flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]
16
 
17
- # --- LATENT MANIPULATION FUNCTIONS ---
 
 
 
18
  def pack_latents(latents, batch_size, num_channels, height, width):
19
  latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
20
  latents = latents.permute(0, 2, 4, 1, 3, 5)
21
- latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
 
 
 
 
22
  return latents
23
 
24
 
25
  def unpack_latents(latents, height, width, h_scale=2, w_scale=2):
26
  batch_size, seq_len, channels = latents.shape
27
- # Flux uses a 2x2 patch, so the factor is 2
28
  latents = latents.view(
29
- batch_size, height // h_scale, width // w_scale, channels // (h_scale * w_scale), h_scale, w_scale
 
 
 
 
 
30
  )
31
  latents = latents.permute(0, 3, 1, 4, 2, 5)
32
- latents = latents.reshape(batch_size, channels // (h_scale * w_scale), height, width)
 
 
 
 
 
33
  return latents
34
 
35
 
36
- # --- CALLBACK (PRESERVED AREA + STEP CAPTURE) ---
37
- def get_gradual_blend_callback(
 
 
 
38
  pipe,
39
  original_image,
40
  preserved_area_mask,
41
  total_steps,
42
  step_images_list,
43
- start_alpha=1.0,
44
- end_alpha=0.2,
45
  ):
46
  device = pipe.device
47
  dtype = pipe.transformer.dtype
48
 
49
- packed_init_latents = None
50
- packed_preserved_mask = None
51
- h_latent = w_latent = None
 
 
 
 
 
 
52
 
53
- if preserved_area_mask is not None:
54
- with torch.no_grad():
55
- img_tensor = (
56
- (torch.from_numpy(np.array(original_image).transpose(2, 0, 1)).float() / 127.5 - 1.0)
57
- .unsqueeze(0)
58
- .to(device, dtype)
59
- )
60
- init_latents = pipe.vae.encode(img_tensor).latent_dist.sample()
61
- init_latents = (init_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
62
 
63
- _, _, h_latent, w_latent = init_latents.shape
64
 
65
- packed_init_latents = pack_latents(
66
- init_latents, batch_size=1, num_channels=16, height=h_latent, width=w_latent
67
- )
 
 
 
 
68
 
69
- mask_tensor = (
70
- (torch.from_numpy(np.array(preserved_area_mask.convert("L"))).float() / 255.0)
71
- .unsqueeze(0)
72
- .unsqueeze(0)
73
- .to(device, dtype)
74
- )
75
- latent_preserved_mask = torch.nn.functional.interpolate(
76
- mask_tensor, size=(h_latent, w_latent), mode="nearest"
77
- )
78
- packed_preserved_mask = pack_latents(
79
- latent_preserved_mask, batch_size=1, num_channels=1, height=h_latent, width=w_latent
80
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  def callback_fn(pipe, step, timestep, callback_kwargs):
83
  latents = callback_kwargs["latents"]
84
 
85
- if packed_preserved_mask is not None:
86
- progress = step / max(1, total_steps - 1)
87
- current_alpha = start_alpha - (start_alpha - end_alpha) * progress
88
-
89
- effective_mask = (packed_preserved_mask * current_alpha).repeat(1, 1, 16)
90
- latents = (1 - effective_mask) * latents + effective_mask * packed_init_latents
91
 
 
92
  if step % 5 == 0 or step == total_steps - 1:
93
  with torch.no_grad():
94
  unpacked = unpack_latents(latents, h_latent, w_latent)
95
- unpacked = (unpacked / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
96
- decoded = pipe.vae.decode(unpacked.to(pipe.vae.dtype)).sample
97
- img_step = pipe.image_processor.postprocess(decoded, output_type="pil")[0]
 
 
 
 
 
 
98
  step_images_list.append(img_step)
99
 
100
  callback_kwargs["latents"] = latents
@@ -103,7 +147,10 @@ def get_gradual_blend_callback(
103
  return callback_fn
104
 
105
 
106
- # --- LoRA's FUNCTIONS ---
 
 
 
107
  def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
108
  adapter_names = []
109
  adapter_weights = []
@@ -120,7 +167,10 @@ def deactivate_loras(pipe):
120
  return pipe
121
 
122
 
123
- # --- GENERATION
 
 
 
124
  def calculate_optimal_dimensions(image):
125
  original_width, original_height = image.size
126
  FIXED_DIMENSION = 1024
@@ -145,25 +195,30 @@ def inpaint(
145
  ):
146
  image = image.convert("RGB")
147
  mask = mask.convert("L")
148
- width, height = calculate_optimal_dimensions(image)
149
 
150
- # Resize to match dimensions
151
  image_resized = image.resize((width, height), Image.LANCZOS)
152
 
153
  pipe.to("cuda")
154
 
155
- # Setup callback if a preserved area mask is provided
156
  step_images = []
157
  callback = None
 
158
  if preserved_area_mask is not None:
159
- preserved_area_resized = preserved_area_mask.resize((width, height), Image.NEAREST)
160
- callback = get_gradual_blend_callback(
161
- pipe, image_resized, preserved_area_resized, num_inference_steps, step_images
 
 
 
 
 
 
162
  )
163
 
164
  result = pipe(
165
  image=image_resized,
166
- mask_image=mask.resize((width, height)),
167
  prompt=prompt,
168
  width=width,
169
  height=height,
@@ -209,12 +264,16 @@ def inpaint_api(
209
  final_prompt = ""
210
  if flux_keywords:
211
  final_prompt += ", ".join(flux_keywords) + ", "
212
-
213
  if selected_loras_with_weights:
214
  for lora, _ in selected_loras_with_weights:
215
  if lora.keyword:
216
- final_prompt += (lora.keyword if isinstance(lora.keyword, str) else ", ".join(lora.keyword)) + ", "
217
-
 
 
 
 
218
  final_prompt += prompt
219
 
220
  if not isinstance(seed, int) or seed < 0:
@@ -232,7 +291,11 @@ def inpaint_api(
232
  )
233
 
234
 
235
- with gr.Blocks(title="FLUX.1 Fill dev + Area Preservation", theme=gr.themes.Soft()) as demo:
 
 
 
 
236
  with gr.Row():
237
  with gr.Column(scale=2):
238
  prompt_input = gr.Text(label="Prompt", lines=4, value="a 25 years old woman")
@@ -242,7 +305,10 @@ with gr.Blocks(title="FLUX.1 Fill dev + Area Preservation", theme=gr.themes.Soft
242
  strength_input = gr.Number(label="Strength", value=1.0, maximum=1.0)
243
 
244
  gr.Markdown("### Flux Keywords")
245
- flux_keywords_input = gr.CheckboxGroup(choices=flux_keywords_available, label="Flux Keywords")
 
 
 
246
 
247
  if loras:
248
  gr.Markdown("### Available LoRAs")
 
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
 
13
# Global Flux Fill pipeline, loaded once at import time in bfloat16 to
# halve memory relative to fp32 while keeping Flux-native precision.
FLUX_FILL_MODEL_ID = "black-forest-labs/FLUX.1-Fill-dev"
pipe = FluxFillPipeline.from_pretrained(FLUX_FILL_MODEL_ID, torch_dtype=torch.bfloat16)

# Keyword tags the UI exposes for optional prompt augmentation.
flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]
19
 
20
+ # ------------------------------------------------------------------
21
+ # LATENT MANIPULATION
22
+ # ------------------------------------------------------------------
23
+
24
def pack_latents(latents, batch_size, num_channels, height, width, h_scale=2, w_scale=2):
    """Pack spatial VAE latents into Flux's flattened patch sequence.

    Flux's transformer consumes latents as a sequence of spatial patches.
    This reshapes ``(batch, channels, height, width)`` latents into
    ``(batch, (height // h_scale) * (width // w_scale),
    channels * h_scale * w_scale)``.

    The patch size was previously hard-coded to 2x2; it is now exposed as
    ``h_scale``/``w_scale`` keyword parameters (defaulting to 2, so existing
    callers are unaffected) for consistency with ``unpack_latents``.

    Args:
        latents: Latent tensor of shape ``(batch, channels, height, width)``.
        batch_size: Batch dimension of ``latents``.
        num_channels: Channel dimension of ``latents``.
        height: Latent height; must be divisible by ``h_scale``.
        width: Latent width; must be divisible by ``w_scale``.
        h_scale: Patch height (default 2).
        w_scale: Patch width (default 2).

    Returns:
        Packed tensor of shape
        ``(batch, seq_len, channels * h_scale * w_scale)``.
    """
    latents = latents.view(
        batch_size, num_channels, height // h_scale, h_scale, width // w_scale, w_scale
    )
    # Bring the spatial grid to the front so each patch's channels stay together.
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(
        batch_size,
        (height // h_scale) * (width // w_scale),
        num_channels * h_scale * w_scale,
    )
    return latents
33
 
34
 
35
def unpack_latents(latents, height, width, h_scale=2, w_scale=2):
    """Invert ``pack_latents``: restore a patch sequence to spatial latents.

    Takes packed latents of shape ``(batch, seq_len, channels)`` and
    rebuilds the ``(batch, channels // (h_scale * w_scale), height, width)``
    spatial layout expected by the VAE decoder.

    Args:
        latents: Packed tensor of shape ``(batch, seq_len, channels)``.
        height: Target latent height.
        width: Target latent width.
        h_scale: Patch height used when packing (default 2).
        w_scale: Patch width used when packing (default 2).

    Returns:
        Spatial latent tensor of shape
        ``(batch, channels // (h_scale * w_scale), height, width)``.
    """
    batch_size, seq_len, channels = latents.shape
    patch_area = h_scale * w_scale
    grid_h = height // h_scale
    grid_w = width // w_scale

    spatial = latents.view(batch_size, grid_h, grid_w, channels // patch_area, h_scale, w_scale)
    # Move channels ahead of the spatial grid, interleaving each patch back in place.
    spatial = spatial.permute(0, 3, 1, 4, 2, 5)
    return spatial.reshape(batch_size, channels // patch_area, height, width)
53
 
54
 
55
+ # ------------------------------------------------------------------
56
+ # HARD PRESERVE CALLBACK (ABSOLUTE LOCK)
57
+ # ------------------------------------------------------------------
58
+
59
+ def get_hard_preserve_callback(
60
  pipe,
61
  original_image,
62
  preserved_area_mask,
63
  total_steps,
64
  step_images_list,
 
 
65
  ):
66
  device = pipe.device
67
  dtype = pipe.transformer.dtype
68
 
69
+ with torch.no_grad():
70
+ # IMAGE → LATENTS
71
+ img_tensor = (
72
+ torch.from_numpy(np.array(original_image).transpose(2, 0, 1))
73
+ .float()
74
+ / 127.5
75
+ - 1.0
76
+ )
77
+ img_tensor = img_tensor.unsqueeze(0).to(device, dtype)
78
 
79
+ init_latents = pipe.vae.encode(img_tensor).latent_dist.sample()
80
+ init_latents = (
81
+ init_latents - pipe.vae.config.shift_factor
82
+ ) * pipe.vae.config.scaling_factor
 
 
 
 
 
83
 
84
+ _, _, h_latent, w_latent = init_latents.shape
85
 
86
+ packed_init_latents = pack_latents(
87
+ init_latents,
88
+ batch_size=1,
89
+ num_channels=16,
90
+ height=h_latent,
91
+ width=w_latent,
92
+ )
93
 
94
+ # MASK → LATENT MASK (BINARY, HARD)
95
+ mask_tensor = (
96
+ torch.from_numpy(np.array(preserved_area_mask.convert("L")))
97
+ .float()
98
+ / 255.0
99
+ )
100
+ mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device, dtype)
101
+
102
+ latent_mask = torch.nn.functional.interpolate(
103
+ mask_tensor,
104
+ size=(h_latent, w_latent),
105
+ mode="nearest", # CRITICAL
106
+ )
107
+
108
+ packed_preserved_mask = pack_latents(
109
+ latent_mask,
110
+ batch_size=1,
111
+ num_channels=1,
112
+ height=h_latent,
113
+ width=w_latent,
114
+ )
115
+
116
+ # strict binary
117
+ packed_preserved_mask = (packed_preserved_mask > 0.5).float()
118
+ packed_preserved_mask = packed_preserved_mask.repeat(1, 1, 16)
119
 
120
  def callback_fn(pipe, step, timestep, callback_kwargs):
121
  latents = callback_kwargs["latents"]
122
 
123
+ # ABSOLUTE OVERWRITE EVERY STEP
124
+ latents = (
125
+ latents * (1.0 - packed_preserved_mask)
126
+ + packed_init_latents * packed_preserved_mask
127
+ )
 
128
 
129
+ # Debug steps
130
  if step % 5 == 0 or step == total_steps - 1:
131
  with torch.no_grad():
132
  unpacked = unpack_latents(latents, h_latent, w_latent)
133
+ unpacked = (
134
+ unpacked / pipe.vae.config.scaling_factor
135
+ ) + pipe.vae.config.shift_factor
136
+ decoded = pipe.vae.decode(
137
+ unpacked.to(pipe.vae.dtype)
138
+ ).sample
139
+ img_step = pipe.image_processor.postprocess(
140
+ decoded, output_type="pil"
141
+ )[0]
142
  step_images_list.append(img_step)
143
 
144
  callback_kwargs["latents"] = latents
 
147
  return callback_fn
148
 
149
 
150
+ # ------------------------------------------------------------------
151
+ # LoRA UTILS
152
+ # ------------------------------------------------------------------
153
+
154
  def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
155
  adapter_names = []
156
  adapter_weights = []
 
167
  return pipe
168
 
169
 
170
+ # ------------------------------------------------------------------
171
+ # GENERATION
172
+ # ------------------------------------------------------------------
173
+
174
  def calculate_optimal_dimensions(image):
175
  original_width, original_height = image.size
176
  FIXED_DIMENSION = 1024
 
195
  ):
196
  image = image.convert("RGB")
197
  mask = mask.convert("L")
 
198
 
199
+ width, height = calculate_optimal_dimensions(image)
200
  image_resized = image.resize((width, height), Image.LANCZOS)
201
 
202
  pipe.to("cuda")
203
 
 
204
  step_images = []
205
  callback = None
206
+
207
  if preserved_area_mask is not None:
208
+ preserved_area_resized = preserved_area_mask.resize(
209
+ (width, height), Image.NEAREST
210
+ )
211
+ callback = get_hard_preserve_callback(
212
+ pipe,
213
+ image_resized,
214
+ preserved_area_resized,
215
+ num_inference_steps,
216
+ step_images,
217
  )
218
 
219
  result = pipe(
220
  image=image_resized,
221
+ mask_image=mask.resize((width, height), Image.NEAREST),
222
  prompt=prompt,
223
  width=width,
224
  height=height,
 
264
  final_prompt = ""
265
  if flux_keywords:
266
  final_prompt += ", ".join(flux_keywords) + ", "
267
+
268
  if selected_loras_with_weights:
269
  for lora, _ in selected_loras_with_weights:
270
  if lora.keyword:
271
+ final_prompt += (
272
+ lora.keyword
273
+ if isinstance(lora.keyword, str)
274
+ else ", ".join(lora.keyword)
275
+ ) + ", "
276
+
277
  final_prompt += prompt
278
 
279
  if not isinstance(seed, int) or seed < 0:
 
291
  )
292
 
293
 
294
+ # ------------------------------------------------------------------
295
+ # UI
296
+ # ------------------------------------------------------------------
297
+
298
+ with gr.Blocks(title="FLUX.1 Fill dev + HARD Area Preservation", theme=gr.themes.Soft()) as demo:
299
  with gr.Row():
300
  with gr.Column(scale=2):
301
  prompt_input = gr.Text(label="Prompt", lines=4, value="a 25 years old woman")
 
305
  strength_input = gr.Number(label="Strength", value=1.0, maximum=1.0)
306
 
307
  gr.Markdown("### Flux Keywords")
308
+ flux_keywords_input = gr.CheckboxGroup(
309
+ choices=flux_keywords_available,
310
+ label="Flux Keywords",
311
+ )
312
 
313
  if loras:
314
  gr.Markdown("### Available LoRAs")