ZhouZJ36DL committed on
Commit
a216e74
·
1 Parent(s): 93bc0f4

modified: app.py

Browse files
app.py CHANGED
@@ -32,441 +32,430 @@ class SamplingOptions:
32
  guidance: float
33
  seed: int | None
34
 
35
- @torch.inference_mode()
36
- def encode(init_image, torch_device, ae):
37
- if next(ae.parameters()).device != torch_device:
38
- ae = ae.to(torch_device)
39
- init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
40
- init_image = init_image.unsqueeze(0)
41
- init_image = init_image.to(torch_device)
42
- with torch.no_grad():
43
- init_image = ae.encode(init_image).to(torch.bfloat16)
44
- return init_image
45
-
46
 
47
- class FluxEditor:
48
- def __init__(self, args):
49
- self.args = args
50
- self.device = torch.device(args.device)
51
- self.offload = args.offload
52
- self.name = args.name
53
- self.is_schnell = args.name == "flux-schnell"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- self.feature_path = 'feature'
56
 
57
- self.reset()
58
-
59
- self.add_sampling_metadata = True
60
-
61
- if self.name not in configs:
62
- available = ", ".join(configs.keys())
63
- raise ValueError(f"Got unknown model name: {self.name}, chose from {available}")
64
-
65
- # init all components
66
- self.clip = load_clip(self.device)
67
- self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 77)
68
- self.model = load_flow_model(self.name, self.device)
69
- self.ae = load_ae(self.name, self.device)
70
- self.t5.eval()
71
- self.clip.eval()
72
- self.ae.eval()
73
- self.model.eval()
74
-
75
- # clear history
76
- if os.path.exists("history_gradio/history.safetensors"):
77
- os.remove("history_gradio/history.safetensors")
 
 
78
 
 
 
 
 
 
79
 
80
- @torch.inference_mode()
81
- def reset(self):
82
- out_root = 'src/gradio_utils/gradio_outputs'
83
- if not os.path.exists(out_root):
84
- os.makedirs(out_root)
85
- name_dir = f'exp_{len(os.listdir(out_root))}'
86
- self.output_dir = os.path.join(out_root, name_dir)
87
- if not os.path.exists(self.output_dir):
88
- os.makedirs(self.output_dir)
89
- if not os.path.exists("heatmap"):
90
- os.makedirs("heatmap")
91
- if not os.path.exists("heatmap/average_heatmaps"):
92
- os.makedirs("heatmap/average_heatmaps")
93
- self.instructions = ['source']
94
- self.source_image = None
95
- self.history_tensors = {
96
- "source img": torch.zeros((1, 1, 1)),
97
- "prev img": torch.zeros((1, 1, 1))}
98
-
99
- source_prompt = "(Optional) Describe the content of the uploaded image."
100
- traget_prompt = "(Required) Describe the desired content of the edited image."
101
- gallery = None
102
- output_image = None
103
- return source_prompt, traget_prompt, gallery, output_image
104
-
105
-
106
- @torch.inference_mode()
107
- def process_image(self,
108
- init_image,
109
- source_prompt,
110
- target_prompt,
111
- editing_strategy,
112
- denoise_strategy,
113
- num_steps,
114
- guidance,
115
- attn_guidance_start_block,
116
- inject_step,
117
- init_image_2=None):
118
- if init_image is None:
119
- img, gr_gallery = self.generate_image(prompt=target_prompt)
120
- else:
121
- img, gr_gallery = self.edit(init_image, source_prompt, target_prompt, editing_strategy, denoise_strategy, num_steps, guidance, attn_guidance_start_block, inject_step, init_image_2)
122
- return img, gr_gallery
123
 
124
 
125
- @spaces.GPU(duration=120)
126
- @torch.inference_mode()
127
- def generate_image(
128
- self,
129
- width=512,
130
- height=512,
131
- num_steps=28,
132
- guidance=3.5,
133
- seed=None,
134
- prompt='',
135
- init_image=None,
136
- image2image_strength=0.0,
137
- add_sampling_metadata=True,
138
- ):
139
-
140
- if seed is None:
141
- g_seed = torch.Generator(device=torch.device("cpu")).seed()
142
- print(f"Generating '{prompt}' with seed {g_seed}")
143
- t0 = time.perf_counter()
144
-
145
- if init_image is not None:
146
- if isinstance(init_image, np.ndarray):
147
- init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0
148
- init_image = init_image.unsqueeze(0)
149
- init_image = init_image.to(self.device)
150
- init_image = torch.nn.functional.interpolate(init_image, (height, width))
151
- if self.offload:
152
- self.ae.encoder.to(self.device)
153
- init_image = self.ae.encode(init_image)
154
- if self.offload:
155
- self.ae = self.ae.cpu()
156
- torch.cuda.empty_cache()
157
-
158
- # prepare input
159
- x = get_noise(
160
- 1,
161
- height,
162
- width,
163
- device=self.device,
164
- dtype=torch.bfloat16,
165
- seed=g_seed,
166
- )
167
- timesteps = get_schedule(
168
- num_steps,
169
- x.shape[-1] * x.shape[-2] // 4,
170
- shift=(not self.is_schnell),
171
- )
172
- if init_image is not None:
173
- t_idx = int((1 - image2image_strength) * num_steps)
174
- t = timesteps[t_idx]
175
- timesteps = timesteps[t_idx:]
176
- x = t * x + (1.0 - t) * init_image.to(x.dtype)
177
-
178
- if self.offload:
179
- self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
180
- inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
181
-
182
- # offload TEs to CPU, load model to gpu
183
- if self.offload:
184
- self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
185
- torch.cuda.empty_cache()
186
- self.model = self.model.to(self.device)
187
-
188
- self.model = self.model.to(self.device)
189
- # denoise initial noise
190
- info = {}
191
- info['feature'] = {}
192
- info['inject_step'] = 0
193
- info['editing_strategy']= ""
194
- info['start_layer_index'] = 0
195
- info['end_layer_index'] = 37
196
- info['reuse_v']= False
197
- qkv_ratio = '1.0,1.0,1.0'
198
- info['qkv_ratio'] = list(map(float, qkv_ratio.split(',')))
199
- x = denoise_rf(self.model, **inp, timesteps=timesteps, guidance=guidance, inverse=False, info=info)
200
-
201
- # offload model, load autoencoder to gpu
202
- if self.offload:
203
- self.model.cpu()
204
  torch.cuda.empty_cache()
