Add pipeline code for trust_remote_code

#3
by kahnchana - opened
Files changed (1)
  1. pipeline_fofpred.py +894 -0
pipeline_fofpred.py ADDED
@@ -0,0 +1,894 @@
+ """
+ FOFPred Diffusion Pipeline.
+
+ Modified from the OmniGen2 Diffusion Pipeline (by the OmniGen2 Team and the HuggingFace Team).
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import inspect
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import PIL.Image
+ import torch
+ import torch.nn.functional as F
+ from diffusers.models.autoencoders import AutoencoderKL
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+ from diffusers.utils import (
+     BaseOutput,
+     is_torch_xla_available,
+     logging,
+ )
+ from diffusers.utils.torch_utils import randn_tensor
+ from transformers import Qwen2_5_VLForConditionalGeneration
+
+ from fofpred.pipelines.image_processor import OmniGen2ImageProcessor
+ from fofpred.utils.teacache_util import TeaCacheParams
+
+ from ...models.transformers import OmniGen2Transformer3DModel
+ from ...models.transformers.repo import OmniGen2RotaryPosEmbed
+ from ..lora_pipeline import OmniGen2LoraLoaderMixin
+
+ if is_torch_xla_available():
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
+ from ...cache_functions import cache_init
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ @dataclass
+ class FMPipelineOutput(BaseOutput):
+     """
+     Output class for the FOFPred pipeline.
+
+     Args:
+         images (Union[List[PIL.Image.Image], np.ndarray]):
+             List of denoised PIL images of length `batch_size` or a numpy array of shape
+             `(batch_size, height, width, num_channels)`. Contains the generated images.
+     """
+
+     images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+ def retrieve_timesteps(
+     scheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     timesteps: Optional[List[int]] = None,
+     **kwargs,
+ ):
+     """
+     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+     Args:
+         scheduler (`SchedulerMixin`):
+             The scheduler to get timesteps from.
+         num_inference_steps (`int`):
+             The number of diffusion steps used when generating samples with a pre-trained model. If used,
+             `timesteps` must be `None`.
+         device (`str` or `torch.device`, *optional*):
+             The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+         timesteps (`List[int]`, *optional*):
+             Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
+             passed, `num_inference_steps` is ignored and recomputed from the length of the custom schedule.
+
+     Returns:
+         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
+         the second element is the number of inference steps.
+     """
+     if timesteps is not None:
+         accepts_timesteps = "timesteps" in set(
+             inspect.signature(scheduler.set_timesteps).parameters.keys()
+         )
+         if not accepts_timesteps:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" timestep schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+     return timesteps, num_inference_steps
+
+
+ class FOFPredPipeline(DiffusionPipeline, OmniGen2LoraLoaderMixin):
+     """
+     Pipeline for text-to-image generation, built on OmniGen2.
+
+     This pipeline implements a text-to-image generation model that uses:
+     - Qwen2.5-VL for text encoding
+     - A custom transformer architecture for image generation
+     - A VAE for image encoding/decoding
+     - FlowMatchEulerDiscreteScheduler for noise scheduling
+
+     Args:
+         transformer (OmniGen2Transformer3DModel): The transformer model for image generation.
+         vae (AutoencoderKL): The VAE model for image encoding/decoding.
+         scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
+         mllm (Qwen2_5_VLForConditionalGeneration): The multimodal language model used as the text encoder.
+         processor: The processor whose tokenizer is used for text processing.
+     """
+
+     model_cpu_offload_seq = "mllm->transformer->vae"
+
+     def __init__(
+         self,
+         transformer: OmniGen2Transformer3DModel,
+         vae: AutoencoderKL,
+         scheduler: FlowMatchEulerDiscreteScheduler,
+         mllm: Qwen2_5_VLForConditionalGeneration,
+         processor,
+     ) -> None:
+         """
+         Initialize the FOFPred pipeline.
+
+         Args:
+             transformer: The transformer model for image generation.
+             vae: The VAE model for image encoding/decoding.
+             scheduler: The scheduler for noise scheduling.
+             mllm: The multimodal language model used as the text encoder.
+             processor: The processor whose tokenizer is used for text processing.
+         """
+         super().__init__()
+
+         self.register_modules(
+             transformer=transformer,
+             vae=vae,
+             scheduler=scheduler,
+             mllm=mllm,
+             processor=processor,
+         )
+         self.vae_scale_factor = (
+             2 ** (len(self.vae.config.block_out_channels) - 1)
+             if hasattr(self, "vae") and self.vae is not None
+             else 8
+         )
+         self.image_processor = OmniGen2ImageProcessor(
+             vae_scale_factor=self.vae_scale_factor * 2, do_resize=True
+         )
+         self.default_sample_size = 128
+
+     def prepare_latents(
+         self,
+         batch_size: int,
+         num_channels_latents: int,
+         height: int,
+         width: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         generator: Optional[torch.Generator],
+         latents: Optional[torch.FloatTensor] = None,
+         frame_count: int = 1,
+     ) -> torch.FloatTensor:
+         """
+         Prepare the initial latents for the diffusion process.
+
+         Args:
+             batch_size: The number of images to generate.
+             num_channels_latents: The number of channels in the latent space.
+             height: The height of the generated image.
+             width: The width of the generated image.
+             dtype: The data type of the latents.
+             device: The device to place the latents on.
+             generator: The random number generator to use.
+             latents: Optional pre-computed latents to use instead of random initialization.
+             frame_count: The number of frames to output.
+
+         Returns:
+             torch.FloatTensor: The prepared latents tensor.
+         """
+         height = int(height) // self.vae_scale_factor
+         width = int(width) // self.vae_scale_factor
+
+         if frame_count > 1:
+             shape = (batch_size, frame_count, num_channels_latents, height, width)
+         else:
+             shape = (batch_size, num_channels_latents, height, width)
+
+         if latents is None:
+             latents = randn_tensor(
+                 shape, generator=generator, device=device, dtype=dtype
+             )
+         else:
+             latents = latents.to(device)
+         return latents
+
+     def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
+         """
+         Encode an image into the VAE latent space.
+
+         Args:
+             img: The input image tensor to encode.
+
+         Returns:
+             torch.FloatTensor: The encoded latent representation.
+         """
+         z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
+         if self.vae.config.shift_factor is not None:
+             z0 = z0 - self.vae.config.shift_factor
+         if self.vae.config.scaling_factor is not None:
+             z0 = z0 * self.vae.config.scaling_factor
+         z0 = z0.to(dtype=self.vae.dtype)
+         return z0
+
+     def prepare_image(
+         self,
+         images: Union[List[PIL.Image.Image], PIL.Image.Image],
+         batch_size: int,
+         num_images_per_prompt: int,
+         max_pixels: int,
+         max_side_length: int,
+         device: torch.device,
+         dtype: torch.dtype,
+     ) -> List[Optional[torch.FloatTensor]]:
+         """
+         Prepare input images for processing by encoding them into the VAE latent space.
+
+         Args:
+             images: Single image or list of images to process.
+             batch_size: The number of images to generate per prompt.
+             num_images_per_prompt: The number of images to generate for each prompt.
+             max_pixels: Maximum number of pixels allowed per input image before downscaling.
+             max_side_length: Maximum side length allowed per input image before downscaling.
+             device: The device to place the encoded latents on.
+             dtype: The data type of the encoded latents.
+
+         Returns:
+             List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
+         """
+         if batch_size == 1:
+             images = [images]
+         latents = []
+         for img in images:
+             if img is not None and len(img) > 0:
+                 ref_latents = []
+                 for img_j in img:
+                     img_j = self.image_processor.preprocess(
+                         img_j, max_pixels=max_pixels, max_side_length=max_side_length
+                     )
+                     ref_latents.append(
+                         self.encode_vae(img_j.to(device=device)).squeeze(0)
+                     )
+             else:
+                 ref_latents = None
+             for _ in range(num_images_per_prompt):
+                 latents.append(ref_latents)
+
+         return latents
+
+     def _get_qwen2_prompt_embeds(
+         self,
+         prompt: Union[str, List[str]],
+         device: Optional[torch.device] = None,
+         max_sequence_length: int = 256,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Get prompt embeddings from the Qwen2.5-VL text encoder.
+
+         Args:
+             prompt: The prompt or list of prompts to encode.
+             device: The device to place the embeddings on. If None, uses the pipeline's device.
+             max_sequence_length: Maximum sequence length for tokenization.
+
+         Returns:
+             Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+                 - The prompt embeddings tensor
+                 - The attention mask tensor
+
+         Logs a warning if the input text is truncated due to the sequence length limit.
+         """
+         device = device or self._execution_device
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         # Pad to the longest sequence in the batch (rather than to `max_length`).
+         text_inputs = self.processor.tokenizer(
+             prompt,
+             padding="longest",
+             max_length=max_sequence_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+
+         text_input_ids = text_inputs.input_ids.to(device)
+         untruncated_ids = self.processor.tokenizer(
+             prompt, padding="longest", return_tensors="pt"
+         ).input_ids.to(device)
+
+         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+             text_input_ids, untruncated_ids
+         ):
+             removed_text = self.processor.tokenizer.batch_decode(
+                 untruncated_ids[:, max_sequence_length - 1 : -1]
+             )
+             logger.warning(
+                 "The following part of your input was truncated because the text encoder can only handle"
+                 f" sequences up to {max_sequence_length} tokens: {removed_text}"
+             )
+
+         prompt_attention_mask = text_inputs.attention_mask.to(device)
+         prompt_embeds = self.mllm(
+             text_input_ids,
+             attention_mask=prompt_attention_mask,
+             output_hidden_states=True,
+         ).hidden_states[-1]
+
+         if self.mllm is not None:
+             dtype = self.mllm.dtype
+         elif self.transformer is not None:
+             dtype = self.transformer.dtype
+         else:
+             dtype = None
+
+         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+         return prompt_embeds, prompt_attention_mask
+
+     def _apply_chat_template(self, prompt: str):
+         messages = [
+             {
+                 "role": "system",
+                 "content": "You are a helpful assistant that generates high-quality images based on user instructions.",
+             },
+             {"role": "user", "content": prompt},
+         ]
+         prompt = self.processor.tokenizer.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=False
+         )
+         return prompt
+
+     def encode_prompt(
+         self,
+         prompt: Union[str, List[str]],
+         do_classifier_free_guidance: bool = True,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: int = 1,
+         device: Optional[torch.device] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         prompt_attention_mask: Optional[torch.Tensor] = None,
+         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+         max_sequence_length: int = 256,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         r"""
+         Encodes the prompt into text encoder hidden states.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 prompt to be encoded
+             do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+                 whether to use classifier-free guidance or not
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt not to guide the image generation. If not defined, one has to pass
+                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                 `guidance_scale` is less than `1`). Defaults to "".
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 number of images that should be generated per prompt
+             device: (`torch.device`, *optional*):
+                 torch device to place the resulting embeddings on
+             prompt_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                 If not provided, text embeddings will be generated from the `prompt` input argument.
+             negative_prompt_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated negative text embeddings. If not provided, they are generated from
+                 `negative_prompt`.
+             max_sequence_length (`int`, defaults to `256`):
+                 Maximum sequence length to use for the prompt.
+         """
+         device = device or self._execution_device
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         prompt = [self._apply_chat_template(_prompt) for _prompt in prompt]
+
+         if prompt is not None:
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+         if prompt_embeds is None:
+             prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
+                 prompt=prompt, device=device, max_sequence_length=max_sequence_length
+             )
+
+         batch_size, seq_len, _ = prompt_embeds.shape
+         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(
+             batch_size * num_images_per_prompt, seq_len, -1
+         )
+         prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+         prompt_attention_mask = prompt_attention_mask.view(
+             batch_size * num_images_per_prompt, -1
+         )
+
+         # Get negative embeddings for classifier-free guidance
+         if do_classifier_free_guidance and negative_prompt_embeds is None:
+             negative_prompt = negative_prompt if negative_prompt is not None else ""
+
+             # Normalize str to list
+             negative_prompt = (
+                 batch_size * [negative_prompt]
+                 if isinstance(negative_prompt, str)
+                 else negative_prompt
+             )
+             negative_prompt = [
+                 self._apply_chat_template(_negative_prompt)
+                 for _negative_prompt in negative_prompt
+             ]
+
+             if prompt is not None and type(prompt) is not type(negative_prompt):
+                 raise TypeError(
+                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                     f" {type(prompt)}."
+                 )
+             elif isinstance(negative_prompt, str):
+                 negative_prompt = [negative_prompt]
+             elif batch_size != len(negative_prompt):
+                 raise ValueError(
+                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                     " the batch size of `prompt`."
+                 )
+             negative_prompt_embeds, negative_prompt_attention_mask = (
+                 self._get_qwen2_prompt_embeds(
+                     prompt=negative_prompt,
+                     device=device,
+                     max_sequence_length=max_sequence_length,
+                 )
+             )
+
+             batch_size, seq_len, _ = negative_prompt_embeds.shape
+             # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+             negative_prompt_embeds = negative_prompt_embeds.repeat(
+                 1, num_images_per_prompt, 1
+             )
+             negative_prompt_embeds = negative_prompt_embeds.view(
+                 batch_size * num_images_per_prompt, seq_len, -1
+             )
+             negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
+                 num_images_per_prompt, 1
+             )
+             negative_prompt_attention_mask = negative_prompt_attention_mask.view(
+                 batch_size * num_images_per_prompt, -1
+             )
+
+         return (
+             prompt_embeds,
+             prompt_attention_mask,
+             negative_prompt_embeds,
+             negative_prompt_attention_mask,
+         )
+
+     @property
+     def num_timesteps(self):
+         return self._num_timesteps
+
+     @property
+     def text_guidance_scale(self):
+         return self._text_guidance_scale
+
+     @property
+     def image_guidance_scale(self):
+         return self._image_guidance_scale
+
+     @property
+     def cfg_range(self):
+         return self._cfg_range
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Optional[Union[str, List[str]]] = None,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         prompt_attention_mask: Optional[torch.LongTensor] = None,
+         negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
+         max_sequence_length: Optional[int] = None,
+         callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
+         input_images: Optional[List[PIL.Image.Image]] = None,
+         num_images_per_prompt: int = 1,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         max_pixels: int = 1024 * 1024,
+         max_input_image_side_length: int = 1024,
+         align_res: bool = True,
+         num_inference_steps: int = 28,
+         text_guidance_scale: float = 4.0,
+         image_guidance_scale: float = 1.0,
+         cfg_range: Tuple[float, float] = (0.0, 1.0),
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+         timesteps: Optional[List[int]] = None,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         frame_count: int = 1,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         verbose: bool = False,
+         step_func=None,
+         get_latents_text_embeds: bool = False,
+     ):
+         # 1. Default height and width
+         height = height or self.default_sample_size * self.vae_scale_factor
+         width = width or self.default_sample_size * self.vae_scale_factor
+
+         self._text_guidance_scale = text_guidance_scale
+         self._image_guidance_scale = image_guidance_scale
+         self._cfg_range = cfg_range
+         self._attention_kwargs = attention_kwargs
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+
+         # 3. Encode input prompt
+         (
+             prompt_embeds,
+             prompt_attention_mask,
+             negative_prompt_embeds,
+             negative_prompt_attention_mask,
+         ) = self.encode_prompt(
+             prompt,
+             self.text_guidance_scale > 1.0,
+             negative_prompt=negative_prompt,
+             num_images_per_prompt=num_images_per_prompt,
+             device=device,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             prompt_attention_mask=prompt_attention_mask,
+             negative_prompt_attention_mask=negative_prompt_attention_mask,
+             max_sequence_length=max_sequence_length,
+         )
+
+         dtype = self.vae.dtype
+         # 4. Prepare reference image latents
+         ref_latents = self.prepare_image(
+             images=input_images,
+             batch_size=batch_size,
+             num_images_per_prompt=num_images_per_prompt,
+             max_pixels=max_pixels,
+             max_side_length=max_input_image_side_length,
+             device=device,
+             dtype=dtype,
+         )
+
+         if input_images is None:
+             input_images = []
+
+         if len(input_images) == 1 and align_res:
+             # Align the output resolution with the single reference image.
+             width, height = (
+                 ref_latents[0][0].shape[-1] * self.vae_scale_factor,
+                 ref_latents[0][0].shape[-2] * self.vae_scale_factor,
+             )
+             ori_width, ori_height = width, height
+         else:
+             ori_width, ori_height = width, height
+
+         # Downscale so the output stays within `max_pixels`, rounded to a multiple of 16.
+         cur_pixels = height * width
+         ratio = (max_pixels / cur_pixels) ** 0.5
+         ratio = min(ratio, 1.0)
+
+         height, width = (
+             int(height * ratio) // 16 * 16,
+             int(width * ratio) // 16 * 16,
+         )
+
+         if len(input_images) == 0:
+             self._image_guidance_scale = 1
+
+         # 5. Prepare latents
+         latent_channels = self.transformer.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             latent_channels,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+             frame_count,
+         )
+
+         freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
+             self.transformer.config.axes_dim_rope,
+             self.transformer.config.axes_lens,
+             theta=10000,
+         )
+
+         # 6. Denoise and decode
+         image = self.processing(
+             latents=latents,
+             ref_latents=ref_latents,
+             prompt_embeds=prompt_embeds,
+             freqs_cis=freqs_cis,
+             negative_prompt_embeds=negative_prompt_embeds,
+             prompt_attention_mask=prompt_attention_mask,
+             negative_prompt_attention_mask=negative_prompt_attention_mask,
+             num_inference_steps=num_inference_steps,
+             timesteps=timesteps,
+             device=device,
+             dtype=dtype,
+             verbose=verbose,
+             step_func=step_func,
+             get_latents_text_embeds=get_latents_text_embeds,
+         )
+
+         if get_latents_text_embeds:
+             return image, prompt_embeds
+
+         if len(image.shape) == 4:
+             image = F.interpolate(image, size=(ori_height, ori_width), mode="bilinear")
+             image = self.image_processor.postprocess(image, output_type=output_type)
+         else:
+             # Temporal output: resize and postprocess each frame, then restack.
+             # Note: restacking assumes a tensor `output_type` (e.g. "pt").
+             image = [
+                 F.interpolate(
+                     image[:, i], size=(ori_height, ori_width), mode="bilinear"
+                 )
+                 for i in range(image.shape[1])
+             ]
+             image = [
+                 self.image_processor.postprocess(x, output_type=output_type)
+                 for x in image
+             ]
+             image = torch.stack(image, dim=1)
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return image
+         else:
+             return FMPipelineOutput(images=image)
+
+     def processing(
+         self,
+         latents,
+         ref_latents,
+         prompt_embeds,
+         freqs_cis,
+         negative_prompt_embeds,
+         prompt_attention_mask,
+         negative_prompt_attention_mask,
+         num_inference_steps,
+         timesteps,
+         device,
+         dtype,
+         verbose,
+         step_func=None,
+         get_latents_text_embeds=False,
+     ):
+         batch_size = latents.shape[0]
+
+         timesteps, num_inference_steps = retrieve_timesteps(
+             self.scheduler,
+             num_inference_steps,
+             device,
+             timesteps,
+             num_tokens=latents.shape[-2] * latents.shape[-1],
+         )
+         num_warmup_steps = max(
+             len(timesteps) - num_inference_steps * self.scheduler.order, 0
+         )
+         self._num_timesteps = len(timesteps)
+
+         enable_taylorseer = getattr(self, "enable_taylorseer", False)
+         if enable_taylorseer:
+             # Separate caches for the conditional, image-conditioned, and unconditional branches.
+             model_pred_cache_dic, model_pred_current = cache_init(
+                 self, num_inference_steps
+             )
+             model_pred_ref_cache_dic, model_pred_ref_current = cache_init(
+                 self, num_inference_steps
+             )
+             model_pred_uncond_cache_dic, model_pred_uncond_current = cache_init(
+                 self, num_inference_steps
+             )
+             self.transformer.enable_taylorseer = True
+         elif self.transformer.enable_teacache:
+             # Use different TeaCacheParams for different conditions
+             teacache_params = TeaCacheParams()
+             teacache_params_uncond = TeaCacheParams()
+             teacache_params_ref = TeaCacheParams()
+
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if enable_taylorseer:
+                     self.transformer.cache_dic = model_pred_cache_dic
+                     self.transformer.current = model_pred_current
+                 elif self.transformer.enable_teacache:
+                     teacache_params.is_first_or_last_step = (
+                         i == 0 or i == len(timesteps) - 1
+                     )
+                     self.transformer.teacache_params = teacache_params
+
+                 model_pred = self.predict(
+                     t=t,
+                     latents=latents,
+                     prompt_embeds=prompt_embeds,
+                     freqs_cis=freqs_cis,
+                     prompt_attention_mask=prompt_attention_mask,
+                     ref_image_hidden_states=ref_latents,
+                 )
+                 # Guidance is only applied within the configured fraction of steps.
+                 text_guidance_scale = (
+                     self.text_guidance_scale
+                     if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1]
+                     else 1.0
+                 )
+                 image_guidance_scale = (
+                     self.image_guidance_scale
+                     if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1]
+                     else 1.0
+                 )
+
+                 if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
+                     if enable_taylorseer:
+                         self.transformer.cache_dic = model_pred_ref_cache_dic
+                         self.transformer.current = model_pred_ref_current
+                     elif self.transformer.enable_teacache:
+                         teacache_params_ref.is_first_or_last_step = (
+                             i == 0 or i == len(timesteps) - 1
+                         )
+                         self.transformer.teacache_params = teacache_params_ref
+
+                     model_pred_ref = self.predict(
+                         t=t,
+                         latents=latents,
+                         prompt_embeds=negative_prompt_embeds,
+                         freqs_cis=freqs_cis,
+                         prompt_attention_mask=negative_prompt_attention_mask,
+                         ref_image_hidden_states=ref_latents,
+                     )
+
+                     if enable_taylorseer:
+                         self.transformer.cache_dic = model_pred_uncond_cache_dic
+                         self.transformer.current = model_pred_uncond_current
+                     elif self.transformer.enable_teacache:
+                         teacache_params_uncond.is_first_or_last_step = (
+                             i == 0 or i == len(timesteps) - 1
+                         )
+                         self.transformer.teacache_params = teacache_params_uncond
+
+                     model_pred_uncond = self.predict(
+                         t=t,
+                         latents=latents,
+                         prompt_embeds=negative_prompt_embeds,
+                         freqs_cis=freqs_cis,
+                         prompt_attention_mask=negative_prompt_attention_mask,
+                         ref_image_hidden_states=None,
+                     )
+
+                     # Two-scale classifier-free guidance over image and text conditions.
+                     model_pred = (
+                         model_pred_uncond
+                         + image_guidance_scale * (model_pred_ref - model_pred_uncond)
+                         + text_guidance_scale * (model_pred - model_pred_ref)
+                     )
+                 elif text_guidance_scale > 1.0:
+                     if enable_taylorseer:
+                         self.transformer.cache_dic = model_pred_uncond_cache_dic
+                         self.transformer.current = model_pred_uncond_current
+                     elif self.transformer.enable_teacache:
+                         teacache_params_uncond.is_first_or_last_step = (
+                             i == 0 or i == len(timesteps) - 1
+                         )
+                         self.transformer.teacache_params = teacache_params_uncond
+
+                     model_pred_uncond = self.predict(
+                         t=t,
+                         latents=latents,
+                         prompt_embeds=negative_prompt_embeds,
+                         freqs_cis=freqs_cis,
+                         prompt_attention_mask=negative_prompt_attention_mask,
+                         ref_image_hidden_states=None,
+                     )
+                     model_pred = model_pred_uncond + text_guidance_scale * (
+                         model_pred - model_pred_uncond
+                     )
+
+                 latents = self.scheduler.step(
+                     model_pred, t, latents, return_dict=False
+                 )[0]
+
+                 latents = latents.to(dtype=dtype)
+
+                 if i == len(timesteps) - 1 or (
+                     (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                 ):
+                     progress_bar.update()
+
+                 if step_func is not None:
+                     step_func(i, self._num_timesteps)
+
+         if enable_taylorseer:
+             del (
+                 model_pred_cache_dic,
+                 model_pred_ref_cache_dic,
+                 model_pred_uncond_cache_dic,
+             )
+             del model_pred_current, model_pred_ref_current, model_pred_uncond_current
+
+         latents = latents.to(dtype=dtype)
+         if get_latents_text_embeds:
+             return latents
+
+         # Undo the VAE latent normalization, then decode to pixel space.
+         if self.vae.config.scaling_factor is not None:
+             latents = latents / self.vae.config.scaling_factor
+         if self.vae.config.shift_factor is not None:
+             latents = latents + self.vae.config.shift_factor
+         if len(latents.shape) == 4:
+             image = self.vae.decode(latents, return_dict=False)[0]
+         else:
+             # Temporal latents: decode each frame separately, then restack.
+             image = [
+                 self.vae.decode(latents[:, i], return_dict=False)[0]
+                 for i in range(latents.shape[1])
+             ]
+             image = torch.stack(image, dim=1)
+
+         return image
+
+     def predict(
+         self,
+         t,
+         latents,
+         prompt_embeds,
+         freqs_cis,
+         prompt_attention_mask,
+         ref_image_hidden_states,
+     ):
+         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+         timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+         if len(latents.shape) == 4:
+             batch_size, num_channels_latents, height, width = latents.shape
+             is_temporal = False
+         else:
+             batch_size, num_frames, num_channels_latents, height, width = latents.shape
+             # Unbind the frame dimension into a list of per-frame latents.
+             latents = list(latents)
+             is_temporal = True
+
+         # Only pass reference latents if the transformer's forward accepts them.
+         optional_kwargs = {}
+         if "ref_image_hidden_states" in set(
+             inspect.signature(self.transformer.forward).parameters.keys()
+         ):
+             optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states
+
+         model_pred = self.transformer(
+             latents,
+             timestep,
+             prompt_embeds,
+             freqs_cis,
+             prompt_attention_mask,
+             **optional_kwargs,
+         )
+
+         if is_temporal:
+             model_pred = torch.stack(model_pred)
+         return model_pred
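
For reference, a minimal sketch of how this file is meant to be used once it sits in the model repo. The repo id below is a placeholder, the `custom_pipeline`/`trust_remote_code` arguments assume the standard diffusers mechanism for Hub-hosted pipeline code, and the call parameters (`num_inference_steps`, `text_guidance_scale`) are the ones defined in `FOFPredPipeline.__call__` above:

import torch
from diffusers import DiffusionPipeline

# Placeholder repo id; substitute the actual model repository.
repo_id = "your-org/fofpred"

# `custom_pipeline` points at the repo containing pipeline_fofpred.py;
# `trust_remote_code=True` opts in to executing that remote code.
pipe = DiffusionPipeline.from_pretrained(
    repo_id,
    custom_pipeline=repo_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

result = pipe(
    prompt="A red bicycle leaning against a brick wall",
    num_inference_steps=28,
    text_guidance_scale=4.0,
)
image = result.images[0]

With `frame_count > 1` the pipeline stacks per-frame outputs into a tensor, so a tensor `output_type` such as "pt" should be used in that case (see the temporal branch of `__call__`).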