asd755 committed on
Commit
70194df
·
verified ·
1 Parent(s): 15e9f7e

Upload 3 files

Browse files
Files changed (3) hide show
  1. pa_src/attn_processor.py +443 -0
  2. pa_src/pipeline.py +1272 -0
  3. pa_src/utils.py +106 -0
pa_src/attn_processor.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from diffusers.models.attention_processor import *
5
+ from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
6
+
7
def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    """Identity hook: keep the module's original attention processor.

    Default callback for ``set_flux_transformer_attn_processor``; the first
    three arguments exist only to satisfy the hook signature and are ignored.
    """
    return ori_attn_proc
14
+
15
def set_flux_transformer_attn_processor(
    transformer: FluxTransformer2DModel,
    set_attn_proc_func: Callable = default_set_attn_proc_func,
    set_attn_module_names: Optional[list[str]] = None,
) -> None:
    """Swap the attention processors of a ``FluxTransformer2DModel``.

    Args:
        transformer: The Flux transformer whose processors are replaced.
        set_attn_proc_func: Callback invoked as
            ``(name, attention_head_dim, num_attention_heads, original_processor)``
            for each selected attention module; its return value becomes the
            module's new processor.
        set_attn_module_names: Optional list of module-name prefixes; only
            modules whose name starts with one of these prefixes are handed to
            ``set_attn_proc_func``. ``None`` selects every module.
    """

    def _should_set(name: str, module_names: Optional[list[str]]) -> bool:
        # Prefix match; no filter means "select everything".
        if module_names is None:
            return True
        return any(name.startswith(module_name) for module_name in module_names)

    # Loop-invariant: the head geometry is a model-level config value.
    dim_head = transformer.config.attention_head_dim
    num_heads = transformer.config.num_attention_heads

    attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        if name.endswith("attn.processor"):
            attn_procs[name] = (
                set_attn_proc_func(name, dim_head, num_heads, attn_processor)
                if _should_set(name, set_attn_module_names)
                else attn_processor
            )

    transformer.set_attn_processor(attn_procs)
38
+
39
class PersonalizeAnythingAttnProcessor:
    """Flux attention processor for single-concept personalization.

    Operates on a batch where index 0 is the reference sample the concept is
    lifted from and index 1 is the generated sample it is injected into.
    Two mechanisms are visible in the code below:

    * token replacement (active while ``timestep > tau``): image tokens of
      sample 1 selected by ``shift_mask`` are overwritten in place with the
      concept tokens of sample 0 (selected by ``mask``) before q/k/v are
      projected;
    * token concatenation (``concept_process=True``): the concept tokens are
      additionally projected and appended to q/k/v, and the extra output
      tokens are sliced off again after attention.

    Sequence layout assumed by the indexing: with ``encoder_hidden_states``
    None (single-stream blocks) ``hidden_states`` is
    ``[512 text tokens | img_dims image tokens]``; otherwise it holds only the
    image tokens. The constant 512 presumably matches the pipeline's
    ``max_sequence_length`` — TODO confirm against the caller.
    """

    def __init__(self, name, mask, device, tau=0.98, concept_process=False, shift_mask = None, img_dims=4096):
        # F.scaled_dot_product_attention is the only attention path used below.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.name = name  # processor/module name, kept for bookkeeping
        # Boolean mask over the flattened image tokens selecting the concept region.
        self.mask = mask.view(img_dims).bool().to(device)
        self.device = device
        self.tau = tau  # timestep threshold above which token replacement is active
        self.concept_process = concept_process  # enables token concatenation
        self.img_dims = img_dims  # number of image tokens (e.g. 64x64 latents)

        # Target region in the edited sample; defaults to the source region.
        if shift_mask is None:
            self.shift_mask = self.mask
        else:
            self.shift_mask = shift_mask.view(img_dims).bool().to(device)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        timestep = None,
    ) -> torch.FloatTensor:
        """Run joint attention with token replacement / concatenation.

        Returns ``(hidden_states, encoder_hidden_states)`` for dual-stream
        blocks (``encoder_hidden_states`` given), otherwise a single tensor
        holding the ``[text | image]`` tokens.

        NOTE(review): ``timestep`` is compared with ``self.tau`` below, so it
        must not be None — confirm callers always pass it.
        """
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        ###################################################################################
        # NOTE(review): no-op reassignment kept from the original.
        if timestep is not None:
            timestep = timestep

        concept_process = self.concept_process  # token concatenation
        c_q = concept_process and True  # if token concatenation is applied to q
        c_kv = concept_process and True  # if token concatenation is applied to kv

        t_flag = timestep > self.tau  # token replacement
        r_q = True and t_flag  # if token replacement is applied to q
        r_k = True and t_flag  # if token replacement is applied to k
        r_v = True and t_flag  # if token replacement is applied to v

        # Concept tokens lifted from the reference sample (batch index 0).
        if encoder_hidden_states is not None:
            concept_feature_ = hidden_states[0, self.mask, :]
        else:
            # Single-stream layout: skip the 512 leading text tokens first.
            concept_feature_ = hidden_states[0, 512:, :][self.mask, :]

        if r_k or r_q or r_v:
            # NOTE(review): plain aliasing, not a copy — the in-place writes
            # below also mutate the caller's `hidden_states`.
            r_hidden_states = hidden_states
            if encoder_hidden_states is not None:
                r_hidden_states[1, self.shift_mask, :] = concept_feature_
            else:
                text_hidden_states = hidden_states[1, :512, :]
                image_hidden_states = hidden_states[1, 512:, :]
                # `image_hidden_states` is a view, so this write already
                # mutates `hidden_states[1]` in place.
                image_hidden_states[self.shift_mask, :] = concept_feature_

                r_hidden_states[1] = torch.cat([text_hidden_states, image_hidden_states], dim=0)
        ###################################################################################

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        query = attn.to_q(hidden_states)

        ###################################################################################
        # Re-project from the token-replaced states when replacement is active.
        if r_k:
            key = attn.to_k(r_hidden_states)
        if r_q:
            query = attn.to_q(r_hidden_states)
        if r_v:
            value = attn.to_v(r_hidden_states)

        # Token concatenation: project the concept tokens and append them so
        # queries can attend to them directly (broadcast across the batch).
        if concept_process:
            if c_q:
                c_query = attn.to_q(concept_feature_)
                c_query = c_query.repeat(query.shape[0], 1, 1)
                query = torch.cat([query, c_query], dim=1)
            if c_kv:
                c_key = attn.to_k(concept_feature_)
                c_key = c_key.repeat(key.shape[0], 1, 1)
                c_value = attn.to_v(concept_feature_)
                c_value = c_value.repeat(value.shape[0], 1, 1)
                key = torch.cat([key, c_key], dim=1)
                value = torch.cat([value, c_value], dim=1)
        ###################################################################################

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, tokens, heads*dim) -> (batch, heads, tokens, dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention: text (context) tokens come first in the sequence.
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)


        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            # use original position emb or text emb
            if not c_q:
                query = apply_rotary_emb(query, image_rotary_emb)
            if not c_kv:
                key = apply_rotary_emb(key, image_rotary_emb)

            ###################################################################################
            # get original position emb
            def get_concept_rotary_emb(ori_rotary_emb, mask):
                # Append the rotary entries of the masked positions so the
                # concatenated concept tokens reuse their source positions.
                enc_emb = ori_rotary_emb[:512, :]
                hid_emb = ori_rotary_emb[512:, :]
                concept_emb = hid_emb[mask, :]

                image_rotary_emb = torch.cat([enc_emb, hid_emb, concept_emb], dim=0)
                return image_rotary_emb

            if concept_process:
                # 1. use original position emb
                image_rotary_emb_0 = get_concept_rotary_emb(image_rotary_emb[0], self.shift_mask)
                image_rotary_emb_1 = get_concept_rotary_emb(image_rotary_emb[1], self.shift_mask)
                image_rotary_emb = (image_rotary_emb_0, image_rotary_emb_1)

                # 2. use text emb (alternative kept for reference)
                # dims = (self.mask == 1).sum().item()
                # concept_rotary_emb_0 = torch.ones((dims, 128)).to(self.device)
                # concept_rotary_emb_1 = torch.zeros((dims, 128)).to(self.device)
                # image_rotary_emb = (
                #     torch.cat([image_rotary_emb[0], concept_rotary_emb_0], dim=0),
                #     torch.cat([image_rotary_emb[1], concept_rotary_emb_1], dim=0))

                if c_q:
                    query = apply_rotary_emb(query, image_rotary_emb)
                if c_kv:
                    key = apply_rotary_emb(key, image_rotary_emb)
            ###################################################################################

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Split back into context and image streams.
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            ################################################################
            # restore after token concatenation: drop appended concept tokens
            hidden_states = hidden_states[:, :self.img_dims, :]
            ################################################################

            return hidden_states, encoder_hidden_states
        else:
            ################################################################
            # drop appended concept tokens; keep [text | image] tokens only
            dims = self.img_dims + 512
            hidden_states = hidden_states[:, :dims, :]
            ################################################################

            return hidden_states