205
- self.ae.decoder.to(x.device)
206
 
207
- # decode latents to pixel space
208
- x = unpack(x[0].float(), height, width)
209
- self.ae = self.ae.to(x.device)
210
- with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
211
- x = self.ae.decode(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
- if self.offload:
214
- self.ae.decoder.cpu()
215
- torch.cuda.empty_cache()
 
 
216
 
217
- t1 = time.perf_counter()
 
 
218
 
219
- print(f"Done in {t1 - t0:.1f}s.")
220
- # bring into PIL format
221
- x = x.clamp(-1, 1)
222
- x = embed_watermark(x.float())
223
- x = rearrange(x[0], "c h w -> h w c")
224
 
225
- img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
226
-
227
- filename = os.path.join(self.output_dir,f"round_0000_[{prompt}].jpg")
228
- os.makedirs(os.path.dirname(filename), exist_ok=True)
229
- exif_data = Image.Exif()
230
- if init_image is None:
231
- exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
232
- else:
233
- exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
234
- exif_data[ExifTags.Base.Make] = "Black Forest Labs"
235
- exif_data[ExifTags.Base.Model] = self.name
236
- if add_sampling_metadata:
237
- exif_data[ExifTags.Base.ImageDescription] = prompt
238
- img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
239
- self.instructions = [prompt]
240
-
241
- #-------------------- 6.4 save editing prompt, update gradio component: gallery ----------------------#
242
- img_and_prompt = []
243
- history_imgs = sorted(os.listdir(self.output_dir))
244
- for img_file, prompt_txt in zip(history_imgs, self.instructions):
245
- img_and_prompt.append((os.path.join(self.output_dir, img_file), prompt_txt))
246
- history_gallery = gr.Gallery(value=img_and_prompt, label="History Image", interactive=True, columns=3)
247
- return img, history_gallery
248
-
249
-
250
- @spaces.GPU(duration=200)
251
- @torch.inference_mode()
252
- def edit(self, init_image, source_prompt, target_prompt, editing_strategy, denoise_strategy, num_steps, guidance, attn_guidance_start_block, inject_step, init_image_2=None):
253
-
254
- torch.cuda.empty_cache()
255
- seed = None
256
 
257
- print(f"Inital_t5_device: {self.t5.hf_module.device}")
258
- print(f"Inital_clip_device: {self.clip.hf_module.device}")
259
- print(f"Inital_flow_model: {self.model.img_in.weight.device}")
260
- print(f"Inital_flow_model self.model.img_in: {self.model.img_in}")
261
- print(f"Inital_flow_model self.model.time_in.out_layer.weight: {self.model.time_in.out_layer.weight}")
262
-
263
- if self.offload:
264
- self.model.cpu()
265
- torch.cuda.empty_cache()
266
- self.ae.encoder.to(self.device)
267
-
268
- #----------------------------- 0.1 prepare multi-turn editing -------------------------------------#
269
- info = {}
270
- shape = init_image.shape
271
- new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
272
- new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
273
-
274
- if not any("round_0000" in fname for fname in os.listdir(self.output_dir)):
275
- Image.fromarray(init_image).save(os.path.join(self.output_dir,"round_0000_[source].jpg"))
276
-
277
-
278
- init_image = init_image[:new_h, :new_w, :]
279
- width, height = init_image.shape[0], init_image.shape[1]
280
- init_image = encode(init_image, self.device, self.ae)
281
-
282
- print(init_image.shape)
283
-
284
- if init_image_2 is None:
285
- print("init_image_2 is not provided, proceeding with single image processing.")
286
- else:
287
- init_image_2_pil = Image.fromarray(init_image_2) # Convert NumPy array to PIL Image
288
- init_image_2_pil = init_image_2_pil.resize((new_w, new_h), Image.Resampling.LANCZOS)
289
- init_image_2 = np.array(init_image_2_pil) # Convert back to NumPy (if needed)
290
- init_image_2 = encode(init_image_2, self.device, self.ae)
291
-
292
- rng = torch.Generator(device=torch.device("cpu"))
293
- opts = SamplingOptions(
294
- source_prompt=source_prompt,
295
- target_prompt=target_prompt,
296
- width=width,
297
- height=height,
298
- num_steps=num_steps,
299
- guidance=guidance,
300
- seed=seed,
301
- )
302
- if opts.seed is None:
303
- opts.seed = torch.Generator(device=torch.device("cpu")).seed()
304
-
305
- print(f"Editing with prompt:\n{opts.source_prompt}")
306
- t0 = time.perf_counter()
307
 
308
- opts.seed = None
309
- if self.offload:
310
- self.ae = self.ae.cpu()
311
- torch.cuda.empty_cache()
312
- self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
313
-
314
- #----------------------------- 0.2 prepare attention strategy -------------------------------------#
315
- info = {}
316
- info['feature'] = {}
317
- info['inject_step'] = inject_step
318
- info['editing_strategy']= " ".join(editing_strategy)
319
- info['start_layer_index'] = 0
320
- info['end_layer_index'] = 37
321
- info['reuse_v']= False
322
- qkv_ratio = '1.0,1.0,1.0'
323
- info['qkv_ratio'] = list(map(float, qkv_ratio.split(',')))
324
- info['attn_guidance'] = attn_guidance_start_block
325
- info['lqr_stop'] = 0.25
326
-
327
- if not os.path.exists(self.feature_path):
328
- os.mkdir(self.feature_path)
329
-
330
-
331
- #----------------------------- 0.3 prepare latents -------------------------------------#
332
- with torch.no_grad():
333
- inp = prepare(self.t5, self.clip, init_image, prompt=opts.source_prompt)
334
- inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt)
335
- if self.source_image is None:
336
- self.source_image = inp['img']
337
- inp_target_2 = None
338
- if not init_image_2 is None:
339
- inp_target_2 = prepare_image(init_image_2)
340
- info['lqr_stop'] = 0.35
341
-
342
- timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
343
- #timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=False)
344
-
345
- # offload TEs to CPU, load model to gpu
346
- if self.offload:
347
- self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
348
- torch.cuda.empty_cache()
349
- self.model = self.model.to(self.device)
350
 
