Upload FOFPred pipeline

#6
by kahnchana - opened
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  processor/tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  processor/tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ __pycache__/transformer_fofpred.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -15,18 +15,20 @@ tags:
15
  ## Usage
16
 
17
  ```python
 
 
18
  import torch
19
- from fofpred.pipelines.fofpred.pipeline_fofpred import FOFPredPipeline
20
- from fofpred.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
21
  from PIL import Image
22
 
23
- pipeline = FOFPredPipeline.from_pretrained(
 
24
  "Salesforce/FOFPred",
25
  torch_dtype=torch.bfloat16,
 
26
  ).to("cuda")
27
 
28
- pipeline.scheduler = FlowMatchEulerDiscreteScheduler()
29
-
30
  results = pipeline(
31
  prompt="Moving the water bottle from right to left.",
32
  input_images=[Image.open("your_image.jpg")],
@@ -40,6 +42,12 @@ results = pipeline(
40
  )
41
 
42
  flow_frames = results.images # [B, F, C, H, W]
 
 
 
 
 
 
43
  ```
44
 
45
  ## Architecture
 
15
  ## Usage
16
 
17
  ```python
18
+ import einops
19
+ import numpy as np
20
  import torch
21
+ from diffusers import DiffusionPipeline
 
22
  from PIL import Image
23
 
24
+ # Load pipeline with trust_remote_code
25
+ pipeline = DiffusionPipeline.from_pretrained(
26
  "Salesforce/FOFPred",
27
  torch_dtype=torch.bfloat16,
28
+ trust_remote_code=True,
29
  ).to("cuda")
30
 
31
+ # Run inference
 
32
  results = pipeline(
33
  prompt="Moving the water bottle from right to left.",
34
  input_images=[Image.open("your_image.jpg")],
 
42
  )
43
 
44
  flow_frames = results.images # [B, F, C, H, W]
45
+
46
+ output_tensor = flow_frames[0] # [F, C, H, W]
47
+ output_np = pipeline.image_processor.pt_to_numpy(output_tensor) # [F, H, W, C]
48
+ reshaped = einops.rearrange(output_np, "f h w c -> h (f w) c")
49
+ img = Image.fromarray((reshaped * 255).astype(np.uint8))
50
+ img.save("output_combined.png")
51
  ```
52
 
53
  ## Architecture
__pycache__/pipeline_fofpred.cpython-311.pyc ADDED
Binary file (88.8 kB). View file
 
__pycache__/scheduler_fofpred.cpython-311.pyc ADDED
Binary file (10.3 kB). View file
 
__pycache__/transformer_fofpred.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4194813ba36a92b72a9fc5e90a0257d61096743c4e5bb6800f3c6683b3774510
3
+ size 124604
model_index.json CHANGED
@@ -4,7 +4,6 @@
4
  "FOFPredPipeline"
5
  ],
6
  "_diffusers_version": "0.34.0",
7
- "_name_or_path": "/export/home/public_repo/FOFPred/pretrained_models/hf_upload",
8
  "mllm": [
9
  "transformers",
10
  "Qwen2_5_VLForConditionalGeneration"
@@ -14,11 +13,11 @@
14
  "Qwen2_5_VLProcessor"
15
  ],
16
  "scheduler": [
17
- "diffusers",
18
  "FlowMatchEulerDiscreteScheduler"
19
  ],
20
  "transformer": [
21
- "transformer_omnigen2",
22
  "OmniGen2Transformer3DModel"
23
  ],
24
  "vae": [
 
4
  "FOFPredPipeline"
5
  ],
6
  "_diffusers_version": "0.34.0",
 
7
  "mllm": [
8
  "transformers",
9
  "Qwen2_5_VLForConditionalGeneration"
 
13
  "Qwen2_5_VLProcessor"
14
  ],
15
  "scheduler": [
16
+ "scheduler_fofpred",
17
  "FlowMatchEulerDiscreteScheduler"
18
  ],
19
  "transformer": [
20
+ "transformer_fofpred",
21
  "OmniGen2Transformer3DModel"
22
  ],
