dominoer committed on
Commit
50651a6
·
verified ·
1 Parent(s): eca6b2d

Upload FlowEdit_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. FlowEdit_utils.py +684 -0
FlowEdit_utils.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+ import torch
3
+ from PIL import Image
4
+ from diffusers import FlowMatchEulerDiscreteScheduler
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+
8
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
9
+
10
+
11
def resize_image_for_flux(
    image: Image.Image,
    max_short_edge: int = 1024,
) -> Tuple[Image.Image, bool]:
    """
    Fit an image to the size constraints used for FLUX.

    If the shorter edge exceeds ``max_short_edge`` the image is scaled down
    (aspect ratio kept) so that the shorter edge matches it. In every case
    both dimensions are floored to a multiple of 16, but never below 16, so
    that very small inputs do not end up with a zero-sized dimension.

    Args:
        image: PIL Image to resize
        max_short_edge: Maximum size for shorter edge (default: 1024)

    Returns:
        Tuple of (resized_image, was_resized). ``was_resized`` is False only
        when the input is returned unchanged.
    """
    multiple = 16

    def _snap(value: int) -> int:
        # Floor to a multiple of 16, clamped to one full multiple: flooring
        # anything below 16 would otherwise give 0, which is not a usable
        # image dimension.
        return max(multiple, (value // multiple) * multiple)

    w, h = image.size
    short_edge = min(w, h)

    if short_edge <= max_short_edge:
        # No downscaling needed; only enforce divisibility by 16.
        new_w, new_h = _snap(w), _snap(h)
        if (new_w, new_h) == (w, h):
            return image, False
        return image.resize((new_w, new_h), Image.LANCZOS), True

    # Downscale so that the shorter edge becomes max_short_edge.
    scale = max_short_edge / short_edge
    new_w = _snap(int(w * scale))
    new_h = _snap(int(h * scale))

    image_resized = image.resize((new_w, new_h), Image.LANCZOS)
    print(f" Resized for FLUX: {w}x{h} -> {new_w}x{new_h}")

    return image_resized, True
51
+
52
+
53
def load_and_resize_image(
    image_path: str,
    max_short_edge: int = 1024,
) -> Image.Image:
    """
    Load an image as RGB and resize it if necessary.

    Args:
        image_path: Path to image file
        max_short_edge: Maximum size for shorter edge

    Returns:
        PIL Image in RGB mode (resized by ``resize_image_for_flux`` if needed)
    """
    # ``convert`` hands back a separate image object, so the file opened by
    # ``Image.open`` can be closed deterministically once conversion is done
    # instead of being left to the garbage collector.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    image, _ = resize_image_for_flux(image, max_short_edge)
    return image
70
+
71
+
72
+
73
def scale_noise(
    scheduler,
    sample: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    Forward process in flow-matching.

    Interpolates between the clean sample and noise using the scheduler's
    sigma for ``timestep``: ``sigma * noise + (1 - sigma) * sample``.

    Args:
        scheduler:
            Scheduler exposing ``_init_step_index``, ``step_index`` and
            ``sigmas``. Its step index is reset by this call.
        sample (`torch.FloatTensor`):
            The input sample.
        timestep (`float` or `torch.FloatTensor`):
            The current timestep in the diffusion chain.
        noise (`torch.FloatTensor`, *optional*):
            Noise to mix in. When omitted, standard Gaussian noise with the
            shape, dtype and device of ``sample`` is drawn.

    Returns:
        `torch.FloatTensor`:
            A scaled input sample.
    """
    # The step index is (re)initialised unconditionally so that the sigma
    # always corresponds to the timestep passed in.
    scheduler._init_step_index(timestep)
    sigma = scheduler.sigmas[scheduler.step_index]

    # ``noise`` is declared optional; without this fallback the default of
    # None would fail in the multiplication below.
    if noise is None:
        noise = torch.randn_like(sample)

    return sigma * noise + (1.0 - sigma) * sample
99
+
100
+
101
+ # for flux
102
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """
    Linearly map an image sequence length to a shift value ``mu``.

    The mapping is the straight line through the points
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``; values
    outside that range are extrapolated along the same line.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len: Sequence length that maps to ``base_shift``.
        max_seq_len: Sequence length that maps to ``max_shift``.
        base_shift: Shift at ``base_seq_len``.
        max_shift: Shift at ``max_seq_len``.

    Returns:
        The interpolated shift ``mu``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
113
+
114
+
115
+
116
def calc_v_sd3(pipe, src_tar_latent_model_input, src_tar_prompt_embeds, src_tar_pooled_prompt_embeds, src_guidance_scale, tar_guidance_scale, t):
    """
    Run the SD3 transformer once on a stacked source/target batch and return
    the source and target predictions.

    Args:
        pipe: Pipeline exposing ``transformer`` and
            ``do_classifier_free_guidance``.
        src_tar_latent_model_input: Stacked latents. With classifier-free
            guidance the batch is split into four equal chunks ordered
            [src_uncond, src_cond, tar_uncond, tar_cond]; without it, into
            two chunks ordered [src, tar].
        src_tar_prompt_embeds: Prompt embeddings stacked in the same order.
        src_tar_pooled_prompt_embeds: Pooled prompt embeddings, same order.
        src_guidance_scale: CFG scale applied to the source prediction.
        tar_guidance_scale: CFG scale applied to the target prediction.
        t: Timestep tensor; it is expanded to the batch size.

    Returns:
        Tuple ``(noise_pred_src, noise_pred_tar)``.
    """
    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timestep = t.expand(src_tar_latent_model_input.shape[0])

    with torch.no_grad():
        noise_pred_src_tar = pipe.transformer(
            hidden_states=src_tar_latent_model_input,
            timestep=timestep,
            encoder_hidden_states=src_tar_prompt_embeds,
            pooled_projections=src_tar_pooled_prompt_embeds,
            joint_attention_kwargs=None,
            return_dict=False,
        )[0]

        if pipe.do_classifier_free_guidance:
            # Combine unconditional and text-conditioned predictions per branch.
            (
                src_noise_pred_uncond,
                src_noise_pred_text,
                tar_noise_pred_uncond,
                tar_noise_pred_text,
            ) = noise_pred_src_tar.chunk(4)
            noise_pred_src = src_noise_pred_uncond + src_guidance_scale * (src_noise_pred_text - src_noise_pred_uncond)
            noise_pred_tar = tar_noise_pred_uncond + tar_guidance_scale * (tar_noise_pred_text - tar_noise_pred_uncond)
        else:
            # Without guidance there is nothing to combine; previously the
            # return values were left unassigned on this path.
            noise_pred_src, noise_pred_tar = noise_pred_src_tar.chunk(2)

    return noise_pred_src, noise_pred_tar
143
+
144
+
145
+
146
def calc_v_zimage(pipe, latents_list, prompt_embeds_list, src_guidance_scale, tar_guidance_scale, t):
    """
    Velocity-field computation for ZImage.

    Args:
        pipe: ZImagePipeline
        latents_list: List[Tensor] - four entries ordered
            [src_uncond, src_cond, tar_uncond, tar_cond]
        prompt_embeds_list: List[Tensor] - the matching prompt embeddings
        src_guidance_scale: float - CFG scale for the source prompt
        tar_guidance_scale: float - CFG scale for the target prompt
        t: Tensor - timestep (0-1000)

    Returns:
        noise_pred_src, noise_pred_tar: velocity fields after applying CFG
    """
    def _apply_cfg(uncond, cond, scale):
        # Standard classifier-free guidance combination.
        return uncond + scale * (cond - uncond)

    # Normalise the timestep (ZImage uses the (1000 - t) / 1000 form) and
    # give every list entry its own copy.
    model_timestep = ((1000 - t) / 1000).expand(len(latents_list))

    # Convert each latent to the list-of-tensors layout the transformer takes:
    # (C, H, W) -> (C, 1, H, W), i.e. add an F (frame) dimension, and cast to
    # the transformer's dtype.
    target_dtype = pipe.transformer.dtype
    model_inputs = []
    for latent in latents_list:
        model_inputs.append(latent.unsqueeze(1).to(target_dtype))

    with torch.no_grad():
        # transformer forward
        raw_predictions = pipe.transformer(
            model_inputs,
            model_timestep,
            prompt_embeds_list,
            return_dict=False,
        )[0]

    # Drop the F dimension again, (C, 1, H, W) -> (C, H, W), and flip the
    # sign (ZImage convention).
    velocities = [-prediction.squeeze(1) for prediction in raw_predictions]

    # Entry order: [src_uncond, src_cond, tar_uncond, tar_cond]
    src_uncond, src_cond, tar_uncond, tar_cond = (velocities[k] for k in range(4))

    noise_pred_src = _apply_cfg(src_uncond, src_cond, src_guidance_scale)
    noise_pred_tar = _apply_cfg(tar_uncond, tar_cond, tar_guidance_scale)
    return noise_pred_src, noise_pred_tar
194
+
195
+
196
def calc_v_flux(pipe, latents, prompt_embeds, pooled_prompt_embeds, guidance, text_ids, latent_image_ids, t):
    """
    Run the FLUX transformer once and return its prediction.

    Args:
        pipe: Pipeline exposing ``transformer``.
        latents: Latents passed as ``hidden_states``; the first dimension is
            used as the batch size for the timestep.
        prompt_embeds: Passed as ``encoder_hidden_states``.
        pooled_prompt_embeds: Passed as ``pooled_projections``.
        guidance: Passed through as ``guidance``.
        text_ids: Passed as ``txt_ids``.
        latent_image_ids: Passed as ``img_ids``.
        t: Timestep tensor; it is divided by 1000 before reaching the model.

    Returns:
        The first element of the transformer's output tuple.
    """
    # One timestep entry per batch item (expand keeps this ONNX/Core ML friendly).
    batch_timestep = t.expand(latents.shape[0])

    with torch.no_grad():
        transformer_kwargs = dict(
            hidden_states=latents,
            timestep=batch_timestep / 1000,
            guidance=guidance,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
            pooled_projections=pooled_prompt_embeds,
            joint_attention_kwargs=None,
            return_dict=False,
        )
        outputs = pipe.transformer(**transformer_kwargs)

    return outputs[0]
220
+
221
+
222
+
223
@torch.no_grad()
def FlowEditSD3(pipe,
                scheduler,
                x_src,
                src_prompt,
                tar_prompt,
                negative_prompt,
                T_steps: int = 50,
                n_avg: int = 1,
                src_guidance_scale: float = 3.5,
                tar_guidance_scale: float = 13.5,
                n_min: int = 0,
                n_max: int = 15,):
    """
    FlowEdit for SD3: edit a source latent towards a target prompt.

    The first ``T_steps - n_max`` timesteps are skipped. While more than
    ``n_min`` steps remain, the edit latent is advanced with the difference
    between the target and source predictions, averaged over ``n_avg`` noise
    draws. The last ``n_min`` steps use regular (SDEdit-style) sampling with
    the target prompt only.

    Args:
        pipe: SD3 pipeline; ``encode_prompt``, ``transformer`` and
            ``do_classifier_free_guidance`` are used, and ``_num_timesteps``
            and ``_guidance_scale`` are overwritten.
        scheduler: Scheduler passed to ``retrieve_timesteps``.
        x_src: Latent of the source image.
        src_prompt: Prompt describing the source image.
        tar_prompt: Prompt describing the desired edit.
        negative_prompt: Negative prompt used for both branches.
        T_steps: Total number of timesteps.
        n_avg: Number of noise draws averaged per step.
        src_guidance_scale: CFG scale for the source branch.
        tar_guidance_scale: CFG scale for the target branch.
        n_min: Number of final steps done with regular sampling.
        n_max: Number of final steps that are processed at all.

    Returns:
        The edited latent.

    NOTE(review): the non-CFG fallbacks below pass a tuple of latents while
    the prompt embeddings are still stacked four ways, so this function looks
    usable only with classifier-free guidance enabled - confirm before
    relying on the non-CFG path.
    """
    device = x_src.device

    timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, device, timesteps=None)

    pipe._num_timesteps = len(timesteps)
    pipe._guidance_scale = src_guidance_scale

    # src prompts
    (
        src_prompt_embeds,
        src_negative_prompt_embeds,
        src_pooled_prompt_embeds,
        src_negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=src_prompt,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        device=device,
    )

    # tar prompts
    pipe._guidance_scale = tar_guidance_scale
    (
        tar_prompt_embeds,
        tar_negative_prompt_embeds,
        tar_pooled_prompt_embeds,
        tar_negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=tar_prompt,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        device=device,
    )

    # CFG prep: stack as [src_uncond, src_cond, tar_uncond, tar_cond]
    src_tar_prompt_embeds = torch.cat([src_negative_prompt_embeds, src_prompt_embeds, tar_negative_prompt_embeds, tar_prompt_embeds], dim=0)
    src_tar_pooled_prompt_embeds = torch.cat([src_negative_pooled_prompt_embeds, src_pooled_prompt_embeds, tar_negative_pooled_prompt_embeds, tar_pooled_prompt_embeds], dim=0)

    # initialize our ODE Zt_edit_1=x_src
    zt_edit = x_src.clone()

    # Stays None until the regular-sampling phase starts. Initialising on
    # first entry (rather than on an exact step index) keeps the phase usable
    # when n_max < n_min or n_min > T_steps.
    xt_tar = None

    for i, t in tqdm(enumerate(timesteps)):

        if T_steps - i > n_max:
            continue

        t_i = t / 1000
        if i + 1 < len(timesteps):
            t_im1 = timesteps[i + 1] / 1000
        else:
            t_im1 = torch.zeros_like(t_i).to(t_i.device)

        if T_steps - i > n_min:

            # Calculate the average of the V predictions
            V_delta_avg = torch.zeros_like(x_src)
            for k in range(n_avg):

                fwd_noise = torch.randn_like(x_src).to(x_src.device)

                zt_src = (1 - t_i) * x_src + t_i * fwd_noise

                # Target point keeps the same offset from the noisy source
                # that the edit latent has from the clean source.
                zt_tar = zt_edit + zt_src - x_src

                src_tar_latent_model_input = torch.cat([zt_src, zt_src, zt_tar, zt_tar]) if pipe.do_classifier_free_guidance else (zt_src, zt_tar)

                Vt_src, Vt_tar = calc_v_sd3(pipe, src_tar_latent_model_input, src_tar_prompt_embeds, src_tar_pooled_prompt_embeds, src_guidance_scale, tar_guidance_scale, t)

                V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src)

            # propagate direct ODE (update in float32, then cast back)
            zt_edit = zt_edit.to(torch.float32)
            zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg
            zt_edit = zt_edit.to(V_delta_avg.dtype)

        else:  # regular sampling for last n_min steps

            if xt_tar is None:
                # initialize SDEDIT-style generation phase
                fwd_noise = torch.randn_like(x_src).to(x_src.device)
                xt_src = scale_noise(scheduler, x_src, t, noise=fwd_noise)
                xt_tar = zt_edit + xt_src - x_src

            src_tar_latent_model_input = torch.cat([xt_tar, xt_tar, xt_tar, xt_tar]) if pipe.do_classifier_free_guidance else (xt_src, xt_tar)

            _, Vt_tar = calc_v_sd3(pipe, src_tar_latent_model_input, src_tar_prompt_embeds, src_tar_pooled_prompt_embeds, src_guidance_scale, tar_guidance_scale, t)

            xt_tar = xt_tar.to(torch.float32)

            prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar

            # Cast back to the prediction dtype. The original referenced an
            # undefined name (noise_pred_tar) here.
            prev_sample = prev_sample.to(Vt_tar.dtype)

            xt_tar = prev_sample

    return zt_edit if xt_tar is None else xt_tar
340
+
341
+
342
+
343
@torch.no_grad()
def FlowEditFLUX(pipe,
                 scheduler,
                 x_src,
                 src_prompt,
                 tar_prompt,
                 negative_prompt,
                 T_steps: int = 28,
                 n_avg: int = 1,
                 src_guidance_scale: float = 1.5,
                 tar_guidance_scale: float = 5.5,
                 n_min: int = 0,
                 n_max: int = 24,):
    """
    FlowEdit for FLUX: edit a source latent towards a target prompt.

    The first ``T_steps - n_max`` timesteps are skipped. While more than
    ``n_min`` steps remain, the packed edit latent is advanced with the
    difference between the target and source predictions, averaged over
    ``n_avg`` noise draws. The last ``n_min`` steps use regular
    (SDEdit-style) sampling with the target prompt only.

    Args:
        pipe: FLUX pipeline; ``_num_timesteps`` and ``_guidance_scale`` are
            overwritten.
        scheduler: Scheduler passed to ``retrieve_timesteps``; its ``sigmas``
            are used as the integration times.
        x_src: VAE-encoded source latent; dimensions 2 and 3 are treated as
            height and width.
        src_prompt: Prompt describing the source image.
        tar_prompt: Prompt describing the desired edit.
        negative_prompt: Accepted for signature parity with the other
            FlowEdit variants; not used here.
        T_steps: Total number of timesteps.
        n_avg: Number of noise draws averaged per step.
        src_guidance_scale: Guidance value fed to the source branch.
        tar_guidance_scale: Guidance value fed to the target branch.
        n_min: Number of final steps done with regular sampling.
        n_max: Number of final steps that are processed at all.

    Returns:
        The edited latent, unpacked with ``pipe._unpack_latents``.
    """
    device = x_src.device
    # Note: orig_height/width should match the actual image dimensions for correct latent_image_ids
    # x_src is VAE-encoded latent (H/8, W/8), so multiply by vae_scale_factor to get original size
    orig_height = x_src.shape[2] * pipe.vae_scale_factor
    orig_width = x_src.shape[3] * pipe.vae_scale_factor
    num_channels_latents = pipe.transformer.config.in_channels // 4

    pipe.check_inputs(
        prompt=src_prompt,
        prompt_2=None,
        height=orig_height,
        width=orig_width,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=512,
    )

    x_src, latent_src_image_ids = pipe.prepare_latents(batch_size=x_src.shape[0], num_channels_latents=num_channels_latents, height=orig_height, width=orig_width, dtype=x_src.dtype, device=x_src.device, generator=None, latents=x_src)
    x_src_packed = pipe._pack_latents(x_src, x_src.shape[0], num_channels_latents, x_src.shape[2], x_src.shape[3])
    # Source and target share the same spatial layout, hence the same ids.
    latent_tar_image_ids = latent_src_image_ids

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / T_steps, T_steps)
    image_seq_len = x_src_packed.shape[1]
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.base_image_seq_len,
        scheduler.config.max_image_seq_len,
        scheduler.config.base_shift,
        scheduler.config.max_shift,
    )
    timesteps, T_steps = retrieve_timesteps(
        scheduler,
        T_steps,
        device,
        timesteps=None,
        sigmas=sigmas,
        mu=mu,
    )

    pipe._num_timesteps = len(timesteps)

    # src prompts
    (
        src_prompt_embeds,
        src_pooled_prompt_embeds,
        src_text_ids,
    ) = pipe.encode_prompt(
        prompt=src_prompt,
        prompt_2=None,
        device=device,
    )

    # tar prompts
    pipe._guidance_scale = tar_guidance_scale
    (
        tar_prompt_embeds,
        tar_pooled_prompt_embeds,
        tar_text_ids,
    ) = pipe.encode_prompt(
        prompt=tar_prompt,
        prompt_2=None,
        device=device,
    )

    # handle guidance
    if pipe.transformer.config.guidance_embeds:
        src_guidance = torch.tensor([src_guidance_scale], device=device)
        src_guidance = src_guidance.expand(x_src_packed.shape[0])
        tar_guidance = torch.tensor([tar_guidance_scale], device=device)
        tar_guidance = tar_guidance.expand(x_src_packed.shape[0])
    else:
        src_guidance = None
        tar_guidance = None

    # initialize our ODE Zt_edit_1=x_src
    zt_edit = x_src_packed.clone()

    # Stays None until the regular-sampling phase starts. Initialising on
    # first entry (rather than on an exact step index) keeps the phase usable
    # when n_max < n_min or n_min > T_steps.
    xt_tar = None

    for i, t in tqdm(enumerate(timesteps)):

        if T_steps - i > n_max:
            continue

        scheduler._init_step_index(t)
        t_i = scheduler.sigmas[scheduler.step_index]
        # Guard the look-ahead against the end of the sigma schedule. The
        # previous check (i < len(timesteps)) was always true and so never
        # protected this index; past the end the next time is taken as 0.
        if scheduler.step_index + 1 < len(scheduler.sigmas):
            t_im1 = scheduler.sigmas[scheduler.step_index + 1]
        else:
            t_im1 = torch.zeros_like(t_i)

        if T_steps - i > n_min:

            # Calculate the average of the V predictions
            V_delta_avg = torch.zeros_like(x_src_packed)

            for k in range(n_avg):

                fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device)

                zt_src = (1 - t_i) * x_src_packed + t_i * fwd_noise

                # Target point keeps the same offset from the noisy source
                # that the edit latent has from the clean source.
                zt_tar = zt_edit + zt_src - x_src_packed

                # Merge in the future to avoid double computation
                Vt_src = calc_v_flux(pipe,
                                     latents=zt_src,
                                     prompt_embeds=src_prompt_embeds,
                                     pooled_prompt_embeds=src_pooled_prompt_embeds,
                                     guidance=src_guidance,
                                     text_ids=src_text_ids,
                                     latent_image_ids=latent_src_image_ids,
                                     t=t)

                Vt_tar = calc_v_flux(pipe,
                                     latents=zt_tar,
                                     prompt_embeds=tar_prompt_embeds,
                                     pooled_prompt_embeds=tar_pooled_prompt_embeds,
                                     guidance=tar_guidance,
                                     text_ids=tar_text_ids,
                                     latent_image_ids=latent_tar_image_ids,
                                     t=t)

                V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src)

            # propagate direct ODE (update in float32, then cast back)
            zt_edit = zt_edit.to(torch.float32)
            zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg
            zt_edit = zt_edit.to(V_delta_avg.dtype)

        else:  # regular sampling last n_min steps

            if xt_tar is None:
                # initialize SDEDIT-style generation phase
                fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device)
                xt_src = scale_noise(scheduler, x_src_packed, t, noise=fwd_noise)
                xt_tar = zt_edit + xt_src - x_src_packed

            Vt_tar = calc_v_flux(pipe,
                                 latents=xt_tar,
                                 prompt_embeds=tar_prompt_embeds,
                                 pooled_prompt_embeds=tar_pooled_prompt_embeds,
                                 guidance=tar_guidance,
                                 text_ids=tar_text_ids,
                                 latent_image_ids=latent_tar_image_ids,
                                 t=t)

            xt_tar = xt_tar.to(torch.float32)

            prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar

            prev_sample = prev_sample.to(Vt_tar.dtype)
            xt_tar = prev_sample

    out = zt_edit if xt_tar is None else xt_tar
    unpacked_out = pipe._unpack_latents(out, orig_height, orig_width, pipe.vae_scale_factor)
    return unpacked_out
