ZhouwqZJ committed
Commit f0e942d · 1 Parent(s): acf7b9d

modified: app.py
Files changed (39)
  1. .gitattributes +1 -0
  2. app.py +543 -4
  3. model_cards/FLUX.1-dev.md +46 -0
  4. model_cards/FLUX.1-schnell.md +41 -0
  5. model_licenses/LICENSE-FLUX1-dev +42 -0
  6. model_licenses/LICENSE-FLUX1-schnell +54 -0
  7. requirements.txt +94 -0
  8. src/flux/__init__.py +11 -0
  9. src/flux/__main__.py +4 -0
  10. src/flux/__pycache__/__init__.cpython-310.pyc +0 -0
  11. src/flux/__pycache__/__init__.cpython-312.pyc +0 -0
  12. src/flux/__pycache__/_version.cpython-312.pyc +0 -0
  13. src/flux/__pycache__/math.cpython-310.pyc +0 -0
  14. src/flux/__pycache__/math.cpython-312.pyc +0 -0
  15. src/flux/__pycache__/math.cpython-38.pyc +0 -0
  16. src/flux/__pycache__/model.cpython-310.pyc +0 -0
  17. src/flux/__pycache__/model.cpython-312.pyc +0 -0
  18. src/flux/__pycache__/sampling.cpython-310.pyc +0 -0
  19. src/flux/__pycache__/sampling.cpython-312.pyc +0 -0
  20. src/flux/__pycache__/util.cpython-310.pyc +0 -0
  21. src/flux/__pycache__/util.cpython-312.pyc +0 -0
  22. src/flux/_version.py +16 -0
  23. src/flux/api.py +194 -0
  24. src/flux/math.py +170 -0
  25. src/flux/model.py +120 -0
  26. src/flux/modules/__pycache__/autoencoder.cpython-310.pyc +0 -0
  27. src/flux/modules/__pycache__/autoencoder.cpython-312.pyc +0 -0
  28. src/flux/modules/__pycache__/conditioner.cpython-310.pyc +0 -0
  29. src/flux/modules/__pycache__/conditioner.cpython-312.pyc +0 -0
  30. src/flux/modules/__pycache__/layers.cpython-310.pyc +0 -0
  31. src/flux/modules/__pycache__/layers.cpython-312.pyc +0 -0
  32. src/flux/modules/autoencoder.py +313 -0
  33. src/flux/modules/conditioner.py +39 -0
  34. src/flux/modules/layers.py +306 -0
  35. src/flux/sampling.py +584 -0
  36. src/flux/util.py +212 -0
  37. src/gradio_examples/000000000011.jpg +3 -0
  38. src/gradio_examples/221000000002.jpg +3 -0
  39. src/gradio_utils/gradio_utils.py +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,7 +1,546 @@
+ import os
+ import re
+ import time
+ from io import BytesIO
+ import uuid
+ from dataclasses import dataclass
+ from glob import iglob
+ import argparse
+ from einops import rearrange
+ #from fire import Fire
+ from PIL import ExifTags, Image
+ from safetensors.torch import load_file, save_file
+
+ import torch
+ import torch.nn.functional as F
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import numpy as np
+ from transformers import pipeline
+
+ from src.flux.sampling import (denoise_fireflow, get_schedule, prepare, prepare_image, unpack,
+                                denoise_rf, denoise_rf_solver, denoise_midpoint,
+                                denoise_rf_inversion, denoise_multi_turn_consistent, get_noise)
+ from src.flux.util import (configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5)
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "2"
+
+ @dataclass
+ class SamplingOptions:
+     source_prompt: str
+     target_prompt: str
+     # prompt: str
+     width: int
+     height: int
+     num_steps: int
+     guidance: float
+     seed: int | None
+
+ @torch.inference_mode()
+ def encode(init_image, torch_device, ae):
+     # map a uint8 HWC image into a [-1, 1] CHW tensor, then encode it to latents
+     init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
+     init_image = init_image.unsqueeze(0)
+     init_image = init_image.to(torch_device)
+     with torch.no_grad():
+         init_image = ae.encode(init_image).to(torch.bfloat16)
+     return init_image
+
+
+ class FluxEditor:
+     def __init__(self, args):
+         self.args = args
+         self.device = torch.device(args.device)
+         self.offload = args.offload
+         self.name = args.name
+         self.is_schnell = args.name == "flux-schnell"
+
+         self.feature_path = 'feature'
+
+         self.reset()
+
+         self.add_sampling_metadata = True
+
+         if self.name not in configs:
+             available = ", ".join(configs.keys())
+             raise ValueError(f"Got unknown model name: {self.name}, choose from {available}")
+
+         # init all components
+         self.clip = load_clip(self.device)
+         self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
+         self.model = load_flow_model(self.name, device="cpu" if self.offload else self.device)
+         self.ae = load_ae(self.name, device="cpu" if self.offload else self.device)
+         self.t5.eval()
+         self.clip.eval()
+         self.ae.eval()
+         self.model.eval()
+
+         # clear history
+         if os.path.exists("history_gradio/history.safetensors"):
+             os.remove("history_gradio/history.safetensors")
+
+     @torch.inference_mode()
+     def reset(self):
+         out_root = 'src/gradio_utils/gradio_outputs'
+         name_dir = f'exp_{len(os.listdir(out_root))}'
+         self.output_dir = os.path.join(out_root, name_dir)
+         if not os.path.exists(self.output_dir):
+             os.makedirs(self.output_dir)
+         self.instructions = ['source']
+         self.source_image = None
+         self.history_tensors = {
+             "source img": torch.zeros((1, 1, 1)),
+             "prev img": torch.zeros((1, 1, 1))}
+
+         source_prompt = "(Optional) Describe the content of the uploaded image."
+         target_prompt = "(Required) Describe the desired content of the edited image."
+         gallery = None
+         output_image = None
+         return source_prompt, target_prompt, gallery, output_image
+
+
+     @torch.inference_mode()
+     def process_image(self,
+                       init_image,
+                       source_prompt,
+                       target_prompt,
+                       editing_strategy,
+                       denoise_strategy,
+                       num_steps,
+                       guidance,
+                       attn_guidance_start_block,
+                       inject_step,
+                       init_image_2=None):
+         # route to text-to-image generation when no image is uploaded, otherwise edit
+         if init_image is None:
+             img, gr_gallery = self.generate_image(prompt=target_prompt)
+         else:
+             img, gr_gallery = self.edit(init_image, source_prompt, target_prompt, editing_strategy,
+                                         denoise_strategy, num_steps, guidance,
+                                         attn_guidance_start_block, inject_step, init_image_2)
+         return img, gr_gallery
+
+
+     @torch.inference_mode()
+     def generate_image(
+         self,
+         width=512,
+         height=512,
+         num_steps=28,
+         guidance=3.5,
+         seed=None,
+         prompt='',
+         init_image=None,
+         image2image_strength=0.0,
+         add_sampling_metadata=True,
+     ):
+         # fall back to a random seed when none is given
+         g_seed = seed if seed is not None else torch.Generator(device="cpu").seed()
+         print(f"Generating '{prompt}' with seed {g_seed}")
+         t0 = time.perf_counter()
+
+         if init_image is not None:
+             if isinstance(init_image, np.ndarray):
+                 init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0
+                 init_image = init_image.unsqueeze(0)
+             init_image = init_image.to(self.device)
+             init_image = torch.nn.functional.interpolate(init_image, (height, width))
+             if self.offload:
+                 self.ae.encoder.to(self.device)
+             init_image = self.ae.encode(init_image)
+             if self.offload:
+                 self.ae = self.ae.cpu()
+                 torch.cuda.empty_cache()
+
+         # prepare input
+         x = get_noise(
+             1,
+             height,
+             width,
+             device=self.device,
+             dtype=torch.bfloat16,
+             seed=g_seed,
+         )
+         timesteps = get_schedule(
+             num_steps,
+             x.shape[-1] * x.shape[-2] // 4,
+             shift=(not self.is_schnell),
+         )
+         if init_image is not None:
+             # img2img: start part-way along the schedule, blending noise with the encoded image
+             t_idx = int((1 - image2image_strength) * num_steps)
+             t = timesteps[t_idx]
+             timesteps = timesteps[t_idx:]
+             x = t * x + (1.0 - t) * init_image.to(x.dtype)
+
+         if self.offload:
+             self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
+         inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
+
+         # offload TEs to CPU, load model to gpu
+         if self.offload:
+             self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
+             torch.cuda.empty_cache()
+             self.model = self.model.to(self.device)
+
+         # denoise initial noise
+         info = {}
+         info['feature'] = {}
+         info['inject_step'] = 0
+         info['editing_strategy'] = ""
+         info['start_layer_index'] = 0
+         info['end_layer_index'] = 37
+         info['reuse_v'] = False
+         qkv_ratio = '1.0,1.0,1.0'
+         info['qkv_ratio'] = list(map(float, qkv_ratio.split(',')))
+         x = denoise_rf(self.model, **inp, timesteps=timesteps, guidance=guidance, inverse=False, info=info)
+
+         # offload model, load autoencoder to gpu
+         if self.offload:
+             self.model.cpu()
+             torch.cuda.empty_cache()
+             self.ae.decoder.to(x.device)
+
+         # decode latents to pixel space
+         x = unpack(x[0].float(), height, width)
+         with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
+             x = self.ae.decode(x)
+
+         if self.offload:
+             self.ae.decoder.cpu()
+             torch.cuda.empty_cache()
+
+         t1 = time.perf_counter()
+
+         print(f"Done in {t1 - t0:.1f}s.")
+         # bring into PIL format
+         x = x.clamp(-1, 1)
+         x = embed_watermark(x.float())
+         x = rearrange(x[0], "c h w -> h w c")
+
+         img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+
+         filename = os.path.join(self.output_dir, f"round_0000_[{prompt}].jpg")
+         os.makedirs(os.path.dirname(filename), exist_ok=True)
+         exif_data = Image.Exif()
+         if init_image is None:
+             exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
+         else:
+             exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
+         exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+         exif_data[ExifTags.Base.Model] = self.name
+         if add_sampling_metadata:
+             exif_data[ExifTags.Base.ImageDescription] = prompt
+         img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
+         self.instructions = [prompt]
+
+         #-------------------- 6.4 save editing prompt, update gradio component: gallery ----------------------#
+         img_and_prompt = []
+         history_imgs = sorted(os.listdir(self.output_dir))
+         for img_file, prompt_txt in zip(history_imgs, self.instructions):
+             img_and_prompt.append((os.path.join(self.output_dir, img_file), prompt_txt))
+         history_gallery = gr.Gallery(value=img_and_prompt, label="History Image", interactive=True, columns=3)
+         return img, history_gallery
+
+     @torch.inference_mode()
+     def edit(self, init_image, source_prompt, target_prompt, editing_strategy, denoise_strategy,
+              num_steps, guidance, attn_guidance_start_block, inject_step, init_image_2=None):
+
+         torch.cuda.empty_cache()
+         seed = None
+
+         if self.offload:
+             self.model.cpu()
+             torch.cuda.empty_cache()
+             self.ae.encoder.to(self.device)
+
+         #----------------------------- 0.1 prepare multi-turn editing -------------------------------------#
+         info = {}
+         # crop to the largest multiple of 16 so the image packs cleanly into latent patches
+         shape = init_image.shape
+         new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
+         new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
+
+         if not any("round_0000" in fname for fname in os.listdir(self.output_dir)):
+             Image.fromarray(init_image).save(os.path.join(self.output_dir, "round_0000_[source].jpg"))
+
+         init_image = init_image[:new_h, :new_w, :]
+         # NOTE: numpy shape is (H, W, C), so these names are swapped; the values are handed back
+         # to unpack(x, height, width) below in the matching order, so the result is still correct
+         width, height = init_image.shape[0], init_image.shape[1]
+         init_image = encode(init_image, self.device, self.ae)
+
+         print(init_image.shape)
+
+         if init_image_2 is None:
+             print("init_image_2 is not provided, proceeding with single-image processing.")
+         else:
+             init_image_2_pil = Image.fromarray(init_image_2)  # convert NumPy array to PIL image
+             init_image_2_pil = init_image_2_pil.resize((new_w, new_h), Image.Resampling.LANCZOS)
+             init_image_2 = np.array(init_image_2_pil)  # convert back to NumPy
+             init_image_2 = encode(init_image_2, self.device, self.ae)
+
+         rng = torch.Generator(device="cpu")
+         opts = SamplingOptions(
+             source_prompt=source_prompt,
+             target_prompt=target_prompt,
+             width=width,
+             height=height,
+             num_steps=num_steps,
+             guidance=guidance,
+             seed=seed,
+         )
+         if opts.seed is None:
+             opts.seed = torch.Generator(device="cpu").seed()
+
+         print(f"Editing with prompt:\n{opts.source_prompt}")
+         t0 = time.perf_counter()
+
+         opts.seed = None
+         if self.offload:
+             self.ae = self.ae.cpu()
+             torch.cuda.empty_cache()
+             self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
+
+         #----------------------------- 0.2 prepare attention strategy -------------------------------------#
+         info = {}
+         info['feature'] = {}
+         info['inject_step'] = inject_step
+         info['editing_strategy'] = " ".join(editing_strategy)
+         info['start_layer_index'] = 0
+         info['end_layer_index'] = 37
+         info['reuse_v'] = False
+         qkv_ratio = '1.0,1.0,1.0'
+         info['qkv_ratio'] = list(map(float, qkv_ratio.split(',')))
+         info['attn_guidance'] = attn_guidance_start_block
+         info['lqr_stop'] = 0.25
+
+         if not os.path.exists(self.feature_path):
+             os.mkdir(self.feature_path)
+
+         #----------------------------- 0.3 prepare latents -------------------------------------#
+         with torch.no_grad():
+             inp = prepare(self.t5, self.clip, init_image, prompt=opts.source_prompt)
+             inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt)
+             if self.source_image is None:
+                 self.source_image = inp['img']
+             inp_target_2 = None
+             if init_image_2 is not None:
+                 inp_target_2 = prepare_image(init_image_2)
+                 info['lqr_stop'] = 0.35
+
+         timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
+         #timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=False)
+
+         # offload TEs to CPU, load model to gpu
+         if self.offload:
+             self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
+             torch.cuda.empty_cache()
+             self.model = self.model.to(self.device)
+
+         #----------------------------- 1 Inverting current image -------------------------------------#
+         denoise_strategies = ['fireflow', 'rf', 'rf_solver', 'midpoint', 'rf_inversion', 'multi_turn_consistent']
+         denoise_funcs = [denoise_fireflow, denoise_rf, denoise_rf_solver, denoise_midpoint, denoise_rf_inversion, denoise_multi_turn_consistent]
+         denoise_func = denoise_funcs[denoise_strategies.index(denoise_strategy)]
+         with torch.no_grad():
+             z, info = denoise_func(self.model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
+
+         #----------------------------- 2 history_tensors used to implement dual-LQR guided editing -------------------------------------#
+         inp_target["img"] = z
+         timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(self.name != "flux-schnell"))
+
+         if torch.all(self.history_tensors['source img'] == 0):
+             self.history_tensors = {
+                 "source img": inp["img"],
+                 "prev img": inp_target_2}
+         else:
+             if inp_target_2 is None:
+                 self.history_tensors["prev img"] = inp["img"]
+             else:
+                 self.history_tensors["source img"] = inp["img"]
+                 self.history_tensors["prev img"] = inp_target_2
+
+         #----------------------------- 3 sampling -------------------------------------#
+         if denoise_strategy in ['rf_inversion', 'multi_turn_consistent']:
+             x, _ = denoise_func(self.model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info, img_LQR=self.history_tensors)
+         else:
+             x, _ = denoise_func(self.model, **inp_target, timesteps=timesteps, guidance=opts.guidance, inverse=False, info=info)
+
+         #----------------------------- 4 update history_tensors -------------------------------------#
+         info = {}
+         self.history_tensors["source img"] = self.source_image
+         self.history_tensors["prev img"] = x
+         # save_file(self.history_tensors, "history_gradio/history.safetensors")
+
+         # offload model, load autoencoder to gpu
+         if self.offload:
+             self.model.cpu()
+             torch.cuda.empty_cache()
+             self.ae.decoder.to(x.device)
+
+         #----------------------------- 5 decode x to image -------------------------------------#
+         x = unpack(x.float(), opts.width, opts.height)
+
+         with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
+             x = self.ae.decode(x)
+
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
+         t1 = time.perf_counter()
+
+         # bring into PIL format and save
+         x = x.clamp(-1, 1)
+         x = embed_watermark(x.float())
+         x = rearrange(x[0], "c h w -> h w c")
+
+         img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+         exif_data = Image.Exif()
+         exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
+         exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+         exif_data[ExifTags.Base.Model] = self.name
+         if self.add_sampling_metadata:
+             exif_data[ExifTags.Base.ImageDescription] = source_prompt
+
+         #-------------------------------- 6 save image -------------------------------------#
+
+         #-------------------- 6.1 prepare output folder ----------------------#
+         if not os.path.exists(self.output_dir):
+             os.makedirs(self.output_dir)
+             idx = 1
+         #-------------------- 6.2 editing round ----------------------#
+         else:
+             fns = [fn for fn in os.listdir(self.output_dir)]
+             if len(fns) > 0:
+                 idx = max(int(fn.split("_")[1]) for fn in fns) + 1
+             else:
+                 idx = 1
+         formatted_idx = str(idx).zfill(4)  # format as a 4-digit string
+
+         #-------------------- 6.3 output name ----------------------#
+         if denoise_strategy == 'multi_turn_consistent':
+             denoise_strategy = 'MTC'
+         if target_prompt == '':
+             target_prompt = 'Reconstruction'
+         if target_prompt == source_prompt:
+             target_prompt = 'Reconstruction: ' + target_prompt
+
+         # single quotes inside the f-string keep this valid on Python < 3.12
+         output_name = f"round_{formatted_idx}_[{' '.join(target_prompt.split()[-5:])}]_{denoise_strategy}.jpg"
+         fn = os.path.join(self.output_dir, output_name)
+
+         print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
+         img.save(fn)
+
+         if 'Reconstruction' in target_prompt:
+             target_prompt = source_prompt
+         self.instructions.append(target_prompt)
+         print("End Edit")
+
+         #-------------------- 6.4 save editing prompt, update gradio component: gallery ----------------------#
+         img_and_prompt = []
+         history_imgs = sorted(os.listdir(self.output_dir))
+         for img_file, prompt_txt in zip(history_imgs, self.instructions):
+             img_and_prompt.append((os.path.join(self.output_dir, img_file), prompt_txt))
+         history_gallery = gr.Gallery(value=img_and_prompt, label="History Image", interactive=True, columns=3)
+
+         return img, history_gallery
+
+
+ def on_select(gallery, selected: gr.SelectData):
+     return gallery[selected.index][0], gallery[selected.index][1]
+
+ def on_upload(path, uploaded: gr.EventData):
+     return path[0][0]
+
+ def on_change(init_image, changed: gr.EventData):
+     img_path = list(changed.target.temp_files)
+     return gr.Gallery(value=[(img_path[0], "")], label="History Image", interactive=True, columns=3)
+
+ def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
+     editor = FluxEditor(args)  # NOTE: relies on the global `args` parsed under __main__
+     is_schnell = model_name == "flux-schnell"
+
+     # Pre-defined examples
+     examples = [
+         ["src/gradio_utils/gradio_examples/000000000011.jpg", "", "a photo of an eagle standing on the branch", ['attn_guidance'], 15, 3.5, 11, 0],
+         ["src/gradio_utils/gradio_examples/221000000002.jpg", "", "a cat wearing a hat standing on the fence", ['attn_guidance'], 15, 3.5, 11, 0],
+     ]
+
+     with gr.Blocks() as demo:
+         gr.Markdown("# Multi-turn Consistent Image Editing (FLUX.1-dev)")
+
+         with gr.Row():
+             with gr.Column():
+                 source_prompt = gr.Textbox(label="Source Prompt", value="(Optional) Describe the content of the uploaded image.")
+                 target_prompt = gr.Textbox(label="Target Prompt", value="(Required) Describe the desired content of the edited image.")
+                 with gr.Row():
+                     init_image = gr.Image(label="Initial Image", visible=False, width=200)
+                     init_image_2 = gr.Image(label="Input Image 2", visible=False, width=200)
+                 gallery = gr.Gallery(label="History Image", interactive=True, columns=3)
+                 editing_strategy = gr.CheckboxGroup(
+                     label="Editing Technique",
+                     choices=['attn_guidance', 'replace_v', 'add_q', 'add_k', 'add_v', 'replace_q', 'replace_k'],
+                     value=['attn_guidance'],  # default: attention guidance selected
+                     interactive=True
+                 )
+                 denoise_strategy = gr.Dropdown(
+                     ['multi_turn_consistent', 'fireflow', 'rf', 'rf_solver', 'midpoint', 'rf_inversion'],
+                     label="Denoising Technique", value='multi_turn_consistent')
+                 generate_btn = gr.Button("Generate")
+
+             with gr.Column():
+                 with gr.Accordion("Advanced Options", open=True):
+                     num_steps = gr.Slider(1, 30, 15, step=1, label="Number of steps")
+                     guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Text Guidance", interactive=not is_schnell)
+                     attn_guidance_start_block = gr.Slider(0, 18, 11, step=1, label="Top activated attn-maps", interactive=not is_schnell)
+                     inject_step = gr.Slider(0, 15, 1, step=1, label="Number of inject steps")
+                 output_image = gr.Image(label="Generated/Edited Image")
+                 reset_btn = gr.Button("Reset")
+
+         gallery.select(on_select, gallery, [init_image, source_prompt])
+         gallery.upload(on_upload, gallery, init_image)
+         init_image.change(on_change, init_image, gallery)
+
+         generate_btn.click(
+             fn=editor.process_image,
+             inputs=[init_image, source_prompt, target_prompt, editing_strategy, denoise_strategy, num_steps, guidance, attn_guidance_start_block, inject_step, init_image_2],
+             outputs=[output_image, gallery]
+         )
+         reset_btn.click(fn=editor.reset, outputs=[source_prompt, target_prompt, gallery, output_image])
+
+         # Add examples
+         gr.Examples(
+             examples=examples,
+             inputs=[
+                 init_image,
+                 source_prompt,
+                 target_prompt,
+                 editing_strategy,
+                 num_steps,
+                 guidance,
+                 attn_guidance_start_block,
+                 inject_step
+             ]
+         )
+
+     return demo
+
 
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Flux")
+     parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name")
+     parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
+     parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
+     parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
+     parser.add_argument("--port", type=int, default=9090)
+     args = parser.parse_args()
 
+     demo = create_demo(args.name, args.device, args.offload)
+     #demo.launch(server_name='0.0.0.0', share=args.share, server_port=args.port)
+     demo.launch(share=True)
model_cards/FLUX.1-dev.md ADDED
@@ -0,0 +1,46 @@
+ ![FLUX.1 [dev] Grid](../assets/dev_grid.jpg)
+
+ `FLUX.1 [dev]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
+ For more information, please read our [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).
+
+ # Key Features
+ 1. Cutting-edge output quality, second only to our state-of-the-art model `FLUX.1 [pro]`.
+ 2. Competitive prompt following, matching the performance of closed-source alternatives.
+ 3. Trained using guidance distillation, making `FLUX.1 [dev]` more efficient.
+ 4. Open weights to drive new scientific research and to empower artists to develop innovative workflows.
+ 5. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [flux-1-dev-non-commercial-license](./licence.md).
+
+ # Usage
+ We provide a reference implementation of `FLUX.1 [dev]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux).
+ Developers and creatives looking to build on top of `FLUX.1 [dev]` are encouraged to use this as a starting point.
+
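+ The model can also be loaded through Hugging Face `diffusers` (this Space pins `diffusers==0.30.0` in `requirements.txt`, which ships `FluxPipeline`). A minimal text-to-image sketch, assuming a CUDA-capable GPU:
+
+ ```python
+ import torch
+ from diffusers import FluxPipeline
+
+ # Load FLUX.1 [dev] in bfloat16; weights are fetched from the Hub on first use.
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+ pipe.enable_model_cpu_offload()  # offload idle submodules to CPU to keep VRAM usage manageable
+
+ image = pipe(
+     "A cat holding a sign that says hello world",
+     guidance_scale=3.5,      # dev is guidance-distilled, so a single guided pass suffices
+     num_inference_steps=50,
+     generator=torch.Generator("cpu").manual_seed(0),
+ ).images[0]
+ image.save("flux-dev.png")
+ ```
+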
+ ## API Endpoints
+ The FLUX.1 models are also available via API from the following sources:
+ 1. [bfl.ml](https://docs.bfl.ml/) (currently `FLUX.1 [pro]`)
+ 2. [replicate.com](https://replicate.com/collections/flux)
+ 3. [fal.ai](https://fal.ai/models/fal-ai/flux/dev)
+
+ ## ComfyUI
+ `FLUX.1 [dev]` is also available in [Comfy UI](https://github.com/comfyanonymous/ComfyUI) for local inference with a node-based workflow.
+
+ ---
+ # Limitations
+ - This model is not intended or able to provide factual information.
+ - As a statistical model, this checkpoint might amplify existing societal biases.
+ - The model may fail to generate output that matches the prompts.
+ - Prompt following is heavily influenced by the prompting style.
+
+ # Out-of-Scope Use
+ The model and its derivatives may not be used:
+
+ - In any way that violates any applicable national, federal, state, local or international law or regulation.
+ - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way, including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
+ - To generate or disseminate verifiably false information and/or content with the purpose of harming others.
+ - To generate or disseminate personally identifiable information that can be used to harm an individual.
+ - To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
+ - To create non-consensual nudity or illegal pornographic content.
+ - For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
+ - To generate or facilitate large-scale disinformation campaigns.
+
+ # License
+ This model falls under the [`FLUX.1 [dev]` Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
model_cards/FLUX.1-schnell.md ADDED
@@ -0,0 +1,41 @@
+ ![FLUX.1 [schnell] Grid](../assets/schnell_grid.jpg)
+
+ `FLUX.1 [schnell]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
+ For more information, please read our [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).
+
+ # Key Features
+ 1. Cutting-edge output quality and competitive prompt following, matching the performance of closed-source alternatives.
+ 2. Trained using latent adversarial diffusion distillation, `FLUX.1 [schnell]` can generate high-quality images in only 1 to 4 steps.
+ 3. Released under the `apache-2.0` licence, the model can be used for personal, scientific, and commercial purposes.
+
+ # Usage
+ We provide a reference implementation of `FLUX.1 [schnell]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux).
+ Developers and creatives looking to build on top of `FLUX.1 [schnell]` are encouraged to use this as a starting point.
+
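+ The same `diffusers` entry point works here; a minimal sketch (again assuming `diffusers>=0.30`) showing the few-step, guidance-free sampling that distinguishes `FLUX.1 [schnell]` from `FLUX.1 [dev]`:
+
+ ```python
+ import torch
+ from diffusers import FluxPipeline
+
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+ pipe.enable_model_cpu_offload()  # offload idle submodules to CPU to save VRAM
+
+ image = pipe(
+     "A cat holding a sign that says hello world",
+     guidance_scale=0.0,       # schnell is timestep-distilled; guidance is disabled
+     num_inference_steps=4,    # 1-4 steps suffice
+     max_sequence_length=256,  # schnell caps the T5 context at 256 tokens
+ ).images[0]
+ image.save("flux-schnell.png")
+ ```
+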
+ ## API Endpoints
+ The FLUX.1 models are also available via API from the following sources:
+ 1. [bfl.ml](https://docs.bfl.ml/) (currently `FLUX.1 [pro]`)
+ 2. [replicate.com](https://replicate.com/collections/flux)
+ 3. [fal.ai](https://fal.ai/models/fal-ai/flux/schnell)
+
+ ## ComfyUI
+ `FLUX.1 [schnell]` is also available in [Comfy UI](https://github.com/comfyanonymous/ComfyUI) for local inference with a node-based workflow.
+
+ ---
+ # Limitations
+ - This model is not intended or able to provide factual information.
+ - As a statistical model, this checkpoint might amplify existing societal biases.
+ - The model may fail to generate output that matches the prompts.
+ - Prompt following is heavily influenced by the prompting style.
+
+ # Out-of-Scope Use
+ The model and its derivatives may not be used:
+
+ - In any way that violates any applicable national, federal, state, local or international law or regulation.
+ - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way, including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
+ - To generate or disseminate verifiably false information and/or content with the purpose of harming others.
+ - To generate or disseminate personally identifiable information that can be used to harm an individual.
+ - To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
+ - To create non-consensual nudity or illegal pornographic content.
+ - For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
+ - To generate or facilitate large-scale disinformation campaigns.
model_licenses/LICENSE-FLUX1-dev ADDED
@@ -0,0 +1,42 @@
+ FLUX.1 [dev] Non-Commercial License
+ Black Forest Labs, Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] text-to-image AI model and its elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI model made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”).
+ By downloading, accessing, use, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity.
+ 1. Definitions. Capitalized terms used in this License but not defined herein have the following meanings:
+ a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License.
+ b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be.
+ c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output: (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment, (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use for revenue-generating activity or direct interactions with or impacts on end users, or use to train, fine tune or distill other models for commercial use is not a Non-Commercial purpose.
+ d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters.
+ e. “you” or “your” means the individual or entity entering into this License with Company.
+ 2. License Grant.
+ a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein in regarding the FLUX.1 [dev] Model also applies to any Derivative you create or that are created on your behalf.
+ b. Non-Commercial Use Only. You may only access, use, Distribute, or creative Derivatives of or the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If You want to use a FLUX.1 [dev] Model a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please contact Company at the following e-mail address if you want to discuss such a license: info@blackforestlabs.ai.
+ c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License.
+ d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model.
+ 3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions:
+ a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License;
+ b. you must make prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”):
+ “The FLUX.1 [dev] Model is licensed by Black Forest Labs. Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs. Inc.
+ IN NO EVENT SHALL BLACK FOREST LABS, INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.”
+ c. in the case of Distribution of Derivatives made by you, you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; and
+ d. in the case of Distribution of Derivatives made by you, any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions.
+ e. In the case of Distribution of Derivatives made by you, you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing.
+ 4. Restrictions. You will not, and will not permit, assist or cause any third party to
+ a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
+ b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model;
+ c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model; or
+ d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License.
+ e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model;
+ f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.
+ 5. DISCLAIMERS. THE FLUX.1 [dev] MODEL IS PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
+ 6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
+ 7. INDEMNIFICATION
+
+ You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (as well as any Output, results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties.
+ 8. Termination; Survival.
+ a. This License will automatically terminate upon any breach by you of the terms of this License.
+ b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
+ c. If You initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model or any Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated.
+ d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model and any Derivatives. The following sections survive termination of this License 2(c), 2(d), 4-11.
+ 9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
+ 10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name or mark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators.
+ 11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Company.
model_licenses/LICENSE-FLUX1-schnell ADDED
@@ -0,0 +1,54 @@
+
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
+
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
requirements.txt ADDED
@@ -0,0 +1,94 @@
+ accelerate==1.4.0
+ aiofiles==23.2.1
+ annotated-types==0.7.0
+ anyio==4.8.0
+ certifi==2025.1.31
+ charset-normalizer==3.4.1
+ click==8.1.8
+ diffusers==0.30.0
+ einops==0.8.1
+ fastapi==0.115.11
+ ffmpy==0.5.0
+ filelock==3.17.0
+ fire==0.7.0
+ -e git+https://github.com/HolmesShuan/FireFlow-Fast-Inversion-of-Rectified-Flow-for-Image-Semantic-Editing@df4eab3f73eed5efa175438532962c42bf033cf2#egg=FireFlow
+ flux==1.3.5
+ fsspec==2025.2.0
+ gradio==5.20.0
+ gradio_client==1.7.2
+ groovy==0.1.2
+ h11==0.14.0
+ httpcore==1.0.7
+ httpx==0.28.1
+ huggingface-hub==0.29.1
+ idna==3.10
+ importlib_metadata==8.6.1
+ invisible-watermark==0.2.0
+ Jinja2==3.1.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ mdurl==0.1.2
+ mpmath==1.3.0
+ networkx==3.4.2
+ numpy==2.2.3
+ nvidia-cublas-cu12==12.4.5.8
+ nvidia-cuda-cupti-cu12==12.4.127
+ nvidia-cuda-nvrtc-cu12==12.4.127
+ nvidia-cuda-runtime-cu12==12.4.127
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cufft-cu12==11.2.1.3
+ nvidia-curand-cu12==10.3.5.147
+ nvidia-cusolver-cu12==11.6.1.9
+ nvidia-cusparse-cu12==12.3.1.170
+ nvidia-cusparselt-cu12==0.6.2
+ nvidia-nccl-cu12==2.21.5
+ nvidia-nvjitlink-cu12==12.4.127
+ nvidia-nvtx-cu12==12.4.127
+ opencv-python==4.11.0.86
+ orjson==3.10.15
+ packaging==24.2
+ pandas==2.2.3
+ pillow==11.1.0
+ protobuf==5.29.3
+ psutil==7.0.0
+ pydantic==2.10.6
+ pydantic_core==2.27.2
+ pydub==0.25.1
+ Pygments==2.19.1
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.20
+ pytz==2025.1
+ PyWavelets==1.8.0
+ PyYAML==6.0.2
+ regex==2024.11.6
+ requests==2.32.3
+ rich==13.9.4
+ ruff==0.9.9
+ safehttpx==0.1.6
+ safetensors==0.5.3
+ scipy==1.15.2
+ semantic-version==2.10.0
+ sentencepiece==0.2.0
+ setuptools==75.8.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ starlette==0.46.0
+ sympy==1.13.1
+ termcolor==2.5.0
+ tokenizers==0.21.0
+ tomlkit==0.13.2
+ torch==2.6.0
+ torch-fidelity==0.3.0
+ torchvision==0.21.0
+ tqdm==4.67.1
+ transformers==4.49.0
+ triton==3.2.0
+ typer==0.15.2
+ typing_extensions==4.12.2
+ tzdata==2025.1
+ urllib3==2.3.0
+ uvicorn==0.34.0
+ websockets==15.0
+ wheel==0.45.1
+ zipp==3.21.0
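Note: the pins above describe a CUDA 12.4 / PyTorch 2.6 environment (see the nvidia-*-cu12 wheels), and the editable "-e git+...#egg=FireFlow" entry checks the upstream FireFlow repository out at a fixed commit. Installing with pip install -r requirements.txt inside a fresh Python 3.10+ virtual environment is the assumed workflow.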
src/flux/__init__.py ADDED
@@ -0,0 +1,11 @@
+ try:
+     from ._version import version as __version__  # type: ignore
+     from ._version import version_tuple
+ except ImportError:
+     __version__ = "unknown (no version information available)"
+     version_tuple = (0, 0, "unknown", "noinfo")
+
+ from pathlib import Path
+
+ PACKAGE = __package__.replace("_", "-")
+ PACKAGE_ROOT = Path(__file__).parent
src/flux/__main__.py ADDED
@@ -0,0 +1,4 @@
+ from .cli import app
+
+ if __name__ == "__main__":
+     app()
src/flux/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (503 Bytes).
src/flux/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (630 Bytes).
src/flux/__pycache__/_version.cpython-312.pyc ADDED
Binary file (576 Bytes).
src/flux/__pycache__/math.cpython-310.pyc ADDED
Binary file (1.6 kB).
src/flux/__pycache__/math.cpython-312.pyc ADDED
Binary file (9.2 kB).
src/flux/__pycache__/math.cpython-38.pyc ADDED
Binary file (1.46 kB).
src/flux/__pycache__/model.cpython-310.pyc ADDED
Binary file (3.46 kB).
src/flux/__pycache__/model.cpython-312.pyc ADDED
Binary file (5.97 kB).
src/flux/__pycache__/sampling.cpython-310.pyc ADDED
Binary file (3.68 kB).
src/flux/__pycache__/sampling.cpython-312.pyc ADDED
Binary file (18.2 kB).
src/flux/__pycache__/util.cpython-310.pyc ADDED
Binary file (5.75 kB).
src/flux/__pycache__/util.cpython-312.pyc ADDED
Binary file (9.3 kB).
src/flux/_version.py ADDED
@@ -0,0 +1,16 @@
+ # file generated by setuptools_scm
+ # don't change, don't track in version control
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple, Union
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+ else:
+     VERSION_TUPLE = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+
+ __version__ = version = '0.0.post0+d20241105'
+ __version_tuple__ = version_tuple = (0, 0, 'd20241105')
src/flux/api.py ADDED
@@ -0,0 +1,194 @@
+ import io
+ import os
+ import time
+ from pathlib import Path
+
+ import requests
+ from PIL import Image
+
+ API_ENDPOINT = "https://api.bfl.ml"
+
+
+ class ApiException(Exception):
+     def __init__(self, status_code: int, detail: str | list[dict] | None = None):
+         super().__init__()
+         self.detail = detail
+         self.status_code = status_code
+
+     def __str__(self) -> str:
+         return self.__repr__()
+
+     def __repr__(self) -> str:
+         if self.detail is None:
+             message = None
+         elif isinstance(self.detail, str):
+             message = self.detail
+         else:
+             message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
+         return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
+
+
+ class ImageRequest:
+     def __init__(
+         self,
+         prompt: str,
+         width: int = 1024,
+         height: int = 1024,
+         name: str = "flux.1-pro",
+         num_steps: int = 50,
+         prompt_upsampling: bool = False,
+         seed: int | None = None,
+         validate: bool = True,
+         launch: bool = True,
+         api_key: str | None = None,
+     ):
+         """
+         Manages an image generation request to the API.
+
+         Args:
+             prompt: Prompt to sample
+             width: Width of the image in pixels
+             height: Height of the image in pixels
+             name: Name of the model
+             num_steps: Number of network evaluations
+             prompt_upsampling: Use prompt upsampling
+             seed: Fix the generation seed
+             validate: Run input validation
+             launch: Directly launch the request
+             api_key: Your API key if not provided by the environment
+
+         Raises:
+             ValueError: For invalid input
+             ApiException: For errors raised from the API
+         """
+         if validate:
+             if name not in ["flux.1-pro"]:
+                 raise ValueError(f"Invalid model {name}")
+             elif width % 32 != 0:
+                 raise ValueError(f"width must be divisible by 32, got {width}")
+             elif not (256 <= width <= 1440):
+                 raise ValueError(f"width must be between 256 and 1440, got {width}")
+             elif height % 32 != 0:
+                 raise ValueError(f"height must be divisible by 32, got {height}")
+             elif not (256 <= height <= 1440):
+                 raise ValueError(f"height must be between 256 and 1440, got {height}")
+             elif not (1 <= num_steps <= 50):
+                 raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
+
+         self.request_json = {
+             "prompt": prompt,
+             "width": width,
+             "height": height,
+             "variant": name,
+             "steps": num_steps,
+             "prompt_upsampling": prompt_upsampling,
+         }
+         if seed is not None:
+             self.request_json["seed"] = seed
+
+         self.request_id: str | None = None
+         self.result: dict | None = None
+         self._image_bytes: bytes | None = None
+         self._url: str | None = None
+         if api_key is None:
+             self.api_key = os.environ.get("BFL_API_KEY")
+         else:
+             self.api_key = api_key
+
+         if launch:
+             self.request()
+
+     def request(self):
+         """
+         Request to generate the image.
+         """
+         if self.request_id is not None:
+             return
+         response = requests.post(
+             f"{API_ENDPOINT}/v1/image",
+             headers={
+                 "accept": "application/json",
+                 "x-key": self.api_key,
+                 "Content-Type": "application/json",
+             },
+             json=self.request_json,
+         )
+         result = response.json()
+         if response.status_code != 200:
+             raise ApiException(status_code=response.status_code, detail=result.get("detail"))
+         self.request_id = result["id"]
+
+     def retrieve(self) -> dict:
+         """
+         Wait for the generation to finish and retrieve the response.
+         """
+         if self.request_id is None:
+             self.request()
+         while self.result is None:
+             response = requests.get(
+                 f"{API_ENDPOINT}/v1/get_result",
+                 headers={
+                     "accept": "application/json",
+                     "x-key": self.api_key,
+                 },
+                 params={
+                     "id": self.request_id,
+                 },
+             )
+             result = response.json()
+             if "status" not in result:
+                 raise ApiException(status_code=response.status_code, detail=result.get("detail"))
+             elif result["status"] == "Ready":
+                 self.result = result["result"]
+             elif result["status"] == "Pending":
+                 time.sleep(0.5)
+             else:
+                 raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
+         return self.result
+
+     @property
+     def bytes(self) -> bytes:
+         """
+         Generated image as bytes.
+         """
+         if self._image_bytes is None:
+             response = requests.get(self.url)
+             if response.status_code == 200:
+                 self._image_bytes = response.content
+             else:
+                 raise ApiException(status_code=response.status_code)
+         return self._image_bytes
+
+     @property
+     def url(self) -> str:
+         """
+         Public url to retrieve the image from.
+         """
+         if self._url is None:
+             result = self.retrieve()
+             self._url = result["sample"]
+         return self._url
+
+     @property
+     def image(self) -> Image.Image:
+         """
+         Load the image as a PIL Image.
+         """
+         return Image.open(io.BytesIO(self.bytes))
+
+     def save(self, path: str):
+         """
+         Save the generated image to a local path.
+         """
+         suffix = Path(self.url).suffix
+         if not path.endswith(suffix):
+             path = path + suffix
+         Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
+         with open(path, "wb") as file:
+             file.write(self.bytes)
+
+
+ if __name__ == "__main__":
+     from fire import Fire
+
+     Fire(ImageRequest)
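For reference, a minimal sketch of driving the ImageRequest client above from Python; the prompt and output name are placeholders, and a valid key is assumed in the BFL_API_KEY environment variable:

    from flux.api import ImageRequest

    # the request is launched immediately (launch=True is the default)
    request = ImageRequest(prompt="a forest in morning mist", width=1024, height=768, seed=42)
    print(request.url)      # polls /v1/get_result until the status is Ready
    request.save("forest")  # the suffix of the result URL is appended automatically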
src/flux/math.py ADDED
@@ -0,0 +1,170 @@
+ import torch
+ from einops import rearrange
+ from torch import Tensor
+ import math
+ from torchvision.utils import save_image
+ from torchvision.io import read_image
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+
+ def adaptive_attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, txt_shape: int, img_shape: int, cur_step: int, cur_block: int, info) -> Tensor:
+     q, k = apply_rope(q, k, pe)
+
+     # x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+     x = scaled_dot_product_attention(q, k, v, txt_shape, img_shape, cur_step, cur_block, info)
+     x = rearrange(x, "B H L D -> B L (H D)")
+
+     return x
+
+
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+     q, k = apply_rope(q, k, pe)
+
+     x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+     x = rearrange(x, "B H L D -> B L (H D)")
+
+     return x
+
+
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
+     assert dim % 2 == 0
+     scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+     omega = 1.0 / (theta**scale)
+     out = torch.einsum("...n,d->...nd", pos, omega)
+     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+     return out.float()
+
+
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
+     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+
+ def auto_mask(load_list, mask_accumulator, thre, info, mask_num=4):
+     mask_list = []
+     for img_path in load_list:
+         load_mask_img = Image.open(img_path).convert('L')
+         # Define the transformation
+         transform = transforms.PILToTensor()
+         mask_tensor = transform(load_mask_img)
+         mask_tensor = mask_tensor.to(device=mask_accumulator.device, dtype=mask_accumulator.dtype)  # Set device and dtype
+         mask_tensor /= 255.0
+         mask_list.append(mask_tensor)  # Collect masks
+
+     # Sort masks based on their activation levels
+     mask_list.sort(key=lambda x: x.sum().item(), reverse=True)
+     # Select a window of mask_num medium-activated masks
+     num_masks = len(mask_list)
+     if num_masks > mask_num:
+         # selected_masks = mask_list[num_masks//2 - mask_num : num_masks//2]
+         start_block = info['attn_guidance']
+         end_block = info['attn_guidance'] + mask_num
+         if end_block > num_masks - 1:
+             selected_masks = mask_list[-mask_num:]
+         else:
+             selected_masks = mask_list[start_block:end_block]
+     else:
+         selected_masks = mask_list
+
+     # Accumulate the selected masks
+     for mask in selected_masks:
+         mask_accumulator += mask
+
+     mask_tensor = (mask_accumulator / len(selected_masks)).to(dtype=mask_accumulator.dtype)  # Average the masks, keeping the original dtype
+     mask_tensor[mask_tensor >= thre] = 1
+     mask_tensor[mask_tensor < thre] = 0
+
+     return mask_tensor
+
+
+ # Efficient implementation equivalent to the following:
+ def scaled_dot_product_attention(query, key, value, txt_shape, img_shape, cur_step, cur_block, info,
+                                  token_index=2, layer=range(19), attn_mask=None, dropout_p=0.0, coefficient=10, tau=0.5, thre=0.3,
+                                  is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
+     L, S = query.size(-2), key.size(-2)
+     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+     attn_bias = torch.zeros(L, S, dtype=query.dtype).cuda()
+     if is_causal:
+         assert attn_mask is None
+         temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+         attn_bias = attn_bias.to(query.dtype)
+
+     if attn_mask is not None:
+         if attn_mask.dtype == torch.bool:
+             attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+         else:
+             attn_bias += attn_mask
+
+     if enable_gqa:
+         key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+         value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+     attn_weight = query @ key.transpose(-2, -1) * scale_factor
+     attn_weight += attn_bias
+     attn_weight = torch.softmax(attn_weight, dim=-1)
+     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+     if not info['inverse']:
+         # GENERATE MASK
+         txt_img_cross = attn_weight[:, :, -img_shape:, :txt_shape]  # lower-left block: image queries attending to text keys
+         # each column maps to a token's heatmap
+         token_heatmap = txt_img_cross[:, :, :, token_index]  # Shape: [1, 24, 1024]
+         token_heatmap = token_heatmap.mean(dim=1)[0]  # Shape: [1024]
+         min_val, max_val = token_heatmap.min(), token_heatmap.max()
+         norm_heatmap = (token_heatmap - min_val) / (max_val - min_val)
+
+         mask_img = torch.sigmoid(coefficient * (norm_heatmap - 0.5))
+
+         H = W = int(math.sqrt(mask_img.size(0)))
+         mask_img = mask_img.reshape(H, W)
+
+         save_path = f'heatmap/step_{cur_step}_layer_{cur_block}_token{token_index}.png'
+         load_path = [f'heatmap/step_{cur_step-1}_layer_{i}_token{token_index}.png' for i in layer]
+         save_image(mask_img.unsqueeze(0), save_path)
+
+         mask_img[mask_img >= thre] = 1
+         mask_img[mask_img < thre] = 0
+         # save_image(mask_img.unsqueeze(0), save_path)
+
+         mask_tensor = torch.zeros_like(mask_img)  # Set mask_tensor as a zero tensor
+         if cur_step > 3:
+             mask_accumulator = torch.zeros_like(mask_tensor.unsqueeze(0), dtype=mask_img.dtype)  # Accumulator for averaging masks
+             mask_tensor = auto_mask(load_path, mask_accumulator, thre, info, mask_num=4)
+             if cur_block == 1:
+                 save_image(mask_tensor, f'heatmap/average_heatmaps/step_{cur_step}_layer_{cur_block}_token{token_index}.png')
+
+         if not torch.all(mask_tensor == 0):
+             highlight_factor = 2.0  # Factor to increase weights in the masked area
+             reduce_factor = 0.8  # Factor to decrease weights in the unmasked area
+
+             mask_tensor = mask_tensor.reshape(1, H * W)
+             mask_tensor = mask_tensor.unsqueeze(1).unsqueeze(-1)
+             # Create a multiplier tensor: highlight_factor where the mask is active, reduce_factor where it is inactive.
+             multiplier = torch.where(
+                 mask_tensor.bool(),
+                 torch.tensor(highlight_factor, device=attn_weight.device, dtype=attn_weight.dtype),
+                 torch.tensor(reduce_factor, device=attn_weight.device, dtype=attn_weight.dtype),
+             )
+             # note: only the first 15 text-token columns are reweighted here
+             attn_weight[:, :, -img_shape:, :15] *= multiplier
+
+     return attn_weight @ value
+
+     '''
+     if cur_step == 14 and (cur_block == 2 or cur_block == 7 or cur_block == 12):
+         mask_img = torch.zeros_like(mask_img)
+         for j in range(5):
+             token_heatmap = txt_img_cross[:, :, :, j]
+             token_heatmap = token_heatmap.mean(dim=1)[0]
+             min_val, max_val = token_heatmap.min(), token_heatmap.max()
+             norm_heatmap = (token_heatmap - min_val) / (max_val - min_val)
+
+             mask_img = torch.sigmoid(coefficient*(norm_heatmap - 0.5))
+
+             H = W = int(math.sqrt(mask_img.size(0)))
+             mask_img = mask_img.reshape(H, W)
+             save_path = f'/home/hfle/personalization/FireFlow-Fast-Inversion-of-Rectified-Flow-for-Image-Semantic-Editing/heatmap/step_{cur_step}_layer_{cur_block}_token{j}.png'
+             save_image(mask_img.unsqueeze(0), save_path)'''
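The mask construction above reduces to three steps: min-max normalize one token's image-to-text attention heat map, sharpen it with a sigmoid, and threshold it into a binary mask. A self-contained sketch using the function's defaults (coefficient=10, thre=0.3); the random heat map stands in for real attention weights:

    import torch

    heatmap = torch.rand(1024)  # attention of 32x32 image tokens to one text token
    norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
    soft = torch.sigmoid(10 * (norm - 0.5))  # coefficient = 10 sharpens around 0.5
    mask = (soft >= 0.3).float()             # thre = 0.3 binarizes the soft mask
    mask = mask.reshape(32, 32)              # back to the spatial layout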
src/flux/model.py ADDED
@@ -0,0 +1,120 @@
+ from dataclasses import dataclass
+
+ import torch
+ from torch import Tensor, nn
+
+ from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
+                                  MLPEmbedder, SingleStreamBlock,
+                                  timestep_embedding)
+
+
+ @dataclass
+ class FluxParams:
+     in_channels: int
+     vec_in_dim: int
+     context_in_dim: int
+     hidden_size: int
+     mlp_ratio: float
+     num_heads: int
+     depth: int
+     depth_single_blocks: int
+     axes_dim: list[int]
+     theta: int
+     qkv_bias: bool
+     guidance_embed: bool
+
+
+ class Flux(nn.Module):
+     """
+     Transformer model for flow matching on sequences.
+     """
+
+     def __init__(self, params: FluxParams):
+         super().__init__()
+
+         self.params = params
+         self.in_channels = params.in_channels
+         self.out_channels = self.in_channels
+         if params.hidden_size % params.num_heads != 0:
+             raise ValueError(
+                 f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+             )
+         pe_dim = params.hidden_size // params.num_heads
+         if sum(params.axes_dim) != pe_dim:
+             raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+         self.hidden_size = params.hidden_size
+         self.num_heads = params.num_heads
+         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+         self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+         self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+         self.guidance_in = (
+             MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
+         )
+         self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+         self.double_blocks = nn.ModuleList(
+             [
+                 DoubleStreamBlock(
+                     self.hidden_size,
+                     self.num_heads,
+                     mlp_ratio=params.mlp_ratio,
+                     qkv_bias=params.qkv_bias,
+                     cur_block=i,
+                 )
+                 for i in range(params.depth)
+             ]
+         )
+
+         self.single_blocks = nn.ModuleList(
+             [
+                 SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
+                 for _ in range(params.depth_single_blocks)
+             ]
+         )
+
+         self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+     def forward(
+         self,
+         img: Tensor,
+         img_ids: Tensor,
+         txt: Tensor,  # t5 text
+         txt_ids: Tensor,
+         timesteps: Tensor,
+         y: Tensor,  # clip text
+         cur_step: int,
+         guidance: Tensor | None = None,
+         info=None,
+     ) -> tuple[Tensor, dict]:
+         if img.ndim != 3 or txt.ndim != 3:
+             raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+         # running on img sequences
+         img = self.img_in(img)
+         vec = self.time_in(timestep_embedding(timesteps, 256))
+         if self.params.guidance_embed:
+             if guidance is None:
+                 raise ValueError("Didn't get guidance strength for guidance distilled model.")
+             vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+         vec = vec + self.vector_in(y)
+         txt = self.txt_in(txt)
+
+         ids = torch.cat((txt_ids, img_ids), dim=1)
+         pe = self.pe_embedder(ids)
+
+         for block in self.double_blocks:
+             img, txt = block(img=img, txt=txt, vec=vec, pe=pe, cur_step=cur_step, info=info)
+
+         cnt = 0
+         img = torch.cat((txt, img), 1)
+         info['type'] = 'single'
+         for block in self.single_blocks:
+             info['id'] = cnt
+             img, info = block(img, vec=vec, pe=pe, info=info)
+             cnt += 1
+
+         img = img[:, txt.shape[1]:, ...]
+
+         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+         return img, info
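As a sanity check on the two ValueError guards in __init__, a sketch that instantiates the model with a FLUX.1-dev-style configuration; the concrete values follow the publicly documented FLUX.1-dev setup and are an assumption here, not read from this diff:

    from flux.model import Flux, FluxParams

    params = FluxParams(
        in_channels=64, vec_in_dim=768, context_in_dim=4096,
        hidden_size=3072, mlp_ratio=4.0, num_heads=24,
        depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56],
        theta=10_000, qkv_bias=True, guidance_embed=True,
    )
    # passes both checks: 3072 % 24 == 0 and sum([16, 56, 56]) == 3072 // 24 == 128
    model = Flux(params)  # allocates the full transformer, so this is memory-heavy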
src/flux/modules/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (9.06 kB).
src/flux/modules/__pycache__/autoencoder.cpython-312.pyc ADDED
Binary file (17.1 kB).
src/flux/modules/__pycache__/conditioner.cpython-310.pyc ADDED
Binary file (1.49 kB).
src/flux/modules/__pycache__/conditioner.cpython-312.pyc ADDED
Binary file (2.31 kB).
src/flux/modules/__pycache__/layers.cpython-310.pyc ADDED
Binary file (10.3 kB).
src/flux/modules/__pycache__/layers.cpython-312.pyc ADDED
Binary file (20.5 kB).
src/flux/modules/autoencoder.py ADDED
@@ -0,0 +1,313 @@
+ from dataclasses import dataclass
+
+ import torch
+ from einops import rearrange
+ from torch import Tensor, nn
+
+
+ @dataclass
+ class AutoEncoderParams:
+     resolution: int
+     in_channels: int
+     ch: int
+     out_ch: int
+     ch_mult: list[int]
+     num_res_blocks: int
+     z_channels: int
+     scale_factor: float
+     shift_factor: float
+
+
+ def swish(x: Tensor) -> Tensor:
+     return x * torch.sigmoid(x)
+
+
+ class AttnBlock(nn.Module):
+     def __init__(self, in_channels: int):
+         super().__init__()
+         self.in_channels = in_channels
+
+         self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+         self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+         self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+         self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+         self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+
+     def attention(self, h_: Tensor) -> Tensor:
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         b, c, h, w = q.shape
+         q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
+         k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
+         v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
+         h_ = nn.functional.scaled_dot_product_attention(q, k, v)
+
+         return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x + self.proj_out(self.attention(x))
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int):
+         super().__init__()
+         self.in_channels = in_channels
+         out_channels = in_channels if out_channels is None else out_channels
+         self.out_channels = out_channels
+
+         self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         if self.in_channels != self.out_channels:
+             self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         h = x
+         h = self.norm1(h)
+         h = swish(h)
+         h = self.conv1(h)
+
+         h = self.norm2(h)
+         h = swish(h)
+         h = self.conv2(h)
+
+         if self.in_channels != self.out_channels:
+             x = self.nin_shortcut(x)
+
+         return x + h
+
+
+ class Downsample(nn.Module):
+     def __init__(self, in_channels: int):
+         super().__init__()
+         # no asymmetric padding in torch conv, must do it ourselves
+         self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+     def forward(self, x: Tensor):
+         pad = (0, 1, 0, 1)
+         x = nn.functional.pad(x, pad, mode="constant", value=0)
+         x = self.conv(x)
+         return x
+
+
+ class Upsample(nn.Module):
+     def __init__(self, in_channels: int):
+         super().__init__()
+         self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, x: Tensor):
+         x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+         x = self.conv(x)
+         return x
+
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         resolution: int,
+         in_channels: int,
+         ch: int,
+         ch_mult: list[int],
+         num_res_blocks: int,
+         z_channels: int,
+     ):
+         super().__init__()
+         self.ch = ch
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+         # downsampling
+         self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+         curr_res = resolution
+         in_ch_mult = (1,) + tuple(ch_mult)
+         self.in_ch_mult = in_ch_mult
+         self.down = nn.ModuleList()
+         block_in = self.ch
+         for i_level in range(self.num_resolutions):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_in = ch * in_ch_mult[i_level]
+             block_out = ch * ch_mult[i_level]
+             for _ in range(self.num_res_blocks):
+                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                 block_in = block_out
+             down = nn.Module()
+             down.block = block
+             down.attn = attn
+             if i_level != self.num_resolutions - 1:
+                 down.downsample = Downsample(block_in)
+                 curr_res = curr_res // 2
+             self.down.append(down)
+
+         # middle
+         self.mid = nn.Module()
+         self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+         self.mid.attn_1 = AttnBlock(block_in)
+         self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+         # end
+         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+         self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, x: Tensor) -> Tensor:
+         # downsampling
+         hs = [self.conv_in(x)]
+         for i_level in range(self.num_resolutions):
+             for i_block in range(self.num_res_blocks):
+                 h = self.down[i_level].block[i_block](hs[-1])
+                 if len(self.down[i_level].attn) > 0:
+                     h = self.down[i_level].attn[i_block](h)
+                 hs.append(h)
+             if i_level != self.num_resolutions - 1:
+                 hs.append(self.down[i_level].downsample(hs[-1]))
+
+         # middle
+         h = hs[-1]
+         h = self.mid.block_1(h)
+         h = self.mid.attn_1(h)
+         h = self.mid.block_2(h)
+         # end
+         h = self.norm_out(h)
+         h = swish(h)
+         h = self.conv_out(h)
+         return h
+
+
+ class Decoder(nn.Module):
+     def __init__(
+         self,
+         ch: int,
+         out_ch: int,
+         ch_mult: list[int],
+         num_res_blocks: int,
+         in_channels: int,
+         resolution: int,
+         z_channels: int,
+     ):
+         super().__init__()
+         self.ch = ch
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+         self.ffactor = 2 ** (self.num_resolutions - 1)
+
+         # compute in_ch_mult, block_in and curr_res at lowest res
+         block_in = ch * ch_mult[self.num_resolutions - 1]
+         curr_res = resolution // 2 ** (self.num_resolutions - 1)
+         self.z_shape = (1, z_channels, curr_res, curr_res)
+
+         # z to block_in
+         self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+         # middle
+         self.mid = nn.Module()
+         self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+         self.mid.attn_1 = AttnBlock(block_in)
+         self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+         # upsampling
+         self.up = nn.ModuleList()
+         for i_level in reversed(range(self.num_resolutions)):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_out = ch * ch_mult[i_level]
+             for _ in range(self.num_res_blocks + 1):
+                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                 block_in = block_out
+             up = nn.Module()
+             up.block = block
+             up.attn = attn
+             if i_level != 0:
+                 up.upsample = Upsample(block_in)
+                 curr_res = curr_res * 2
+             self.up.insert(0, up)  # prepend to get consistent order
+
+         # end
+         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, z: Tensor) -> Tensor:
+         # z to block_in
+         h = self.conv_in(z)
+
+         # middle
+         h = self.mid.block_1(h)
+         h = self.mid.attn_1(h)
+         h = self.mid.block_2(h)
+
+         # upsampling
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks + 1):
+                 h = self.up[i_level].block[i_block](h)
+                 if len(self.up[i_level].attn) > 0:
+                     h = self.up[i_level].attn[i_block](h)
+             if i_level != 0:
+                 h = self.up[i_level].upsample(h)
+
+         # end
+         h = self.norm_out(h)
+         h = swish(h)
+         h = self.conv_out(h)
+         return h
+
+
+ class DiagonalGaussian(nn.Module):
+     def __init__(self, sample: bool = True, chunk_dim: int = 1):
+         super().__init__()
+         self.sample = sample
+         self.chunk_dim = chunk_dim
+
+     def forward(self, z: Tensor) -> Tensor:
+         mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
+         if self.sample:
+             std = torch.exp(0.5 * logvar)
+             # noise term commented out: the posterior mean is returned, so encoding is deterministic
+             return mean  # + std * torch.randn_like(mean)
+         else:
+             return mean
+
+
+ class AutoEncoder(nn.Module):
+     def __init__(self, params: AutoEncoderParams):
+         super().__init__()
+         self.encoder = Encoder(
+             resolution=params.resolution,
+             in_channels=params.in_channels,
+             ch=params.ch,
+             ch_mult=params.ch_mult,
+             num_res_blocks=params.num_res_blocks,
+             z_channels=params.z_channels,
+         )
+         self.decoder = Decoder(
+             resolution=params.resolution,
+             in_channels=params.in_channels,
+             ch=params.ch,
+             out_ch=params.out_ch,
+             ch_mult=params.ch_mult,
+             num_res_blocks=params.num_res_blocks,
+             z_channels=params.z_channels,
+         )
+         self.reg = DiagonalGaussian()
+
+         self.scale_factor = params.scale_factor
+         self.shift_factor = params.shift_factor
+
+     def encode(self, x: Tensor) -> Tensor:
+         z = self.reg(self.encoder(x))
+         z = self.scale_factor * (z - self.shift_factor)
+         return z
+
+     def decode(self, z: Tensor) -> Tensor:
+         z = z / self.scale_factor + self.shift_factor
+         return self.decoder(z)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.decode(self.encode(x))
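In terms of shapes, encode packs an RGB image into a scaled latent mean and decode inverts it. A round-trip sketch; the parameter values mirror the published FLUX autoencoder configuration and are an assumption here:

    import torch
    from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

    ae = AutoEncoder(AutoEncoderParams(
        resolution=256, in_channels=3, ch=128, out_ch=3,
        ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16,
        scale_factor=0.3611, shift_factor=0.1159,
    ))
    x = torch.randn(1, 3, 512, 512)
    z = ae.encode(x)      # -> (1, 16, 64, 64): three Downsample stages give an 8x factor
    x_rec = ae.decode(z)  # -> (1, 3, 512, 512)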
src/flux/modules/conditioner.py ADDED
@@ -0,0 +1,39 @@
+ from torch import Tensor, nn
+ from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
+                           T5Tokenizer)
+ import os
+
+
+ class HFEmbedder(nn.Module):
+     def __init__(self, version: str, max_length: int, is_clip, **hf_kwargs):
+         super().__init__()
+         self.is_clip = is_clip
+         self.max_length = max_length
+         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
+
+         if self.is_clip:
+             self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
+             self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
+         else:
+             self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
+             # self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained("black-forest-labs/FLUX.1-dev/tokenizer_2", max_length=max_length)
+             self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
+
+         self.hf_module = self.hf_module.eval().requires_grad_(False)
+
+     def forward(self, text: list[str]) -> Tensor:
+         batch_encoding = self.tokenizer(
+             text,
+             truncation=True,
+             max_length=self.max_length,
+             return_length=False,
+             return_overflowing_tokens=False,
+             padding="max_length",
+             return_tensors="pt",
+         )
+
+         outputs = self.hf_module(
+             input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+             attention_mask=None,
+             output_hidden_states=False,
+         )
+         return outputs[self.output_key]
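A sketch of how HFEmbedder is typically paired for FLUX; the checkpoint ids and max lengths below follow the standard FLUX setup and are assumptions, not values read from this file:

    import torch
    from flux.modules.conditioner import HFEmbedder

    t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, is_clip=False, torch_dtype=torch.bfloat16)
    clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16)
    txt = t5(["a prompt"])    # (1, 512, 4096) last_hidden_state, the transformer's txt input
    vec = clip(["a prompt"])  # (1, 768) pooler_output, the transformer's y input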
src/flux/modules/layers.py ADDED
@@ -0,0 +1,306 @@
+ import math
+ from dataclasses import dataclass
+
+ import torch
+ from einops import rearrange
+ from torch import Tensor, nn
+
+ from flux.math import attention, rope, adaptive_attention
+
+ import os
+
+
+ class EmbedND(nn.Module):
+     def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+         self.axes_dim = axes_dim
+
+     def forward(self, ids: Tensor) -> Tensor:
+         n_axes = ids.shape[-1]
+         emb = torch.cat(
+             [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+             dim=-3,
+         )
+
+         return emb.unsqueeze(1)
+
+
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
+     """
+     Create sinusoidal timestep embeddings.
+     :param t: a 1-D Tensor of N indices, one per batch element.
+               These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an (N, D) Tensor of positional embeddings.
+     """
+     t = time_factor * t
+     half = dim // 2
+     freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+         t.device
+     )
+
+     args = t[:, None].float() * freqs[None]
+     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+     if dim % 2:
+         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+     if torch.is_floating_point(t):
+         embedding = embedding.to(t)
+     return embedding
+
+
+ class MLPEmbedder(nn.Module):
+     def __init__(self, in_dim: int, hidden_dim: int):
+         super().__init__()
+         self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+         self.silu = nn.SiLU()
+         self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.out_layer(self.silu(self.in_layer(x)))
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int):
+         super().__init__()
+         self.scale = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: Tensor):
+         x_dtype = x.dtype
+         x = x.float()
+         rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
+         return (x * rrms).to(dtype=x_dtype) * self.scale
+
+
+ class QKNorm(torch.nn.Module):
+     def __init__(self, dim: int):
+         super().__init__()
+         self.query_norm = RMSNorm(dim)
+         self.key_norm = RMSNorm(dim)
+
+     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
+         q = self.query_norm(q)
+         k = self.key_norm(k)
+         return q.to(v), k.to(v)
+
+
+ class SelfAttention(nn.Module):
+     def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.norm = QKNorm(head_dim)
+         self.proj = nn.Linear(dim, dim)
+
+     def forward(self, x: Tensor, pe: Tensor) -> Tensor:
+         qkv = self.qkv(x)
+         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+         q, k = self.norm(q, k, v)
+         x = attention(q, k, v, pe=pe)
+         x = self.proj(x)
+         return x
+
+
+ @dataclass
+ class ModulationOut:
+     shift: Tensor
+     scale: Tensor
+     gate: Tensor
+
+
+ class Modulation(nn.Module):
+     def __init__(self, dim: int, double: bool):
+         super().__init__()
+         self.is_double = double
+         self.multiplier = 6 if double else 3
+         self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+     def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
+         out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+
+         return (
+             ModulationOut(*out[:3]),
+             ModulationOut(*out[3:]) if self.is_double else None,
+         )
+
+
+ class DoubleStreamBlock(nn.Module):
+     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, cur_block: int, qkv_bias: bool = False):
+         super().__init__()
+
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         self.num_heads = num_heads
+         self.hidden_size = hidden_size
+         self.img_mod = Modulation(hidden_size, double=True)
+         self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+         self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.img_mlp = nn.Sequential(
+             nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+             nn.GELU(approximate="tanh"),
+             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+         )
+
+         self.txt_mod = Modulation(hidden_size, double=True)
+         self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+         self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.txt_mlp = nn.Sequential(
+             nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+             nn.GELU(approximate="tanh"),
+             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+         )
+         self.cur_block = cur_block
+
+     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, cur_step: int, info) -> tuple[Tensor, Tensor]:
+         img_mod1, img_mod2 = self.img_mod(vec)
+         txt_mod1, txt_mod2 = self.txt_mod(vec)
+
+         # prepare image for attention
+         img_modulated = self.img_norm1(img)
+         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+         img_qkv = self.img_attn.qkv(img_modulated)
+         img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
+
+         # prepare txt for attention
+         txt_modulated = self.txt_norm1(txt)
+         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+         txt_qkv = self.txt_attn.qkv(txt_modulated)
+         txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
+
+         # run actual attention
+         q = torch.cat((txt_q, img_q), dim=2)  # [8, 24, 512, 128] + [8, 24, 900, 128] -> [8, 24, 1412, 128]
+         k = torch.cat((txt_k, img_k), dim=2)
+         v = torch.cat((txt_v, img_v), dim=2)
+
+         # use adaptive attention guidance during sampling, plain attention otherwise
+         if not info['inverse'] and 'attn_guidance' in info['editing_strategy']:
+             attn = adaptive_attention(q, k, v, pe=pe, txt_shape=txt.shape[1], img_shape=img.shape[1], cur_step=cur_step, cur_block=self.cur_block, info=info)
+         else:
+             attn = attention(q, k, v, pe=pe)
+
+         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+
+         # calculate the img blocks
+         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+         img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+
+         # calculate the txt blocks
+         txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+         txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+
+         return img, txt
+
+
+ class SingleStreamBlock(nn.Module):
+     """
+     A DiT block with parallel linear layers as described in
+     https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+     """
+
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qk_scale: float | None = None,
+     ):
+         super().__init__()
+         self.hidden_dim = hidden_size
+         self.num_heads = num_heads
+         head_dim = hidden_size // num_heads
+         self.scale = qk_scale or head_dim**-0.5
+
+         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         # qkv and mlp_in
+         self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+         # proj and mlp_out
+         self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+         self.norm = QKNorm(head_dim)
+
+         self.hidden_size = hidden_size
+         self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+         self.mlp_act = nn.GELU(approximate="tanh")
+         self.modulation = Modulation(hidden_size, double=False)
+
+     def forward(self, x: Tensor, vec: Tensor, pe: Tensor, info) -> tuple[Tensor, dict]:
+         mod, _ = self.modulation(vec)
+         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+
+         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+         q, k = self.norm(q, k, v)
+
+         # Save the features in the memory
+         if info['inject'] and info['id'] <= info['end_layer_index'] and info['id'] >= info['start_layer_index']:
+             v_feature_name = str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'V'
+             k_feature_name = str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'K'
+             q_feature_name = str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'Q'
+             if info['inverse']:
+                 if info['reuse_v']:
+                     info['feature'][v_feature_name] = v.cpu()
+                 else:
+                     editing_strategy = info['editing_strategy']
+                     qkv_ratio = info['qkv_ratio']
+                     if 'q' in editing_strategy:
+                         info['feature'][q_feature_name] = (q * qkv_ratio[0]).cpu()
+                     if 'k' in editing_strategy:
+                         info['feature'][k_feature_name] = (k * qkv_ratio[1]).cpu()
+                     if 'v' in editing_strategy:
+                         info['feature'][v_feature_name] = (v * qkv_ratio[2]).cpu()
+             else:
+                 if info['reuse_v']:
+                     if v_feature_name in info['feature']:
+                         v = info['feature'][v_feature_name].cuda()
+                 else:
+                     editing_strategy = info['editing_strategy']
+                     if 'replace_v' in editing_strategy:
+                         if v_feature_name in info['feature']:
+                             v = info['feature'][v_feature_name].cuda()
+                     if 'add_v' in editing_strategy:
+                         if v_feature_name in info['feature']:
+                             v += info['feature'][v_feature_name].cuda()
+                     if 'replace_k' in editing_strategy:
+                         if k_feature_name in info['feature']:
+                             k = info['feature'][k_feature_name].cuda()
+                     if 'add_k' in editing_strategy:
+                         if k_feature_name in info['feature']:
+                             k += info['feature'][k_feature_name].cuda()
+                     if 'replace_q' in editing_strategy:
+                         if q_feature_name in info['feature']:
+                             q = info['feature'][q_feature_name].cuda()
+                     if 'add_q' in editing_strategy:
+                         if q_feature_name in info['feature']:
+                             q += info['feature'][q_feature_name].cuda()
+
+         # compute attention
+         attn = attention(q, k, v, pe=pe)
+         # compute activation in mlp stream, cat again and run second linear layer
+         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+         return x + mod.gate * output, info
+
+
+ class LastLayer(nn.Module):
+     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+     def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+         shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+         x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+         x = self.linear(x)
+         return x
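The injection block above addresses its feature cache with string keys of the form timestep_secondorder_blockid_blocktype_tensor. An equivalent helper, shown only for illustration (the original inlines the concatenation):

    def feature_key(info: dict, tensor_name: str) -> str:
        # e.g. "0.75_False_3_single_V": V of single block 3 at t=0.75, first-order call
        return f"{info['t']}_{info['second_order']}_{info['id']}_{info['type']}_{tensor_name}"

Keys are written during inversion and read back during editing, which is why both passes must visit the same (timestep, block) pairs.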
src/flux/sampling.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Callable, Optional, Union, List, Dict, Any
3
+ import os
4
+ from PIL import Image
5
+
6
+ import torch
7
+ from einops import rearrange, repeat
8
+ from torch import Tensor
9
+
10
+ from .model import Flux
11
+ from .modules.conditioner import HFEmbedder
12
+ from .modules.autoencoder import AutoEncoder
13
+
14
+
15
+
16
+ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
17
+ bs, c, h, w = img.shape
18
+ if bs == 1 and not isinstance(prompt, str):
19
+ bs = len(prompt)
20
+
21
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
22
+ if img.shape[0] == 1 and bs > 1:
23
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
24
+
25
+ img_ids = torch.zeros(h // 2, w // 2, 3)
26
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
27
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
28
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
29
+
30
+ if isinstance(prompt, str):
31
+ prompt = [prompt]
32
+ txt = t5(prompt)
33
+ if txt.shape[0] == 1 and bs > 1:
34
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
35
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
36
+
37
+ vec = clip(prompt)
38
+ if vec.shape[0] == 1 and bs > 1:
39
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
40
+
41
+ return {
42
+ "img": img,
43
+ "img_ids": img_ids.to(img.device),
44
+ "txt": txt.to(img.device),
45
+ "txt_ids": txt_ids.to(img.device),
46
+ "vec": vec.to(img.device),
47
+ }
48
+
49
+
50
+ def prepare_image(img: Tensor):
51
+ bs, c, h, w = img.shape
52
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
53
+ if img.shape[0] == 1 and bs > 1:
54
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
55
+
56
+ return img
57
+
58
+
59
+ def time_shift(mu: float, sigma: float, t: Tensor):
60
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
61
+
62
+
63
+ def get_lin_function(
64
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
65
+ ) -> Callable[[float], float]:
66
+ m = (y2 - y1) / (x2 - x1)
67
+ b = y1 - m * x1
68
+ return lambda x: m * x + b
69
+
70
+
71
+ def get_noise(
72
+ num_samples: int,
73
+ height: int,
74
+ width: int,
75
+ device: torch.device,
76
+ dtype: torch.dtype,
77
+ seed: int,
78
+ ):
79
+ return torch.randn(
80
+ num_samples,
81
+ 16,
82
+ # allow for packing
83
+ 2 * math.ceil(height / 16),
84
+ 2 * math.ceil(width / 16),
85
+ device=device,
86
+ dtype=dtype,
87
+ generator=torch.Generator(device=device).manual_seed(seed),
88
+ )
89
+
90
+
91
+
92
+ def get_schedule(
93
+ num_steps: int,
94
+ image_seq_len: int,
95
+ base_shift: float = 0.5,
96
+ max_shift: float = 1.15,
97
+ shift: bool = True,
98
+ ) -> list[float]:
99
+ # extra step for zero
100
+ timesteps = torch.linspace(1, 0, num_steps + 1)
101
+
102
+ # shifting the schedule to favor high timesteps for higher signal images
103
+ if shift:
104
+ # estimate mu based on linear estimation between two points
105
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
106
+ timesteps = time_shift(mu, 1.0, timesteps)
107
+
108
+ return timesteps.tolist()
109
+
110
+
111
+ def denoise_rf(
112
+ model: Flux,
113
+ # model input
114
+ img: Tensor,
115
+ img_ids: Tensor,
116
+ txt: Tensor,
117
+ txt_ids: Tensor,
118
+ vec: Tensor,
119
+ # sampling parameters
120
+ timesteps: list[float],
121
+ inverse,
122
+ info,
123
+ guidance: float = 4.0
124
+ ):
125
+ # this is ignored for schnell
126
+ inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
127
+
128
+ if inverse:
129
+ timesteps = timesteps[::-1]
130
+ inject_list = inject_list[::-1]
131
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
132
+
133
+ step_list = []
134
+ for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
135
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
136
+ info['t'] = t_prev if inverse else t_curr
137
+ info['inverse'] = inverse
138
+ info['second_order'] = False
139
+ info['inject'] = inject_list[i]
140
+
141
+ pred, info = model(
142
+ img=img,
143
+ img_ids=img_ids,
144
+ txt=txt,
145
+ txt_ids=txt_ids,
146
+ y=vec,
147
+ timesteps=t_vec,
148
+ guidance=guidance_vec,
149
+ info=info,
150
+ cur_step = i
151
+ )
152
+ img = img + (t_prev - t_curr) * pred
153
+
154
+ return img, info
155
+
156
+
157
+ def denoise_rf_solver(
158
+ model: Flux,
159
+ # model input
160
+ img: Tensor,
161
+ img_ids: Tensor,
162
+ txt: Tensor,
163
+ txt_ids: Tensor,
164
+ vec: Tensor,
165
+ # sampling parameters
166
+ timesteps: list[float],
167
+ inverse,
168
+ info,
169
+ guidance: float = 4.0,
170
+ img_ori: Optional[Tensor] = None
171
+ ):
172
+ # this is ignored for schnell
173
+ inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
174
+
175
+ if inverse:
176
+ timesteps = timesteps[::-1]
177
+ inject_list = inject_list[::-1]
178
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
179
+
180
+ step_list = []
181
+ for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
182
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
183
+ info['t'] = t_prev if inverse else t_curr
184
+ info['inverse'] = inverse
185
+ info['second_order'] = False
186
+ info['inject'] = inject_list[i]
187
+
188
+ pred, info = model(
189
+ img=img,
190
+ img_ids=img_ids,
191
+ txt=txt,
192
+ txt_ids=txt_ids,
193
+ y=vec,
194
+ timesteps=t_vec,
195
+ guidance=guidance_vec,
196
+ info=info,
197
+ cur_step = i
198
+ )
199
+
200
+ img_mid = img + (t_prev - t_curr) / 2 * pred
201
+
202
+ t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
203
+ info['second_order'] = True
204
+ pred_mid, info = model(
205
+ img=img_mid,
206
+ img_ids=img_ids,
207
+ txt=txt,
208
+ txt_ids=txt_ids,
209
+ y=vec,
210
+ timesteps=t_vec_mid,
211
+ guidance=guidance_vec,
212
+ info=info,
213
+ cur_step = i
214
+ )
215
+
216
+ first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2)
217
+ img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order
218
+
219
+ return img, info
220
+
221
+
+ def denoise_fireflow(
+     model: Flux,
+     # model input
+     img: Tensor,
+     img_ids: Tensor,
+     txt: Tensor,
+     txt_ids: Tensor,
+     vec: Tensor,
+     # sampling parameters
+     timesteps: list[float],
+     inverse,
+     info,
+     guidance: float = 4.0,
+     img_ori: Optional[Tensor] = None,
+     ae: Optional[AutoEncoder] = None,  # optional autoencoder for decoding intermediates
+     device: Optional[Union[str, torch.device]] = None  # optional device for decoding
+ ):
+     # this is ignored for schnell
+     inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
+
+     if inverse:
+         timesteps = timesteps[::-1]
+         inject_list = inject_list[::-1]
+     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+
+     next_step_velocity = None
+     for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
+         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+         info['t'] = t_prev if inverse else t_curr
+         info['inverse'] = inverse
+         info['second_order'] = False
+         info['inject'] = inject_list[i]
+
+         # reuse the midpoint velocity from the previous step when available,
+         # so only the very first step needs an extra model evaluation
+         if next_step_velocity is None:
+             pred, info = model(
+                 img=img,
+                 img_ids=img_ids,
+                 txt=txt,
+                 txt_ids=txt_ids,
+                 y=vec,
+                 timesteps=t_vec,
+                 guidance=guidance_vec,
+                 info=info,
+                 cur_step=i
+             )
+         else:
+             pred = next_step_velocity
+
+         img_mid = img + (t_prev - t_curr) / 2 * pred
+
+         t_vec_mid = torch.full((img.shape[0],), t_curr + (t_prev - t_curr) / 2, dtype=img.dtype, device=img.device)
+         info['second_order'] = True
+         pred_mid, info = model(
+             img=img_mid,
+             img_ids=img_ids,
+             txt=txt,
+             txt_ids=txt_ids,
+             y=vec,
+             timesteps=t_vec_mid,
+             guidance=guidance_vec,
+             info=info,
+             cur_step=i
+         )
+         next_step_velocity = pred_mid
+
+         img = img + (t_prev - t_curr) * pred_mid
+
+         ########################### save intermediate steps ##############################
+         # idx = len(timesteps) - 1
+         # fn = f'result/intermediate_{idx}steps'
+         # if not os.path.exists(fn):
+         #     os.makedirs(fn)
+         # fn += f'/fireflow_{t_prev}.jpg'
+         # if inverse:
+         #     fn = f'result/intermediate_{idx}steps/inverse_fireflow_{t_prev}.jpg'
+         #
+         # # decode latents to pixel space
+         # x = unpack(img.float(), img.shape[1] ** 0.5 * 16, img.shape[1] ** 0.5 * 16)
+         # with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
+         #     x = ae.decode(x)
+         #
+         # # bring into PIL format and save
+         # x = x.clamp(-1, 1)
+         # x = rearrange(x[0], "c h w -> h w c")
+         # x = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+         # x.save(fn)
+         ########################### save intermediate steps ##############################
+
+     return img, info
+
+
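+ # Explicit midpoint (RK2) integrator: take a half step with the current
+ # velocity, re-evaluate at the midpoint, then advance the full step with the
+ # midpoint velocity.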
+ def denoise_midpoint(
+     model: Flux,
+     # model input
+     img: Tensor,
+     img_ids: Tensor,
+     txt: Tensor,
+     txt_ids: Tensor,
+     vec: Tensor,
+     # sampling parameters
+     timesteps: list[float],
+     inverse,
+     info,
+     guidance: float = 4.0
+ ):
+     # this is ignored for schnell
+     inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
+
+     if inverse:
+         timesteps = timesteps[::-1]
+         inject_list = inject_list[::-1]
+     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+
+     for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
+         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+         info['t'] = t_prev if inverse else t_curr
+         info['inverse'] = inverse
+         info['second_order'] = False
+         info['inject'] = inject_list[i]
+
+         pred, info = model(
+             img=img,
+             img_ids=img_ids,
+             txt=txt,
+             txt_ids=txt_ids,
+             y=vec,
+             timesteps=t_vec,
+             guidance=guidance_vec,
+             info=info
+         )
+
+         # half step with the current velocity ...
+         img_mid = img + (t_prev - t_curr) / 2 * pred
+
+         t_vec_mid = torch.full((img.shape[0],), t_curr + (t_prev - t_curr) / 2, dtype=img.dtype, device=img.device)
+         info['second_order'] = True
+         pred_mid, info = model(
+             img=img_mid,
+             img_ids=img_ids,
+             txt=txt,
+             txt_ids=txt_ids,
+             y=vec,
+             timesteps=t_vec_mid,
+             guidance=guidance_vec,
+             info=info
+         )
+
+         # ... then a full step with the midpoint velocity
+         img = img + (t_prev - t_curr) * pred_mid
+
+     return img, info
+
+
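+ # `unpack` undoes the 2x2 patchify applied when latents are packed into a token
+ # sequence: `height` and `width` are in pixels, and the division by 16 accounts
+ # for the 8x VAE downsampling times the 2x2 patch size.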
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
+     return rearrange(
+         x,
+         "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+         h=math.ceil(height / 16),
+         w=math.ceil(width / 16),
+         ph=2,
+         pw=2,
+     )
+
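+
+ # Minimal counterpart sketch of the forward packing, assuming the same 2x2
+ # patch layout used by `unpack` (the repo performs this rearrange inline when
+ # preparing latents):
+ def pack(x: Tensor) -> Tensor:
+     return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)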
+
+
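+ # RF-Inversion-style controlled sampling: each step blends the model's
+ # unconditional velocity with an analytic conditional field that points at a
+ # reference (the noise y1 when inverting, the source latents y0 when denoising),
+ # weighted by the time-varying controller gain γ. The numbered comments below
+ # follow the step numbering of the RF-Inversion algorithm.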
+ def denoise_rf_inversion(
+     model: Flux,
+     # model input
+     img: Tensor,
+     img_ids: Tensor,
+     txt: Tensor,
+     txt_ids: Tensor,
+     vec: Tensor,
+     # sampling parameters
+     timesteps: list[float],
+     inverse,
+     info,
+     guidance: float = 4.0,
+     stop_timestep: float = 0.35,
+     img_LQR: Dict = {"source img": None, "prev img": None}
+ ):
+     # this is ignored for schnell
+     inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
+
+     # γ ∈ [0, 1] is the controller guidance; it can be time-varying
+     gamma_steps = int(stop_timestep * len(timesteps[:-1]))
+     gamma = [0.9] * gamma_steps + [0] * (len(timesteps[:-1]) - gamma_steps)
+
+     if inverse:
+         # TODO: when inverting, the text prompt should be the null prompt φ
+         timesteps = timesteps[::-1]
+         inject_list = inject_list[::-1]
+         gamma = [0.5] * len(timesteps[:-1])
+
+     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+
+     y1 = torch.randn(img.shape, device=img.device, dtype=img.dtype)
+
+     y0, y_prev = None, None
+     if img_LQR['source img'] is not None:
+         y0 = img_LQR['source img'].to(img.device)
+     if img_LQR['prev img'] is not None:
+         y_prev = img_LQR['prev img'].to(img.device)
+
+     for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
+         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+         info['t'] = t_prev if inverse else t_curr
+         info['inverse'] = inverse
+         info['second_order'] = False
+         info['inject'] = inject_list[i]
+
+         pred, info = model(
+             img=img,
+             img_ids=img_ids,
+             txt=txt,
+             txt_ids=txt_ids,
+             y=vec,
+             timesteps=t_vec,
+             guidance=guidance_vec,
+             info=info,
+             cur_step=i
+         )
+
+         # 6. Unconditional vector field: u_t(Y_t) = u(Y_t, t, Φ(""); φ)
+         unconditional_vector_field = pred
+         if not inverse:
+             unconditional_vector_field = -unconditional_vector_field
+
+         if inverse:
+             # 7. Conditional vector field: u_t(Y_t | y1) = (y1 - Y_t) / (1 - t)
+             conditional_vector_field = (y1 - img) / (1 - t_curr)
+         else:
+             # 7. Conditional vector field: u_t(X_t | y0) = (y0 - X_t) / (1 - t)
+             t_i = i / len(timesteps[:-1])  # empirically gives better results than t_curr
+             # conditional_vector_field = (y0 - img) / t_curr
+             if y_prev is None:
+                 conditional_vector_field = (y0 - img) / (1 - t_i)
+             else:
+                 # blend the fields pointing at the source and at the previous turn's result
+                 # conditional_vector_field = (y_prev - img) / (1 - t_i)
+                 conditional_vector_field = (y0 - img) / (1 - t_i) + 0.7 * ((y_prev - img) / (1 - t_i) - (y0 - img) / (1 - t_i))
+
+         # 8. Controlled vector field: û_t(Y_t) = u_t(Y_t) + γ (u_t(Y_t | y1) - u_t(Y_t))
+         controlled_vector_field = unconditional_vector_field + gamma[i] * (conditional_vector_field - unconditional_vector_field)
+
+         # 9. Next state: Y_{t_{i+1}} = Y_{t_i} + û_t(Y_{t_i}) (σ(t_{i+1}) - σ(t_i));
+         # use a positive step size, the direction is carried by the vector field's sign
+         delta_t = t_prev - t_curr
+         if delta_t < 0:
+             delta_t = t_curr - t_prev
+         img = img + delta_t * controlled_vector_field
+
+     return img, info
+
+
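+ # Multi-turn consistent editing: combines the FireFlow velocity reuse with the
+ # RF-Inversion controlled vector field, optionally anchoring each turn to the
+ # original source latents (y0) and to the previous turn's result (y_prev).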
+ def denoise_multi_turn_consistent(
+     model: Flux,
+     # model input
+     img: Tensor,
+     img_ids: Tensor,
+     txt: Tensor,
+     txt_ids: Tensor,
+     vec: Tensor,
+     # sampling parameters
+     timesteps: list[float],
+     inverse,
+     info,
+     guidance: float = 4.0,
+     # img_ori: Optional[Tensor] = None
+     img_LQR: Dict = {"source img": None, "prev img": None}
+ ):
+     # this is ignored for schnell
+     inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
+
+     # γ ∈ [0, 1] is the controller guidance; it can be time-varying
+     gamma_steps = int(info['lqr_stop'] * len(timesteps[:-1]))
+     gamma = [0.9] * gamma_steps + [0] * (len(timesteps[:-1]) - gamma_steps)
+
+     if inverse:
+         # TODO: when inverting, the text prompt should be the null prompt φ
+         timesteps = timesteps[::-1]
+         inject_list = inject_list[::-1]
+         gamma = [0.5] * len(timesteps[:-1])
+
+     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+
+     y1 = torch.randn(img.shape, device=img.device, dtype=img.dtype)
+
+     y0, y_prev = None, None
+     if img_LQR['source img'] is not None:
+         y0 = img_LQR['source img'].to(img.device)
+     if img_LQR['prev img'] is not None:
+         y_prev = img_LQR['prev img'].to(img.device)
+
+     next_step_velocity = None
+     for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
+         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+         info['t'] = t_prev if inverse else t_curr
+         info['inverse'] = inverse
+         info['second_order'] = False
+         info['inject'] = inject_list[i]
+
+         if next_step_velocity is None:
+             pred, info = model(
+                 img=img,
+                 img_ids=img_ids,
+                 txt=txt,
+                 txt_ids=txt_ids,
+                 y=vec,
+                 timesteps=t_vec,
+                 guidance=guidance_vec,
+                 info=info,
+                 cur_step=i
+             )
+         else:
+             pred = next_step_velocity
+
+         img_mid = img + (t_prev - t_curr) / 2 * pred
+
+         t_vec_mid = torch.full((img.shape[0],), t_curr + (t_prev - t_curr) / 2, dtype=img.dtype, device=img.device)
+         info['second_order'] = True
+         pred_mid, info = model(
+             img=img_mid,
+             img_ids=img_ids,
+             txt=txt,
+             txt_ids=txt_ids,
+             y=vec,
+             timesteps=t_vec_mid,
+             guidance=guidance_vec,
+             info=info,
+             cur_step=i
+         )
+         next_step_velocity = pred_mid
+
+         # 6. Unconditional vector field: u_t(Y_t) = u(Y_t, t, Φ(""); φ)
+         unconditional_vector_field = pred_mid
+         if not inverse:
+             unconditional_vector_field = -unconditional_vector_field
+
+         if inverse:
+             # 7. Conditional vector field: u_t(Y_t | y1) = (y1 - Y_t) / (1 - t),
+             # evaluated at the midpoint time t_curr + (t_prev - t_curr) / 2
+             conditional_vector_field = (y1 - img) / (1 - (t_curr + (t_prev - t_curr) / 2))
+         else:
+             # 7. Conditional vector field: u_t(X_t | y0) = (y0 - X_t) / (1 - t)
+             t_i = i / len(timesteps[:-1])  # empirically gives better results than t_curr
+             # conditional_vector_field = (y0 - img) / t_curr
+             if y_prev is None:
+                 conditional_vector_field = (y0 - img) / (1 - t_i)
+             else:
+                 # blend the fields pointing at the source and at the previous turn's result
+                 conditional_vector_field = (y0 - img) / (1 - t_i) + 0.7 * ((y_prev - img) / (1 - t_i) - (y0 - img) / (1 - t_i))
+                 # conditional_vector_field = (y_prev - img) / (1 - t_i)
+
+         # 8. Controlled vector field: û_t(Y_t) = u_t(Y_t) + γ (u_t(Y_t | y1) - u_t(Y_t))
+         controlled_vector_field = unconditional_vector_field + gamma[i] * (conditional_vector_field - unconditional_vector_field)
+
+         # 9. Next state: Y_{t_{i+1}} = Y_{t_i} + û_t(Y_{t_i}) (σ(t_{i+1}) - σ(t_i));
+         # use a positive step size, the direction is carried by the vector field's sign
+         delta_t = t_prev - t_curr
+         if delta_t < 0:
+             delta_t = t_curr - t_prev
+         img = img + delta_t * controlled_vector_field
+
+     return img, info
src/flux/util.py ADDED
@@ -0,0 +1,212 @@
+ import os
+ import pickle
+ from dataclasses import dataclass
+
+ import torch
+ from einops import rearrange
+ from huggingface_hub import hf_hub_download
+ from imwatermark import WatermarkEncoder
+ from safetensors.torch import load_file as load_sft
+
+ from flux.model import Flux, FluxParams
+ from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
+ from flux.modules.conditioner import HFEmbedder
+ from flux.sampling import unpack
+
+
+ @dataclass
+ class ModelSpec:
+     params: FluxParams
+     ae_params: AutoEncoderParams
+     ckpt_path: str | None
+     ae_path: str | None
+     repo_id: str | None
+     repo_flow: str | None
+     repo_ae: str | None
+
+
+ configs = {
+     "flux-dev": ModelSpec(
+         repo_id="black-forest-labs/FLUX.1-dev",
+         repo_flow="flux1-dev.safetensors",
+         repo_ae="ae.safetensors",
+         ckpt_path=os.getenv("FLUX_DEV"),
+         params=FluxParams(
+             in_channels=64,
+             vec_in_dim=768,
+             context_in_dim=4096,
+             hidden_size=3072,
+             mlp_ratio=4.0,
+             num_heads=24,
+             depth=19,
+             depth_single_blocks=38,
+             axes_dim=[16, 56, 56],
+             theta=10_000,
+             qkv_bias=True,
+             guidance_embed=True,
+         ),
+         ae_path=os.getenv("AE"),
+         ae_params=AutoEncoderParams(
+             resolution=256,
+             in_channels=3,
+             ch=128,
+             out_ch=3,
+             ch_mult=[1, 2, 4, 4],
+             num_res_blocks=2,
+             z_channels=16,
+             scale_factor=0.3611,
+             shift_factor=0.1159,
+         ),
+     ),
+     "flux-schnell": ModelSpec(
+         repo_id="black-forest-labs/FLUX.1-schnell",
+         repo_flow="flux1-schnell.safetensors",
+         repo_ae="ae.safetensors",
+         ckpt_path=os.getenv("FLUX_SCHNELL"),
+         params=FluxParams(
+             in_channels=64,
+             vec_in_dim=768,
+             context_in_dim=4096,
+             hidden_size=3072,
+             mlp_ratio=4.0,
+             num_heads=24,
+             depth=19,
+             depth_single_blocks=38,
+             axes_dim=[16, 56, 56],
+             theta=10_000,
+             qkv_bias=True,
+             guidance_embed=False,
+         ),
+         ae_path=os.getenv("AE"),
+         ae_params=AutoEncoderParams(
+             resolution=256,
+             in_channels=3,
+             ch=128,
+             out_ch=3,
+             ch_mult=[1, 2, 4, 4],
+             num_res_blocks=2,
+             z_channels=16,
+             scale_factor=0.3611,
+             shift_factor=0.1159,
+         ),
+     ),
+ }
+
+
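+ # Note: checkpoint locations are resolved from the FLUX_DEV / FLUX_SCHNELL and
+ # AE environment variables read in `configs` above, so paths can be overridden
+ # without code changes, e.g.:
+ #   FLUX_DEV=/path/to/flux1-dev.safetensors AE=/path/to/ae.safetensors python app.py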
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
+     if len(missing) > 0 and len(unexpected) > 0:
+         print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+         print("\n" + "-" * 79 + "\n")
+         print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+     elif len(missing) > 0:
+         print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+     elif len(unexpected) > 0:
+         print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+
+
+ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
+     # Loading Flux
+     print("Init model")
+
+     ckpt_path = configs[name].ckpt_path
+     if (
+         ckpt_path is None
+         and configs[name].repo_id is not None
+         and configs[name].repo_flow is not None
+         and hf_download
+     ):
+         # ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
+         # NOTE: hardcoded local checkpoint path; restore the hf_hub_download call above to fetch from the Hub
+         ckpt_path = "/homedata/HuggingFace/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors"
+
+     # instantiate on the meta device when a checkpoint will be loaded, so the
+     # parameters are materialized directly from the state dict (assign=True)
+     with torch.device("meta" if ckpt_path is not None else device):
+         model = Flux(configs[name].params).to(torch.bfloat16)
+
+     if ckpt_path is not None:
+         print("Loading checkpoint")
+         # load_sft doesn't support torch.device
+         sd = load_sft(ckpt_path, device=str(device))
+         missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+         print_load_warning(missing, unexpected)
+     return model
+
+
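+ # Typical wiring of the loaders in this module (a minimal sketch):
+ #   model = load_flow_model("flux-dev", device="cuda")
+ #   ae = load_ae("flux-dev", device="cuda")
+ #   t5, clip = load_t5("cuda"), load_clip("cuda")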
+ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
+     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
+     return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
+     # return HFEmbedder("/homedata/HuggingFace/black-forest-labs/FLUX.1-dev/text_encoder_2", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
+
+
+ def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
+     return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
+
+
+ def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
+     ckpt_path = configs[name].ae_path
+     if (
+         ckpt_path is None
+         and configs[name].repo_id is not None
+         and configs[name].repo_ae is not None
+         and hf_download
+     ):
+         # ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
+         # NOTE: hardcoded local checkpoint path; restore the hf_hub_download call above to fetch from the Hub
+         ckpt_path = "/homedata/HuggingFace/black-forest-labs/FLUX.1-dev/ae.safetensors"
+
+     # Loading the autoencoder
+     print("Init AE")
+     with torch.device("meta" if ckpt_path is not None else device):
+         ae = AutoEncoder(configs[name].ae_params)
+
+     if ckpt_path is not None:
+         sd = load_sft(ckpt_path, device=str(device))
+         missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
+         print_load_warning(missing, unexpected)
+     return ae
+
+
+ class WatermarkEmbedder:
+     def __init__(self, watermark):
+         self.watermark = watermark
+         self.num_bits = len(WATERMARK_BITS)
+         self.encoder = WatermarkEncoder()
+         self.encoder.set_watermark("bits", self.watermark)
+
+     def __call__(self, image: torch.Tensor) -> torch.Tensor:
+         """
+         Adds a predefined watermark to the input image.
+
+         Args:
+             image: ([N,] B, RGB, H, W) in range [-1, 1]
+
+         Returns:
+             same as input but watermarked
+         """
+         image = 0.5 * image + 0.5
+         squeeze = len(image.shape) == 4
+         if squeeze:
+             image = image[None, ...]
+         n = image.shape[0]
+         # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) in [0, 255];
+         # the watermarking library expects input in cv2 BGR format
+         image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
+         for k in range(image_np.shape[0]):
+             image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
+         image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
+             image.device
+         )
+         image = torch.clamp(image / 255, min=0.0, max=1.0)
+         if squeeze:
+             image = image[0]
+         image = 2 * image - 1
+         return image
+
+
+ def save_velocity_distribution(info, prefix=""):
+     velocity_list = info['velocity']
+     pkl_file_name = prefix + "_velocity.pkl"
+     with open(pkl_file_name, "wb") as f:
+         pickle.dump(velocity_list, f)
+     print("List saved to " + pkl_file_name)
+
+
+ # A fixed 48-bit message that was chosen at random
+ WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
src/gradio_examples/000000000011.jpg ADDED

Git LFS Details

  • SHA256: 86e46a552bb5d7c6fa2dff8c3b00b4f0536beb1e2bc5346529e5c0af5013c483
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
src/gradio_examples/221000000002.jpg ADDED

Git LFS Details

  • SHA256: 81e88f4b4c7211b9dd8e29aaa7f31e38e474f148dcc8d47ba772a7b3c59bbbb1
  • Pointer size: 131 Bytes
  • Size of remote file: 161 kB
src/gradio_utils/gradio_utils.py ADDED
File without changes