23
  "vae": [
pipeline_fofpred.py CHANGED
@@ -17,39 +17,1003 @@ limitations under the License.
17
  """
18
 
19
  import inspect
 
 
20
  from dataclasses import dataclass
21
- from typing import Any, Dict, List, Optional, Tuple, Union
22
 
23
  import numpy as np
24
  import PIL.Image
25
  import torch
 
26
  import torch.nn.functional as F
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  from diffusers.models.autoencoders import AutoencoderKL
 
28
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
29
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
30
  from diffusers.utils import (
 
31
  BaseOutput,
 
 
 
32
  is_torch_xla_available,
 
 
33
  logging,
34
  )
35
  from diffusers.utils.torch_utils import randn_tensor
 
 
36
  from transformers import Qwen2_5_VLForConditionalGeneration
37
 
38
- from fofpred.pipelines.image_processor import OmniGen2ImageProcessor
39
- from fofpred.utils.teacache_util import TeaCacheParams
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- from ...models.transformers import OmniGen2Transformer3DModel
42
- from ...models.transformers.repo import OmniGen2RotaryPosEmbed
43
- from ..lora_pipeline import OmniGen2LoraLoaderMixin
44
 
45
  if is_torch_xla_available():
46
  XLA_AVAILABLE = True
47
  else:
48
  XLA_AVAILABLE = False
49
 
50
- from ...cache_functions import cache_init
51
 
52
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
 
55
  @dataclass
 
17
  """
18
 
19
  import inspect
20
+ import os
21
+ import warnings
22
  from dataclasses import dataclass
23
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
24
 
25
  import numpy as np
26
  import PIL.Image
27
  import torch
28
+ import torch.nn as nn
29
  import torch.nn.functional as F
30
+ from diffusers.configuration_utils import register_to_config
31
+ from diffusers.image_processor import (
32
+ PipelineImageInput,
33
+ VaeImageProcessor,
34
+ is_valid_image_imagelist,
35
+ )
36
+ from diffusers.loaders.lora_base import ( # noqa
37
+ LoraBaseMixin,
38
+ _fetch_state_dict,
39
+ )
40
+ from diffusers.loaders.lora_conversion_utils import (
41
+ _convert_non_diffusers_lumina2_lora_to_diffusers,
42
+ )
43
  from diffusers.models.autoencoders import AutoencoderKL
44
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
45
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 
46
  from diffusers.utils import (
47
+ USE_PEFT_BACKEND,
48
  BaseOutput,
49
+ is_peft_available,
50
+ is_peft_version,
51
+ is_torch_version,
52
  is_torch_xla_available,
53
+ is_transformers_available,
54
+ is_transformers_version,
55
  logging,
56
  )
57
  from diffusers.utils.torch_utils import randn_tensor
58
+ from einops import repeat
59
+ from huggingface_hub.utils import validate_hf_hub_args
60
  from transformers import Qwen2_5_VLForConditionalGeneration
61
 
62
+ from .scheduler_fofpred import FlowMatchEulerDiscreteScheduler
63
+ from .transformer_fofpred import OmniGen2Transformer3DModel
64
+
65
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
66
+
67
+
68
# True only when torch, peft and transformers are all installed at versions that
# satisfy the checks below.
# NOTE(review): presumably the default for the `low_cpu_mem_usage` LoRA option -
# confirm against the loader code that consumes it.
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = bool(
    is_torch_version(">=", "1.9.0")
    and is_peft_available()
    and is_peft_version(">=", "0.13.1")
    and is_transformers_available()
    and is_transformers_version(">", "4.45.2")
)

# Whether the torch-xla backend can be used in this environment.
XLA_AVAILABLE = bool(is_torch_xla_available())

# Name of the pipeline component that LoRA layers are loaded into.
TRANSFORMER_NAME = "transformer"
86
+
87
+
88
class OmniGen2ImageProcessor(VaeImageProcessor):
    """
    Image processor that resizes inputs to VAE-compatible dimensions, with optional size caps.

    Extends [`VaeImageProcessor`] with two optional limits, `max_pixels` and `max_side_length`. The limits can only
    shrink an image: the resize ratio is clamped to at most `1.0`, so inputs are never upscaled.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
            `height` and `width` arguments from the `preprocess` method.
        vae_scale_factor (`int`, *optional*, defaults to `16`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        max_pixels (`int`, *optional*, defaults to `None`):
            Default upper bound on `height * width` of the resized image. `None` disables this limit.
        max_side_length (`int`, *optional*, defaults to `None`):
            Default upper bound on the longer image side. `None` disables this limit.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
        do_binarize (`bool`, *optional*, defaults to `False`):
            Whether to binarize the image to 0/1.
        do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
            Whether to convert the images to grayscale format.

    Note:
        `do_convert_rgb` is not a constructor argument of this class; `preprocess` reads it from `self.config`.
    """

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 16,
        resample: str = "lanczos",
        max_pixels: Optional[int] = None,
        max_side_length: Optional[int] = None,
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_grayscale: bool = False,
    ):
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

        # Instance-level defaults used when `preprocess` / `get_new_height_width` receive `None`.
        self.max_pixels = max_pixels
        self.max_side_length = max_side_length

    def get_new_height_width(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        height: Optional[int] = None,
        width: Optional[int] = None,
        max_pixels: Optional[int] = None,
        max_side_length: Optional[int] = None,
    ) -> Tuple[int, int]:
        r"""
        Returns the target height and width: capped by the size limits, then floored to a multiple of
        `vae_scale_factor`.

        Args:
            image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
                The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
                should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
                tensor, it should have shape `[batch, channels, height, width]`.
            height (`Optional[int]`, *optional*, defaults to `None`):
                The height of the preprocessed image. If `None`, the height of the `image` input will be used.
            width (`Optional[int]`, *optional*, defaults to `None`):
                The width of the preprocessed image. If `None`, the width of the `image` input will be used.
            max_pixels (`Optional[int]`, *optional*, defaults to `None`):
                Upper bound on `height * width`. If `None`, falls back to `self.max_pixels`; if that is also `None`,
                no pixel-count limit is applied.
            max_side_length (`Optional[int]`, *optional*, defaults to `None`):
                Upper bound on the longer side. If `None`, falls back to `self.max_side_length`; if that is also
                `None`, no side-length limit is applied.

        Returns:
            `Tuple[int, int]`:
                A tuple containing the height and width, both rounded down to an integer multiple of
                `vae_scale_factor`.
        """

        if height is None:
            if isinstance(image, PIL.Image.Image):
                height = image.height
            elif isinstance(image, torch.Tensor):
                height = image.shape[2]
            else:
                height = image.shape[1]

        if width is None:
            if isinstance(image, PIL.Image.Image):
                width = image.width
            elif isinstance(image, torch.Tensor):
                width = image.shape[3]
            else:
                width = image.shape[2]

        if max_side_length is None:
            max_side_length = self.max_side_length

        if max_pixels is None:
            max_pixels = self.max_pixels

        # Start at 1.0 so the input is never upscaled; each configured limit can only lower the ratio.
        # Each limit is applied independently, so leaving either one as `None` no longer raises.
        ratio = 1.0
        if max_side_length is not None:
            ratio = min(ratio, max_side_length / max(height, width))
        if max_pixels is not None:
            # Square root because the ratio scales both sides, i.e. the area scales with ratio ** 2.
            ratio = min(ratio, (max_pixels / (height * width)) ** 0.5)

        scale_factor = self.config.vae_scale_factor
        new_height = int(height * ratio) // scale_factor * scale_factor
        new_width = int(width * ratio) // scale_factor * scale_factor
        return new_height, new_width

    def preprocess(
        self,
        image: PipelineImageInput,
        height: Optional[int] = None,
        width: Optional[int] = None,
        max_pixels: Optional[int] = None,
        max_side_length: Optional[int] = None,
        resize_mode: str = "default",  # "default", "fill", "crop"
        crops_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> torch.Tensor:
        """
        Preprocess the image input.

        Args:
            image (`PipelineImageInput`):
                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
                supported formats.
            height (`int`, *optional*):
                The height of the preprocessed image. If `None`, it is derived by `get_new_height_width()`.
            width (`int`, *optional*):
                The width of the preprocessed image. If `None`, it is derived by `get_new_height_width()`.
            max_pixels (`int`, *optional*):
                Pixel-count limit forwarded to `get_new_height_width()`.
            max_side_length (`int`, *optional*):
                Longer-side limit forwarded to `get_new_height_width()`.
            resize_mode (`str`, *optional*, defaults to `default`):
                The resize mode, can be one of `default`, `fill` or `crop`. If `default`, will resize the image to
                fit within the specified width and height, and it may not maintain the original aspect ratio. If
                `fill`, will resize the image to fit within the specified width and height, maintaining the aspect
                ratio, and then center the image within the dimensions, filling empty with data from image. If
                `crop`, will resize the image to fit within the specified width and height, maintaining the aspect
                ratio, and then center the image within the dimensions, cropping the excess. Note that resize_mode
                `fill` and `crop` are only supported for PIL image input.
            crops_coords (`Tuple[int, int, int, int]`, *optional*, defaults to `None`):
                Crop box applied to every PIL image before resizing. If `None`, will not crop the image.

        Returns:
            `torch.Tensor`:
                The preprocessed image.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
        if (
            self.config.do_convert_grayscale
            and isinstance(image, (torch.Tensor, np.ndarray))
            and image.ndim == 3
        ):
            if isinstance(image, torch.Tensor):
                # if image is a pytorch tensor could have 2 possible shapes:
                #    1. batch x height x width: we should insert the channel dimension at position 1
                #    2. channel x height x width: we should insert batch dimension at position 0,
                #       however, since both channel and batch dimension has same size 1, it is same to insert at position 1
                #    for simplicity, we insert a dimension of size 1 at position 1 for both cases
                image = image.unsqueeze(1)
            else:
                # if it is a numpy array, it could have 2 possible shapes:
                #   1. batch x height x width: insert channel dimension on last position
                #   2. height x width x channel: insert batch dimension on first position
                if image.shape[-1] == 1:
                    image = np.expand_dims(image, axis=0)
                else:
                    image = np.expand_dims(image, axis=-1)

        if (
            isinstance(image, list)
            and isinstance(image[0], np.ndarray)
            and image[0].ndim == 4
        ):
            warnings.warn(
                "Passing `image` as a list of 4d np.ndarray is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
                FutureWarning,
            )
            image = np.concatenate(image, axis=0)
        if (
            isinstance(image, list)
            and isinstance(image[0], torch.Tensor)
            and image[0].ndim == 4
        ):
            warnings.warn(
                "Passing `image` as a list of 4d torch.Tensor is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
                FutureWarning,
            )
            image = torch.cat(image, axis=0)

        if not is_valid_image_imagelist(image):
            raise ValueError(
                f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
            )
        if not isinstance(image, list):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            if crops_coords is not None:
                image = [i.crop(crops_coords) for i in image]
            if self.config.do_resize:
                # The target size is derived from the first image and applied to the whole list.
                height, width = self.get_new_height_width(
                    image[0], height, width, max_pixels, max_side_length
                )
                image = [
                    self.resize(i, height, width, resize_mode=resize_mode)
                    for i in image
                ]
            if self.config.do_convert_rgb:
                image = [self.convert_to_rgb(i) for i in image]
            elif self.config.do_convert_grayscale:
                image = [self.convert_to_grayscale(i) for i in image]
            image = self.pil_to_numpy(image)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            image = (
                np.concatenate(image, axis=0)
                if image[0].ndim == 4
                else np.stack(image, axis=0)
            )

            image = self.numpy_to_pt(image)

            height, width = self.get_new_height_width(
                image, height, width, max_pixels, max_side_length
            )
            if self.config.do_resize:
                image = self.resize(image, height, width)

        elif isinstance(image[0], torch.Tensor):
            image = (
                torch.cat(image, axis=0)
                if image[0].ndim == 4
                else torch.stack(image, axis=0)
            )

            if self.config.do_convert_grayscale and image.ndim == 3:
                image = image.unsqueeze(1)

            channel = image.shape[1]
            # don't need any preprocess if the image is latents
            if channel == self.config.vae_latent_channels:
                return image

            height, width = self.get_new_height_width(
                image, height, width, max_pixels, max_side_length
            )
            if self.config.do_resize:
                image = self.resize(image, height, width)

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if do_normalize and image.min() < 0:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            # Negative values mean the input is already outside [0,1]; skip normalization.
            do_normalize = False
        if do_normalize:
            image = self.normalize(image)

        if self.config.do_binarize:
            image = self.binarize(image)

        return image
367
+
368
+
369
@dataclass
class TeaCacheParams:
    """
    TeaCache parameters for `OmniGen2Transformer3DModel`
    See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding

    Args:
        previous_residual (Optional[torch.Tensor]):
            The tensor difference between the output and the input of the transformer layers from the previous timestep.
        previous_modulated_inp (Optional[torch.Tensor]):
            The modulated input from the previous timestep used to indicate the change of the transformer layer's output.
        accumulated_rel_l1_distance (float):
            The accumulated relative L1 distance. Starts at `0.0`.
        is_first_or_last_step (bool):
            Whether the current timestep is the first or last step.
    """

    previous_residual: Optional[torch.Tensor] = None
    previous_modulated_inp: Optional[torch.Tensor] = None
    # Float literal so the default matches the declared `float` type (it was the int `0`).
    accumulated_rel_l1_distance: float = 0.0
    is_first_or_last_step: bool = False
390
+
391
+
392
class OmniGen2RotaryPosEmbed(nn.Module):
    """
    Builds 3-axis position ids for a caption + reference-image + image token sequence and gathers the
    matching rotary embeddings from per-axis frequency tables.

    The module holds no learnable parameters; it only stores the configuration passed to `__init__`.
    """

    def __init__(
        self,
        theta: int,
        axes_dim: Tuple[int, int, int],
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        patch_size: int = 2,
    ):
        """
        Args:
            theta: Stored as `self.theta`; `get_freqs_cis` takes the same-named value as the `theta` argument of
                `get_1d_rotary_pos_embed`.
            axes_dim: One entry per position axis; its length fixes how many axes `_get_freqs_cis` gathers.
            axes_lens: One entry per position axis (stored; see `get_freqs_cis` for how such values are used).
            patch_size: Divisor applied to image heights/widths in `forward` to obtain token-grid sizes.
        """
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

    @staticmethod
    def get_freqs_cis(
        axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int], theta: int
    ) -> List[torch.Tensor]:
        """
        Build one rotary frequency table per axis by calling `get_1d_rotary_pos_embed(d, e, ...)` for each
        `(axes_dim, axes_lens)` pair.

        Returns:
            A list with one table per axis, in axis order.
        """
        freqs_cis = []
        # float64 is used unless an MPS backend is available, in which case float32 is used.
        freqs_dtype = (
            torch.float32 if torch.backends.mps.is_available() else torch.float64
        )
        for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
            emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
            freqs_cis.append(emb)
        return freqs_cis

    def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
        """
        Look up the rotary embedding of every position id and concatenate the per-axis results.

        Args:
            freqs_cis: Indexable collection of per-axis tables (presumably the output of `get_freqs_cis` -
                confirm against the caller).
            ids: Position ids indexed as `ids[:, :, axis]`, i.e. a 3-D tensor whose last dimension is the axis.

        Returns:
            The gathered per-axis embeddings concatenated along the last dimension, moved back to the original
            device of `ids`.
        """
        device = ids.device
        # The gather is done on CPU when the ids live on an MPS device; the result is moved back at the end.
        if ids.device.type == "mps":
            ids = ids.to("cpu")

        result = []
        for i in range(len(self.axes_dim)):
            freqs = freqs_cis[i].to(ids.device)
            # Broadcast the id of axis `i` across the embedding width so it can serve as a gather index.
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(
                torch.gather(
                    freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index
                )
            )
        return torch.cat(result, dim=-1).to(device)

    def forward(
        self,
        freqs_cis,
        attention_mask,
        l_effective_ref_img_len,
        l_effective_img_len,
        ref_img_sizes,
        img_sizes,
        device,
    ):
        """
        Compute rotary embeddings for the joint sequence and split them into caption / reference-image / image
        parts.

        Position ids have three axes. Caption tokens use their running index on all three axes. Every image
        (reference or target) uses a constant "shift" value on axis 0, the token row on axis 1 and the token
        column on axis 2; after each image the shift grows by `max(H_tokens, W_tokens)`.

        Args:
            freqs_cis: Per-axis frequency tables forwarded to `_get_freqs_cis`.
            attention_mask: 2-D mask over caption tokens; its row sums are used as the effective caption
                lengths (assumes an integer/bool mask so the sums can be used as slice bounds - TODO confirm).
            l_effective_ref_img_len: Per sample, an iterable of token counts, one per reference image.
            l_effective_img_len: Per sample, either a single token count or a list of token counts (the
                "t-dim" case). The case is selected by inspecting the first sample.
            ref_img_sizes: Per sample, `None` or an iterable of `(H, W)` pairs matching
                `l_effective_ref_img_len`.
            img_sizes: Per sample, an `(H, W)` pair, or a list of such pairs in the t-dim case.
            device: Device on which the position ids and output buffers are allocated.

        Returns:
            Tuple `(cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len,
            seq_lengths)`: the three zero-padded per-part embeddings, the embeddings of the full sequence, the
            per-sample caption lengths and the per-sample total sequence lengths.

        Raises:
            AssertionError: If a token count does not equal `(H // patch_size) * (W // patch_size)`, or (single
                image case) if the parts do not add up to the sample's sequence length.
        """
        batch_size = len(attention_mask)
        p = self.patch_size

        encoder_seq_len = attention_mask.shape[1]
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()

        if isinstance(l_effective_img_len[0], list):  # Check for t-dim case
            seq_lengths = [
                cap_len + sum(ref_img_len) + sum(img_len)
                for cap_len, ref_img_len, img_len in zip(
                    l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len
                )
            ]
        else:  # Original case
            seq_lengths = [
                cap_len + sum(ref_img_len) + img_len
                for cap_len, ref_img_len, img_len in zip(
                    l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len
                )
            ]

        # Padded sizes of the full sequence and of each part across the batch.
        max_seq_len = max(seq_lengths)
        max_ref_img_len = max(
            [sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]
        )
        if isinstance(l_effective_img_len[0], list):
            max_img_len = max([sum(ln) for ln in l_effective_img_len])
        else:
            max_img_len = max(l_effective_img_len)

        # Create position IDs
        position_ids = torch.zeros(
            batch_size, max_seq_len, 3, dtype=torch.int32, device=device
        )

        for i, (cap_seq_len, seq_len) in enumerate(
            zip(l_effective_cap_len, seq_lengths)
        ):
            # add text position ids
            position_ids[i, :cap_seq_len] = repeat(
                torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3"
            )

            # pe_shift: axis-0 value assigned to the next image.
            # pe_shift_len: offset in the sequence where the next image's tokens start.
            pe_shift = cap_seq_len
            pe_shift_len = cap_seq_len

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(
                    ref_img_sizes[i], l_effective_ref_img_len[i]
                ):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p
                    assert ref_H_tokens * ref_W_tokens == ref_img_len
                    # add image position ids

                    row_ids = repeat(
                        torch.arange(ref_H_tokens, dtype=torch.int32, device=device),
                        "h -> h w",
                        w=ref_W_tokens,
                    ).flatten()
                    col_ids = repeat(
                        torch.arange(ref_W_tokens, dtype=torch.int32, device=device),
                        "w -> h w",
                        h=ref_H_tokens,
                    ).flatten()
                    position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 0] = (
                        pe_shift
                    )
                    position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 1] = (
                        row_ids
                    )
                    position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 2] = (
                        col_ids
                    )

                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            if isinstance(l_effective_img_len[i], list):  # New case
                # Several target images: each is laid out like a reference image, one after another.
                for img_size, img_len in zip(img_sizes[i], l_effective_img_len[i]):
                    H, W = img_size
                    H_tokens, W_tokens = H // p, W // p
                    assert H_tokens * W_tokens == img_len

                    row_ids = repeat(
                        torch.arange(H_tokens, dtype=torch.int32, device=device),
                        "h -> h w",
                        w=W_tokens,
                    ).flatten()
                    col_ids = repeat(
                        torch.arange(W_tokens, dtype=torch.int32, device=device),
                        "w -> h w",
                        h=H_tokens,
                    ).flatten()

                    end_idx = pe_shift_len + img_len

                    position_ids[i, pe_shift_len:end_idx, 0] = pe_shift
                    position_ids[i, pe_shift_len:end_idx, 1] = row_ids
                    position_ids[i, pe_shift_len:end_idx, 2] = col_ids

                    pe_shift += max(H_tokens, W_tokens)
                    pe_shift_len = end_idx
            else:  # Original case
                # Single target image: it must fill the remainder of the sample's sequence.
                H, W = img_sizes[i]
                H_tokens, W_tokens = H // p, W // p
                assert H_tokens * W_tokens == l_effective_img_len[i]

                row_ids = repeat(
                    torch.arange(H_tokens, dtype=torch.int32, device=device),
                    "h -> h w",
                    w=W_tokens,
                ).flatten()
                col_ids = repeat(
                    torch.arange(W_tokens, dtype=torch.int32, device=device),
                    "w -> h w",
                    h=H_tokens,
                ).flatten()

                assert pe_shift_len + l_effective_img_len[i] == seq_len
                position_ids[i, pe_shift_len:seq_len, 0] = pe_shift
                position_ids[i, pe_shift_len:seq_len, 1] = row_ids
                position_ids[i, pe_shift_len:seq_len, 2] = col_ids

        # Get combined rotary embeddings
        freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)

        # create separate rotary embeddings for captions and images
        cap_freqs_cis = torch.zeros(
            batch_size,
            encoder_seq_len,
            freqs_cis.shape[-1],
            device=device,
            dtype=freqs_cis.dtype,
        )
        ref_img_freqs_cis = torch.zeros(
            batch_size,
            max_ref_img_len,
            freqs_cis.shape[-1],
            device=device,
            dtype=freqs_cis.dtype,
        )
        img_freqs_cis = torch.zeros(
            batch_size,
            max_img_len,
            freqs_cis.shape[-1],
            device=device,
            dtype=freqs_cis.dtype,
        )

        # Copy each sample's caption / reference-image / image slice into the zero-padded buffers.
        for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(
            zip(
                l_effective_cap_len,
                l_effective_ref_img_len,
                l_effective_img_len,
                seq_lengths,
            )
        ):
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, : sum(ref_img_len)] = freqs_cis[
                i, cap_seq_len : cap_seq_len + sum(ref_img_len)
            ]
            # In the t-dim case the per-image counts are collapsed into one total.
            if isinstance(img_len, list):
                img_len = sum(img_len)
            img_freqs_cis[i, :img_len] = freqs_cis[
                i,
                cap_seq_len + sum(ref_img_len) : cap_seq_len
                + sum(ref_img_len)
                + img_len,
            ]

        return (
            cap_freqs_cis,
            ref_img_freqs_cis,
            img_freqs_cis,
            freqs_cis,
            l_effective_cap_len,
            seq_lengths,
        )
624
+
625
+
626
+ class OmniGen2LoraLoaderMixin(LoraBaseMixin):
627
+ r"""
628
+ Load LoRA layers into [`OmniGen2Transformer3DModel`]. Specific to [`FOFPredPipeline`].
629
+ """
630
+
631
+ _lora_loadable_modules = ["transformer"]
632
+ transformer_name = TRANSFORMER_NAME
633
+
634
@classmethod
@validate_hf_hub_args
def lora_state_dict(
    cls,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    **kwargs,
):
    r"""
    Fetch a LoRA state dict and return it in the key layout expected by Diffusers.

    <Tip warning={true}>

    We support loading A1111 formatted LoRA checkpoints in a limited capacity.

    This function is experimental and might change in the future.

    </Tip>

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        weight_name (`str`, *optional*):
            Forwarded to `_fetch_state_dict` as `weight_name`.
        use_safetensors (`bool`, *optional*):
            Forwarded to `_fetch_state_dict`. When left as `None` it is set to `True` and `allow_pickle=True` is
            passed as well; otherwise `allow_pickle` is `False`.

    Returns:
        The fetched state dict with any `dora_scale` keys removed and, when keys start with `diffusion_model.`,
        converted by `_convert_non_diffusers_lumina2_lora_to_diffusers`.
    """
    # Pull every download option out of kwargs; anything left over is ignored here.
    fetch_options = {
        "cache_dir": kwargs.pop("cache_dir", None),
        "force_download": kwargs.pop("force_download", False),
        "proxies": kwargs.pop("proxies", None),
        "local_files_only": kwargs.pop("local_files_only", None),
        "token": kwargs.pop("token", None),
        "revision": kwargs.pop("revision", None),
        "subfolder": kwargs.pop("subfolder", None),
        "weight_name": kwargs.pop("weight_name", None),
    }

    # No explicit preference from the caller: ask for safetensors, but also pass allow_pickle=True.
    use_safetensors = kwargs.pop("use_safetensors", None)
    allow_pickle = use_safetensors is None
    if allow_pickle:
        use_safetensors = True

    state_dict = _fetch_state_dict(
        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
        use_safetensors=use_safetensors,
        user_agent={
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        },
        allow_pickle=allow_pickle,
        **fetch_options,
    )

    # DoRA scale entries are dropped (with a warning) rather than loaded.
    if any("dora_scale" in key for key in state_dict):
        logger.warning(
            "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
        )
        state_dict = {
            key: value for key, value in state_dict.items() if "dora_scale" not in key
        }

    # Keys prefixed with "diffusion_model." mark a non-Diffusers layout that needs conversion.
    if any(key.startswith("diffusion_model.") for key in state_dict):
        state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)

    return state_dict
735
+
736
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
737
def load_lora_weights(
    self,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    adapter_name=None,
    **kwargs,
):
    """
    Load a LoRA checkpoint and attach it to the pipeline's transformer.

    The checkpoint is first resolved through `self.lora_state_dict`, then validated, and finally handed to
    `self.load_lora_into_transformer`.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Forwarded to `self.lora_state_dict`. When a `dict` is given, a shallow copy is used so the caller's
            object is not rebound or modified by this method.
        adapter_name (`str`, *optional*):
            Forwarded unchanged to `self.load_lora_into_transformer`.
        low_cpu_mem_usage (`bool`, *optional*):
            Taken out of `kwargs`; defaults to `_LOW_CPU_MEM_USAGE_DEFAULT_LORA`. A truthy value requires
            `peft>=0.13.0`.
        kwargs (`dict`, *optional*):
            All remaining keyword arguments are forwarded to `self.lora_state_dict`.

    Raises:
        ValueError: If the PEFT backend is unavailable, if `low_cpu_mem_usage` is requested with an older `peft`,
            or if any key of the loaded state dict does not contain the substring `"lora"`.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    low_cpu_mem_usage = kwargs.pop(
        "low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA
    )
    if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
        raise ValueError(
            "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
        )

    # Work on a copy when a dict is supplied instead of touching the caller's object.
    checkpoint = pretrained_model_name_or_path_or_dict
    if isinstance(checkpoint, dict):
        checkpoint = checkpoint.copy()

    # Resolve the checkpoint first so that incompatible files fail before the model is touched.
    state_dict = self.lora_state_dict(checkpoint, **kwargs)

    if any("lora" not in key for key in state_dict.keys()):
        raise ValueError("Invalid LoRA checkpoint.")

    # Prefer the `transformer` attribute; otherwise look the component up by `self.transformer_name`.
    if hasattr(self, "transformer"):
        target_transformer = self.transformer
    else:
        target_transformer = getattr(self, self.transformer_name)

    self.load_lora_into_transformer(
        state_dict,
        transformer=target_transformer,
        adapter_name=adapter_name,
        _pipeline=self,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
797
+
798
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
def load_lora_into_transformer(
    cls,
    state_dict,
    transformer,
    adapter_name=None,
    _pipeline=None,
    low_cpu_mem_usage=False,
    hotswap: bool = False,
):
    """
    Load the LoRA layers contained in `state_dict` into `transformer`.

    This is a thin wrapper: after a `peft` version check, every argument is passed on to
    `transformer.load_lora_adapter` (with `network_alphas=None`).

    Parameters:
        state_dict (`dict`):
            State dict containing the LoRA layer parameters.
        transformer:
            The transformer model to load the LoRA layers into. It must provide `load_lora_adapter`.
        adapter_name (`str`, *optional*):
            Name used to reference the loaded adapter; forwarded to `transformer.load_lora_adapter`.
        _pipeline (*optional*):
            Pipeline object forwarded to `transformer.load_lora_adapter`.
        low_cpu_mem_usage (`bool`, *optional*):
            Forwarded to `transformer.load_lora_adapter`. A truthy value requires `peft>=0.13.0`.
        hotswap (`bool`, *optional*, defaults to `False`):
            Forwarded to `transformer.load_lora_adapter`. In upstream diffusers this replaces an already loaded
            adapter in place instead of adding a new one; see
            https://huggingface.co/docs/peft/main/en/package_reference/hotswap for limitations. If the new and
            old adapters differ in rank and/or alpha, upstream requires calling
            `pipeline.enable_lora_hotswap(target_rank=max_rank)` before loading/compiling.

    Raises:
        ValueError: If `low_cpu_mem_usage` is truthy and the installed `peft` is older than 0.13.0.
    """
    peft_too_old = low_cpu_mem_usage and is_peft_version("<", "0.13.0")
    if peft_too_old:
        raise ValueError(
            "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
        )

    # Load the layers corresponding to transformer.
    logger.info(f"Loading {cls.transformer_name}.")
    transformer.load_lora_adapter(
        state_dict,
        network_alphas=None,
        adapter_name=adapter_name,
        _pipeline=_pipeline,
        low_cpu_mem_usage=low_cpu_mem_usage,
        hotswap=hotswap,
    )
864
+
865
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
    cls,
    save_directory: Union[str, os.PathLike],
    transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    is_main_process: bool = True,
    weight_name: str = None,
    save_function: Callable = None,
    safe_serialization: bool = True,
):
    r"""
    Save the LoRA parameters of the transformer.

    The layers are packed with `cls.pack_weights` under the `cls.transformer_name` prefix and then written with
    `cls.write_lora_layers`.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to save LoRA parameters to; forwarded to `cls.write_lora_layers`.
        transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
            State dict of the LoRA layers corresponding to the `transformer`. Must be non-empty.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Forwarded to `cls.write_lora_layers`. Intended for distributed training, where only the main process
            should set it to `True`.
        weight_name (`str`, *optional*):
            Forwarded to `cls.write_lora_layers`.
        save_function (`Callable`, *optional*):
            Forwarded to `cls.write_lora_layers`; the function used to save the state dictionary.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Forwarded to `cls.write_lora_layers`; selects `safetensors` over `pickle`-based saving.

    Raises:
        ValueError: If `transformer_lora_layers` is `None` or empty.
    """
    if not transformer_lora_layers:
        raise ValueError("You must pass `transformer_lora_layers`.")

    # Prefix the transformer layers so the saved keys identify their component.
    packed_layers = cls.pack_weights(transformer_lora_layers, cls.transformer_name)
    state_dict = dict(packed_layers)

    # Save the model
    cls.write_lora_layers(
        state_dict=state_dict,
        save_directory=save_directory,
        is_main_process=is_main_process,
        weight_name=weight_name,
        save_function=save_function,
        safe_serialization=safe_serialization,
    )
914
+
915
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
916
def fuse_lora(
    self,
    components: List[str] = ["transformer"],
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
    **kwargs,
):
    r"""
    Fuse the LoRA parameters into the original parameters of the selected components.

    This override only changes the default of `components` to `["transformer"]`; all arguments are passed on to
    the parent class implementation of `fuse_lora`.

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        components (`List[str]`, defaults to `["transformer"]`):
            LoRA-injectable components to fuse the LoRAs into.
        lora_scale (`float`, defaults to 1.0):
            Controls how much to influence the outputs with the LoRA parameters.
        safe_fusing (`bool`, defaults to `False`):
            Whether to check fused weights for NaN values before fusing and, if values are NaN, not fuse them.
        adapter_names (`List[str]`, *optional*):
            Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
        kwargs (`dict`, *optional*):
            Extra keyword arguments forwarded to the parent implementation.
    """
    # NOTE: `components` has a list default; it is only forwarded, never mutated here.
    super().fuse_lora(
        components=components,
        lora_scale=lora_scale,
        safe_fusing=safe_fusing,
        adapter_names=adapter_names,
        **kwargs,
    )
962
+
963
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
964
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
    r"""
    Undo the effect of
    [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

    This override only changes the default of `components` to `["transformer"]`; the call is delegated to the
    parent class implementation of `unfuse_lora`.

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        components (`List[str]`, defaults to `["transformer"]`):
            LoRA-injectable components to unfuse LoRA from.
        kwargs (`dict`, *optional*):
            Extra keyword arguments forwarded to the parent implementation.
    """
    # NOTE: `components` has a list default; it is only forwarded, never mutated here.
    super().unfuse_lora(components=components, **kwargs)
980
+
981
+
982
def cache_init(self, num_steps: int):
    """
    Build the initial cache bookkeeping structures.

    Args:
        self: Object exposing `transformer.layers`; only the number of layers is read.
        num_steps (`int`): Stored as-is under `current["num_steps"]`.

    Returns:
        `tuple`: `(cache_dic, current)` where `cache_dic` holds the cache settings together with one empty
        per-layer slot in both `cache_dic["cache"][-1]["layers_stream"]` and `cache_dic["cache_index"][-1]`,
        and `current` holds the step-tracking state starting at step 0.
    """
    layer_count = len(self.transformer.layers)

    # One fresh (unshared) empty dict per transformer layer in each structure.
    cache = {-1: {"layers_stream": {layer: {} for layer in range(layer_count)}}}
    cache_index = {
        -1: {layer: {} for layer in range(layer_count)},
        "layer_index": {},
    }

    cache_dic = {
        "cache_counter": 0,
        "Delta-DiT": False,
        "cache_type": "random",
        "cache_index": cache_index,
        "cache": cache,
        "fresh_ratio_schedule": "ToCa",
        "fresh_ratio": 0.0,
        "fresh_threshold": 3,
        "soft_fresh_weight": 0.0,
        "taylor_cache": True,
        "max_order": 4,
        "first_enhance": 5,
    }

    current = {
        "activated_steps": [0],
        "step": 0,
        "num_steps": num_steps,
    }

    return cache_dic, current
1017
 
1018
 
1019
  @dataclass
scheduler/scheduler_config.json CHANGED
@@ -1,18 +1,5 @@
1
  {
2
  "_class_name": "FlowMatchEulerDiscreteScheduler",
3
  "_diffusers_version": "0.34.0",
4
- "base_image_seq_len": 256,
5
- "base_shift": 0.5,
6
- "invert_sigmas": false,
7
- "max_image_seq_len": 4096,
8
- "max_shift": 1.15,
9
- "num_train_timesteps": 1000,
10
- "shift": 1.0,
11
- "shift_terminal": null,
12
- "stochastic_sampling": false,
13
- "time_shift_type": "exponential",
14
- "use_beta_sigmas": false,
15
- "use_dynamic_shifting": false,
16
- "use_exponential_sigmas": false,
17
- "use_karras_sigmas": false
18
  }
 
1
  {
2
  "_class_name": "FlowMatchEulerDiscreteScheduler",
3
  "_diffusers_version": "0.34.0",
4
+ "num_train_timesteps": 1000
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  }
scheduler_fofpred.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.loaders.lora_base import ( # noqa
8
+ LoraBaseMixin,
9
+ _fetch_state_dict,
10
+ )
11
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
12
+ from diffusers.utils import BaseOutput
13
+
14
+
15
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor`):
            Sample returned by one call to the scheduler's `step`. It should be used as the next model input in
            the denoising loop. Upstream documents the shape as `(batch_size, num_channels, height, width)` for
            images; nothing in this file enforces a shape — TODO confirm against the pipeline.
    """

    # Result of a single scheduler step.
    prev_sample: torch.FloatTensor
27
+
28
+
29
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler operating on a time axis that runs from 0 towards 1.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model. Also the value returned by `len(scheduler)`.
        dynamic_time_shift (`bool`, defaults to `True`):
            Whether `set_timesteps` warps the uniform schedule as a function of `num_tokens` (only applied when
            `num_tokens` is given and no custom `timesteps` are passed).
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self, num_train_timesteps: int = 1000, dynamic_time_shift: bool = True
    ):
        timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[
            :-1
        ]

        self.timesteps = timesteps
        # `_timesteps` is the schedule plus the terminal value 1.0. Defining it here keeps
        # `index_for_timestep`/`step` usable (instead of raising AttributeError) before `set_timesteps` runs.
        self._timesteps = torch.cat([timesteps, torch.ones(1)])
        self.num_inference_steps = None

        self._step_index = None
        self._begin_index = None

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` inside `schedule_timesteps`.

        Args:
            timestep: Value to look up; matched by exact equality.
            schedule_timesteps (*optional*): Schedule to search. Defaults to `self._timesteps`.

        Returns:
            `int`: The second matching index when there are several matches, otherwise the only match.

        Raises:
            IndexError: If `timestep` does not occur in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self._timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[float]] = None,
        num_tokens: Optional[int] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model. Required
                unless `timesteps` is given.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[float]`, *optional*):
                Custom schedule to use verbatim instead of the generated one. No time shift is applied to it.
            num_tokens (`int`, *optional*):
                Token count used for the dynamic time shift when `config.dynamic_time_shift` is enabled.

        Raises:
            ValueError: If neither `num_inference_steps` nor `timesteps` is provided.
        """

        if timesteps is None:
            if num_inference_steps is None:
                raise ValueError(
                    "Either `num_inference_steps` or `timesteps` must be provided."
                )
            self.num_inference_steps = num_inference_steps
            timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[
                :-1
            ]
            if self.config.dynamic_time_shift and num_tokens is not None:
                m = (
                    np.sqrt(num_tokens) / 40
                )  # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2
                timesteps = timesteps / (m - m * timesteps + timesteps)
        else:
            self.num_inference_steps = len(timesteps)

        # `torch.as_tensor` accepts lists, NumPy arrays and tensors alike; `torch.from_numpy` would reject
        # the plain Python lists that the `timesteps` argument advertises.
        timesteps = torch.as_tensor(timesteps, dtype=torch.float32, device=device)
        _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])

        self.timesteps = timesteps
        self._timesteps = _timesteps
        self._step_index = None
        self._begin_index = None

    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `begin_index` if set, otherwise by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Advance `sample` by one Euler step: `sample + (t_next - t) * model_output`.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep; must be one of the scheduler's timesteps, not an integer index.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                Accepted for interface compatibility; not used by this scheduler.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a tuple
                is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an `int` or an integer tensor.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)
        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        t = self._timesteps[self.step_index]
        t_next = self._timesteps[self.step_index + 1]

        prev_sample = sample + (t_next - t) * model_output

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
+ return self.config.num_train_timesteps
transformer/config.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "_class_name": "OmniGen2Transformer3DModel",
3
  "_diffusers_version": "0.34.0",
4
- "_name_or_path": "/export/home/public_repo/FOFPred/pretrained_models/hf_upload/transformer",
5
  "axes_dim_rope": [
6
  40,
7
  40,
 
1
  {
2
  "_class_name": "OmniGen2Transformer3DModel",
3
  "_diffusers_version": "0.34.0",
4
+ "_name_or_path": "pretrained_models/ft_023/transformer",
5
  "axes_dim_rope": [
6
  40,
7
  40,
transformer_fofpred.py ADDED
The diff for this file is too large to render. See raw diff
 
vae/config.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "_class_name": "AutoencoderKL",
3
  "_diffusers_version": "0.34.0",
4
- "_name_or_path": "/export/home/public_repo/FOFPred/pretrained_models/hf_upload/vae",
5
  "act_fn": "silu",
6
  "block_out_channels": [
7
  128,
 
1
  {
2
  "_class_name": "AutoencoderKL",
3
  "_diffusers_version": "0.34.0",
4
+ "_name_or_path": "/export/home/.cache/huggingface/hub/models--OmniGen2--OmniGen2/snapshots/df5dca8a981d74e6c3af214c145f5c735fe72367/vae",
5
  "act_fn": "silu",
6
  "block_out_channels": [
7
  128,