518
+
519
+
520
@torch.no_grad()
def FlowEditZImage(pipe,
                   scheduler,
                   x_src,
                   src_prompt,
                   tar_prompt,
                   negative_prompt,
                   T_steps: int = 28,
                   n_avg: int = 1,
                   src_guidance_scale: float = 1.5,
                   tar_guidance_scale: float = 5.5,
                   n_min: int = 0,
                   n_max: int = 24,
                   max_sequence_length: int = 512,):
    """
    FlowEdit implementation for ZImage.

    Args:
        pipe: ZImagePipeline
        scheduler: FlowMatchEulerDiscreteScheduler; its ``sigma_min`` is set
            to 0.0 by this function
        x_src: Tensor - latent of the source image (B, C, H, W)
        src_prompt: str - source prompt
        tar_prompt: str - target prompt
        negative_prompt: str - negative prompt
        T_steps: int - total number of steps
        n_avg: int - number of velocity-field samples averaged per step
        src_guidance_scale: float - source CFG scale
        tar_guidance_scale: float - target CFG scale
        n_min: int - number of final steps that switch to regular sampling
        n_max: int - maximum number of steps on which flow editing is applied
        max_sequence_length: int - maximum prompt sequence length

    Returns:
        Tensor - the edited latent

    NOTE(review): latents are passed to the model via ``squeeze(0)`` and
    restored with ``unsqueeze(0)``, which presumably assumes a batch size of
    1 - confirm before using larger batches.
    """
    def _first_if_list(embeds):
        # ZImage's encode_prompt returns List[Tensor]; take the single entry.
        return embeds[0] if isinstance(embeds, list) else embeds

    device = x_src.device

    # Timestep preparation (ZImage uses calculate_shift)
    image_seq_len = (x_src.shape[2] // 2) * (x_src.shape[3] // 2)

    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )
    scheduler.sigma_min = 0.0
    timesteps, T_steps = retrieve_timesteps(
        scheduler,
        T_steps,
        device,
        sigmas=None,
        mu=mu,
    )

    # Prompt encoding
    # Source prompt
    src_prompt_embeds, src_negative_prompt_embeds = pipe.encode_prompt(
        prompt=src_prompt,
        device=device,
        do_classifier_free_guidance=True,
        negative_prompt=negative_prompt,
        max_sequence_length=max_sequence_length,
    )

    # Target prompt
    tar_prompt_embeds, tar_negative_prompt_embeds = pipe.encode_prompt(
        prompt=tar_prompt,
        device=device,
        do_classifier_free_guidance=True,
        negative_prompt=negative_prompt,
        max_sequence_length=max_sequence_length,
    )

    # prompt_embeds_list: [src_uncond, src_cond, tar_uncond, tar_cond]
    prompt_embeds_list = [
        _first_if_list(src_negative_prompt_embeds),
        _first_if_list(src_prompt_embeds),
        _first_if_list(tar_negative_prompt_embeds),
        _first_if_list(tar_prompt_embeds),
    ]

    # initialize ODE: zt_edit = x_src
    zt_edit = x_src.clone()

    # Stays None until the regular-sampling phase starts. Initialising on
    # first entry (rather than on an exact step index) keeps the phase usable
    # when n_max < n_min or n_min > T_steps.
    xt_tar = None

    for i, t in tqdm(enumerate(timesteps)):

        if T_steps - i > n_max:
            continue

        # Current and next integration time, taken from the sigma schedule
        scheduler._init_step_index(t)
        t_i = scheduler.sigmas[scheduler.step_index]
        if scheduler.step_index + 1 < len(scheduler.sigmas):
            t_im1 = scheduler.sigmas[scheduler.step_index + 1]
        else:
            t_im1 = torch.zeros_like(t_i)

        if T_steps - i > n_min:
            # Flow-based editing phase

            V_delta_avg = torch.zeros_like(x_src)

            for k in range(n_avg):
                # Random noise
                fwd_noise = torch.randn_like(x_src).to(device)

                # Forward process: source trajectory
                zt_src = (1 - t_i) * x_src + t_i * fwd_noise

                # Target trajectory (keeps the offset from the source)
                zt_tar = zt_edit + zt_src - x_src

                # latents_list: [src_uncond, src_cond, tar_uncond, tar_cond]
                latents_list = [zt_src.squeeze(0), zt_src.squeeze(0), zt_tar.squeeze(0), zt_tar.squeeze(0)]

                # Velocity-field computation
                Vt_src, Vt_tar = calc_v_zimage(
                    pipe,
                    latents_list,
                    prompt_embeds_list,
                    src_guidance_scale,
                    tar_guidance_scale,
                    t,
                )

                # Accumulate the velocity difference
                V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src).unsqueeze(0)

            # ODE update (in float32, then cast back)
            zt_edit = zt_edit.to(torch.float32)
            zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg
            zt_edit = zt_edit.to(V_delta_avg.dtype)

        else:  # regular sampling (last n_min steps)

            if xt_tar is None:
                # Initialise the SDEDIT-style generation phase
                fwd_noise = torch.randn_like(x_src).to(device)
                xt_src = scale_noise(scheduler, x_src, t, noise=fwd_noise)
                xt_tar = zt_edit + xt_src - x_src

            # Only the target prediction is used, so all four slots carry xt_tar
            latents_list = [xt_tar.squeeze(0), xt_tar.squeeze(0), xt_tar.squeeze(0), xt_tar.squeeze(0)]

            _, Vt_tar = calc_v_zimage(
                pipe,
                latents_list,
                prompt_embeds_list,
                src_guidance_scale,
                tar_guidance_scale,
                t,
            )

            # ODE update (in float32, then cast back)
            xt_tar = xt_tar.to(torch.float32)
            prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar.unsqueeze(0)
            prev_sample = prev_sample.to(Vt_tar.dtype)
            xt_tar = prev_sample

    return zt_edit if xt_tar is None else xt_tar