351
- self.model = self.model.to(self.device)
352
- print(f"model has been moved to {self.device}")
353
-
354
- #----------------------------- 1 Inverting current image -------------------------------------#
355
- denoise_strategies = ['fireflow', 'rf', 'rf_solver', 'midpoint', 'rf_inversion', 'multi_turn_consistent']
356
- denoise_funcs = [denoise_fireflow, denoise_rf, denoise_rf_solver, denoise_midpoint, denoise_rf_inversion, denoise_multi_turn_consistent]
357
- denoise_func = denoise_funcs[denoise_strategies.index(denoise_strategy)]
358
- with torch.no_grad():
359
- z, info = denoise_func(self.model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
360
-
361
- print(f"Inverted Z: {z}")
362
- print(info)
363
-
364
-
365
- #----------------------------- 2 history_tensors used to implement dual-LQR guiding editing -------------------------------------#
366
- inp_target["img"] = z
367
- timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(self.name != "flux-schnell"))
368
-
369
- if torch.all(self.history_tensors['source img'] == 0):
370
- self.history_tensors = {
371
- "source img": inp["img"],
372
- "prev img": inp_target_2}
373
- else:
374
- if inp_target_2 is None:
375
- self.history_tensors["prev img"] = inp["img"]
376
- else:
377
- self.history_tensors["source img"] = inp["img"]
378
- self.history_tensors["prev img"] = inp_target_2
379
-
380
- #----------------------------- 3 sampling -------------------------------------#
381
- if denoise_strategy in ['rf_inversion', 'multi_turn_consistent']:
382
- x, _ = denoise_func(self.model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info, img_LQR=self.history_tensors)
383
- else:
384
- x, _ = denoise_func(self.model, **inp_target, timesteps=timesteps, guidance=opts.guidance, inverse=False, info=info)
 
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
- #----------------------------- 4 update history_tensors -------------------------------------#
388
- info = {}
389
- self.history_tensors["source img"] = self.source_image
390
- self.history_tensors["prev img"] = x
391
- '''save_file(history_tensors, "history_gradio/history.safetensors")'''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
 
393
- # offload model, load autoencoder to gpu
394
- if self.offload:
395
- self.model.cpu()
396
- torch.cuda.empty_cache()
397
- self.ae.decoder.to(x.device)
398
 
 
 
399
 
 
 
 
 
 
 
 
 
400
 
401
- #----------------------------- 5 decode x to image -------------------------------------#
402
- x = unpack(x.float(), opts.width, opts.height)
403
 
404
- with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
405
- x = self.ae.decode(x)
 
406
 
407
- if torch.cuda.is_available():
408
- torch.cuda.synchronize()
409
- t1 = time.perf_counter()
 
410
 
411
- # bring into PIL format and save
412
- x = x.clamp(-1, 1)
413
- x = embed_watermark(x.float())
414
- x = rearrange(x[0], "c h w -> h w c")
 
 
 
415
 
416
- img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
417
- exif_data = Image.Exif()
418
- exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
419
- exif_data[ExifTags.Base.Make] = "Black Forest Labs"
420
- exif_data[ExifTags.Base.Model] = self.name
421
- if self.add_sampling_metadata:
422
- exif_data[ExifTags.Base.ImageDescription] = source_prompt
423
-
424
 
