X-HighVoltage-X committed on
Commit
4ad6fb2
·
verified ·
1 Parent(s): beb57ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -104
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import random
2
- from typing import List, Tuple
3
 
4
  import gradio as gr
5
  import numpy as np
@@ -7,6 +6,7 @@ import spaces
7
  import torch
8
  from diffusers import FluxFillPipeline
9
  from loras import LoRA, loras
 
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
 
@@ -14,62 +14,129 @@ pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", tor
14
 
15
  flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]
16
 
 
 
 
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
19
  adapter_names = []
20
  adapter_weights = []
21
-
22
  for lora, weight in loras_with_weights:
23
- print(f"Loading LoRA: {lora.name} with weight {weight}")
24
  pipe.load_lora_weights(lora.id, weight=weight, adapter_name=lora.name)
25
  adapter_names.append(lora.name)
26
  adapter_weights.append(weight)
27
-
28
- print(f"Activating adapters: {adapter_names} with weights {adapter_weights}")
29
  pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
30
-
31
  return pipe
32
 
33
 
34
- def get_loras() -> list[dict]:
35
- return loras
36
-
37
-
38
  def deactivate_loras(pipe):
39
- print("Unloading all LoRAs...")
40
  pipe.unload_lora_weights()
41
  return pipe
42
 
43
 
 
44
  def calculate_optimal_dimensions(image):
45
  original_width, original_height = image.size
46
- MIN_ASPECT_RATIO = 9 / 16
47
- MAX_ASPECT_RATIO = 16 / 9
48
  FIXED_DIMENSION = 1024
49
- original_aspect_ratio = original_width / original_height
50
- if original_aspect_ratio > 1:
51
- width = FIXED_DIMENSION
52
- height = round(FIXED_DIMENSION / original_aspect_ratio)
53
  else:
54
- height = FIXED_DIMENSION
55
- width = round(FIXED_DIMENSION * original_aspect_ratio)
56
- width = (width // 8) * 8
57
- height = (height // 8) * 8
58
- calculated_aspect_ratio = width / height
59
- if calculated_aspect_ratio > MAX_ASPECT_RATIO:
60
- width = int((height * MAX_ASPECT_RATIO // 8) * 8)
61
- elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
62
- height = int((width / MIN_ASPECT_RATIO // 8) * 8)
63
- width = max(width, 576) if width == FIXED_DIMENSION else width
64
- height = max(height, 576) if height == FIXED_DIMENSION else height
65
-
66
- return width, height
67
-
68
-
69
- @spaces.GPU(duration=45)
70
  def inpaint(
71
  image,
72
  mask,
 
73
  prompt: str = "",
74
  seed: int = 0,
75
  num_inference_steps: int = 28,
@@ -80,10 +147,23 @@ def inpaint(
80
  mask = mask.convert("L")
81
  width, height = calculate_optimal_dimensions(image)
82
 
 
 
 
83
  pipe.to("cuda")
 
 
 
 
 
 
 
 
 
 
84
  result = pipe(
85
- image=image,
86
- mask_image=mask,
87
  prompt=prompt,
88
  width=width,
89
  height=height,
@@ -91,62 +171,50 @@ def inpaint(
91
  guidance_scale=guidance_scale,
92
  strength=strength,
93
  generator=torch.Generator().manual_seed(seed),
 
 
94
  ).images[0]
95
 
96
- return result.convert("RGBA"), prompt, seed
97
 
98
 
99
  def inpaint_api(
100
  image,
101
  mask,
102
- prompt: str,
103
- seed: int,
104
- num_inference_steps: int,
105
- guidance_scale: int,
106
- strength: float,
107
- flux_keywords: List[str] = None,
108
- loras_selected: List[Tuple[str, float]] = None,
 
109
  ):
110
- flux_keywords = flux_keywords or []
111
- loras_selected = loras_selected or []
112
-
113
- # Convertir nombres a objetos LoRA
114
  selected_loras_with_weights = []
115
 
116
- for name, weight_value in loras_selected:
117
- try:
118
- # Convierte explícitamente el peso (que viene como string) a float
119
- weight = float(weight_value)
120
- except (ValueError, TypeError):
121
- # Ignora si el valor no es un número válido (ej: None o string vacío)
122
- print(f"Valor de peso inválido '{weight_value}' para LoRA '{name}', omitiendo.")
123
- continue # Pasa al siguiente LoRA
124
-
125
- lora_obj = next((l for l in loras if l.display_name == name), None)
126
-
127
- # Ahora la comparación 'weight != 0.0' es segura (float con float)
128
- if lora_obj and weight != 0.0:
129
- selected_loras_with_weights.append((lora_obj, weight))
130
 
131
  deactivate_loras(pipe)
132
  if selected_loras_with_weights:
133
  activate_loras(pipe, selected_loras_with_weights)
134
 
135
- # Construir prompt final
136
  final_prompt = ""
137
-
138
  if flux_keywords:
139
  final_prompt += ", ".join(flux_keywords) + ", "
140
-
141
- for lora, _ in selected_loras_with_weights:
142
- if lora.keyword:
143
- if isinstance(lora.keyword, str):
144
- final_prompt += lora.keyword + ", "
145
- else:
146
- final_prompt += ", ".join(lora.keyword) + ", "
147
-
148
- if final_prompt:
149
- final_prompt += "\n\n"
150
  final_prompt += prompt
151
 
152
  if not isinstance(seed, int) or seed < 0:
@@ -155,6 +223,7 @@ def inpaint_api(
155
  return inpaint(
156
  image=image,
157
  mask=mask,
 
158
  prompt=final_prompt,
159
  seed=seed,
160
  num_inference_steps=num_inference_steps,
@@ -163,24 +232,14 @@ def inpaint_api(
163
  )
164
 
165
 
166
- # ========================
167
- # UI DIRECTA A inpaint_api
168
- # ========================
169
- with gr.Blocks(title="Flux.1 Fill dev Inpainting with LoRAs", theme=gr.themes.Soft()) as demo:
170
- gr.api(get_loras, api_name="get_loras")
171
  with gr.Row():
172
  with gr.Column(scale=2):
173
  prompt_input = gr.Text(label="Prompt", lines=4, value="a 25 years old woman")
174
-
175
- seed_slider = gr.Slider(
176
- label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1, info="(-1 = Random)", interactive=True
177
- )
178
-
179
- num_inference_steps_input = gr.Number(label="Inference steps", value=40, interactive=True)
180
-
181
- guidance_scale_input = gr.Number(label="Guidance scale", value=28, interactive=True)
182
-
183
- strength_input = gr.Number(label="Strength", value=1.0, interactive=True, maximum=1.0)
184
 
185
  gr.Markdown("### Flux Keywords")
186
  flux_keywords_input = gr.CheckboxGroup(choices=flux_keywords_available, label="Flux Keywords")
@@ -192,31 +251,29 @@ with gr.Blocks(title="Flux.1 Fill dev Inpainting with LoRAs", theme=gr.themes.So
192
  type="array",
193
  headers=["LoRA", "Weight"],
194
  value=[[name, 0.0] for name in lora_names],
195
- datatype=["str", "number"], # Primera columna string, segunda número
196
- interactive=[False, True], # Solo la segunda columna editable
197
- static_columns=[0],
198
- label="LoRA selection (Weight 0 = disable)",
199
  )
200
 
201
  with gr.Column(scale=3):
202
- image_input = gr.Image(label="Image", type="pil")
203
-
204
- mask_input = gr.Image(label="Mask", type="pil")
205
-
206
- run_btn = gr.Button("Run", variant="primary")
207
 
208
  with gr.Column(scale=3):
209
  result_image = gr.Image(label="Result")
210
-
211
- used_prompt_box = gr.Text(label="Used prompt", lines=4)
212
-
213
- used_seed_box = gr.Number(label="Used seed")
214
 
215
  run_btn.click(
216
  fn=inpaint_api,
217
  inputs=[
218
  image_input,
219
  mask_input,
 
220
  prompt_input,
221
  seed_slider,
222
  num_inference_steps_input,
@@ -225,9 +282,8 @@ with gr.Blocks(title="Flux.1 Fill dev Inpainting with LoRAs", theme=gr.themes.So
225
  flux_keywords_input,
226
  loras_selected_input,
227
  ],
228
- outputs=[result_image, used_prompt_box, used_seed_box],
229
- api_name="inpaint",
230
  )
231
 
232
  if __name__ == "__main__":
233
- demo.launch(share=False, show_error=True)
 
1
  import random
 
2
 
3
  import gradio as gr
4
  import numpy as np
 
6
  import torch
7
  from diffusers import FluxFillPipeline
8
  from loras import LoRA, loras
9
+ from PIL import Image
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
 
 
14
 
15
  flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]
16
 
17
# --- LATENT MANIPULATION FUNCTIONS ---
def pack_latents(latents, batch_size, num_channels, height, width):
    """Rearrange (B, C, H, W) latents into Flux's packed (B, H/2*W/2, C*4) layout.

    Each 2x2 spatial patch is folded into the channel dimension, producing the
    flat token sequence the Flux transformer consumes.
    """
    patched = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
    patched = patched.permute(0, 2, 4, 1, 3, 5)
    return patched.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
23
 
24
+
25
def unpack_latents(latents, height, width, h_scale=2, w_scale=2):
    """Invert `pack_latents`: restore packed (B, seq, C*scale²) latents to (B, C, H, W).

    Args:
        latents: packed tensor of shape (batch, seq_len, channels).
        height: target latent height of the unpacked tensor.
        width: target latent width of the unpacked tensor.
        h_scale: vertical patch size used when packing (Flux uses 2).
        w_scale: horizontal patch size used when packing (Flux uses 2).

    Returns:
        Tensor of shape (batch, channels // (h_scale * w_scale), height, width).
    """
    # seq_len is implied by height/width, so discard it (was an unused local).
    batch_size, _, channels = latents.shape
    # Flux uses a 2x2 patch, so the factor is 2
    latents = latents.view(
        batch_size, height // h_scale, width // w_scale, channels // (h_scale * w_scale), h_scale, w_scale
    )
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    latents = latents.reshape(batch_size, channels // (h_scale * w_scale), height, width)
    return latents
34
+
35
+
36
# --- CALLBACK (PRESERVED AREA + STEP CAPTURE) ---
def get_gradual_blend_callback(
    pipe,
    original_image,
    preserved_area_mask,
    total_steps,
    step_images_list,
    start_alpha=1.0,
    end_alpha=0.2,
):
    """Build a `callback_on_step_end` function that (a) gradually re-injects the
    original image's latents inside `preserved_area_mask`, fading the blend
    strength from `start_alpha` down to `end_alpha` over the run, and
    (b) decodes a preview image every 5 steps into `step_images_list`.

    Args:
        pipe: the Flux fill pipeline (used for its VAE, transformer dtype and
            image post-processor).
        original_image: PIL RGB image already resized to the generation size.
        preserved_area_mask: PIL mask of the region to keep (white = preserve),
            or None to disable blending.
        total_steps: number of inference steps (used to schedule the alpha ramp).
        step_images_list: mutable list the per-step preview PILs are appended to.
        start_alpha: blend strength at step 0 (1.0 = fully original).
        end_alpha: blend strength at the last step.

    Returns:
        A callback compatible with diffusers' `callback_on_step_end` signature.

    NOTE(review): the periodic decode below uses h_latent/w_latent, which stay
    None when `preserved_area_mask` is None — the factory appears to assume it
    is only called with a mask (as `inpaint` does); confirm before reusing.
    """
    device = pipe.device
    dtype = pipe.transformer.dtype

    # Precomputed packed tensors; remain None when no mask is supplied.
    packed_init_latents = None
    packed_preserved_mask = None
    h_latent = w_latent = None

    if preserved_area_mask is not None:
        with torch.no_grad():
            # Normalize the PIL image to [-1, 1] CHW and move to pipe device/dtype.
            img_tensor = (
                (torch.from_numpy(np.array(original_image).transpose(2, 0, 1)).float() / 127.5 - 1.0)
                .unsqueeze(0)
                .to(device, dtype)
            )
            # Encode to VAE latent space with Flux's shift/scale convention.
            init_latents = pipe.vae.encode(img_tensor).latent_dist.sample()
            init_latents = (init_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor

            _, _, h_latent, w_latent = init_latents.shape

            # Pack to the transformer's (B, seq, C*4) token layout.
            packed_init_latents = pack_latents(
                init_latents, batch_size=1, num_channels=16, height=h_latent, width=w_latent
            )

            # Mask → [0, 1] single-channel tensor, downsampled to latent resolution.
            mask_tensor = (
                (torch.from_numpy(np.array(preserved_area_mask.convert("L"))).float() / 255.0)
                .unsqueeze(0)
                .unsqueeze(0)
                .to(device, dtype)
            )
            latent_preserved_mask = torch.nn.functional.interpolate(
                mask_tensor, size=(h_latent, w_latent), mode="nearest"
            )
            packed_preserved_mask = pack_latents(
                latent_preserved_mask, batch_size=1, num_channels=1, height=h_latent, width=w_latent
            )

    def callback_fn(pipe, step, timestep, callback_kwargs):
        """Per-step hook: blend preserved-area latents, capture periodic previews."""
        latents = callback_kwargs["latents"]

        if packed_preserved_mask is not None:
            # Linear ramp: alpha goes start_alpha → end_alpha across the run.
            progress = step / max(1, total_steps - 1)
            current_alpha = start_alpha - (start_alpha - end_alpha) * progress

            # Mask is 1 channel per 2x2 patch position; repeat across the 16
            # latent channels folded into each packed token.
            effective_mask = (packed_preserved_mask * current_alpha).repeat(1, 1, 16)
            latents = (1 - effective_mask) * latents + effective_mask * packed_init_latents

        if step % 5 == 0 or step == total_steps - 1:
            with torch.no_grad():
                # Decode a preview: unpack, undo Flux scale/shift, VAE-decode.
                unpacked = unpack_latents(latents, h_latent, w_latent)
                unpacked = (unpacked / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
                decoded = pipe.vae.decode(unpacked.to(pipe.vae.dtype)).sample
                img_step = pipe.image_processor.postprocess(decoded, output_type="pil")[0]
                step_images_list.append(img_step)

        # Hand the (possibly blended) latents back to the pipeline.
        callback_kwargs["latents"] = latents
        return callback_kwargs

    return callback_fn
104
+
105
+
106
# --- LoRA's FUNCTIONS ---
def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
    """Load and enable the given LoRA adapters on *pipe*.

    Args:
        pipe: the Flux fill pipeline to attach adapters to.
        loras_with_weights: (LoRA, weight) pairs; each adapter is registered
            under its LoRA name and then activated with the paired weight.

    Returns:
        The same pipeline, with all listed adapters set active.
    """
    names = []
    weights = []
    for lora, lora_weight in loras_with_weights:
        pipe.load_lora_weights(lora.id, weight=lora_weight, adapter_name=lora.name)
        names.append(lora.name)
        weights.append(lora_weight)
    pipe.set_adapters(names, adapter_weights=weights)
    return pipe
116
 
117
 
 
 
 
 
118
def deactivate_loras(pipe):
    """Unload every currently loaded LoRA adapter from the pipeline.

    Returns the same pipeline object so calls can be chained.
    """
    pipe.unload_lora_weights()
    return pipe
121
 
122
 
123
# --- GENERATION
def calculate_optimal_dimensions(image):
    """Choose generation dimensions for *image*.

    The longer side is fixed at 1024 and the shorter side follows the original
    aspect ratio; both are then rounded down to a multiple of 8 (required by
    the diffusion model).

    Args:
        image: object exposing a PIL-style ``size`` tuple (width, height).

    Returns:
        (width, height) tuple of ints, each a multiple of 8.
    """
    FIXED_DIMENSION = 1024
    original_width, original_height = image.size
    ratio = original_width / original_height
    if ratio > 1:  # landscape: cap the width
        width = FIXED_DIMENSION
        height = round(FIXED_DIMENSION / ratio)
    else:  # portrait or square: cap the height
        height = FIXED_DIMENSION
        width = round(FIXED_DIMENSION * ratio)
    # Snap both sides down to the nearest multiple of 8.
    return (width // 8) * 8, (height // 8) * 8
133
+
134
+
135
+ @spaces.GPU(duration=60)
 
 
 
 
 
 
 
 
 
 
 
136
  def inpaint(
137
  image,
138
  mask,
139
+ preserved_area_mask=None,
140
  prompt: str = "",
141
  seed: int = 0,
142
  num_inference_steps: int = 28,
 
147
  mask = mask.convert("L")
148
  width, height = calculate_optimal_dimensions(image)
149
 
150
+ # Resize to match dimensions
151
+ image_resized = image.resize((width, height), Image.LANCZOS)
152
+
153
  pipe.to("cuda")
154
+
155
+ # Setup callback if a preserved area mask is provided
156
+ step_images = []
157
+ callback = None
158
+ if preserved_area_mask is not None:
159
+ preserved_area_resized = preserved_area_mask.resize((width, height), Image.NEAREST)
160
+ callback = get_gradual_blend_callback(
161
+ pipe, image_resized, preserved_area_resized, num_inference_steps, step_images
162
+ )
163
+
164
  result = pipe(
165
+ image=image_resized,
166
+ mask_image=mask.resize((width, height)),
167
  prompt=prompt,
168
  width=width,
169
  height=height,
 
171
  guidance_scale=guidance_scale,
172
  strength=strength,
173
  generator=torch.Generator().manual_seed(seed),
174
+ callback_on_step_end=callback,
175
+ callback_on_step_end_tensor_inputs=["latents"] if callback else None,
176
  ).images[0]
177
 
178
+ return result.convert("RGBA"), step_images, prompt, seed
179
 
180
 
181
  def inpaint_api(
182
  image,
183
  mask,
184
+ preserved_area_mask=None,
185
+ prompt: str = "",
186
+ seed: int = -1,
187
+ num_inference_steps: int = 40,
188
+ guidance_scale: float = 30.0,
189
+ strength: float = 1.0,
190
+ flux_keywords: list[str] = None,
191
+ loras_selected: list[tuple[str, float]] = None,
192
  ):
 
 
 
 
193
  selected_loras_with_weights = []
194
 
195
+ if loras_selected:
196
+ for name, weight_value in loras_selected:
197
+ try:
198
+ weight = float(weight_value)
199
+ except (ValueError, TypeError):
200
+ continue
201
+ lora_obj = next((l for l in loras if l.display_name == name), None)
202
+ if lora_obj and weight != 0.0:
203
+ selected_loras_with_weights.append((lora_obj, weight))
 
 
 
 
 
204
 
205
  deactivate_loras(pipe)
206
  if selected_loras_with_weights:
207
  activate_loras(pipe, selected_loras_with_weights)
208
 
 
209
  final_prompt = ""
 
210
  if flux_keywords:
211
  final_prompt += ", ".join(flux_keywords) + ", "
212
+
213
+ if selected_loras_with_weights:
214
+ for lora, _ in selected_loras_with_weights:
215
+ if lora.keyword:
216
+ final_prompt += (lora.keyword if isinstance(lora.keyword, str) else ", ".join(lora.keyword)) + ", "
217
+
 
 
 
 
218
  final_prompt += prompt
219
 
220
  if not isinstance(seed, int) or seed < 0:
 
223
  return inpaint(
224
  image=image,
225
  mask=mask,
226
+ preserved_area_mask=preserved_area_mask,
227
  prompt=final_prompt,
228
  seed=seed,
229
  num_inference_steps=num_inference_steps,
 
232
  )
233
 
234
 
235
+ with gr.Blocks(title="FLUX.1 Fill dev + Area Preservation", theme=gr.themes.Soft()) as demo:
 
 
 
 
236
  with gr.Row():
237
  with gr.Column(scale=2):
238
  prompt_input = gr.Text(label="Prompt", lines=4, value="a 25 years old woman")
239
+ seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
240
+ num_inference_steps_input = gr.Number(label="Inference steps", value=40)
241
+ guidance_scale_input = gr.Number(label="Guidance scale", value=30)
242
+ strength_input = gr.Number(label="Strength", value=1.0, maximum=1.0)
 
 
 
 
 
 
243
 
244
  gr.Markdown("### Flux Keywords")
245
  flux_keywords_input = gr.CheckboxGroup(choices=flux_keywords_available, label="Flux Keywords")
 
251
  type="array",
252
  headers=["LoRA", "Weight"],
253
  value=[[name, 0.0] for name in lora_names],
254
+ datatype=["str", "number"],
255
+ interactive=[False, True],
256
+ label="LoRA selection",
 
257
  )
258
 
259
  with gr.Column(scale=3):
260
+ image_input = gr.Image(label="Original Image", type="pil")
261
+ mask_input = gr.Image(label="Inpaint Mask (Area to change)", type="pil")
262
+ preserved_area_input = gr.Image(label="Preserved Area Mask (Area to keep)", type="pil")
263
+ run_btn = gr.Button("Generate", variant="primary")
 
264
 
265
  with gr.Column(scale=3):
266
  result_image = gr.Image(label="Result")
267
+ used_prompt_box = gr.Text(label="Final Prompt")
268
+ used_seed_box = gr.Number(label="Used Seed")
269
+ steps_gallery = gr.Gallery(label="Evolution (Steps)", columns=3, preview=True)
 
270
 
271
  run_btn.click(
272
  fn=inpaint_api,
273
  inputs=[
274
  image_input,
275
  mask_input,
276
+ preserved_area_input,
277
  prompt_input,
278
  seed_slider,
279
  num_inference_steps_input,
 
282
  flux_keywords_input,
283
  loras_selected_input,
284
  ],
285
+ outputs=[result_image, steps_gallery, used_prompt_box, used_seed_box],
 
286
  )
287
 
288
  if __name__ == "__main__":
289
+ demo.launch()