234
+
235
class MultiPersonalizeAnythingAttnProcessor:
    """Flux attention processor for multi-concept personalization.

    Generalizes ``PersonalizeAnythingAttnProcessor`` to several concepts: the
    indexing below reads concept ``i`` from batch index ``i`` (one reference
    sample per concept) and injects all concepts into the last batch entry
    (index -1). The same two mechanisms are used: in-place token replacement
    while ``timestep > tau``, and optional token concatenation
    (``concept_process=True``) with the extra tokens sliced off after
    attention. The 512-text-token / ``img_dims``-image-token layout matches
    the single-concept processor — TODO confirm against the caller.
    """

    def __init__(self, name, masks, device, tau=0.98, concept_process=False, shift_masks = None, img_dims=4096):
        # F.scaled_dot_product_attention is the only attention path used below.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.name = name  # processor/module name, kept for bookkeeping
        self.device = device
        self.tau = tau  # timestep threshold above which token replacement is active
        self.concept_process = concept_process  # enables token concatenation
        self.img_dims = img_dims  # number of image tokens per sample

        # One boolean mask per concept over the flattened image tokens.
        # NOTE(review): the caller's `masks` list is mutated in place here.
        for i in range(len(masks)):
            masks[i] = masks[i].view(img_dims).bool().to(device)
        self.masks = masks

        # Target regions in the edited sample; default to the source regions.
        if shift_masks is None:
            self.shift_masks = self.masks
        else:
            for i in range(len(shift_masks)):
                shift_masks[i] = shift_masks[i].view(img_dims).bool().to(device)
            self.shift_masks = shift_masks

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        timestep = None,
    ) -> torch.FloatTensor:
        """Run joint attention with multi-concept replacement / concatenation.

        Returns ``(hidden_states, encoder_hidden_states)`` for dual-stream
        blocks, otherwise a single ``[text | image]`` tensor.

        NOTE(review): ``timestep`` is compared with ``self.tau`` below, so it
        must not be None — confirm callers always pass it.
        """
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        ###################################################################################
        # NOTE(review): no-op reassignment kept from the original.
        if timestep is not None:
            timestep = timestep

        concept_process = self.concept_process  # token concatenation
        c_q = concept_process and True  # if token concatenation is applied to q
        c_kv = concept_process and True  # if token concatenation is applied to kv

        t_flag = timestep > self.tau  # token replacement
        r_q = True and t_flag  # if token replacement is applied to q
        r_k = True and t_flag  # if token replacement is applied to k
        r_v = True and t_flag  # if token replacement is applied to v

        concept_features = []
        # NOTE(review): plain aliasing, not a copy — the in-place writes below
        # also mutate the caller's `hidden_states`.
        r_hidden_states = hidden_states
        # Lift concept i from batch index i; inject into the last batch entry.
        for id, mask in enumerate(self.masks):
            if encoder_hidden_states is not None:
                concept_feature_ = hidden_states[id, mask, :]
            else:
                # Single-stream layout: skip the 512 leading text tokens.
                concept_feature_ = hidden_states[id, 512:, :][mask, :]

            shift_mask = self.shift_masks[id]
            concept_features.append(concept_feature_)

            if r_k or r_q or r_v:
                if encoder_hidden_states is not None:
                    r_hidden_states[-1, shift_mask, :] = concept_feature_
                else:
                    text_hidden_states = r_hidden_states[-1, :512, :]
                    image_hidden_states = r_hidden_states[-1, 512:, :]
                    # View write: already mutates r_hidden_states[-1] in place.
                    image_hidden_states[shift_mask, :] = concept_feature_
                    r_hidden_states[-1] = torch.cat([text_hidden_states, image_hidden_states], dim=0)
        ###################################################################################

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        query = attn.to_q(hidden_states)

        ###################################################################################
        # Re-project from the token-replaced states when replacement is active.
        if r_k:
            key = attn.to_k(r_hidden_states)
        if r_q:
            query = attn.to_q(r_hidden_states)
        if r_v:
            value = attn.to_v(r_hidden_states)

        # Token concatenation: append every concept's projected tokens in order.
        if concept_process:
            for concept_feature_ in concept_features:
                if c_q:
                    c_query = attn.to_q(concept_feature_)
                    c_query = c_query.repeat(query.shape[0], 1, 1)
                    query = torch.cat([query, c_query], dim=1)
                if c_kv:
                    c_key = attn.to_k(concept_feature_)
                    c_key = c_key.repeat(key.shape[0], 1, 1)

                    c_value = attn.to_v(concept_feature_)
                    c_value = c_value.repeat(value.shape[0], 1, 1)

                    key = torch.cat([key, c_key], dim=1)
                    value = torch.cat([value, c_value], dim=1)
        ###################################################################################

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, tokens, heads*dim) -> (batch, heads, tokens, dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention: text (context) tokens come first in the sequence.
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            # use original position emb or text emb
            if not c_q:
                query = apply_rotary_emb(query, image_rotary_emb)
            if not c_kv:
                key = apply_rotary_emb(key, image_rotary_emb)

            ###################################################################################
            def get_concept_rotary_emb(ori_rotary_emb, shift_masks):
                # Append rotary entries for every concept's target positions so
                # the concatenated concept tokens reuse those positions.
                enc_emb = ori_rotary_emb[:512, :]
                hid_emb = ori_rotary_emb[512:, :]

                concept_embs = []
                for mask in shift_masks:
                    concept_embs.append(hid_emb[mask, :])
                # Empty-mask-list fallback keeps the cat well-defined.
                concept_emb = torch.cat(concept_embs, dim=0) if len(concept_embs) > 0 else torch.zeros(0, hid_emb.shape[1], device=hid_emb.device)
                image_rotary_emb = torch.cat([enc_emb, hid_emb, concept_emb], dim=0)
                return image_rotary_emb

            if concept_process:
                # 1. use original position emb with plural masks
                image_rotary_emb_0 = get_concept_rotary_emb(image_rotary_emb[0], self.shift_masks)
                image_rotary_emb_1 = get_concept_rotary_emb(image_rotary_emb[1], self.shift_masks)
                image_rotary_emb = (image_rotary_emb_0, image_rotary_emb_1)

                # 2. use text emb with plural masks (alternative kept for reference)
                # total_dims = sum((mask == 1).sum().item() for mask in self.masks)
                # concept_rotary_emb_0 = torch.ones((total_dims, 128)).to(self.device)
                # concept_rotary_emb_1 = torch.zeros((total_dims, 128)).to(self.device)
                # image_rotary_emb = (
                #     torch.cat([image_rotary_emb[0], concept_rotary_emb_0], dim=0),
                #     torch.cat([image_rotary_emb[1], concept_rotary_emb_1], dim=0)
                # )

                if c_q:
                    query = apply_rotary_emb(query, image_rotary_emb)
                if c_kv:
                    key = apply_rotary_emb(key, image_rotary_emb)
            ###################################################################################

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Split back into context and image streams.
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            ################################################################
            # restore after token concatenation: drop appended concept tokens
            hidden_states = hidden_states[:, :self.img_dims, :]
            ################################################################

            return hidden_states, encoder_hidden_states
        else:
            ################################################################
            # drop appended concept tokens; keep [text | image] tokens only
            dims = self.img_dims + 512
            hidden_states = hidden_states[:, :dims, :]
            ################################################################

            return hidden_states
pa_src/pipeline.py ADDED
@@ -0,0 +1,1272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
7
+
8
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
9
+ from diffusers.loaders import (
10
+ FluxLoraLoaderMixin,
11
+ FromSingleFileMixin,
12
+ TextualInversionLoaderMixin,
13
+ )
14
+ from diffusers.models.autoencoders import AutoencoderKL
15
+ from diffusers.models.transformers import FluxTransformer2DModel
16
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
17
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
18
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
19
+ from diffusers.utils import (
20
+ USE_PEFT_BACKEND,
21
+ is_torch_xla_available,
22
+ logging,
23
+ replace_example_docstring,
24
+ scale_lora_layers,
25
+ unscale_lora_layers,
26
+ )
27
+ from diffusers.utils.torch_utils import randn_tensor
28
+
29
+
30
+ if is_torch_xla_available():
31
+ import torch_xla.core.xla_model as xm
32
+
33
+ XLA_AVAILABLE = True
34
+ else:
35
+ XLA_AVAILABLE = False
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+ EXAMPLE_DOC_STRING = """
41
+ Examples:
42
+ ```py
43
+ >>> import torch
44
+ >>> import requests
45
+ >>> import PIL
46
+ >>> from io import BytesIO
47
+ >>> from diffusers import DiffusionPipeline
48
+
49
+ >>> pipe = DiffusionPipeline.from_pretrained(
50
+ ... "black-forest-labs/FLUX.1-dev",
51
+ ... torch_dtype=torch.bfloat16,
52
+ ... custom_pipeline="pipeline_flux_rf_inversion")
53
+ >>> pipe.to("cuda")
54
+
55
+ >>> def download_image(url):
56
+ ... response = requests.get(url)
57
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
58
+
59
+
60
+ >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
61
+ >>> image = download_image(img_url)
62
+
63
+ >>> inverted_latents, image_latents, latent_image_ids = pipe.invert(image=image, num_inversion_steps=28, gamma=0.5)
64
+
65
+ >>> edited_image = pipe(
66
+ ... prompt="a tomato",
67
+ ... inverted_latents=inverted_latents,
68
+ ... image_latents=image_latents,
69
+ ... latent_image_ids=latent_image_ids,
70
+ ... start_timestep=0,
71
+ ... stop_timestep=.25,
72
+ ... num_inference_steps=28,
73
+ ... eta=0.9,
74
+ ... ).images[0]
75
+ ```
76
+ """
77
+
78
+
79
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly interpolate the scheduler shift ``mu`` from the sequence length.

    Maps ``base_seq_len -> base_shift`` and ``max_seq_len -> max_shift``;
    lengths outside that range are extrapolated on the same line.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return slope * image_seq_len + intercept
91
+
92
+
93
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` method and retrieve the resulting
    timesteps. Handles custom `timesteps` or `sigmas` schedules; any extra
    kwargs are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`): The scheduler to get timesteps from.
        num_inference_steps (`int`): Number of diffusion steps; must be `None`
            when `timesteps` is passed.
        device (`str` or `torch.device`, *optional*): Device to move the
            timesteps to; `None` leaves them where the scheduler puts them.
        timesteps (`List[int]`, *optional*): Custom timestep schedule;
            mutually exclusive with `num_inference_steps` and `sigmas`.
        sigmas (`List[float]`, *optional*): Custom sigma schedule; mutually
            exclusive with `num_inference_steps` and `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The timestep schedule and the number of
        inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the
            scheduler does not support the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    if timesteps is not None:
        # Guard: not every scheduler accepts an explicit timestep schedule.
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        # Guard: not every scheduler accepts an explicit sigma schedule.
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Default path: the caller-provided step count is returned unchanged.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps

    return timesteps, num_inference_steps
157
+
158
+
159
+ class RFPanoInversionParallelFluxPipeline(
160
+ DiffusionPipeline,
161
+ FluxLoraLoaderMixin,
162
+ FromSingleFileMixin,
163
+ TextualInversionLoaderMixin,
164
+ ):
165
+ r"""
166
+ The Flux pipeline for text-to-image generation.
167
+
168
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
169
+
170
+ Args:
171
+ transformer ([`FluxTransformer2DModel`]):
172
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
173
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
174
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
175
+ vae ([`AutoencoderKL`]):
176
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
177
+ text_encoder ([`CLIPTextModel`]):
178
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
179
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
180
+ text_encoder_2 ([`T5EncoderModel`]):
181
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
182
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
183
+ tokenizer (`CLIPTokenizer`):
184
+ Tokenizer of class
185
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
186
+ tokenizer_2 (`T5TokenizerFast`):
187
+ Second Tokenizer of class
188
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
189
+ """
190
+
191
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
192
+ _optional_components = []
193
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
194
+
195
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
    ):
        """Register the Flux model components and derive processing constants.

        All modules are registered via `register_modules` (the DiffusionPipeline
        machinery), then VAE/tokenizer-dependent constants are cached with
        fallbacks for the case where a component is missing.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Spatial downscaling factor of the VAE; 8 is the fallback when no VAE
        # is registered.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1)
            if hasattr(self, "vae") and self.vae is not None
            else 8
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        # CLIP tokenizer context length; 77 is the usual CLIP fallback.
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length
            if hasattr(self, "tokenizer") and self.tokenizer is not None
            else 77
        )
        # Default latent sample size (in latent pixels) — presumably matches
        # Flux's 1024px default; TODO confirm.
        self.default_sample_size = 128
228
+
229
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
230
+ def _get_t5_prompt_embeds(
231
+ self,
232
+ prompt: Union[str, List[str]] = None,
233
+ num_images_per_prompt: int = 1,
234
+ max_sequence_length: int = 512,
235
+ device: Optional[torch.device] = None,
236
+ dtype: Optional[torch.dtype] = None,
237
+ ):
238
+ device = device or self._execution_device
239
+ dtype = dtype or self.text_encoder.dtype
240
+
241
+ prompt = [prompt] if isinstance(prompt, str) else prompt
242
+ batch_size = len(prompt)
243
+
244
+ if isinstance(self, TextualInversionLoaderMixin):
245
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
246
+
247
+ text_inputs = self.tokenizer_2(
248
+ prompt,
249
+ padding="max_length",
250
+ max_length=max_sequence_length,
251
+ truncation=True,
252
+ return_length=False,
253
+ return_overflowing_tokens=False,
254
+ return_tensors="pt",
255
+ )
256
+ text_input_ids = text_inputs.input_ids
257
+ untruncated_ids = self.tokenizer_2(
258
+ prompt, padding="longest", return_tensors="pt"
259
+ ).input_ids
260
+
261
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
262
+ text_input_ids, untruncated_ids
263
+ ):
264
+ removed_text = self.tokenizer_2.batch_decode(
265
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
266
+ )
267
+ logger.warning(
268
+ "The following part of your input was truncated because `max_sequence_length` is set to "
269
+ f" {max_sequence_length} tokens: {removed_text}"
270
+ )
271
+
272
+ prompt_embeds = self.text_encoder_2(
273
+ text_input_ids.to(device), output_hidden_states=False
274
+ )[0]
275
+
276
+ dtype = self.text_encoder_2.dtype
277
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
278
+
279
+ _, seq_len, _ = prompt_embeds.shape
280
+
281
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
282
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
283
+ prompt_embeds = prompt_embeds.view(
284
+ batch_size * num_images_per_prompt, seq_len, -1
285
+ )
286
+
287
+ return prompt_embeds
288
+
289
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        """Encode `prompt` with the CLIP text encoder and return its pooled output.

        Returns a tensor of shape (batch_size * num_images_per_prompt, clip_dim);
        only the pooled (sentence-level) embedding is used by Flux, not the
        per-token hidden states.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Resolve textual-inversion placeholder tokens, if the mixin is present.
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation so we can warn about what was cut off.
        untruncated_ids = self.tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(
            text_input_ids.to(device), output_hidden_states=False
        )

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds
342
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

        Returns:
            `(prompt_embeds, pooled_prompt_embeds, text_ids)` — T5 per-token
            embeddings, CLIP pooled embedding, and the all-zero RoPE position
            ids for the text tokens.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        dtype = (
            self.text_encoder.dtype
            if self.text_encoder is not None
            else self.transformer.dtype
        )
        # Flux assigns every text token the position id (0, 0, 0).
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids
426
    @torch.no_grad()
    # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image
    def encode_image(
        self,
        image,
        dtype=None,
        height=None,
        width=None,
        resize_mode="default",
        crops_coords=None,
    ):
        """Preprocess `image` and encode it into scaled/shifted VAE latents.

        Returns:
            (x0, resized): `x0` are the sampled, shift/scale-normalized VAE
            latents; `resized` is the preprocessed image converted back to PIL
            (useful for showing exactly what was encoded).

        NOTE(review): `dtype` has no default fallback — `image.to(dtype)` with
        `dtype=None` likely fails; callers appear expected to pass it. Confirm.
        """
        image = self.image_processor.preprocess(
            image=image,
            height=height,
            width=width,
            resize_mode=resize_mode,
            crops_coords=crops_coords,
        )
        resized = self.image_processor.postprocess(image=image, output_type="pil")

        # Heuristic resolution warning; presumably `sample_size` is the model's
        # native training resolution — confirm against the VAE config schema.
        if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
            logger.warning(
                "Your input images far exceed the default resolution of the underlying diffusion model. "
                "The output images may contain severe artifacts! "
                "Consider down-sampling the input using the `height` and `width` parameters"
            )
        image = image.to(dtype)

        # Sample from the VAE posterior, then apply Flux's latent normalization.
        x0 = self.vae.encode(image.to(self.device)).latent_dist.sample()
        x0 = (x0 - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        x0 = x0.to(dtype)
        return x0, resized
459
+ def check_inputs(
460
+ self,
461
+ prompt,
462
+ prompt_2,
463
+ inverted_latents,
464
+ image_latents,
465
+ latent_image_ids,
466
+ height,
467
+ width,
468
+ start_timestep,
469
+ stop_timestep,
470
+ prompt_embeds=None,
471
+ pooled_prompt_embeds=None,
472
+ callback_on_step_end_tensor_inputs=None,
473
+ max_sequence_length=None,
474
+ ):
475
+ if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
476
+ raise ValueError(
477
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
478
+ )
479
+
480
+ if callback_on_step_end_tensor_inputs is not None and not all(
481
+ k in self._callback_tensor_inputs
482
+ for k in callback_on_step_end_tensor_inputs
483
+ ):
484
+ raise ValueError(
485
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
486
+ )
487
+
488
+ if prompt is not None and prompt_embeds is not None:
489
+ raise ValueError(
490
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
491
+ " only forward one of the two."
492
+ )
493
+ elif prompt_2 is not None and prompt_embeds is not None:
494
+ raise ValueError(
495
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
496
+ " only forward one of the two."
497
+ )
498
+ elif prompt is None and prompt_embeds is None:
499
+ raise ValueError(
500
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
501
+ )
502
+ elif prompt is not None and (
503
+ not isinstance(prompt, str) and not isinstance(prompt, list)
504
+ ):
505
+ raise ValueError(
506
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
507
+ )
508
+ elif prompt_2 is not None and (
509
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
510
+ ):
511
+ raise ValueError(
512
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
513
+ )
514
+
515
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
516
+ raise ValueError(
517
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
518
+ )
519
+
520
+ if max_sequence_length is not None and max_sequence_length > 512:
521
+ raise ValueError(
522
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
523
+ )
524
+
525
+ if inverted_latents is not None and (
526
+ image_latents is None or latent_image_ids is None
527
+ ):
528
+ raise ValueError(
529
+ "If `inverted_latents` are provided, `image_latents` and `latent_image_ids` also have to be passed. "
530
+ )
531
+ # check start_timestep and stop_timestep
532
+ if start_timestep < 0 or start_timestep > stop_timestep:
533
+ raise ValueError(
534
+ f"`start_timestep` should be in [0, stop_timestep] but is {start_timestep}"
535
+ )
536
+
537
+ @staticmethod
538
+ def _prepare_latent_image_ids_offset(
539
+ batch_size, height, width, device, dtype, x_offset=0, y_offset=32
540
+ ):
541
+ latent_image_ids = torch.zeros(height, width, 3)
542
+ latent_image_ids[..., 1] = (
543
+ latent_image_ids[..., 1] + torch.arange(height)[:, None]
544
+ )
545
+ latent_image_ids[..., 2] = (
546
+ latent_image_ids[..., 2] + torch.arange(width)[None, :]
547
+ )
548
+
549
+ latent_image_ids[..., 1] += x_offset
550
+ latent_image_ids[..., 2] += y_offset
551
+
552
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
553
+ latent_image_ids.shape
554
+ )
555
+
556
+ latent_image_ids = latent_image_ids.reshape(
557
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
558
+ )
559
+
560
+ return latent_image_ids.to(device=device, dtype=dtype)
561
+
562
+ @staticmethod
563
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
564
+ latent_image_ids = torch.zeros(height, width, 3)
565
+ latent_image_ids[..., 1] = (
566
+ latent_image_ids[..., 1] + torch.arange(height)[:, None]
567
+ )
568
+ latent_image_ids[..., 2] = (
569
+ latent_image_ids[..., 2] + torch.arange(width)[None, :]
570
+ )
571
+
572
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
573
+ latent_image_ids.shape
574
+ )
575
+
576
+ latent_image_ids = latent_image_ids.reshape(
577
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
578
+ )
579
+
580
+ return latent_image_ids.to(device=device, dtype=dtype)
581
+
582
+ @staticmethod
583
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
584
+ latents = latents.view(
585
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
586
+ )
587
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
588
+ latents = latents.reshape(
589
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
590
+ )
591
+
592
+ return latents
593
+
594
+ @staticmethod
595
+ def _unpack_latents(latents, height, width, vae_scale_factor):
596
+ batch_size, num_patches, channels = latents.shape
597
+
598
+ height = height // vae_scale_factor
599
+ width = width // vae_scale_factor
600
+
601
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
602
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
603
+
604
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
605
+
606
+ return latents
607
+
608
+ def enable_vae_slicing(self):
609
+ r"""
610
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
611
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
612
+ """
613
+ self.vae.enable_slicing()
614
+
615
+ def disable_vae_slicing(self):
616
+ r"""
617
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
618
+ computing decoding in one step.
619
+ """
620
+ self.vae.disable_slicing()
621
+
622
+ def enable_vae_tiling(self):
623
+ r"""
624
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
625
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
626
+ processing larger images.
627
+ """
628
+ self.vae.enable_tiling()
629
+
630
+ def disable_vae_tiling(self):
631
+ r"""
632
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
633
+ computing decoding in one step.
634
+ """
635
+ self.vae.disable_tiling()
636
+
637
+ def prepare_latents_inversion(
638
+ self,
639
+ batch_size,
640
+ num_channels_latents,
641
+ height,
642
+ width,
643
+ dtype,
644
+ device,
645
+ image_latents,
646
+ ):
647
+ height = int(height) // self.vae_scale_factor
648
+ width = int(width) // self.vae_scale_factor
649
+ latents = self._pack_latents(
650
+ image_latents, batch_size, num_channels_latents, height, width
651
+ )
652
+
653
+ latent_image_ids = self._prepare_latent_image_ids(
654
+ batch_size, height // 2, width // 2, device, dtype
655
+ )
656
+
657
+ return latents, latent_image_ids
658
+
659
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
660
+ def prepare_latents(
661
+ self,
662
+ batch_size,
663
+ num_channels_latents,
664
+ height,
665
+ width,
666
+ dtype,
667
+ device,
668
+ generator,
669
+ latents=None,
670
+ ):
671
+ # VAE applies 8x compression on images but we must also account for packing which requires
672
+ # latent height and width to be divisible by 2.
673
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
674
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
675
+
676
+ shape = (batch_size, num_channels_latents, height, width)
677
+
678
+ if latents is not None:
679
+ latent_image_ids = self._prepare_latent_image_ids(
680
+ batch_size, height // 2, width // 2, device, dtype
681
+ )
682
+ return latents.to(device=device, dtype=dtype), latent_image_ids
683
+
684
+ if isinstance(generator, list) and len(generator) != batch_size:
685
+ raise ValueError(
686
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
687
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
688
+ )
689
+
690
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
691
+ latents = self._pack_latents(
692
+ latents, batch_size, num_channels_latents, height, width
693
+ )
694
+
695
+ latent_image_ids = self._prepare_latent_image_ids(
696
+ batch_size, height // 2, width // 2, device, dtype
697
+ )
698
+
699
+ return latents, latent_image_ids
700
+
701
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
702
+ def get_timesteps(self, num_inference_steps, strength=1.0):
703
+ # get the original timestep using init_timestep
704
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
705
+
706
+ t_start = int(max(num_inference_steps - init_timestep, 0))
707
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
708
+ sigmas = self.scheduler.sigmas[t_start * self.scheduler.order :]
709
+ if hasattr(self.scheduler, "set_begin_index"):
710
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
711
+
712
+ return timesteps, sigmas, num_inference_steps - t_start
713
+
714
    @property
    def guidance_scale(self):
        """Guidance scale stored by the most recent `__call__` invocation."""
        return self._guidance_scale
718
    @property
    def joint_attention_kwargs(self):
        """Extra kwargs forwarded to the attention processors, stored by `__call__`."""
        return self._joint_attention_kwargs
722
    @property
    def num_timesteps(self):
        """Number of denoising timesteps used by the most recent `__call__`."""
        return self._num_timesteps
726
    @property
    def interrupt(self):
        """Whether the denoising loop has been asked to stop early."""
        return self._interrupt
730
+ @torch.no_grad()
731
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
732
+ def __call__(
733
+ self,
734
+ prompt: Union[str, List[str]] = None,
735
+ prompt_2: Optional[Union[str, List[str]]] = None,
736
+ inverted_latents: Optional[torch.FloatTensor] = None,
737
+ image_latents: Optional[torch.FloatTensor] = None,
738
+ latent_image_ids: Optional[torch.FloatTensor] = None,
739
+ height: Optional[int] = None,
740
+ width: Optional[int] = None,
741
+ eta: float = 1.0,
742
+ decay_eta: Optional[bool] = False,
743
+ eta_decay_power: Optional[float] = 1.0,
744
+ strength: float = 1.0,
745
+ start_timestep: float = 0,
746
+ stop_timestep: float = 0.25,
747
+ num_inference_steps: int = 28,
748
+ sigmas: Optional[List[float]] = None,
749
+ timesteps: List[int] = None,
750
+ guidance_scale: float = 3.5,
751
+ num_images_per_prompt: Optional[int] = 1,
752
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
753
+ latents: Optional[torch.FloatTensor] = None,
754
+ prompt_embeds: Optional[torch.FloatTensor] = None,
755
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
756
+ output_type: Optional[str] = "pil",
757
+ return_dict: bool = True,
758
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
759
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
760
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
761
+ max_sequence_length: int = 512,
762
+ ###################################
763
+ mask = None,
764
+ use_timestep = False
765
+ ):
766
+ r"""
767
+ Function invoked when calling the pipeline for generation.
768
+
769
+ Args:
770
+ prompt (`str` or `List[str]`, *optional*):
771
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
772
+ instead.
773
+ prompt_2 (`str` or `List[str]`, *optional*):
774
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
775
+ will be used instead
776
+ inverted_latents (`torch.Tensor`, *optional*):
777
+ The inverted latents from `pipe.invert`.
778
+ image_latents (`torch.Tensor`, *optional*):
779
+ The image latents from `pipe.invert`.
780
+ latent_image_ids (`torch.Tensor`, *optional*):
781
+ The latent image ids from `pipe.invert`.
782
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
783
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
784
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
785
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
786
+ eta (`float`, *optional*, defaults to 1.0):
787
+ The controller guidance, balancing faithfulness & editability:
788
+ higher eta - better faithfullness, less editability. For more significant edits, lower the value of eta.
789
+ num_inference_steps (`int`, *optional*, defaults to 50):
790
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
791
+ expense of slower inference.
792
+ timesteps (`List[int]`, *optional*):
793
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
794
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
795
+ passed will be used. Must be in descending order.
796
+ guidance_scale (`float`, *optional*, defaults to 7.0):
797
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
798
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
799
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
800
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
801
+ usually at the expense of lower image quality.
802
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
803
+ The number of images to generate per prompt.
804
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
805
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
806
+ to make generation deterministic.
807
+ latents (`torch.FloatTensor`, *optional*):
808
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
809
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
810
+ tensor will ge generated by sampling using the supplied random `generator`.
811
+ prompt_embeds (`torch.FloatTensor`, *optional*):
812
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
813
+ provided, text embeddings will be generated from `prompt` input argument.
814
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
815
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
816
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
817
+ output_type (`str`, *optional*, defaults to `"pil"`):
818
+ The output format of the generate image. Choose between
819
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
820
+ return_dict (`bool`, *optional*, defaults to `True`):
821
+ Whether to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
822
+ joint_attention_kwargs (`dict`, *optional*):
823
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
824
+ `self.processor` in
825
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
826
+ callback_on_step_end (`Callable`, *optional*):
827
+ A function that calls at the end of each denoising steps during the inference. The function is called
828
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
829
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
830
+ `callback_on_step_end_tensor_inputs`.
831
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
832
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
833
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
834
+ `._callback_tensor_inputs` attribute of your pipeline class.
835
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
836
+
837
+ Examples:
838
+
839
+ Returns:
840
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
841
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
842
+ images.
843
+ """
844
+
845
+ height = height or self.default_sample_size * self.vae_scale_factor
846
+ width = width or self.default_sample_size * self.vae_scale_factor
847
+
848
+ # 1. Check inputs. Raise error if not correct
849
+ self.check_inputs(
850
+ prompt,
851
+ prompt_2,
852
+ inverted_latents,
853
+ image_latents,
854
+ latent_image_ids,
855
+ height,
856
+ width,
857
+ start_timestep,
858
+ stop_timestep,
859
+ prompt_embeds=prompt_embeds,
860
+ pooled_prompt_embeds=pooled_prompt_embeds,
861
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
862
+ max_sequence_length=max_sequence_length,
863
+ )
864
+
865
+ self._guidance_scale = guidance_scale
866
+ self._joint_attention_kwargs = joint_attention_kwargs
867
+ self._interrupt = False
868
+ do_rf_inversion = inverted_latents is not None
869
+
870
+ # 2. Define call parameters
871
+ if prompt is not None and isinstance(prompt, str):
872
+ batch_size = 1
873
+ elif prompt is not None and isinstance(prompt, list):
874
+ batch_size = len(prompt)
875
+ else:
876
+ batch_size = prompt_embeds.shape[0]
877
+
878
+ device = self._execution_device
879
+
880
+ lora_scale = (
881
+ self.joint_attention_kwargs.get("scale", None)
882
+ if self.joint_attention_kwargs is not None
883
+ else None
884
+ )
885
+ (
886
+ prompt_embeds,
887
+ pooled_prompt_embeds,
888
+ text_ids,
889
+ ) = self.encode_prompt(
890
+ prompt=prompt,
891
+ prompt_2=prompt_2,
892
+ prompt_embeds=prompt_embeds,
893
+ pooled_prompt_embeds=pooled_prompt_embeds,
894
+ device=device,
895
+ num_images_per_prompt=num_images_per_prompt,
896
+ max_sequence_length=max_sequence_length,
897
+ lora_scale=lora_scale,
898
+ )
899
+
900
+ # 4. Prepare latent variables
901
+ num_channels_latents = self.transformer.config.in_channels // 4
902
+ if do_rf_inversion:
903
+ latents = inverted_latents
904
+ new_latents, _ = self.prepare_latents(
905
+ 1,
906
+ num_channels_latents,
907
+ height,
908
+ width,
909
+ prompt_embeds.dtype,
910
+ device,
911
+ generator,
912
+ latents=None,
913
+ )
914
+ ###############################################
915
+ n_h = height // 16
916
+ n_w = width // 16
917
+ bsz, _, dim = latents.shape
918
+
919
+ new_latents = new_latents.reshape(bsz, n_h, n_w, dim)
920
+ first_col = new_latents[:, :, 0:1, :]
921
+ last_col = new_latents[:, :, -1:, :]
922
+ new_latents = torch.cat([last_col, new_latents, first_col], dim=2)
923
+ new_latents = new_latents.reshape(bsz, -1, dim)
924
+ ###############################################
925
+
926
+ latents = torch.cat((latents, new_latents), dim=0)
927
+ bsz, _, dim = latents.shape
928
+ else:
929
+ latents, latent_image_ids = self.prepare_latents(
930
+ batch_size * num_images_per_prompt,
931
+ num_channels_latents,
932
+ height,
933
+ width,
934
+ prompt_embeds.dtype,
935
+ device,
936
+ generator,
937
+ latents,
938
+ )
939
+
940
+ # 5. Prepare timesteps
941
+ sigmas = (
942
+ np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
943
+ if sigmas is None
944
+ else sigmas
945
+ )
946
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (
947
+ int(width) // self.vae_scale_factor // 2
948
+ )
949
+ mu = calculate_shift(
950
+ image_seq_len,
951
+ self.scheduler.config.base_image_seq_len,
952
+ self.scheduler.config.max_image_seq_len,
953
+ self.scheduler.config.base_shift,
954
+ self.scheduler.config.max_shift,
955
+ )
956
+ # if timesteps is not None:
957
+ # sigmas = None
958
+ timesteps, num_inference_steps = retrieve_timesteps(
959
+ self.scheduler,
960
+ num_inference_steps,
961
+ device,
962
+ timesteps,
963
+ sigmas,
964
+ mu=mu,
965
+ )
966
+ if do_rf_inversion:
967
+ start_timestep = int(start_timestep * num_inference_steps)
968
+ stop_timestep = min(
969
+ int(stop_timestep * num_inference_steps), num_inference_steps
970
+ )
971
+ timesteps, sigmas, num_inference_steps = self.get_timesteps(
972
+ num_inference_steps, strength
973
+ )
974
+ num_warmup_steps = max(
975
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
976
+ )
977
+ self._num_timesteps = len(timesteps)
978
+
979
+ # handle guidance
980
+ if self.transformer.config.guidance_embeds:
981
+ guidance = torch.full(
982
+ [1], guidance_scale, device=device, dtype=torch.float32
983
+ )
984
+ guidance = guidance.expand(latents.shape[0])
985
+ else:
986
+ guidance = None
987
+
988
+ if do_rf_inversion:
989
+ y_0 = image_latents.clone()
990
+
991
+ # 6. Denoising loop / Controlled Reverse ODE, Algorithm 2 from: https://arxiv.org/pdf/2410.10792
992
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
993
+ for i, t in enumerate(timesteps):
994
+ if do_rf_inversion:
995
+ # ti (current timestep) as annotated in algorithm 2 - i/num_inference_steps.
996
+ t_i = 1 - t / 1000
997
+ dt = torch.tensor(1 / (len(timesteps) - 1), device=device)
998
+
999
+ if self.interrupt:
1000
+ continue
1001
+
1002
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1003
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1004
+
1005
+ # 添加额外的 timestep
1006
+ if use_timestep:
1007
+ self._joint_attention_kwargs = (
1008
+ {}
1009
+ if self._joint_attention_kwargs is None
1010
+ else self._joint_attention_kwargs
1011
+ )
1012
+ self._joint_attention_kwargs["timestep"] = timestep[0].item() / 1000
1013
+
1014
+ noise_pred = self.transformer(
1015
+ hidden_states=latents,
1016
+ timestep=timestep / 1000,
1017
+ guidance=guidance,
1018
+ pooled_projections=pooled_prompt_embeds,
1019
+ encoder_hidden_states=prompt_embeds,
1020
+ txt_ids=text_ids,
1021
+ img_ids=latent_image_ids,
1022
+ joint_attention_kwargs=self.joint_attention_kwargs,
1023
+ return_dict=False,
1024
+ )[0]
1025
+
1026
+ latents_dtype = latents.dtype
1027
+
1028
+ # noise_pred.shape: torch.Size([2, 4096, 64])
1029
+ if do_rf_inversion:
1030
+ v_t = -noise_pred[:-1]
1031
+ v_t_cond = (y_0 - latents[:-1]) / (1 - t_i)
1032
+ eta_t = eta if start_timestep <= i < stop_timestep else 0.0
1033
+ if decay_eta:
1034
+ eta_t = (
1035
+ eta_t * (1 - i / num_inference_steps) ** eta_decay_power
1036
+ ) # Decay eta over the loop
1037
+ v_hat_t = v_t + eta_t * (v_t_cond - v_t)
1038
+ # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792
1039
+ latents[:-1] = latents[:-1] + v_hat_t * (sigmas[i] - sigmas[i + 1])
1040
+
1041
+ latents[-1] = self.scheduler.step(
1042
+ noise_pred[-1], t, latents[-1], return_dict=False
1043
+ )[0]
1044
+
1045
+ else:
1046
+ # compute the previous noisy sample x_t -> x_t-1
1047
+ latents = self.scheduler.step(
1048
+ noise_pred, t, latents, return_dict=False
1049
+ )[0]
1050
+
1051
+ ############################################
1052
+ # enhance image consistency
1053
+ t_l = 0.5 # smaller values enforce stronger consistency but may cause visual discontinuities
1054
+ if mask is not None and timestep[1].item() / 1000 >= t_l:
1055
+ mask = mask.to(device)
1056
+ latents[1] = latents[1] * (1.0-mask) + latents[0] * mask
1057
+ ############################################
1058
+
1059
+ if latents.dtype != latents_dtype:
1060
+ if torch.backends.mps.is_available():
1061
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1062
+ latents = latents.to(latents_dtype)
1063
+
1064
+ if callback_on_step_end is not None:
1065
+ callback_kwargs = {}
1066
+ for k in callback_on_step_end_tensor_inputs:
1067
+ callback_kwargs[k] = locals()[k]
1068
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1069
+
1070
+ latents = callback_outputs.pop("latents", latents)
1071
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1072
+
1073
+ # call the callback, if provided
1074
+ if i == len(timesteps) - 1 or (
1075
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1076
+ ):
1077
+ progress_bar.update()
1078
+
1079
+ if XLA_AVAILABLE:
1080
+ xm.mark_step()
1081
+
1082
+ ######################################################
1083
+ latents = latents.reshape(bsz, n_h, n_w+2, dim)
1084
+ latents = latents[:, :, 1:-1, :]
1085
+ latents = latents.reshape(bsz, -1, dim)
1086
+ ######################################################
1087
+
1088
+ if output_type == "latent":
1089
+ image = latents
1090
+
1091
+ else:
1092
+ latents = self._unpack_latents(
1093
+ latents, height, width, self.vae_scale_factor
1094
+ )
1095
+ latents = (
1096
+ latents / self.vae.config.scaling_factor
1097
+ ) + self.vae.config.shift_factor
1098
+ image = self.vae.decode(latents, return_dict=False)[0]
1099
+ image = self.image_processor.postprocess(image, output_type=output_type)
1100
+
1101
+ # Offload all models
1102
+ self.maybe_free_model_hooks()
1103
+
1104
+ if not return_dict:
1105
+ return (image,)
1106
+
1107
+ return FluxPipelineOutput(images=image)
1108
+
1109
    @torch.no_grad()
    def invert(
        self,
        image: PipelineImageInput,
        source_prompt: str = "",
        source_guidance_scale=0.0,
        num_inversion_steps: int = 28,
        strength: float = 1.0,
        gamma: float = 0.5,
        height: Optional[int] = None,
        width: Optional[int] = None,
        timesteps: Optional[List[int]] = None,
        dtype: Optional[torch.dtype] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        generator: Optional[torch.Generator] = None,
    ):
        r"""
        Performs Algorithm 1: Controlled Forward ODE from https://arxiv.org/pdf/2410.10792
        Args:
            image (`PipelineImageInput`):
                Input for the image(s) that are to be edited. Multiple input images have to default to the same aspect
                ratio.
            source_prompt (`str` or `List[str]`, *optional* defaults to an empty prompt as done in the original paper):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            source_guidance_scale (`float`, *optional*, defaults to 0.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). For this algorithm, it's better to keep it 0.
            num_inversion_steps (`int`, *optional*, defaults to 28):
                The number of discretization steps.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            gamma (`float`, *optional*, defaults to 0.5):
                The controller guidance for the forward ODE, balancing faithfulness & editability:
                higher eta - better faithfullness, less editability. For more significant edits, lower the value of eta.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.

        Returns:
            Tuple of (inverted latents `Y_t`, column-padded encoded image latents,
            matching padded `latent_image_ids`) — the start point, reference image
            and positional ids consumed by the denoising loop.
        """
        dtype = dtype or self.text_encoder.dtype
        batch_size = 1
        self._joint_attention_kwargs = joint_attention_kwargs
        # Flux packs 2x2 latent patches, so the transformer sees in_channels // 4.
        num_channels_latents = self.transformer.config.in_channels // 4

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        device = self._execution_device

        # 1. prepare image: VAE-encode, then pack into (batch, seq, dim) latents.
        image_latents, _ = self.encode_image(
            image, height=height, width=width, dtype=dtype
        )
        image_latents, latent_image_ids = self.prepare_latents_inversion(
            batch_size,
            num_channels_latents,
            height,
            width,
            dtype,
            device,
            image_latents,
        )
        #####################################################
        # Circularly pad the latent grid along the width axis: prepend the last
        # column and append the first column (and do the same to the positional
        # ids). Presumably this enforces horizontal wrap-around consistency for
        # panorama generation; the denoising loop strips these two extra
        # columns back off — TODO confirm against the caller.
        # n_h/n_w: latent-grid size; 16 = vae_scale_factor * 2 (packing) — TODO confirm.
        n_h = height // 16
        n_w = width // 16
        bsz, _, dim = image_latents.shape

        image_latents = image_latents.reshape(bsz, n_h, n_w, dim)
        first_col = image_latents[:, :, 0:1, :]
        last_col = image_latents[:, :, -1:, :]
        image_latents = torch.cat([last_col, image_latents, first_col], dim=2)
        image_latents = image_latents.reshape(bsz, -1, dim)

        latent_image_ids = latent_image_ids.reshape(n_h, n_w, 3)
        first_col = latent_image_ids[:, 0:1, :]
        last_col = latent_image_ids[:, -1:, :]
        latent_image_ids = torch.cat([last_col, latent_image_ids, first_col], dim=1)
        latent_image_ids = latent_image_ids.reshape(-1, 3)
        #####################################################

        # 2. prepare timesteps (sigmas run from 1.0 down to 1/num_inversion_steps,
        # shifted by mu according to the image sequence length).
        sigmas = np.linspace(1.0, 1 / num_inversion_steps, num_inversion_steps)
        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (
            int(width) // self.vae_scale_factor // 2
        )
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inversion_steps = retrieve_timesteps(
            self.scheduler,
            num_inversion_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        # NOTE(review): this overwrites the values from retrieve_timesteps with
        # the strength-truncated schedule; `sigmas` below comes from here.
        timesteps, sigmas, num_inversion_steps = self.get_timesteps(
            num_inversion_steps, strength
        )

        # 3. prepare text embeddings for the (usually empty) source prompt.
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=source_prompt,
            prompt_2=source_prompt,
            device=device,
        )
        # 4. handle guidance (only guidance-distilled Flux variants embed it).
        if self.transformer.config.guidance_embeds:
            guidance = torch.full(
                [1], source_guidance_scale, device=device, dtype=torch.float32
            )
        else:
            guidance = None

        # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt
        # Y_t starts at the encoded image; y_1 is the Gaussian endpoint.
        Y_t = image_latents
        y_1 = torch.randn(
            Y_t.shape, device=Y_t.device, dtype=Y_t.dtype, generator=generator
        )
        N = len(sigmas)

        # forward ODE loop (Euler integration over the sigma schedule)
        with self.progress_bar(total=N - 1) as progress_bar:
            for i in range(N - 1):
                # t_i advances uniformly in [0, 1); the transformer expects a
                # normalized timestep in that range.
                t_i = torch.tensor(i / (N), dtype=Y_t.dtype, device=device)
                timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(
                    batch_size
                )
                # invert_timestep_list.append(timestep.item() * 1000)
                # get the unconditional vector field
                u_t_i = self.transformer(
                    hidden_states=Y_t,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # get the conditional vector field (straight line toward y_1)
                u_t_i_cond = (y_1 - Y_t) / (1 - t_i)

                # controlled vector field
                # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt
                u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i)
                Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1])
                progress_bar.update()

        # return the inverted latents (start point for the denoising loop), encoded image & latent image ids
        return Y_t, image_latents, latent_image_ids
pa_src/utils.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import cv2
5
+ from typing import Optional
6
+
7
+
8
def shift_tensor(tensor, x):
    """Shift ``tensor`` along dim 1 by ``x`` positions, zero-filling vacated slots.

    Positive ``x`` moves entries toward higher indices, negative toward lower.
    ``x == 0`` returns the input tensor itself (no copy is made).
    """
    if x == 0:
        return tensor

    shifted = torch.zeros_like(tensor)
    if x > 0:
        shifted[:, x:] = tensor[:, :-x]
    else:
        shifted[:, :x] = tensor[:, -x:]
    return shifted
19
+
20
+
21
def create_mask(input_image_path, w=64, h=64):
    """Load an image, downscale it to ``(w, h)``, and return a binary mask.

    Only pixels that are pure white after grayscale conversion (value == 255)
    become 1; everything else becomes 0. Returns an int tensor of shape (h, w).
    """
    grayscale = (
        Image.open(input_image_path)
        .resize((w, h), Image.Resampling.NEAREST)
        .convert("L")
    )
    pixels = np.array(grayscale)
    binary = np.where(pixels == 255, 1, 0)
    return torch.tensor(binary).int()
32
+
33
+
34
def save_array_as_png(array, path):
    """Write an RGBA numpy array to ``path`` as a PNG.

    Non-uint8 arrays (presumably floats in [0, 1] — confirm at call sites) are
    scaled by 255, clipped to [0, 255] and cast before saving.
    """
    if array.dtype != np.uint8:
        array = (array * 255).clip(0, 255).astype(np.uint8)
    Image.fromarray(array, "RGBA").save(path)
39
+
40
+
41
def convert_to_mask_inpainting(image_array, mask_path):
    """Derive an inpainting mask from an RGBA array and save it to ``mask_path``.

    Fully transparent pixels (alpha == 0) stay white (255); every pixel with a
    nonzero alpha becomes black (0). Returns the saved grayscale PIL image.

    Raises:
        ValueError: if ``image_array`` does not have 4 channels.
    """
    if image_array.shape[2] != 4:
        raise ValueError("输入数组必须是 RGBA 格式")

    alpha = image_array[:, :, 3]
    mask = np.full(image_array.shape[:2], 255, dtype=np.uint8)
    mask[alpha != 0] = 0

    mask_image = Image.fromarray(mask, mode="L")
    mask_image.save(mask_path)
    return mask_image
51
+
52
+
53
+ # mask for Subject Customiztion
54
def composite_images(background_path: str, mask_path: str) -> Image.Image:
    """Composite the background image over a white canvas using a mask.

    Bright mask pixels (value > 128) keep the background content; dark pixels
    are replaced with white. The mask is resized to the background's size when
    they differ. Returns an RGB image.

    Args:
        background_path: Path of the image to keep inside the mask.
        mask_path: Path of the grayscale mask image.

    Returns:
        The composited image, converted to RGB.
    """
    background = Image.open(background_path).convert("RGBA")
    mask = Image.open(mask_path).convert("L")

    if background.size != mask.size:
        mask = mask.resize(background.size)

    # `background` is always RGBA here (converted above), so the canvas must
    # match; the previous RGB fallback branch was unreachable.
    white_canvas = Image.new("RGBA", background.size, (255, 255, 255, 255))

    # Build an explicit 0/255 "L" mask instead of relying on PIL's handling of
    # raw bool arrays — same composite result, no mode-"1" edge cases.
    binary = ((np.array(mask) > 128).astype(np.uint8)) * 255
    composite = Image.composite(background, white_canvas, Image.fromarray(binary, mode="L"))

    return composite.convert("RGB")
71
+
72
+
73
def process_mask_array(mask_array: np.ndarray) -> Image.Image:
    """Convert an RGBA array into a binary mask image.

    Pixels with any opacity (alpha > 0) map to black (0); fully transparent
    pixels map to white (255). Returns a mode-"1" PIL image.
    """
    transparent = mask_array[..., 3] == 0
    gray_array = transparent.astype(np.uint8) * 255
    return Image.fromarray(gray_array, mode="L").convert("1")
78
+
79
+
80
def process_mask(mask: Image.Image) -> Image.Image:
    """Binarize a PIL mask: grayscale values above 128 map to 1, the rest to 0.

    Returns a mode-"1" image.
    """
    grayscale = mask if mask.mode == "L" else mask.convert("L")
    return grayscale.point(lambda v: 1 if v > 128 else 0, mode="1")
84
+
85
+
86
def merge_masks(mask1: Image.Image, mask2: Image.Image) -> Image.Image:
    """Intersect two binary masks (pixel-wise logical AND).

    Returns the intersection as a mode-"1" PIL image.
    """
    intersection = np.array(mask1, dtype=bool) & np.array(mask2, dtype=bool)
    return Image.fromarray(intersection).convert("1")
91
+
92
+
93
def save_merged_mask(
    mask_array: np.ndarray, mask: Optional[Image.Image], output_path: str
) -> None:
    """Build a binary mask from ``mask_array`` (RGBA), optionally intersect it
    with ``mask``, and save the result to ``output_path``.

    When ``mask`` is given and its size differs, it is resized with
    nearest-neighbor sampling before merging.
    """
    base_mask = process_mask_array(mask_array)

    if mask is None:
        merged = base_mask
    else:
        extra_mask = process_mask(mask)
        if extra_mask.size != base_mask.size:
            extra_mask = extra_mask.resize(base_mask.size, Image.NEAREST)
        merged = merge_masks(base_mask, extra_mask)

    merged.save(output_path)