425
 
426
- #-------------------------------- 6 save image -------------------------------------#
427
 
428
- #-------------------- 6.1 prepare output folder ----------------------#
429
- if not os.path.exists(self.output_dir):
430
- os.makedirs(self.output_dir)
431
- idx = 0
432
- #-------------------- 6.2 editing round ----------------------#
 
 
 
 
433
  else:
434
- fns = [fn for fn in os.listdir(self.output_dir)]
435
- if len(fns) > 0:
436
- idx = max(int(fn.split("_")[1]) for fn in fns) + 1
437
- else:
438
- idx = 0
439
- formatted_idx = str(idx).zfill(4) # Format as a 4-digit string
440
-
441
- #-------------------- 6.3 output name ----------------------#
442
- if denoise_strategy == 'multi_turn_consistent':
443
- denoise_strategy = 'MTC'
444
- if target_prompt == '':
445
- target_prompt = 'Reconstruction'
446
- if target_prompt == source_prompt:
447
- target_prompt = 'Reconstruction: ' + target_prompt
448
-
449
- target_suffix = " ".join(target_prompt.split()[-5:])
450
- output_name = f"round_{formatted_idx}_{target_suffix}_{denoise_strategy}.jpg"
451
-
452
- fn = os.path.join(self.output_dir, output_name)
453
-
454
- print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
455
- img.save(fn)
456
 
457
- if 'Reconstruction' in target_prompt:
458
- target_prompt = source_prompt
459
- self.instructions.append(target_prompt)
460
- print("End Edit")
461
 
462
- #-------------------- 6.4 save editing prompt, update gradio component: gallery ----------------------#
463
- img_and_prompt = []
464
- history_imgs = sorted(os.listdir(self.output_dir))
465
- for img_file, prompt_txt in zip(history_imgs, self.instructions):
466
- img_and_prompt.append((os.path.join(self.output_dir, img_file), prompt_txt))
467
- history_gallery = gr.Gallery(value=img_and_prompt, label="History Image", interactive=True, columns=3)
468
-
469
- return img, history_gallery
470
 
471
 
472
  def on_select(gallery, selected: gr.SelectData):
@@ -480,8 +469,7 @@ def on_change(init_image, changed: gr.EventData):
480
  return gr.Gallery(value=[(img_path[0], "")], label="History Image", interactive=True, columns=3), img_path[0]
481
 
482
 
