X-HighVoltage-X committed on
Commit
beb57ea
·
verified ·
1 Parent(s): d852664

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -153
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import random
 
2
 
3
  import gradio as gr
4
  import numpy as np
@@ -6,7 +7,6 @@ import spaces
6
  import torch
7
  from diffusers import FluxFillPipeline
8
  from loras import LoRA, loras
9
- from PIL import Image
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
 
@@ -14,122 +14,62 @@ pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", tor
14
 
15
  flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]
16
 
17
- # --- LATENT MANIPULATION FUNCTIONS ---
18
- def pack_latents(latents, batch_size, num_channels, height, width):
19
- latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
20
- latents = latents.permute(0, 2, 4, 1, 3, 5)
21
- latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
22
- return latents
23
 
24
-
25
- def unpack_latents(latents, height, width, h_scale=2, w_scale=2):
26
- batch_size, seq_len, channels = latents.shape
27
- # Flux uses a 2x2 patch, so the factor is 2
28
- latents = latents.view(
29
- batch_size, height // h_scale, width // w_scale, channels // (h_scale * w_scale), h_scale, w_scale
30
- )
31
- latents = latents.permute(0, 3, 1, 4, 2, 5)
32
- latents = latents.reshape(batch_size, channels // (h_scale * w_scale), height, width)
33
- return latents
34
-
35
-
36
- # --- CALLBACK (PRESERVED AREA + STEP CAPTURE) ---
37
- def get_gradual_blend_callback(
38
- pipe, original_image, preserved_area_mask, total_steps, step_images_list, start_alpha=1.0, end_alpha=0.2
39
- ):
40
- device = pipe.device
41
- dtype = pipe.transformer.dtype
42
-
43
- with torch.no_grad():
44
- # Prepare original image
45
- img_tensor = (
46
- (torch.from_numpy(np.array(original_image).transpose(2, 0, 1)).float() / 127.5 - 1.0)
47
- .unsqueeze(0)
48
- .to(device, dtype)
49
- )
50
- init_latents = pipe.vae.encode(img_tensor).latent_dist.sample()
51
- init_latents = (init_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
52
-
53
- # Dimensions in latent space
54
- _, _, h_latent, w_latent = init_latents.shape
55
-
56
- # Pack original latents (64 channels)
57
- packed_init_latents = pack_latents(init_latents, batch_size=1, num_channels=16, height=h_latent, width=w_latent)
58
-
59
- # Prepare and pack the preserved area mask (4 channels)
60
- mask_tensor = (
61
- (torch.from_numpy(np.array(preserved_area_mask.convert("L"))).float() / 255.0)
62
- .unsqueeze(0)
63
- .unsqueeze(0)
64
- .to(device, dtype)
65
- )
66
- latent_preserved_mask = torch.nn.functional.interpolate(mask_tensor, size=(h_latent, w_latent), mode="nearest")
67
- packed_preserved_mask = pack_latents(
68
- latent_preserved_mask, batch_size=1, num_channels=1, height=h_latent, width=w_latent
69
- )
70
-
71
- def callback_fn(pipe, step, timestep, callback_kwargs):
72
- latents = callback_kwargs["latents"]
73
-
74
- # A. Preserved Area Logic
75
- progress = step / max(1, total_steps - 1)
76
- current_alpha = start_alpha - (start_alpha - end_alpha) * progress
77
-
78
- # We use .repeat(1, 1, 16) so the 4 mask channels affect the 64 latent channels
79
- effective_mask_64 = (packed_preserved_mask * current_alpha).repeat(1, 1, 16)
80
- latents = (1 - effective_mask_64) * latents + effective_mask_64 * packed_init_latents
81
-
82
- # B. Step Capture (Save an image every 5 steps to save GPU memory)
83
- if step % 5 == 0 or step == total_steps - 1:
84
- with torch.no_grad():
85
- unpacked = unpack_latents(latents, h_latent, w_latent)
86
- unpacked = (unpacked / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
87
-
88
- # Decode and convert to PIL image
89
- decoded = pipe.vae.decode(unpacked.to(pipe.vae.dtype)).sample
90
- img_step = pipe.image_processor.postprocess(decoded, output_type="pil")[0]
91
- step_images_list.append(img_step)
92
-
93
- callback_kwargs["latents"] = latents
94
- return callback_kwargs
95
-
96
- return callback_fn
97
-
98
-
99
- # --- LoRA's FUNCTIONS ---
100
  def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
101
  adapter_names = []
102
  adapter_weights = []
 
103
  for lora, weight in loras_with_weights:
 
104
  pipe.load_lora_weights(lora.id, weight=weight, adapter_name=lora.name)
105
  adapter_names.append(lora.name)
106
  adapter_weights.append(weight)
 
 
107
  pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
 
108
  return pipe
109
 
110
 
 
 
 
 
111
  def deactivate_loras(pipe):
 
112
  pipe.unload_lora_weights()
113
  return pipe
114
 
115
 
116
- # --- GENERATION
117
  def calculate_optimal_dimensions(image):
118
  original_width, original_height = image.size
 
 
119
  FIXED_DIMENSION = 1024
120
- aspect_ratio = original_width / original_height
121
- if aspect_ratio > 1:
122
- width, height = FIXED_DIMENSION, round(FIXED_DIMENSION / aspect_ratio)
 
123
  else:
124
- height, width = FIXED_DIMENSION, round(FIXED_DIMENSION * aspect_ratio)
125
- return (width // 8) * 8, (height // 8) * 8
126
-
127
-
128
- @spaces.GPU(duration=60)
 
 
 
 
 
 
 
 
 
 
 
129
  def inpaint(
130
  image,
131
  mask,
132
- preserved_area_mask=None,
133
  prompt: str = "",
134
  seed: int = 0,
135
  num_inference_steps: int = 28,
@@ -140,23 +80,10 @@ def inpaint(
140
  mask = mask.convert("L")
141
  width, height = calculate_optimal_dimensions(image)
142
 
143
- # Resize to match dimensions
144
- image_resized = image.resize((width, height), Image.LANCZOS)
145
-
146
  pipe.to("cuda")
147
-
148
- # Setup callback if a preserved area mask is provided
149
- step_images = []
150
- callback = None
151
- if preserved_area_mask is not None:
152
- preserved_area_resized = preserved_area_mask.resize((width, height), Image.NEAREST)
153
- callback = get_gradual_blend_callback(
154
- pipe, image_resized, preserved_area_resized, num_inference_steps, step_images
155
- )
156
-
157
  result = pipe(
158
- image=image_resized,
159
- mask_image=mask.resize((width, height)),
160
  prompt=prompt,
161
  width=width,
162
  height=height,
@@ -164,50 +91,62 @@ def inpaint(
164
  guidance_scale=guidance_scale,
165
  strength=strength,
166
  generator=torch.Generator().manual_seed(seed),
167
- callback_on_step_end=callback,
168
- callback_on_step_end_tensor_inputs=["latents"] if callback else None,
169
  ).images[0]
170
 
171
- return result.convert("RGBA"), step_images, prompt, seed
172
 
173
 
174
  def inpaint_api(
175
  image,
176
  mask,
177
- preserved_area_mask=None,
178
- prompt: str = "",
179
- seed: int = -1,
180
- num_inference_steps: int = 40,
181
- guidance_scale: float = 30.0,
182
- strength: float = 1.0,
183
- flux_keywords: list[str] = None,
184
- loras_selected: list[tuple[str, float]] = None,
185
  ):
 
 
 
 
186
  selected_loras_with_weights = []
187
 
188
- if loras_selected:
189
- for name, weight_value in loras_selected:
190
- try:
191
- weight = float(weight_value)
192
- except (ValueError, TypeError):
193
- continue
194
- lora_obj = next((l for l in loras if l.display_name == name), None)
195
- if lora_obj and weight != 0.0:
196
- selected_loras_with_weights.append((lora_obj, weight))
 
 
 
 
 
197
 
198
  deactivate_loras(pipe)
199
  if selected_loras_with_weights:
200
  activate_loras(pipe, selected_loras_with_weights)
201
 
 
202
  final_prompt = ""
 
203
  if flux_keywords:
204
  final_prompt += ", ".join(flux_keywords) + ", "
205
-
206
- if selected_loras_with_weights:
207
- for lora, _ in selected_loras_with_weights:
208
- if lora.keyword:
209
- final_prompt += (lora.keyword if isinstance(lora.keyword, str) else ", ".join(lora.keyword)) + ", "
210
-
 
 
 
 
211
  final_prompt += prompt
212
 
213
  if not isinstance(seed, int) or seed < 0:
@@ -216,7 +155,6 @@ def inpaint_api(
216
  return inpaint(
217
  image=image,
218
  mask=mask,
219
- preserved_area_mask=preserved_area_mask,
220
  prompt=final_prompt,
221
  seed=seed,
222
  num_inference_steps=num_inference_steps,
@@ -225,14 +163,24 @@ def inpaint_api(
225
  )
226
 
227
 
228
- with gr.Blocks(title="FLUX.1 Fill dev + Area Preservation", theme=gr.themes.Soft()) as demo:
 
 
 
 
229
  with gr.Row():
230
  with gr.Column(scale=2):
231
  prompt_input = gr.Text(label="Prompt", lines=4, value="a 25 years old woman")
232
- seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
233
- num_inference_steps_input = gr.Number(label="Inference steps", value=40)
234
- guidance_scale_input = gr.Number(label="Guidance scale", value=30)
235
- strength_input = gr.Number(label="Strength", value=1.0, maximum=1.0)
 
 
 
 
 
 
236
 
237
  gr.Markdown("### Flux Keywords")
238
  flux_keywords_input = gr.CheckboxGroup(choices=flux_keywords_available, label="Flux Keywords")
@@ -244,29 +192,31 @@ with gr.Blocks(title="FLUX.1 Fill dev + Area Preservation", theme=gr.themes.Soft
244
  type="array",
245
  headers=["LoRA", "Weight"],
246
  value=[[name, 0.0] for name in lora_names],
247
- datatype=["str", "number"],
248
- interactive=[False, True],
249
- label="LoRA selection",
 
250
  )
251
 
252
  with gr.Column(scale=3):
253
- image_input = gr.Image(label="Original Image", type="pil")
254
- mask_input = gr.Image(label="Inpaint Mask (Area to change)", type="pil")
255
- preserved_area_input = gr.Image(label="Preserved Area Mask (Area to keep)", type="pil")
256
- run_btn = gr.Button("Generate", variant="primary")
 
257
 
258
  with gr.Column(scale=3):
259
  result_image = gr.Image(label="Result")
260
- used_prompt_box = gr.Text(label="Final Prompt")
261
- used_seed_box = gr.Number(label="Used Seed")
262
- steps_gallery = gr.Gallery(label="Evolution (Steps)", columns=3, preview=True)
 
263
 
264
  run_btn.click(
265
  fn=inpaint_api,
266
  inputs=[
267
  image_input,
268
  mask_input,
269
- preserved_area_input,
270
  prompt_input,
271
  seed_slider,
272
  num_inference_steps_input,
@@ -275,8 +225,9 @@ with gr.Blocks(title="FLUX.1 Fill dev + Area Preservation", theme=gr.themes.Soft
275
  flux_keywords_input,
276
  loras_selected_input,
277
  ],
278
- outputs=[result_image, steps_gallery, used_prompt_box, used_seed_box],
 
279
  )
280
 
281
  if __name__ == "__main__":
282
- demo.launch()
 
1
  import random
2
+ from typing import List, Tuple
3
 
4
  import gradio as gr
5
  import numpy as np
 
7
  import torch
8
  from diffusers import FluxFillPipeline
9
  from loras import LoRA, loras
 
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
 
 
14
 
15
  flux_keywords_available = ["IMG_1025.HEIC", "Selfie"]
16
 
 
 
 
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
def activate_loras(pipe: FluxFillPipeline, loras_with_weights: list[tuple[LoRA, float]]):
    """Load the given LoRA adapters into the pipeline and enable them.

    Each ``(LoRA, weight)`` pair is loaded under the LoRA's own name; once
    all pairs are loaded, every adapter is activated together with its
    corresponding weight.

    Args:
        pipe: the FluxFillPipeline to attach the adapters to.
        loras_with_weights: pairs of (LoRA object, blending weight).

    Returns:
        The same pipeline instance, for call chaining.
    """
    names: list = []
    weights: list = []

    for lora, weight in loras_with_weights:
        print(f"Loading LoRA: {lora.name} with weight {weight}")
        # NOTE(review): `weight=` here is forwarded as an extra kwarg to
        # load_lora_weights; the effective blending weight is the one passed
        # to set_adapters below — confirm the kwarg is intentional.
        pipe.load_lora_weights(lora.id, weight=weight, adapter_name=lora.name)
        names.append(lora.name)
        weights.append(weight)

    print(f"Activating adapters: {names} with weights {weights}")
    pipe.set_adapters(names, adapter_weights=weights)

    return pipe
32
 
33
 
34
def get_loras() -> "list[LoRA]":
    """Return the module-level LoRA catalog (exposed as a Gradio API endpoint).

    NOTE(review): the previous annotation said ``list[dict]``, but ``loras``
    holds LoRA objects — sibling code accesses ``l.display_name`` on its
    elements — so ``list[LoRA]`` matches the actual return value. The
    annotation is quoted so it is not evaluated at definition time.
    """
    return loras
36
+
37
+
38
def deactivate_loras(pipe):
    """Detach every loaded LoRA adapter from *pipe* and return the pipeline.

    Args:
        pipe: a pipeline exposing ``unload_lora_weights()``.

    Returns:
        The same pipeline instance, for call chaining.
    """
    print("Unloading all LoRAs...")
    pipe.unload_lora_weights()
    return pipe
42
 
43
 
 
44
def calculate_optimal_dimensions(image):
    """Compute the generation width/height for *image*.

    The longer side is scaled to 1024 px, both sides are snapped down to
    multiples of 8 (a diffusion-model requirement), and the aspect ratio
    is clamped into the [9:16, 16:9] range by shrinking the longer side.

    Args:
        image: any object with a PIL-style ``.size`` -> (width, height).

    Returns:
        A ``(width, height)`` tuple of ints, both multiples of 8.
    """
    original_width, original_height = image.size
    MIN_ASPECT_RATIO = 9 / 16
    MAX_ASPECT_RATIO = 16 / 9
    FIXED_DIMENSION = 1024

    original_aspect_ratio = original_width / original_height
    if original_aspect_ratio > 1:
        # Landscape: pin the width, derive the height.
        width = FIXED_DIMENSION
        height = round(FIXED_DIMENSION / original_aspect_ratio)
    else:
        # Portrait or square: pin the height, derive the width.
        height = FIXED_DIMENSION
        width = round(FIXED_DIMENSION * original_aspect_ratio)

    # Snap both sides down to the nearest multiple of 8.
    width = (width // 8) * 8
    height = (height // 8) * 8

    # Clamp extreme aspect ratios by shrinking the longer side.
    # The `// 8 * 8` keeps the clamped side a multiple of 8.
    calculated_aspect_ratio = width / height
    if calculated_aspect_ratio > MAX_ASPECT_RATIO:
        width = int((height * MAX_ASPECT_RATIO // 8) * 8)
    elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
        height = int((width / MIN_ASPECT_RATIO // 8) * 8)

    # NOTE: the former `width = max(width, 576) if width == FIXED_DIMENSION
    # else width` (and the matching height line) were no-ops — when the
    # guard held, the value was already 1024 >= 576 — so they were removed.
    # Behavior is unchanged.

    return width, height
67
+
68
+
69
+ @spaces.GPU(duration=45)
70
  def inpaint(
71
  image,
72
  mask,
 
73
  prompt: str = "",
74
  seed: int = 0,
75
  num_inference_steps: int = 28,
 
80
  mask = mask.convert("L")
81
  width, height = calculate_optimal_dimensions(image)
82
 
 
 
 
83
  pipe.to("cuda")
 
 
 
 
 
 
 
 
 
 
84
  result = pipe(
85
+ image=image,
86
+ mask_image=mask,
87
  prompt=prompt,
88
  width=width,
89
  height=height,
 
91
  guidance_scale=guidance_scale,
92
  strength=strength,
93
  generator=torch.Generator().manual_seed(seed),
 
 
94
  ).images[0]
95
 
96
+ return result.convert("RGBA"), prompt, seed
97
 
98
 
99
  def inpaint_api(
100
  image,
101
  mask,
102
+ prompt: str,
103
+ seed: int,
104
+ num_inference_steps: int,
105
+ guidance_scale: int,
106
+ strength: float,
107
+ flux_keywords: List[str] = None,
108
+ loras_selected: List[Tuple[str, float]] = None,
 
109
  ):
110
+ flux_keywords = flux_keywords or []
111
+ loras_selected = loras_selected or []
112
+
113
+ # Convertir nombres a objetos LoRA
114
  selected_loras_with_weights = []
115
 
116
+ for name, weight_value in loras_selected:
117
+ try:
118
+ # Convierte explícitamente el peso (que viene como string) a float
119
+ weight = float(weight_value)
120
+ except (ValueError, TypeError):
121
+ # Ignora si el valor no es un número válido (ej: None o string vacío)
122
+ print(f"Valor de peso inválido '{weight_value}' para LoRA '{name}', omitiendo.")
123
+ continue # Pasa al siguiente LoRA
124
+
125
+ lora_obj = next((l for l in loras if l.display_name == name), None)
126
+
127
+ # Ahora la comparación 'weight != 0.0' es segura (float con float)
128
+ if lora_obj and weight != 0.0:
129
+ selected_loras_with_weights.append((lora_obj, weight))
130
 
131
  deactivate_loras(pipe)
132
  if selected_loras_with_weights:
133
  activate_loras(pipe, selected_loras_with_weights)
134
 
135
+ # Construir prompt final
136
  final_prompt = ""
137
+
138
  if flux_keywords:
139
  final_prompt += ", ".join(flux_keywords) + ", "
140
+
141
+ for lora, _ in selected_loras_with_weights:
142
+ if lora.keyword:
143
+ if isinstance(lora.keyword, str):
144
+ final_prompt += lora.keyword + ", "
145
+ else:
146
+ final_prompt += ", ".join(lora.keyword) + ", "
147
+
148
+ if final_prompt:
149
+ final_prompt += "\n\n"
150
  final_prompt += prompt
151
 
152
  if not isinstance(seed, int) or seed < 0:
 
155
  return inpaint(
156
  image=image,
157
  mask=mask,
 
158
  prompt=final_prompt,
159
  seed=seed,
160
  num_inference_steps=num_inference_steps,
 
163
  )
164
 
165
 
166
+ # ========================
167
+ # UI DIRECTA A inpaint_api
168
+ # ========================
169
+ with gr.Blocks(title="Flux.1 Fill dev Inpainting with LoRAs", theme=gr.themes.Soft()) as demo:
170
+ gr.api(get_loras, api_name="get_loras")
171
  with gr.Row():
172
  with gr.Column(scale=2):
173
  prompt_input = gr.Text(label="Prompt", lines=4, value="a 25 years old woman")
174
+
175
+ seed_slider = gr.Slider(
176
+ label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1, info="(-1 = Random)", interactive=True
177
+ )
178
+
179
+ num_inference_steps_input = gr.Number(label="Inference steps", value=40, interactive=True)
180
+
181
+ guidance_scale_input = gr.Number(label="Guidance scale", value=28, interactive=True)
182
+
183
+ strength_input = gr.Number(label="Strength", value=1.0, interactive=True, maximum=1.0)
184
 
185
  gr.Markdown("### Flux Keywords")
186
  flux_keywords_input = gr.CheckboxGroup(choices=flux_keywords_available, label="Flux Keywords")
 
192
  type="array",
193
  headers=["LoRA", "Weight"],
194
  value=[[name, 0.0] for name in lora_names],
195
+ datatype=["str", "number"], # Primera columna string, segunda número
196
+ interactive=[False, True], # Solo la segunda columna editable
197
+ static_columns=[0],
198
+ label="LoRA selection (Weight 0 = disable)",
199
  )
200
 
201
  with gr.Column(scale=3):
202
+ image_input = gr.Image(label="Image", type="pil")
203
+
204
+ mask_input = gr.Image(label="Mask", type="pil")
205
+
206
+ run_btn = gr.Button("Run", variant="primary")
207
 
208
  with gr.Column(scale=3):
209
  result_image = gr.Image(label="Result")
210
+
211
+ used_prompt_box = gr.Text(label="Used prompt", lines=4)
212
+
213
+ used_seed_box = gr.Number(label="Used seed")
214
 
215
  run_btn.click(
216
  fn=inpaint_api,
217
  inputs=[
218
  image_input,
219
  mask_input,
 
220
  prompt_input,
221
  seed_slider,
222
  num_inference_steps_input,
 
225
  flux_keywords_input,
226
  loras_selected_input,
227
  ],
228
+ outputs=[result_image, used_prompt_box, used_seed_box],
229
+ api_name="inpaint",
230
  )
231
 
232
  if __name__ == "__main__":
233
+ demo.launch(share=False, show_error=True)