483
- def create_demo(model_name: str, device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
484
- editor = FluxEditor(args)
485
  is_schnell = model_name == "flux-schnell"
486
 
487
  # Pre-defined examples
@@ -527,11 +515,11 @@ def create_demo(model_name: str, device: str | torch.device = "cuda" if torch.cu
527
  example_image.change(on_change, example_image, [gallery, init_image])
528
 
529
  generate_btn.click(
530
- fn=editor.process_image,
531
  inputs=[init_image, source_prompt, target_prompt, editing_strategy, denoise_strategy, num_steps, guidance, attn_guidance_start_block, inject_step, init_image_2],
532
  outputs=[output_image, gallery]
533
  )
534
- reset_btn.click(fn = editor.reset, outputs=[source_prompt, target_prompt, gallery, output_image])
535
 
536
  # Add examples
537
  gr.Examples(
@@ -552,17 +540,6 @@ def create_demo(model_name: str, device: str | torch.device = "cuda" if torch.cu
552
  return demo
553
 
554
 
555
- if __name__ == "__main__":
556
- import argparse
557
- parser = argparse.ArgumentParser(description="Flux")
558
- parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name")
559
- parser.add_argument("--device", type=str, default="cuda", help="Device to use")
560
- parser.add_argument("--offload", default=False, help="Offload model to CPU when not in use")
561
- parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
562
- parser.add_argument("--port", type=int, default=9090)
563
- args = parser.parse_args()
564
- print(vars(args))
565
-
566
- demo = create_demo(args.name, args.device, args.offload)
567
- #demo.launch(server_name='0.0.0.0', share=args.share, server_port=args.port)
568
- demo.launch(debug=True)
 
32
  guidance: float
33
  seed: int | None
34
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
# ---------------------------------------------------------------------------
# Module-level runtime state shared by the functions below.
# ---------------------------------------------------------------------------

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
# When True, the autoencoder and flow model are loaded on the CPU (see the
# load_ae / load_flow_model calls below) instead of on `torch_device`.
offload = False
# Same value as `torch_device`; both names are kept because the loader calls
# below use `device` while other code in this file uses `torch_device`.
device = torch_device
name = 'flux-dev'

# Load all model components once at import time.
ae = load_ae(name, device="cpu" if offload else torch_device)
t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
clip = load_clip(device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
# Derived from `name` so it cannot drift from the model that was loaded.
is_schnell = name == "flux-schnell"
add_sampling_metadata = True

# Clear history left over from a previous run. EAFP avoids the
# check-then-remove race of an os.path.exists() guard.
try:
    os.remove("history_gradio/history.safetensors")
except FileNotFoundError:
    pass

# Each run writes into a fresh `exp_<n>` folder, numbered from the count of
# entries already present in the output root.
out_root = 'src/gradio_utils/gradio_outputs'
os.makedirs(out_root, exist_ok=True)
name_dir = f'exp_{len(os.listdir(out_root))+1}'
output_dir = os.path.join(out_root, name_dir)
os.makedirs(output_dir, exist_ok=True)
# Creates "heatmap" as well as its "average_heatmaps" child.
os.makedirs("heatmap/average_heatmaps", exist_ok=True)

# Editing-session state. The all-zero tensors are placeholders: edit() tests
# `history_tensors['source img']` against zero to detect the first round.
source_image = None
history_tensors = {
    "source img": torch.zeros((1, 1, 1)),
    "prev img": torch.zeros((1, 1, 1)),
}
# One caption per image saved in `output_dir`, in save order.
instructions = ['source']
67
 
 
68
 
69
@torch.inference_mode()
def reset():
    """Start a fresh editing session and return default UI values.

    Removes any saved history file, allocates a new ``exp_<n>`` output
    folder, makes sure the heatmap folders exist, and resets the
    module-level session state (``output_dir``, ``instructions``,
    ``source_image``, ``history_tensors``).

    Returns:
        A 4-tuple ``(source_prompt, target_prompt, gallery, output_image)``:
        the two placeholder prompt strings followed by ``None`` for the
        gallery and for the output image.
    """
    # Without this declaration the assignments below would only bind locals
    # and the session state read by the other functions would never change.
    global output_dir, instructions, source_image, history_tensors

    # clear history (EAFP: no exists()/remove() race)
    try:
        os.remove("history_gradio/history.safetensors")
    except FileNotFoundError:
        pass

    # New output folder, numbered from the entries already in the root.
    out_root = 'src/gradio_utils/gradio_outputs'
    os.makedirs(out_root, exist_ok=True)
    name_dir = f'exp_{len(os.listdir(out_root))+1}'
    output_dir = os.path.join(out_root, name_dir)
    os.makedirs(output_dir, exist_ok=True)
    # Creates "heatmap" as well as its "average_heatmaps" child.
    os.makedirs("heatmap/average_heatmaps", exist_ok=True)

    # Back to the initial session state (all-zero tensors are placeholders).
    instructions = ['source']
    source_image = None
    history_tensors = {
        "source img": torch.zeros((1, 1, 1)),
        "prev img": torch.zeros((1, 1, 1)),
    }

    source_prompt = "(Optional) Describe the content of the uploaded image."
    traget_prompt = "(Required) Describe the desired content of the edited image."
    gallery = None
    output_image = None
    return source_prompt, traget_prompt, gallery, output_image
98
 
99
+
100
@torch.inference_mode()
def process_image(
        init_image,
        source_prompt,
        target_prompt,
        editing_strategy,
        denoise_strategy,
        num_steps,
        guidance,
        attn_guidance_start_block,
        inject_step,
        init_image_2=None):
    """Dispatch a UI request to text-to-image generation or image editing.

    With no ``init_image`` the target prompt is passed to
    ``generate_image``; otherwise every argument is forwarded unchanged to
    ``edit``.

    Returns:
        The ``(image, gallery)`` pair produced by the chosen function.
    """
    if init_image is not None:
        # An image was supplied: run the editing pipeline on it.
        outcome = edit(
            init_image,
            source_prompt,
            target_prompt,
            editing_strategy,
            denoise_strategy,
            num_steps,
            guidance,
            attn_guidance_start_block,
            inject_step,
            init_image_2,
        )
    else:
        # Nothing to edit: synthesise an image from the target prompt alone.
        outcome = generate_image(prompt=target_prompt)

    picture, history = outcome
    return picture, history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
@spaces.GPU(duration=120)
@torch.inference_mode()
def generate_image(
    width=512,
    height=512,
    num_steps=28,
    guidance=3.5,
    seed=None,
    prompt='',
    init_image=None,
    image2image_strength=0.0,
):
    """Generate an image from ``prompt`` and save it as round 0 of the session.

    Args:
        width: Output width passed to ``get_noise`` / ``unpack``.
        height: Output height passed to ``get_noise`` / ``unpack``.
        num_steps: Number of steps requested from ``get_schedule``.
        guidance: Guidance value forwarded to ``denoise_rf``.
        seed: Noise seed. ``None`` (the default) draws a random seed.
        prompt: Text prompt; also used in the output file name and, when
            ``add_sampling_metadata`` is set, stored in the EXIF description.
        init_image: Optional starting image. A NumPy array is converted with
            ``permute(2, 0, 1)`` and scaled by ``1/255`` before encoding.
        image2image_strength: Used with ``init_image`` to pick the timestep
            index ``int((1 - strength) * num_steps)`` to start from.

    Returns:
        ``(img, history_gallery)``: the generated PIL image and a
        ``gr.Gallery`` listing the files in ``output_dir`` with captions.
    """
    # Only names that are rebound here need `global`; the remaining module
    # settings (name, is_schnell, output_dir, ...) are merely read.
    # `instructions` must be listed, otherwise the caption list assigned
    # below would be a throw-away local.
    global ae, t5, clip, model, instructions

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.cuda.empty_cache()

    # Honour an explicit seed; draw a random one only when none was given.
    g_seed = torch.Generator(device="cpu").seed() if seed is None else seed
    print(f"Generating '{prompt}' with seed {g_seed}")
    t0 = time.perf_counter()

    if init_image is not None:
        if isinstance(init_image, np.ndarray):
            init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0
            init_image = init_image.unsqueeze(0)
        init_image = init_image.to(device)
        init_image = torch.nn.functional.interpolate(init_image, (height, width))
        if offload:
            ae.encoder.to(device)
        init_image = ae.encode(init_image)
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()

    # prepare input
    x = get_noise(
        1,
        height,
        width,
        device=device,
        dtype=torch.bfloat16,
        seed=g_seed,
    )
    timesteps = get_schedule(
        num_steps,
        x.shape[-1] * x.shape[-2] // 4,
        shift=(not is_schnell),
    )
    if init_image is not None:
        # Start part-way through the schedule and blend the noise with the
        # encoded image at that timestep.
        t_idx = int((1 - image2image_strength) * num_steps)
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        x = t * x + (1.0 - t) * init_image.to(x.dtype)

    if offload:
        t5, clip = t5.to(device), clip.to(device)
    inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)

    # offload TEs to CPU, load model to gpu
    if offload:
        t5, clip = t5.cpu(), clip.cpu()
        torch.cuda.empty_cache()
        model = model.to(device)

    # denoise initial noise (plain sampling: no feature injection)
    info = {
        'feature': {},
        'inject_step': 0,
        'editing_strategy': "",
        'start_layer_index': 0,
        'end_layer_index': 37,
        'reuse_v': False,
    }
    qkv_ratio = '1.0,1.0,1.0'
    info['qkv_ratio'] = list(map(float, qkv_ratio.split(',')))
    # NOTE(review): edit() unpacks this function's result as a 2-tuple
    # `(latents, info)`; unpacking here too keeps `x` a tensor so `x.device`
    # below is valid on the offload path — confirm against denoise_rf.
    x, _ = denoise_rf(model, **inp, timesteps=timesteps, guidance=guidance, inverse=False, info=info)

    # offload model, load autoencoder to gpu
    if offload:
        model.cpu()
        torch.cuda.empty_cache()
        ae.decoder.to(x.device)

    # decode latents to pixel space
    x = unpack(x.float(), height, width)
    # Autocast on the device actually in use instead of assuming CUDA.
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        x = ae.decode(x)

    if offload:
        ae.decoder.cpu()
        torch.cuda.empty_cache()

    t1 = time.perf_counter()

    print(f"Done in {t1 - t0:.1f}s.")
    # bring into PIL format
    x = x.clamp(-1, 1)
    x = embed_watermark(x.float())
    x = rearrange(x[0], "c h w -> h w c")

    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

    filename = os.path.join(output_dir, f"round_0000_[{prompt}].jpg")
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    exif_data = Image.Exif()
    if init_image is None:
        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
    else:
        exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
    exif_data[ExifTags.Base.Model] = name
    if add_sampling_metadata:
        exif_data[ExifTags.Base.ImageDescription] = prompt
    img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
    # The generated image becomes the first entry of the session history.
    instructions = [prompt]

    #-------------------- 6.4 save editing prompt, update gradio component: gallery ----------------------#
    # Pair each saved file (sorted by name) with its caption.
    history_imgs = sorted(os.listdir(output_dir))
    img_and_prompt = [
        (os.path.join(output_dir, img_file), prompt_txt)
        for img_file, prompt_txt in zip(history_imgs, instructions)
    ]
    history_gallery = gr.Gallery(value=img_and_prompt, label="History Image", interactive=True, columns=3)
    return img, history_gallery
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
+
245
+ @spaces.GPU(duration=200)
246
+ @torch.inference_mode()
247
+ def edit(init_image, source_prompt, target_prompt, editing_strategy, denoise_strategy, num_steps, guidance, attn_guidance_start_block, inject_step, init_image_2=None):
248
+ global ae, t5, clip, model, name, is_schnell, output_dir, add_sampling_metadata, offload, source_image, history_tensors, instructions
249
+
250
+ device = "cuda" if torch.cuda.is_available() else "cpu"
251
+ torch.cuda.empty_cache()
252
+ seed = None
253
+
254
+ print(f"Inital_t5_device: {t5.hf_module.device}")
255
+ print(f"Inital_clip_device: {clip.hf_module.device}")
256
+ print(f"Inital_flow_model: {model.img_in.weight.device}")
257
+ print(f"Inital_flow_model self.model.img_in: {model.img_in}")
258
+ print(f"Inital_flow_model self.model.time_in.out_layer.weight: {model.time_in.out_layer.weight}")
259
+
260
+ #----------------------------- 0.1 prepare multi-turn editing -------------------------------------#
261
+ info = {}
262
+ shape = init_image.shape
263
+ new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
264
+ new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
265
+
266
+ if not any("round_0000" in fname for fname in os.listdir(output_dir)):
267
+ Image.fromarray(init_image).save(os.path.join(output_dir,"round_0000_[source].jpg"))
268
+
269
+ init_image = init_image[:new_h, :new_w, :]
270
+ width, height = init_image.shape[0], init_image.shape[1]
271
+
272
+ init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
273
+ init_image = init_image.unsqueeze(0)
274
+ init_image = init_image.to(device)
275
+ if offload:
276
+ model.cpu()
277
+ torch.cuda.empty_cache()
278
+ ae.encoder.to(device)
279
 
280
+ with torch.no_grad():
281
+ init_image = ae.encode(init_image.to()).to(torch.bfloat16)
282
+
283
+ if init_image_2 is None:
284
+ print("init_image_2 is not provided, proceeding with single image processing.")
285
+ else:
286
+ init_image_2_pil = Image.fromarray(init_image_2) # Convert NumPy array to PIL Image
287
+ init_image_2_pil = init_image_2_pil.resize((new_w, new_h), Image.Resampling.LANCZOS)
288
+ init_image_2 = np.array(init_image_2_pil) # Convert back to NumPy (if needed)
289
+ init_image_2 = torch.from_numpy(init_image_2).permute(2, 0, 1).float() / 127.5 - 1
290
+
291
+ rng = torch.Generator(device=torch.device("cpu"))
292
+ opts = SamplingOptions(
293
+ source_prompt=source_prompt,
294
+ target_prompt=target_prompt,
295
+ width=width,
296
+ height=height,
297
+ num_steps=num_steps,
298
+ guidance=guidance,
299
+ seed=None,
300
+ )
301
+ if opts.seed is None:
302
+ opts.seed = torch.Generator(device=torch.device("cpu")).seed()
303
+
304
+ print(f"Editing with prompt:\n{opts.source_prompt}")
305
+ t0 = time.perf_counter()
306
+
307
+ if offload:
308
+ ae = ae.cpu()
309
+ torch.cuda.empty_cache()
310
+ t5, clip = t5.to(torch_device), clip.to(torch_device)
311
+ opts.seed = None
312
+
313
+
314
+ #----------------------------- 0.2 prepare attention strategy -------------------------------------#
315
+ info = {}
316
+ info['feature'] = {}
317
+ info['inject_step'] = inject_step
318
+ info['editing_strategy']= " ".join(editing_strategy)
319
+ info['start_layer_index'] = 0
320
+ info['end_layer_index'] = 37
321
+ info['reuse_v']= False
322
+ qkv_ratio = '1.0,1.0,1.0'
323
+ info['qkv_ratio'] = list(map(float, qkv_ratio.split(',')))
324
+ info['attn_guidance'] = attn_guidance_start_block
325
+ info['lqr_stop'] = 0.25
326
+
327
+ #----------------------------- 0.3 prepare latents -------------------------------------#
328
+ with torch.no_grad():
329
+ inp = prepare(t5, clip, init_image, prompt=opts.source_prompt)
330
+ inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt)
331
+ if source_image is None:
332
+ source_image = inp['img']
333
+ inp_target_2 = None
334
+ if not init_image_2 is None:
335
+ inp_target_2 = prepare_image(init_image_2)
336
+ info['lqr_stop'] = 0.35
337
+
338
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
339
+ #timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=False)
340
+
341
+ # offload TEs to CPU, load model to gpu
342
+
343
+ if offload:
344
+ t5, clip = t5.cpu(), clip.cpu()
345
+ torch.cuda.empty_cache()
346
+ model = model.to(torch_device)
347
 
348
+ #----------------------------- 1 Inverting current image -------------------------------------#
349
+ denoise_strategies = ['fireflow', 'rf', 'rf_solver', 'midpoint', 'rf_inversion', 'multi_turn_consistent']
350
+ denoise_funcs = [denoise_fireflow, denoise_rf, denoise_rf_solver, denoise_midpoint, denoise_rf_inversion, denoise_multi_turn_consistent]
351
+ denoise_func = denoise_funcs[denoise_strategies.index(denoise_strategy)]
352
+ with torch.no_grad():
353
+ z, info = denoise_func(model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
354
+
355
+
356
+ #----------------------------- 2 history_tensors used to implement dual-LQR guiding editing -------------------------------------#
357
+ inp_target["img"] = z
358
+ timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(name != "flux-schnell"))
359
+
360
+ if torch.all(history_tensors['source img'] == 0):
361
+ history_tensors = {
362
+ "source img": inp["img"],
363
+ "prev img": inp_target_2}
364
+ else:
365
+ if inp_target_2 is None:
366
+ history_tensors["prev img"] = inp["img"]
367
+ else:
368
+ history_tensors["source img"] = inp["img"]
369
+ history_tensors["prev img"] = inp_target_2
370
+
371
+ #----------------------------- 3 sampling -------------------------------------#
372
+ if denoise_strategy in ['rf_inversion', 'multi_turn_consistent']:
373
+ x, _ = denoise_func(model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info, img_LQR=history_tensors)
374
+ else:
375
+ x, _ = denoise_func(model, **inp_target, timesteps=timesteps, guidance=opts.guidance, inverse=False, info=info)
376
+
377
 
378
+ #----------------------------- 4 update history_tensors -------------------------------------#
379
+ info = {}
380
+ history_tensors["source img"] = source_image
381
+ history_tensors["prev img"] = x
 
382
 
383
+ #----------------------------- 5 decode x to image -------------------------------------#
384
+ x = unpack(x.float(), opts.width, opts.height)
385
 
386
+ if offload:
387
+ model.cpu()
388
+ torch.cuda.empty_cache()
389
+ ae.decoder.to(x.device)
390
+
391
+ device = torch.device("cuda")
392
+ with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
393
+ x = ae.decode(x)
394
 
 
 
395
 
396
+ if torch.cuda.is_available():
397
+ torch.cuda.synchronize()
398
+ t1 = time.perf_counter()
399
 
400
+ # bring into PIL format and save
401
+ x = x.clamp(-1, 1)
402
+ x = embed_watermark(x.float())
403
+ x = rearrange(x[0], "c h w -> h w c")
404
 
405
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
406
+ exif_data = Image.Exif()
407
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
408
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
409
+ exif_data[ExifTags.Base.Model] = name
410
+ if add_sampling_metadata:
411
+ exif_data[ExifTags.Base.ImageDescription] = source_prompt
412
 
 
 
 
 
 
 
 
 
413
 
414
 
415
+ #-------------------------------- 6 save image -------------------------------------#
416
 
417
+ #-------------------- 6.1 prepare output folder ----------------------#
418
+ if not os.path.exists(output_dir):
419
+ os.makedirs(output_dir)
420
+ idx = 1
421
+ #-------------------- 6.2 editing round ----------------------#
422
+ else:
423
+ fns = [fn for fn in os.listdir(output_dir)]
424
+ if len(fns) > 0:
425
+ idx = max(int(fn.split("_")[1]) for fn in fns) + 1
426
  else:
427
+ idx = 1
428
+ formatted_idx = str(idx).zfill(4) # Format as a 4-digit string
429
+
430
+ #-------------------- 6.3 output name ----------------------#
431
+ if denoise_strategy == 'multi_turn_consistent':
432
+ denoise_strategy = 'MTC'
433
+ if target_prompt == '':
434
+ target_prompt = 'Reconstruction'
435
+ if target_prompt == source_prompt:
436
+ target_prompt = 'Reconstruction: ' + target_prompt
437
+
438
+ target_suffix = " ".join(target_prompt.split()[-5:])
439
+ output_name = f"round_{formatted_idx}_{target_suffix}_{denoise_strategy}.jpg"
440
+
441
+ fn = os.path.join(output_dir, output_name)
442
+
443
+ print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
444
+ img.save(fn)
 
 
 
 
445
 
446
+ if 'Reconstruction' in target_prompt:
447
+ target_prompt = source_prompt
448
+ instructions.append(target_prompt)
449
+ print("End Edit")
450
 
451
+ #-------------------- 6.4 save editing prompt, update gradio component: gallery ----------------------#
452
+ img_and_prompt = []
453
+ history_imgs = sorted(os.listdir(output_dir))
454
+ for img_file, prompt_txt in zip(history_imgs, instructions):
455
+ img_and_prompt.append((os.path.join(output_dir, img_file), prompt_txt))
456
+ history_gallery = gr.Gallery(value=img_and_prompt, label="History Image", interactive=True, columns=3)
457
+
458
+ return img, history_gallery
459
 
460
 
461
  def on_select(gallery, selected: gr.SelectData):
 
469
  return gr.Gallery(value=[(img_path[0], "")], label="History Image", interactive=True, columns=3), img_path[0]
470
 
471
 
472
+ def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
 
473
  is_schnell = model_name == "flux-schnell"
474
 
475
  # Pre-defined examples
 
515
  example_image.change(on_change, example_image, [gallery, init_image])
516
 
517
  generate_btn.click(
518
+ fn=process_image,
519
  inputs=[init_image, source_prompt, target_prompt, editing_strategy, denoise_strategy, num_steps, guidance, attn_guidance_start_block, inject_step, init_image_2],
520
  outputs=[output_image, gallery]
521
  )
522
+ reset_btn.click(fn = reset, outputs=[source_prompt, target_prompt, gallery, output_image])
523
 
524
  # Add examples
525
  gr.Examples(
 
540
  return demo
541
 
542
 
543
+ demo = create_demo(name, "cuda")
544
+ #demo.launch(server_name='0.0.0.0', share=args.share, server_port=args.port)
545
+ demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
src/flux/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/__init__.cpython-310.pyc and b/src/flux/__pycache__/__init__.cpython-310.pyc differ
 
src/flux/__pycache__/_version.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/_version.cpython-310.pyc and b/src/flux/__pycache__/_version.cpython-310.pyc differ
 
src/flux/__pycache__/math.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/math.cpython-310.pyc and b/src/flux/__pycache__/math.cpython-310.pyc differ
 
src/flux/__pycache__/model.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/model.cpython-310.pyc and b/src/flux/__pycache__/model.cpython-310.pyc differ
 
src/flux/__pycache__/sampling.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/sampling.cpython-310.pyc and b/src/flux/__pycache__/sampling.cpython-310.pyc differ
 
src/flux/__pycache__/util.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/util.cpython-310.pyc and b/src/flux/__pycache__/util.cpython-310.pyc differ
 
src/flux/modules/__pycache__/autoencoder.cpython-310.pyc CHANGED
Binary files a/src/flux/modules/__pycache__/autoencoder.cpython-310.pyc and b/src/flux/modules/__pycache__/autoencoder.cpython-310.pyc differ
 
src/flux/modules/__pycache__/conditioner.cpython-310.pyc CHANGED
Binary files a/src/flux/modules/__pycache__/conditioner.cpython-310.pyc and b/src/flux/modules/__pycache__/conditioner.cpython-310.pyc differ
 
src/flux/modules/__pycache__/layers.cpython-310.pyc CHANGED
Binary files a/src/flux/modules/__pycache__/layers.cpython-310.pyc and b/src/flux/modules/__pycache__/layers.cpython-310.pyc differ