jbilcke-hf committed
Commit 0a515a6 · verified · 1 Parent(s): 057bd2e

Delete hyvideo

hyvideo/__init__.py DELETED
File without changes
hyvideo/config.py DELETED
@@ -1,398 +0,0 @@
- import argparse
- from .constants import *
- import re
- from .modules.models import HUNYUAN_VIDEO_CONFIG
-
-
- def parse_args(namespace=None):
-     parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
-
-     parser = add_network_args(parser)
-     parser = add_extra_models_args(parser)
-     parser = add_denoise_schedule_args(parser)
-     parser = add_inference_args(parser)
-     parser = add_parallel_args(parser)
-
-     args = parser.parse_args(namespace=namespace)
-     args = sanity_check_args(args)
-
-     return args
-
-
- def add_network_args(parser: argparse.ArgumentParser):
-     group = parser.add_argument_group(title="HunyuanVideo network args")
-
-     # Main model
-     group.add_argument(
-         "--model",
-         type=str,
-         choices=list(HUNYUAN_VIDEO_CONFIG.keys()),
-         default="HYVideo-T/2-cfgdistill",
-     )
-     group.add_argument(
-         "--latent-channels",
-         type=int,
-         default=16,
-         help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
-         "it still needs to match the latent channels of the VAE model.",
-     )
-     group.add_argument(
-         "--precision",
-         type=str,
-         default="bf16",
-         choices=PRECISIONS,
-         help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.",
-     )
-
-     # RoPE
-     group.add_argument(
-         "--rope-theta", type=int, default=256, help="Theta used in RoPE."
-     )
-     return parser
-
-
- def add_extra_models_args(parser: argparse.ArgumentParser):
-     group = parser.add_argument_group(
-         title="Extra models args, including vae, text encoders and tokenizers"
-     )
-
-     # - VAE
-     group.add_argument(
-         "--vae",
-         type=str,
-         default="884-16c-hy",
-         choices=list(VAE_PATH),
-         help="Name of the VAE model.",
-     )
-     group.add_argument(
-         "--vae-precision",
-         type=str,
-         default="fp16",
-         choices=PRECISIONS,
-         help="Precision mode for the VAE model.",
-     )
-     group.add_argument(
-         "--vae-tiling",
-         action="store_true",
-         help="Enable tiling for the VAE model to save GPU memory.",
-     )
-     group.set_defaults(vae_tiling=True)
-
-     group.add_argument(
-         "--text-encoder",
-         type=str,
-         default="llm",
-         choices=list(TEXT_ENCODER_PATH),
-         help="Name of the text encoder model.",
-     )
-     group.add_argument(
-         "--text-encoder-precision",
-         type=str,
-         default="fp16",
-         choices=PRECISIONS,
-         help="Precision mode for the text encoder model.",
-     )
-     group.add_argument(
-         "--text-states-dim",
-         type=int,
-         default=4096,
-         help="Dimension of the text encoder hidden states.",
-     )
-     group.add_argument(
-         "--text-len", type=int, default=256, help="Maximum length of the text input."
-     )
-     group.add_argument(
-         "--tokenizer",
-         type=str,
-         default="llm",
-         choices=list(TOKENIZER_PATH),
-         help="Name of the tokenizer model.",
-     )
-     group.add_argument(
-         "--prompt-template",
-         type=str,
-         default="dit-llm-encode",
-         choices=PROMPT_TEMPLATE,
-         help="Image prompt template for the decoder-only text encoder model.",
-     )
-     group.add_argument(
-         "--prompt-template-video",
-         type=str,
-         default="dit-llm-encode-video",
-         choices=PROMPT_TEMPLATE,
-         help="Video prompt template for the decoder-only text encoder model.",
-     )
-     group.add_argument(
-         "--hidden-state-skip-layer",
-         type=int,
-         default=2,
-         help="Skip layer for hidden states.",
-     )
-     group.add_argument(
-         "--apply-final-norm",
-         action="store_true",
-         help="Apply final normalization to the used text encoder hidden states.",
-     )
-
-     # - CLIP
-     group.add_argument(
-         "--text-encoder-2",
-         type=str,
-         default="clipL",
-         choices=list(TEXT_ENCODER_PATH),
-         help="Name of the second text encoder model.",
-     )
-     group.add_argument(
-         "--text-encoder-precision-2",
-         type=str,
-         default="fp16",
-         choices=PRECISIONS,
-         help="Precision mode for the second text encoder model.",
-     )
-     group.add_argument(
-         "--text-states-dim-2",
-         type=int,
-         default=768,
-         help="Dimension of the second text encoder hidden states.",
-     )
-     group.add_argument(
-         "--tokenizer-2",
-         type=str,
-         default="clipL",
-         choices=list(TOKENIZER_PATH),
-         help="Name of the second tokenizer model.",
-     )
-     group.add_argument(
-         "--text-len-2",
-         type=int,
-         default=77,
-         help="Maximum length of the second text input.",
-     )
-
-     return parser
-
-
- def add_denoise_schedule_args(parser: argparse.ArgumentParser):
-     group = parser.add_argument_group(title="Denoise schedule args")
-
-     group.add_argument(
-         "--denoise-type",
-         type=str,
-         default="flow",
-         help="Denoise type for noised inputs.",
-     )
-
-     # Flow Matching
-     group.add_argument(
-         "--flow-shift",
-         type=float,
-         default=7.0,
-         help="Shift factor for flow matching schedulers.",
-     )
-     group.add_argument(
-         "--flow-reverse",
-         action="store_true",
-         help="If reverse, learning/sampling from t=1 -> t=0.",
-     )
-     group.add_argument(
-         "--flow-solver",
-         type=str,
-         default="euler",
-         help="Solver for flow matching.",
-     )
-     group.add_argument(
-         "--use-linear-quadratic-schedule",
-         action="store_true",
-         help="Use linear quadratic schedule for flow matching. "
-         "Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
-     )
-     group.add_argument(
-         "--linear-schedule-end",
-         type=int,
-         default=25,
-         help="End step for linear quadratic schedule for flow matching.",
-     )
-
-     return parser
-
-
- def add_inference_args(parser: argparse.ArgumentParser):
-     group = parser.add_argument_group(title="Inference args")
-
-     # ======================== Model loads ========================
-     group.add_argument(
-         "--model-base",
-         type=str,
-         default=".",
-         help="Root path of all the models, including t2v models and extra models.",
-     )
-     group.add_argument(
-         "--dit-weight",
-         type=str,
-         default="./hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
-         help="Path to the HunyuanVideo model. If None, search the model in the args.model_root. "
-         "1. If it is a file, load the model directly. "
-         "2. If it is a directory, search the model in the directory. Support two types of models: "
-         "1) named `pytorch_model_*.pt`; "
-         "2) named `*_model_states.pt`, where * can be `mp_rank_00`.",
-     )
-     group.add_argument(
-         "--model-resolution",
-         type=str,
-         default="540p",
-         choices=["540p", "720p"],
-         help="Resolution of the model checkpoint to use.",
-     )
-     group.add_argument(
-         "--load-key",
-         type=str,
-         default="module",
-         help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
-     )
-     group.add_argument(
-         "--use-cpu-offload",
-         action="store_true",
-         help="Use CPU offload for the model load.",
-     )
-
-     # ======================== Inference general setting ========================
-     group.add_argument(
-         "--batch-size",
-         type=int,
-         default=1,
-         help="Batch size for inference and evaluation.",
-     )
-     group.add_argument(
-         "--infer-steps",
-         type=int,
-         default=50,
-         help="Number of denoising steps for inference.",
-     )
-     group.add_argument(
-         "--disable-autocast",
-         action="store_true",
-         help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
-     )
-     group.add_argument(
-         "--save-path",
-         type=str,
-         default="./results",
-         help="Path to save the generated samples.",
-     )
-     group.add_argument(
-         "--save-path-suffix",
-         type=str,
-         default="",
-         help="Suffix for the directory of saved samples.",
-     )
-     group.add_argument(
-         "--name-suffix",
-         type=str,
-         default="",
-         help="Suffix for the names of saved samples.",
-     )
-     group.add_argument(
-         "--num-videos",
-         type=int,
-         default=1,
-         help="Number of videos to generate for each prompt.",
-     )
-     # ---sample size---
-     group.add_argument(
-         "--video-size",
-         type=int,
-         nargs="+",
-         default=(720, 1280),
-         help="Video size for generation. If a single value is provided, it will be used for both height "
-         "and width. If two values are provided, they will be used for height and width "
-         "respectively.",
-     )
-     group.add_argument(
-         "--video-length",
-         type=int,
-         default=129,
-         help="How many frames to sample from a video. If using a 3D VAE, the number should be 4n+1.",
-     )
-     # --- prompt ---
-     group.add_argument(
-         "--prompt",
-         type=str,
-         default=None,
-         help="Prompt for sampling during evaluation.",
-     )
-     group.add_argument(
-         "--seed-type",
-         type=str,
-         default="auto",
-         choices=["file", "random", "fixed", "auto"],
-         help="Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a "
-         "random seed. If fixed, use the fixed seed given by `--seed`. If auto, `csv` will use the "
-         "seed column if available, otherwise use the fixed `seed` value. `prompt` will use the "
-         "fixed `seed` value.",
-     )
-     group.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
-
-     # Classifier-Free Guidance
-     group.add_argument(
-         "--neg-prompt", type=str, default=None, help="Negative prompt for sampling."
-     )
-     group.add_argument(
-         "--cfg-scale", type=float, default=1.0, help="Classifier free guidance scale."
-     )
-     group.add_argument(
-         "--embedded-cfg-scale",
-         type=float,
-         default=6.0,
-         help="Embedded classifier-free guidance scale.",
-     )
-
-     group.add_argument(
-         "--use-fp8",
-         action="store_true",
-         help="Enable fp8 for inference acceleration."
-     )
-
-     group.add_argument(
-         "--reproduce",
-         action="store_true",
-         help="Enable reproducibility by setting random seeds and deterministic algorithms.",
-     )
-
-     return parser
-
-
- def add_parallel_args(parser: argparse.ArgumentParser):
-     group = parser.add_argument_group(title="Parallel args")
-
-     # ======================== Model loads ========================
-     group.add_argument(
-         "--ulysses-degree",
-         type=int,
-         default=1,
-         help="Ulysses degree.",
-     )
-     group.add_argument(
-         "--ring-degree",
-         type=int,
-         default=1,
-         help="Ring degree.",
-     )
-
-     return parser
-
-
- def sanity_check_args(args):
-     # VAE channels
-     vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
-     if not re.match(vae_pattern, args.vae):
-         raise ValueError(
-             f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
-         )
-     vae_channels = int(args.vae.split("-")[1][:-1])
-     if args.latent_channels is None:
-         args.latent_channels = vae_channels
-     if vae_channels != args.latent_channels:
-         raise ValueError(
-             f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
-         )
-     return args
hyvideo/constants.py DELETED
@@ -1,90 +0,0 @@
- import os
- import torch
-
- __all__ = [
-     "C_SCALE",
-     "PROMPT_TEMPLATE",
-     "MODEL_BASE",
-     "PRECISIONS",
-     "NORMALIZATION_TYPE",
-     "ACTIVATION_TYPE",
-     "VAE_PATH",
-     "TEXT_ENCODER_PATH",
-     "TOKENIZER_PATH",
-     "TEXT_PROJECTION",
-     "DATA_TYPE",
-     "NEGATIVE_PROMPT",
- ]
-
- PRECISION_TO_TYPE = {
-     'fp32': torch.float32,
-     'fp16': torch.float16,
-     'bf16': torch.bfloat16,
- }
-
- # =================== Constant Values =====================
- # Computation scale factor, 1P = 1_000_000_000_000_000. Tensorboard will display the value in PetaFLOPS to avoid
- # overflow error when tensorboard logging values.
- C_SCALE = 1_000_000_000_000_000
-
- # When using decoder-only models, we must provide a prompt template to instruct the text encoder
- # on how to generate the text.
- # --------------------------------------------------------------------
- PROMPT_TEMPLATE_ENCODE = (
-     "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
-     "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
-     "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
- )
- PROMPT_TEMPLATE_ENCODE_VIDEO = (
-     "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
-     "1. The main content and theme of the video."
-     "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
-     "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
-     "4. background environment, light, style and atmosphere."
-     "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
-     "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
- )
-
- NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
-
- PROMPT_TEMPLATE = {
-     "dit-llm-encode": {
-         "template": PROMPT_TEMPLATE_ENCODE,
-         "crop_start": 36,
-     },
-     "dit-llm-encode-video": {
-         "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
-         "crop_start": 95,
-     },
- }
-
- # ======================= Model ======================
- PRECISIONS = {"fp32", "fp16", "bf16"}
- NORMALIZATION_TYPE = {"layer", "rms"}
- ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"}
-
- # =================== Model Path =====================
- MODEL_BASE = os.getenv("MODEL_BASE", ".")
-
- # =================== Data =======================
- DATA_TYPE = {"image", "video", "image_video"}
-
- # 3D VAE
- VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"}
-
- # Text Encoder
- TEXT_ENCODER_PATH = {
-     "clipL": f"{MODEL_BASE}/text_encoder_2",
-     "llm": f"{MODEL_BASE}/text_encoder",
- }
-
- # Tokenizer
- TOKENIZER_PATH = {
-     "clipL": f"{MODEL_BASE}/text_encoder_2",
-     "llm": f"{MODEL_BASE}/text_encoder",
- }
-
- TEXT_PROJECTION = {
-     "linear",  # Default, an nn.Linear() layer
-     "single_refiner",  # Single TokenRefiner. Refer to LI-DiT
- }
hyvideo/diffusion/__init__.py DELETED
@@ -1,2 +0,0 @@
- from .pipelines import HunyuanVideoPipeline
- from .schedulers import FlowMatchDiscreteScheduler
hyvideo/diffusion/pipelines/__init__.py DELETED
@@ -1 +0,0 @@
- from .pipeline_hunyuan_video import HunyuanVideoPipeline
hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py DELETED
@@ -1,1100 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- #
16
- # Modified from diffusers==0.29.2
17
- #
18
- # ==============================================================================
19
- import inspect
20
- from typing import Any, Callable, Dict, List, Optional, Union, Tuple
21
- import torch
22
- import torch.distributed as dist
23
- import numpy as np
24
- from dataclasses import dataclass
25
- from packaging import version
26
-
27
- from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
28
- from diffusers.configuration_utils import FrozenDict
29
- from diffusers.image_processor import VaeImageProcessor
30
- from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
31
- from diffusers.models import AutoencoderKL
32
- from diffusers.models.lora import adjust_lora_scale_text_encoder
33
- from diffusers.schedulers import KarrasDiffusionSchedulers
34
- from diffusers.utils import (
35
- USE_PEFT_BACKEND,
36
- deprecate,
37
- logging,
38
- replace_example_docstring,
39
- scale_lora_layers,
40
- unscale_lora_layers,
41
- )
42
- from diffusers.utils.torch_utils import randn_tensor
43
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
- from diffusers.utils import BaseOutput
45
-
46
- from ...constants import PRECISION_TO_TYPE
47
- from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
48
- from ...text_encoder import TextEncoder
49
- from ...modules import HYVideoDiffusionTransformer
50
-
51
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
-
53
- EXAMPLE_DOC_STRING = """"""
54
-
55
-
56
- def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
57
- """
58
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
59
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
60
- """
61
- std_text = noise_pred_text.std(
62
- dim=list(range(1, noise_pred_text.ndim)), keepdim=True
63
- )
64
- std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
65
- # rescale the results from guidance (fixes overexposure)
66
- noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
67
- # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
68
- noise_cfg = (
69
- guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
70
- )
71
- return noise_cfg
72
-
73
-
74
- def retrieve_timesteps(
75
- scheduler,
76
- num_inference_steps: Optional[int] = None,
77
- device: Optional[Union[str, torch.device]] = None,
78
- timesteps: Optional[List[int]] = None,
79
- sigmas: Optional[List[float]] = None,
80
- **kwargs,
81
- ):
82
- """
83
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
84
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
85
-
86
- Args:
87
- scheduler (`SchedulerMixin`):
88
- The scheduler to get timesteps from.
89
- num_inference_steps (`int`):
90
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
91
- must be `None`.
92
- device (`str` or `torch.device`, *optional*):
93
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
94
- timesteps (`List[int]`, *optional*):
95
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
96
- `num_inference_steps` and `sigmas` must be `None`.
97
- sigmas (`List[float]`, *optional*):
98
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
99
- `num_inference_steps` and `timesteps` must be `None`.
100
-
101
- Returns:
102
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
103
- second element is the number of inference steps.
104
- """
105
- if timesteps is not None and sigmas is not None:
106
- raise ValueError(
107
- "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
108
- )
109
- if timesteps is not None:
110
- accepts_timesteps = "timesteps" in set(
111
- inspect.signature(scheduler.set_timesteps).parameters.keys()
112
- )
113
- if not accepts_timesteps:
114
- raise ValueError(
115
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
116
- f" timestep schedules. Please check whether you are using the correct scheduler."
117
- )
118
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
119
- timesteps = scheduler.timesteps
120
- num_inference_steps = len(timesteps)
121
- elif sigmas is not None:
122
- accept_sigmas = "sigmas" in set(
123
- inspect.signature(scheduler.set_timesteps).parameters.keys()
124
- )
125
- if not accept_sigmas:
126
- raise ValueError(
127
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128
- f" sigmas schedules. Please check whether you are using the correct scheduler."
129
- )
130
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
131
- timesteps = scheduler.timesteps
132
- num_inference_steps = len(timesteps)
133
- else:
134
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
135
- timesteps = scheduler.timesteps
136
- return timesteps, num_inference_steps
137
-
138
-
139
- @dataclass
140
- class HunyuanVideoPipelineOutput(BaseOutput):
141
- videos: Union[torch.Tensor, np.ndarray]
142
-
143
-
144
- class HunyuanVideoPipeline(DiffusionPipeline):
145
- r"""
146
- Pipeline for text-to-video generation using HunyuanVideo.
147
-
148
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
149
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
150
-
151
- Args:
152
- vae ([`AutoencoderKL`]):
153
- Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
154
- text_encoder ([`TextEncoder`]):
155
- Frozen text-encoder.
156
- text_encoder_2 ([`TextEncoder`]):
157
- Frozen text-encoder_2.
158
- transformer ([`HYVideoDiffusionTransformer`]):
159
- A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
160
- scheduler ([`SchedulerMixin`]):
161
- A scheduler to be used in combination with `unet` to denoise the encoded image latents.
162
- """
163
-
164
- model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
165
- _optional_components = ["text_encoder_2"]
166
- _exclude_from_cpu_offload = ["transformer"]
167
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
168
-
169
- def __init__(
170
- self,
171
- vae: AutoencoderKL,
172
- text_encoder: TextEncoder,
173
- transformer: HYVideoDiffusionTransformer,
174
- scheduler: KarrasDiffusionSchedulers,
175
- text_encoder_2: Optional[TextEncoder] = None,
176
- progress_bar_config: Dict[str, Any] = None,
177
- args=None,
178
- ):
179
- super().__init__()
180
-
181
- # ==========================================================================================
182
- if progress_bar_config is None:
183
- progress_bar_config = {}
184
- if not hasattr(self, "_progress_bar_config"):
185
- self._progress_bar_config = {}
186
- self._progress_bar_config.update(progress_bar_config)
187
-
188
- self.args = args
189
- # ==========================================================================================
190
-
191
- if (
192
- hasattr(scheduler.config, "steps_offset")
193
- and scheduler.config.steps_offset != 1
194
- ):
195
- deprecation_message = (
196
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
197
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
198
- "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
199
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
200
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
201
- " file"
202
- )
203
- deprecate(
204
- "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
205
- )
206
- new_config = dict(scheduler.config)
207
- new_config["steps_offset"] = 1
208
- scheduler._internal_dict = FrozenDict(new_config)
209
-
210
- if (
211
- hasattr(scheduler.config, "clip_sample")
212
- and scheduler.config.clip_sample is True
213
- ):
214
- deprecation_message = (
215
- f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
216
- " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
217
- " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
218
- " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
219
- " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
220
- )
221
- deprecate(
222
- "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
223
- )
224
- new_config = dict(scheduler.config)
225
- new_config["clip_sample"] = False
226
- scheduler._internal_dict = FrozenDict(new_config)
227
-
228
- self.register_modules(
229
- vae=vae,
230
- text_encoder=text_encoder,
231
- transformer=transformer,
232
- scheduler=scheduler,
233
- text_encoder_2=text_encoder_2,
234
- )
235
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
236
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
237
-
238
- def encode_prompt(
239
- self,
240
- prompt,
241
- device,
242
- num_videos_per_prompt,
243
- do_classifier_free_guidance,
244
- negative_prompt=None,
245
- prompt_embeds: Optional[torch.Tensor] = None,
246
- attention_mask: Optional[torch.Tensor] = None,
247
- negative_prompt_embeds: Optional[torch.Tensor] = None,
248
- negative_attention_mask: Optional[torch.Tensor] = None,
249
- lora_scale: Optional[float] = None,
250
- clip_skip: Optional[int] = None,
251
- text_encoder: Optional[TextEncoder] = None,
252
- data_type: Optional[str] = "image",
253
- ):
254
- r"""
255
- Encodes the prompt into text encoder hidden states.
256
-
257
- Args:
258
- prompt (`str` or `List[str]`, *optional*):
259
- prompt to be encoded
260
- device: (`torch.device`):
261
- torch device
262
- num_videos_per_prompt (`int`):
263
- number of videos that should be generated per prompt
264
- do_classifier_free_guidance (`bool`):
265
- whether to use classifier free guidance or not
266
- negative_prompt (`str` or `List[str]`, *optional*):
267
- The prompt or prompts not to guide the video generation. If not defined, one has to pass
268
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
269
- less than `1`).
270
- prompt_embeds (`torch.Tensor`, *optional*):
271
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
272
- provided, text embeddings will be generated from `prompt` input argument.
273
- attention_mask (`torch.Tensor`, *optional*):
274
- negative_prompt_embeds (`torch.Tensor`, *optional*):
275
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
276
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
277
- argument.
278
- negative_attention_mask (`torch.Tensor`, *optional*):
279
- lora_scale (`float`, *optional*):
280
- A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
281
- clip_skip (`int`, *optional*):
282
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
283
- the output of the pre-final layer will be used for computing the prompt embeddings.
284
- text_encoder (TextEncoder, *optional*):
285
- data_type (`str`, *optional*):
286
- """
287
- if text_encoder is None:
288
- text_encoder = self.text_encoder
289
-
290
- # set lora scale so that monkey patched LoRA
291
- # function of text encoder can correctly access it
292
- if lora_scale is not None and isinstance(self, LoraLoaderMixin):
293
- self._lora_scale = lora_scale
294
-
295
- # dynamically adjust the LoRA scale
296
- if not USE_PEFT_BACKEND:
297
- adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
298
- else:
299
- scale_lora_layers(text_encoder.model, lora_scale)
300
-
301
- if prompt is not None and isinstance(prompt, str):
302
- batch_size = 1
303
- elif prompt is not None and isinstance(prompt, list):
304
- batch_size = len(prompt)
305
- else:
306
- batch_size = prompt_embeds.shape[0]
307
-
308
- if prompt_embeds is None:
309
- # textual inversion: process multi-vector tokens if necessary
310
- if isinstance(self, TextualInversionLoaderMixin):
311
- prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
312
-
313
- text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
314
-
315
- if clip_skip is None:
316
- prompt_outputs = text_encoder.encode(
317
- text_inputs, data_type=data_type, device=device
318
- )
319
- prompt_embeds = prompt_outputs.hidden_state
320
- else:
321
- prompt_outputs = text_encoder.encode(
322
- text_inputs,
323
- output_hidden_states=True,
324
- data_type=data_type,
325
- device=device,
326
- )
327
- # Access the `hidden_states` first, that contains a tuple of
328
- # all the hidden states from the encoder layers. Then index into
329
- # the tuple to access the hidden states from the desired layer.
330
- prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
331
- # We also need to apply the final LayerNorm here to not mess with the
332
- # representations. The `last_hidden_states` that we typically use for
333
- # obtaining the final prompt representations passes through the LayerNorm
334
- # layer.
335
- prompt_embeds = text_encoder.model.text_model.final_layer_norm(
336
- prompt_embeds
337
- )
338
-
339
- attention_mask = prompt_outputs.attention_mask
340
- if attention_mask is not None:
341
- attention_mask = attention_mask.to(device)
342
- bs_embed, seq_len = attention_mask.shape
343
- attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
344
- attention_mask = attention_mask.view(
345
- bs_embed * num_videos_per_prompt, seq_len
346
- )
347
-
348
- if text_encoder is not None:
349
- prompt_embeds_dtype = text_encoder.dtype
350
- elif self.transformer is not None:
351
- prompt_embeds_dtype = self.transformer.dtype
352
- else:
353
- prompt_embeds_dtype = prompt_embeds.dtype
354
-
355
- prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
356
-
357
- if prompt_embeds.ndim == 2:
358
- bs_embed, _ = prompt_embeds.shape
359
- # duplicate text embeddings for each generation per prompt, using mps friendly method
360
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
361
- prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
362
- else:
363
- bs_embed, seq_len, _ = prompt_embeds.shape
364
- # duplicate text embeddings for each generation per prompt, using mps friendly method
365
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
366
- prompt_embeds = prompt_embeds.view(
367
- bs_embed * num_videos_per_prompt, seq_len, -1
368
- )
369
-
370
- # get unconditional embeddings for classifier free guidance
371
- if do_classifier_free_guidance and negative_prompt_embeds is None:
372
- uncond_tokens: List[str]
373
- if negative_prompt is None:
374
- uncond_tokens = [""] * batch_size
375
- elif prompt is not None and type(prompt) is not type(negative_prompt):
376
- raise TypeError(
377
- f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
378
- f" {type(prompt)}."
379
- )
380
- elif isinstance(negative_prompt, str):
381
- uncond_tokens = [negative_prompt]
382
- elif batch_size != len(negative_prompt):
383
- raise ValueError(
384
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
385
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
386
- " the batch size of `prompt`."
387
- )
388
- else:
389
- uncond_tokens = negative_prompt
390
-
391
- # textual inversion: process multi-vector tokens if necessary
392
- if isinstance(self, TextualInversionLoaderMixin):
393
- uncond_tokens = self.maybe_convert_prompt(
394
- uncond_tokens, text_encoder.tokenizer
395
- )
396
-
397
- # max_length = prompt_embeds.shape[1]
398
- uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
399
-
400
- negative_prompt_outputs = text_encoder.encode(
401
- uncond_input, data_type=data_type, device=device
402
- )
403
- negative_prompt_embeds = negative_prompt_outputs.hidden_state
404
-
405
- negative_attention_mask = negative_prompt_outputs.attention_mask
406
- if negative_attention_mask is not None:
407
- negative_attention_mask = negative_attention_mask.to(device)
408
- _, seq_len = negative_attention_mask.shape
409
- negative_attention_mask = negative_attention_mask.repeat(
410
- 1, num_videos_per_prompt
411
- )
412
- negative_attention_mask = negative_attention_mask.view(
413
- batch_size * num_videos_per_prompt, seq_len
414
- )
415
-
416
- if do_classifier_free_guidance:
417
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
418
- seq_len = negative_prompt_embeds.shape[1]
419
-
420
- negative_prompt_embeds = negative_prompt_embeds.to(
421
- dtype=prompt_embeds_dtype, device=device
422
- )
423
-
424
- if negative_prompt_embeds.ndim == 2:
425
- negative_prompt_embeds = negative_prompt_embeds.repeat(
426
- 1, num_videos_per_prompt
427
- )
428
- negative_prompt_embeds = negative_prompt_embeds.view(
429
- batch_size * num_videos_per_prompt, -1
430
- )
431
- else:
432
- negative_prompt_embeds = negative_prompt_embeds.repeat(
433
- 1, num_videos_per_prompt, 1
434
- )
435
- negative_prompt_embeds = negative_prompt_embeds.view(
436
- batch_size * num_videos_per_prompt, seq_len, -1
437
- )
438
-
439
- if text_encoder is not None:
440
- if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
441
- # Retrieve the original scale by scaling back the LoRA layers
442
- unscale_lora_layers(text_encoder.model, lora_scale)
443
-
444
- return (
445
- prompt_embeds,
446
- negative_prompt_embeds,
447
- attention_mask,
448
- negative_attention_mask,
449
- )
450
-
451
- def decode_latents(self, latents, enable_tiling=True):
452
- deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
453
- deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
454
-
455
- latents = 1 / self.vae.config.scaling_factor * latents
456
- if enable_tiling:
457
- self.vae.enable_tiling()
458
- image = self.vae.decode(latents, return_dict=False)[0]
459
- else:
460
- image = self.vae.decode(latents, return_dict=False)[0]
461
- image = (image / 2 + 0.5).clamp(0, 1)
462
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
463
- if image.ndim == 4:
464
- image = image.cpu().permute(0, 2, 3, 1).float()
465
- else:
466
- image = image.cpu().float()
467
- return image
468
-
469
- def prepare_extra_func_kwargs(self, func, kwargs):
470
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
471
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
472
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
473
- # and should be between [0, 1]
474
- extra_step_kwargs = {}
475
-
476
- for k, v in kwargs.items():
477
- accepts = k in set(inspect.signature(func).parameters.keys())
478
- if accepts:
479
- extra_step_kwargs[k] = v
480
- return extra_step_kwargs
481
-
482
- def check_inputs(
483
- self,
484
- prompt,
485
- height,
486
- width,
487
- video_length,
488
- callback_steps,
489
- negative_prompt=None,
490
- prompt_embeds=None,
491
- negative_prompt_embeds=None,
492
- callback_on_step_end_tensor_inputs=None,
493
- vae_ver="88-4c-sd",
494
- ):
495
- if height % 8 != 0 or width % 8 != 0:
496
- raise ValueError(
497
- f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
498
- )
499
-
500
- if video_length is not None:
501
- if "884" in vae_ver:
502
- if video_length != 1 and (video_length - 1) % 4 != 0:
503
- raise ValueError(
504
- f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
505
- )
506
- elif "888" in vae_ver:
507
- if video_length != 1 and (video_length - 1) % 8 != 0:
508
- raise ValueError(
509
- f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
510
- )
511
-
512
- if callback_steps is not None and (
513
- not isinstance(callback_steps, int) or callback_steps <= 0
514
- ):
515
- raise ValueError(
516
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
517
- f" {type(callback_steps)}."
518
- )
519
- if callback_on_step_end_tensor_inputs is not None and not all(
520
- k in self._callback_tensor_inputs
521
- for k in callback_on_step_end_tensor_inputs
522
- ):
523
- raise ValueError(
524
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
525
- )
526
-
527
- if prompt is not None and prompt_embeds is not None:
528
- raise ValueError(
529
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
530
- " only forward one of the two."
531
- )
532
- elif prompt is None and prompt_embeds is None:
533
- raise ValueError(
534
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
535
- )
536
- elif prompt is not None and (
537
- not isinstance(prompt, str) and not isinstance(prompt, list)
538
- ):
539
- raise ValueError(
540
- f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
541
- )
542
-
543
- if negative_prompt is not None and negative_prompt_embeds is not None:
544
- raise ValueError(
545
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
546
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
547
- )
548
-
549
- if prompt_embeds is not None and negative_prompt_embeds is not None:
550
- if prompt_embeds.shape != negative_prompt_embeds.shape:
551
- raise ValueError(
552
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
553
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
554
- f" {negative_prompt_embeds.shape}."
555
- )
556
-
557
-
558
- def prepare_latents(
559
- self,
560
- batch_size,
561
- num_channels_latents,
562
- height,
563
- width,
564
- video_length,
565
- dtype,
566
- device,
567
- generator,
568
- latents=None,
569
- ):
570
- shape = (
571
- batch_size,
572
- num_channels_latents,
573
- video_length,
574
- int(height) // self.vae_scale_factor,
575
- int(width) // self.vae_scale_factor,
576
- )
577
- if isinstance(generator, list) and len(generator) != batch_size:
578
- raise ValueError(
579
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
580
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
581
- )
582
-
583
- if latents is None:
584
- latents = randn_tensor(
585
- shape, generator=generator, device=device, dtype=dtype
586
- )
587
- else:
588
- latents = latents.to(device)
589
-
590
- # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
591
- if hasattr(self.scheduler, "init_noise_sigma"):
592
- # scale the initial noise by the standard deviation required by the scheduler
593
- latents = latents * self.scheduler.init_noise_sigma
594
- return latents
595
-
596
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
597
- def get_guidance_scale_embedding(
598
- self,
599
- w: torch.Tensor,
600
- embedding_dim: int = 512,
601
- dtype: torch.dtype = torch.float32,
602
- ) -> torch.Tensor:
603
- """
604
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
605
-
606
- Args:
607
- w (`torch.Tensor`):
608
- Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
609
- embedding_dim (`int`, *optional*, defaults to 512):
610
- Dimension of the embeddings to generate.
611
- dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
612
- Data type of the generated embeddings.
613
-
614
- Returns:
615
- `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
616
- """
617
- assert len(w.shape) == 1
618
- w = w * 1000.0
619
-
620
- half_dim = embedding_dim // 2
621
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
622
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
623
- emb = w.to(dtype)[:, None] * emb[None, :]
624
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
625
- if embedding_dim % 2 == 1: # zero pad
626
- emb = torch.nn.functional.pad(emb, (0, 1))
627
- assert emb.shape == (w.shape[0], embedding_dim)
628
- return emb
629
-
630
- @property
631
- def guidance_scale(self):
632
- return self._guidance_scale
633
-
634
- @property
635
- def guidance_rescale(self):
636
- return self._guidance_rescale
637
-
638
- @property
639
- def clip_skip(self):
640
- return self._clip_skip
641
-
642
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
643
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
644
- # corresponds to doing no classifier free guidance.
645
- @property
646
- def do_classifier_free_guidance(self):
647
- # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
648
- return self._guidance_scale > 1
649
-
650
- @property
651
- def cross_attention_kwargs(self):
652
- return self._cross_attention_kwargs
653
-
654
- @property
655
- def num_timesteps(self):
656
- return self._num_timesteps
657
-
658
- @property
659
- def interrupt(self):
660
- return self._interrupt
661
-
662
- @torch.no_grad()
663
- @replace_example_docstring(EXAMPLE_DOC_STRING)
664
- def __call__(
665
- self,
666
- prompt: Union[str, List[str]],
667
- height: int,
668
- width: int,
669
- video_length: int,
670
- data_type: str = "video",
671
- num_inference_steps: int = 50,
672
- timesteps: List[int] = None,
673
- sigmas: List[float] = None,
674
- guidance_scale: float = 7.5,
675
- negative_prompt: Optional[Union[str, List[str]]] = None,
676
- num_videos_per_prompt: Optional[int] = 1,
677
- eta: float = 0.0,
678
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
679
- latents: Optional[torch.Tensor] = None,
680
- prompt_embeds: Optional[torch.Tensor] = None,
681
- attention_mask: Optional[torch.Tensor] = None,
682
- negative_prompt_embeds: Optional[torch.Tensor] = None,
683
- negative_attention_mask: Optional[torch.Tensor] = None,
684
- output_type: Optional[str] = "pil",
685
- return_dict: bool = True,
686
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
687
- guidance_rescale: float = 0.0,
688
- clip_skip: Optional[int] = None,
689
- callback_on_step_end: Optional[
690
- Union[
691
- Callable[[int, int, Dict], None],
692
- PipelineCallback,
693
- MultiPipelineCallbacks,
694
- ]
695
- ] = None,
696
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
697
- freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
698
- vae_ver: str = "88-4c-sd",
699
- enable_tiling: bool = False,
700
- n_tokens: Optional[int] = None,
701
- embedded_guidance_scale: Optional[float] = None,
702
- **kwargs,
703
- ):
704
- r"""
705
- The call function to the pipeline for generation.
706
-
707
- Args:
708
- prompt (`str` or `List[str]`):
709
- The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
710
- height (`int`):
711
- The height in pixels of the generated image.
712
- width (`int`):
713
- The width in pixels of the generated image.
714
- video_length (`int`):
715
- The number of frames in the generated video.
716
- num_inference_steps (`int`, *optional*, defaults to 50):
717
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
718
- expense of slower inference.
719
- timesteps (`List[int]`, *optional*):
720
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
721
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
722
- passed will be used. Must be in descending order.
723
- sigmas (`List[float]`, *optional*):
724
- Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
725
- their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
726
- will be used.
727
- guidance_scale (`float`, *optional*, defaults to 7.5):
728
- A higher guidance scale value encourages the model to generate images closely linked to the text
729
- `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
730
- negative_prompt (`str` or `List[str]`, *optional*):
731
- The prompt or prompts to guide what to not include in image generation. If not defined, you need to
732
- pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
733
- num_videos_per_prompt (`int`, *optional*, defaults to 1):
734
- The number of images to generate per prompt.
735
- eta (`float`, *optional*, defaults to 0.0):
736
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
737
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
738
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
739
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
740
- generation deterministic.
741
- latents (`torch.Tensor`, *optional*):
742
- Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
743
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
744
- tensor is generated by sampling using the supplied random `generator`.
745
- prompt_embeds (`torch.Tensor`, *optional*):
746
- Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
747
- provided, text embeddings are generated from the `prompt` input argument.
748
- negative_prompt_embeds (`torch.Tensor`, *optional*):
749
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
750
- not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
751
-
752
- output_type (`str`, *optional*, defaults to `"pil"`):
753
- The output format of the generated image. Choose between `PIL.Image` or `np.array`.
754
- return_dict (`bool`, *optional*, defaults to `True`):
755
- Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
756
- plain tuple.
757
- cross_attention_kwargs (`dict`, *optional*):
758
- A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
759
- [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
760
- guidance_rescale (`float`, *optional*, defaults to 0.0):
761
- Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
762
- Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
763
- using zero terminal SNR.
764
- clip_skip (`int`, *optional*):
765
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
766
- the output of the pre-final layer will be used for computing the prompt embeddings.
767
- callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
768
- A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
769
- each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
770
- DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
771
- list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
772
- callback_on_step_end_tensor_inputs (`List`, *optional*):
773
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
774
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
775
- `._callback_tensor_inputs` attribute of your pipeline class.
776
-
777
- Examples:
778
-
779
- Returns:
780
- [`~HunyuanVideoPipelineOutput`] or `tuple`:
781
- If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
782
- otherwise a `tuple` is returned where the first element is a list with the generated images and the
783
- second element is a list of `bool`s indicating whether the corresponding generated image contains
784
- "not-safe-for-work" (nsfw) content.
785
- """
786
- callback = kwargs.pop("callback", None)
787
- callback_steps = kwargs.pop("callback_steps", None)
788
-
789
- if callback is not None:
790
- deprecate(
791
- "callback",
792
- "1.0.0",
793
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
794
- )
795
- if callback_steps is not None:
796
- deprecate(
797
- "callback_steps",
798
- "1.0.0",
799
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
800
- )
801
-
802
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
803
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
804
-
805
- # 0. Default height and width to unet
806
- # height = height or self.transformer.config.sample_size * self.vae_scale_factor
807
- # width = width or self.transformer.config.sample_size * self.vae_scale_factor
808
- # to deal with lora scaling and other possible forward hooks
809
-
810
- # 1. Check inputs. Raise error if not correct
811
- self.check_inputs(
812
- prompt,
813
- height,
814
- width,
815
- video_length,
816
- callback_steps,
817
- negative_prompt,
818
- prompt_embeds,
819
- negative_prompt_embeds,
820
- callback_on_step_end_tensor_inputs,
821
- vae_ver=vae_ver,
822
- )
823
-
824
- self._guidance_scale = guidance_scale
825
- self._guidance_rescale = guidance_rescale
826
- self._clip_skip = clip_skip
827
- self._cross_attention_kwargs = cross_attention_kwargs
828
- self._interrupt = False
829
-
830
- # 2. Define call parameters
831
- if prompt is not None and isinstance(prompt, str):
832
- batch_size = 1
833
- elif prompt is not None and isinstance(prompt, list):
834
- batch_size = len(prompt)
835
- else:
836
- batch_size = prompt_embeds.shape[0]
837
-
838
- device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
839
-
840
- # 3. Encode input prompt
841
- lora_scale = (
842
- self.cross_attention_kwargs.get("scale", None)
843
- if self.cross_attention_kwargs is not None
844
- else None
845
- )
846
-
847
- (
848
- prompt_embeds,
849
- negative_prompt_embeds,
850
- prompt_mask,
851
- negative_prompt_mask,
852
- ) = self.encode_prompt(
853
- prompt,
854
- device,
855
- num_videos_per_prompt,
856
- self.do_classifier_free_guidance,
857
- negative_prompt,
858
- prompt_embeds=prompt_embeds,
859
- attention_mask=attention_mask,
860
- negative_prompt_embeds=negative_prompt_embeds,
861
- negative_attention_mask=negative_attention_mask,
862
- lora_scale=lora_scale,
863
- clip_skip=self.clip_skip,
864
- data_type=data_type,
865
- )
866
- if self.text_encoder_2 is not None:
867
- (
868
- prompt_embeds_2,
869
- negative_prompt_embeds_2,
870
- prompt_mask_2,
871
- negative_prompt_mask_2,
872
- ) = self.encode_prompt(
873
- prompt,
874
- device,
875
- num_videos_per_prompt,
876
- self.do_classifier_free_guidance,
877
- negative_prompt,
878
- prompt_embeds=None,
879
- attention_mask=None,
880
- negative_prompt_embeds=None,
881
- negative_attention_mask=None,
882
- lora_scale=lora_scale,
883
- clip_skip=self.clip_skip,
884
- text_encoder=self.text_encoder_2,
885
- data_type=data_type,
886
- )
887
- else:
888
- prompt_embeds_2 = None
889
- negative_prompt_embeds_2 = None
890
- prompt_mask_2 = None
891
- negative_prompt_mask_2 = None
892
-
893
- # For classifier free guidance, we need to do two forward passes.
894
- # Here we concatenate the unconditional and text embeddings into a single batch
895
- # to avoid doing two forward passes
896
- if self.do_classifier_free_guidance:
897
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
898
- if prompt_mask is not None:
899
- prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
900
- if prompt_embeds_2 is not None:
901
- prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
902
- if prompt_mask_2 is not None:
903
- prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
904
-
905
-
906
- # 4. Prepare timesteps
907
- extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
908
- self.scheduler.set_timesteps, {"n_tokens": n_tokens}
909
- )
910
- timesteps, num_inference_steps = retrieve_timesteps(
911
- self.scheduler,
912
- num_inference_steps,
913
- device,
914
- timesteps,
915
- sigmas,
916
- **extra_set_timesteps_kwargs,
917
- )
918
-
919
- if "884" in vae_ver:
920
- video_length = (video_length - 1) // 4 + 1
921
- elif "888" in vae_ver:
922
- video_length = (video_length - 1) // 8 + 1
923
- else:
924
- video_length = video_length
925
-
926
- # 5. Prepare latent variables
927
- num_channels_latents = self.transformer.config.in_channels
928
- latents = self.prepare_latents(
929
- batch_size * num_videos_per_prompt,
930
- num_channels_latents,
931
- height,
932
- width,
933
- video_length,
934
- prompt_embeds.dtype,
935
- device,
936
- generator,
937
- latents,
938
- )
939
-
940
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
941
- extra_step_kwargs = self.prepare_extra_func_kwargs(
942
- self.scheduler.step,
943
- {"generator": generator, "eta": eta},
944
- )
945
-
946
- target_dtype = PRECISION_TO_TYPE[self.args.precision]
947
- autocast_enabled = (
948
- target_dtype != torch.float32
949
- ) and not self.args.disable_autocast
950
- vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
951
- vae_autocast_enabled = (
952
- vae_dtype != torch.float32
953
- ) and not self.args.disable_autocast
954
-
955
- # 7. Denoising loop
956
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
957
- self._num_timesteps = len(timesteps)
958
-
959
- # if is_progress_bar:
960
- with self.progress_bar(total=num_inference_steps) as progress_bar:
961
- for i, t in enumerate(timesteps):
962
- if self.interrupt:
963
- continue
964
-
965
- # expand the latents if we are doing classifier free guidance
966
- latent_model_input = (
967
- torch.cat([latents] * 2)
968
- if self.do_classifier_free_guidance
969
- else latents
970
- )
971
- latent_model_input = self.scheduler.scale_model_input(
972
- latent_model_input, t
973
- )
974
-
975
- t_expand = t.repeat(latent_model_input.shape[0])
976
- guidance_expand = (
977
- torch.tensor(
978
- [embedded_guidance_scale] * latent_model_input.shape[0],
979
- dtype=torch.float32,
980
- device=device,
981
- ).to(target_dtype)
982
- * 1000.0
983
- if embedded_guidance_scale is not None
984
- else None
985
- )
986
-
987
- # predict the noise residual
988
- with torch.autocast(
989
- device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
990
- ):
991
- noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
992
- latent_model_input, # [2, 16, 33, 24, 42]
993
- t_expand, # [2]
994
- text_states=prompt_embeds, # [2, 256, 4096]
995
- text_mask=prompt_mask, # [2, 256]
996
- text_states_2=prompt_embeds_2, # [2, 768]
997
- freqs_cos=freqs_cis[0], # [seqlen, head_dim]
998
- freqs_sin=freqs_cis[1], # [seqlen, head_dim]
999
- guidance=guidance_expand,
1000
- return_dict=True,
1001
- )[
1002
- "x"
1003
- ]
1004
-
1005
- # perform guidance
1006
- if self.do_classifier_free_guidance:
1007
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1008
- noise_pred = noise_pred_uncond + self.guidance_scale * (
1009
- noise_pred_text - noise_pred_uncond
1010
- )
1011
-
1012
- if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1013
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1014
- noise_pred = rescale_noise_cfg(
1015
- noise_pred,
1016
- noise_pred_text,
1017
- guidance_rescale=self.guidance_rescale,
1018
- )
1019
-
1020
- # compute the previous noisy sample x_t -> x_t-1
1021
- latents = self.scheduler.step(
1022
- noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1023
- )[0]
1024
-
1025
- if callback_on_step_end is not None:
1026
- callback_kwargs = {}
1027
- for k in callback_on_step_end_tensor_inputs:
1028
- callback_kwargs[k] = locals()[k]
1029
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1030
-
1031
- latents = callback_outputs.pop("latents", latents)
1032
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1033
- negative_prompt_embeds = callback_outputs.pop(
1034
- "negative_prompt_embeds", negative_prompt_embeds
1035
- )
1036
-
1037
- # call the callback, if provided
1038
- if i == len(timesteps) - 1 or (
1039
- (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1040
- ):
1041
- if progress_bar is not None:
1042
- progress_bar.update()
1043
- if callback is not None and i % callback_steps == 0:
1044
- step_idx = i // getattr(self.scheduler, "order", 1)
1045
- callback(step_idx, t, latents)
1046
-
1047
- if output_type != "latent":
1048
- expand_temporal_dim = False
1049
- if len(latents.shape) == 4:
1050
- if isinstance(self.vae, AutoencoderKLCausal3D):
1051
- latents = latents.unsqueeze(2)
1052
- expand_temporal_dim = True
1053
- elif len(latents.shape) == 5:
1054
- pass
1055
- else:
1056
- raise ValueError(
1057
- f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
1058
- )
1059
-
1060
- if (
1061
- hasattr(self.vae.config, "shift_factor")
1062
- and self.vae.config.shift_factor
1063
- ):
1064
- latents = (
1065
- latents / self.vae.config.scaling_factor
1066
- + self.vae.config.shift_factor
1067
- )
1068
- else:
1069
- latents = latents / self.vae.config.scaling_factor
1070
-
1071
- with torch.autocast(
1072
- device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
1073
- ):
1074
- if enable_tiling:
1075
- self.vae.enable_tiling()
1076
- image = self.vae.decode(
1077
- latents, return_dict=False, generator=generator
1078
- )[0]
1079
- else:
1080
- image = self.vae.decode(
1081
- latents, return_dict=False, generator=generator
1082
- )[0]
1083
-
1084
- if expand_temporal_dim or image.shape[2] == 1:
1085
- image = image.squeeze(2)
1086
-
1087
- else:
1088
- image = latents
1089
-
1090
- image = (image / 2 + 0.5).clamp(0, 1)
1091
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1092
- image = image.cpu().float()
1093
-
1094
- # Offload all models
1095
- self.maybe_free_model_hooks()
1096
-
1097
- if not return_dict:
1098
- return image
1099
-
1100
- return HunyuanVideoPipelineOutput(videos=image)
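For reference, the classifier-free guidance step inside the denoising loop above (combine the unconditional and text-conditioned predictions, then optionally rescale as in Sec. 3.4 of arXiv:2305.08891) reduces to the standalone sketch below; the function and tensor names are illustrative, not taken from the pipeline itself.

    import torch

    def apply_cfg(noise_pred_uncond, noise_pred_text, guidance_scale, guidance_rescale=0.0):
        # standard classifier-free guidance combination
        noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        if guidance_rescale > 0.0:
            # pull the result back towards the std of the text-conditioned prediction
            dims = list(range(1, noise_pred_text.ndim))
            std_text = noise_pred_text.std(dim=dims, keepdim=True)
            std_cfg = noise_cfg.std(dim=dims, keepdim=True)
            rescaled = noise_cfg * (std_text / std_cfg)
            noise_cfg = guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
        return noise_cfg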
 
 
hyvideo/diffusion/schedulers/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
 
 
hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py DELETED
@@ -1,257 +0,0 @@
1
- # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- #
16
- # Modified from diffusers==0.29.2
17
- #
18
- # ==============================================================================
19
-
20
- from dataclasses import dataclass
21
- from typing import Optional, Tuple, Union
22
-
23
- import numpy as np
24
- import torch
25
-
26
- from diffusers.configuration_utils import ConfigMixin, register_to_config
27
- from diffusers.utils import BaseOutput, logging
28
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
-
30
-
31
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
-
33
-
34
- @dataclass
35
- class FlowMatchDiscreteSchedulerOutput(BaseOutput):
36
- """
37
- Output class for the scheduler's `step` function output.
38
-
39
- Args:
40
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
- Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42
- denoising loop.
43
- """
44
-
45
- prev_sample: torch.FloatTensor
46
-
47
-
48
- class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
49
- """
50
- Discrete flow-matching scheduler with a first-order (Euler) solver.
51
-
52
- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
53
- methods the library implements for all schedulers such as loading and saving.
54
-
55
- Args:
56
- num_train_timesteps (`int`, defaults to 1000):
57
- The number of diffusion steps to train the model.
58
- timestep_spacing (`str`, defaults to `"linspace"`):
59
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
60
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
61
- shift (`float`, defaults to 1.0):
62
- The shift value for the timestep schedule.
63
- reverse (`bool`, defaults to `True`):
64
- Whether to reverse the timestep schedule.
65
- """
66
-
67
- _compatibles = []
68
- order = 1
69
-
70
- @register_to_config
71
- def __init__(
72
- self,
73
- num_train_timesteps: int = 1000,
74
- shift: float = 1.0,
75
- reverse: bool = True,
76
- solver: str = "euler",
77
- n_tokens: Optional[int] = None,
78
- ):
79
- sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
80
-
81
- if not reverse:
82
- sigmas = sigmas.flip(0)
83
-
84
- self.sigmas = sigmas
85
- # the value fed to model
86
- self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
87
-
88
- self._step_index = None
89
- self._begin_index = None
90
-
91
- self.supported_solver = ["euler"]
92
- if solver not in self.supported_solver:
93
- raise ValueError(
94
- f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
95
- )
96
-
97
- @property
98
- def step_index(self):
99
- """
100
- The index counter for the current timestep. It increases by 1 after each scheduler step.
101
- """
102
- return self._step_index
103
-
104
- @property
105
- def begin_index(self):
106
- """
107
- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
108
- """
109
- return self._begin_index
110
-
111
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
112
- def set_begin_index(self, begin_index: int = 0):
113
- """
114
- Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
115
-
116
- Args:
117
- begin_index (`int`):
118
- The begin index for the scheduler.
119
- """
120
- self._begin_index = begin_index
121
-
122
- def _sigma_to_t(self, sigma):
123
- return sigma * self.config.num_train_timesteps
124
-
125
- def set_timesteps(
126
- self,
127
- num_inference_steps: int,
128
- device: Union[str, torch.device] = None,
129
- n_tokens: int = None,
130
- ):
131
- """
132
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
133
-
134
- Args:
135
- num_inference_steps (`int`):
136
- The number of diffusion steps used when generating samples with a pre-trained model.
137
- device (`str` or `torch.device`, *optional*):
138
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
139
- n_tokens (`int`, *optional*):
140
- Number of tokens in the input sequence.
141
- """
142
- self.num_inference_steps = num_inference_steps
143
-
144
- sigmas = torch.linspace(1, 0, num_inference_steps + 1)
145
- sigmas = self.sd3_time_shift(sigmas)
146
-
147
- if not self.config.reverse:
148
- sigmas = 1 - sigmas
149
-
150
- self.sigmas = sigmas
151
- self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
152
- dtype=torch.float32, device=device
153
- )
154
-
155
- # Reset step index
156
- self._step_index = None
157
-
158
- def index_for_timestep(self, timestep, schedule_timesteps=None):
159
- if schedule_timesteps is None:
160
- schedule_timesteps = self.timesteps
161
-
162
- indices = (schedule_timesteps == timestep).nonzero()
163
-
164
- # The sigma index that is taken for the **very** first `step`
165
- # is always the second index (or the last index if there is only 1)
166
- # This way we can ensure we don't accidentally skip a sigma in
167
- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
168
- pos = 1 if len(indices) > 1 else 0
169
-
170
- return indices[pos].item()
171
-
172
- def _init_step_index(self, timestep):
173
- if self.begin_index is None:
174
- if isinstance(timestep, torch.Tensor):
175
- timestep = timestep.to(self.timesteps.device)
176
- self._step_index = self.index_for_timestep(timestep)
177
- else:
178
- self._step_index = self._begin_index
179
-
180
- def scale_model_input(
181
- self, sample: torch.Tensor, timestep: Optional[int] = None
182
- ) -> torch.Tensor:
183
- return sample
184
-
185
- def sd3_time_shift(self, t: torch.Tensor):
186
- return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
187
-
188
- def step(
189
- self,
190
- model_output: torch.FloatTensor,
191
- timestep: Union[float, torch.FloatTensor],
192
- sample: torch.FloatTensor,
193
- return_dict: bool = True,
194
- ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
195
- """
196
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
197
- process from the learned model outputs (most often the predicted noise).
198
-
199
- Args:
200
- model_output (`torch.FloatTensor`):
201
- The direct output from learned diffusion model.
202
- timestep (`float`):
203
- The current discrete timestep in the diffusion chain.
204
- sample (`torch.FloatTensor`):
205
- A current instance of a sample created by the diffusion process.
206
- generator (`torch.Generator`, *optional*):
207
- A random number generator.
208
- n_tokens (`int`, *optional*):
209
- Number of tokens in the input sequence.
210
- return_dict (`bool`):
211
- Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or
212
- tuple.
213
-
214
- Returns:
215
- [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
216
- If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is
217
- returned, otherwise a tuple is returned where the first element is the sample tensor.
218
- """
219
-
220
- if (
221
- isinstance(timestep, int)
222
- or isinstance(timestep, torch.IntTensor)
223
- or isinstance(timestep, torch.LongTensor)
224
- ):
225
- raise ValueError(
226
- (
227
- "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
228
- " `FlowMatchDiscreteScheduler.step()` is not supported. Make sure to pass"
229
- " one of the `scheduler.timesteps` as a timestep."
230
- ),
231
- )
232
-
233
- if self.step_index is None:
234
- self._init_step_index(timestep)
235
-
236
- # Upcast to avoid precision issues when computing prev_sample
237
- sample = sample.to(torch.float32)
238
-
239
- dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
240
-
241
- if self.config.solver == "euler":
242
- prev_sample = sample + model_output.to(torch.float32) * dt
243
- else:
244
- raise ValueError(
245
- f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
246
- )
247
-
248
- # upon completion increase step index by one
249
- self._step_index += 1
250
-
251
- if not return_dict:
252
- return (prev_sample,)
253
-
254
- return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
255
-
256
- def __len__(self):
257
- return self.config.num_train_timesteps
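Driving this scheduler outside the pipeline takes only a few lines. A minimal sketch (it relies on the pre-deletion import path, the zero-velocity lambda is a stand-in for the transformer, and the latent shape is illustrative):

    import torch
    from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler

    scheduler = FlowMatchDiscreteScheduler(shift=7.0, reverse=True, solver="euler")
    scheduler.set_timesteps(num_inference_steps=30)

    sample = torch.randn(1, 16, 9, 24, 42)             # (b, c, f, h, w) latent video
    velocity_model = lambda x, t: torch.zeros_like(x)  # placeholder for the DiT

    for t in scheduler.timesteps:
        model_output = velocity_model(sample, t)
        sample = scheduler.step(model_output, t, sample).prev_sample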
 
 
hyvideo/inference.py DELETED
@@ -1,671 +0,0 @@
1
- import os
2
- import time
3
- import random
4
- import functools
5
- from typing import List, Optional, Tuple, Union
6
-
7
- from pathlib import Path
8
- from loguru import logger
9
-
10
- import torch
11
- import torch.distributed as dist
12
- from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE
13
- from hyvideo.vae import load_vae
14
- from hyvideo.modules import load_model
15
- from hyvideo.text_encoder import TextEncoder
16
- from hyvideo.utils.data_utils import align_to
17
- from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
18
- from hyvideo.modules.fp8_optimization import convert_fp8_linear
19
- from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
20
- from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
21
-
22
- try:
23
- import xfuser
24
- from xfuser.core.distributed import (
25
- get_sequence_parallel_world_size,
26
- get_sequence_parallel_rank,
27
- get_sp_group,
28
- initialize_model_parallel,
29
- init_distributed_environment
30
- )
31
- except ImportError:
32
- xfuser = None
33
- get_sequence_parallel_world_size = None
34
- get_sequence_parallel_rank = None
35
- get_sp_group = None
36
- initialize_model_parallel = None
37
- init_distributed_environment = None
38
-
39
-
40
- def parallelize_transformer(pipe):
41
- transformer = pipe.transformer
42
- original_forward = transformer.forward
43
-
44
- @functools.wraps(transformer.__class__.forward)
45
- def new_forward(
46
- self,
47
- x: torch.Tensor,
48
- t: torch.Tensor, # Should be in range(0, 1000).
49
- text_states: torch.Tensor = None,
50
- text_mask: torch.Tensor = None, # Now we don't use it.
51
- text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
52
- freqs_cos: Optional[torch.Tensor] = None,
53
- freqs_sin: Optional[torch.Tensor] = None,
54
- guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
55
- return_dict: bool = True,
56
- ):
57
- if x.shape[-2] // 2 % get_sequence_parallel_world_size() == 0:
58
- # try to split x by height
59
- split_dim = -2
60
- elif x.shape[-1] // 2 % get_sequence_parallel_world_size() == 0:
61
- # try to split x by width
62
- split_dim = -1
63
- else:
64
- raise ValueError(f"Cannot split video sequence into ulysses_degree x ring_degree ({get_sequence_parallel_world_size()}) parts evenly")
65
-
66
- # patch sizes for the temporal, height, and width dimensions are 1, 2, and 2.
67
- temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2
68
-
69
- x = torch.chunk(x, get_sequence_parallel_world_size(),dim=split_dim)[get_sequence_parallel_rank()]
70
-
71
- dim_thw = freqs_cos.shape[-1]
72
- freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)
73
- freqs_cos = torch.chunk(freqs_cos, get_sequence_parallel_world_size(),dim=split_dim - 1)[get_sequence_parallel_rank()]
74
- freqs_cos = freqs_cos.reshape(-1, dim_thw)
75
- dim_thw = freqs_sin.shape[-1]
76
- freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
77
- freqs_sin = torch.chunk(freqs_sin, get_sequence_parallel_world_size(),dim=split_dim - 1)[get_sequence_parallel_rank()]
78
- freqs_sin = freqs_sin.reshape(-1, dim_thw)
79
-
80
- from xfuser.core.long_ctx_attention import xFuserLongContextAttention
81
-
82
- for block in transformer.double_blocks + transformer.single_blocks:
83
- block.hybrid_seq_parallel_attn = xFuserLongContextAttention()
84
-
85
- output = original_forward(
86
- x,
87
- t,
88
- text_states,
89
- text_mask,
90
- text_states_2,
91
- freqs_cos,
92
- freqs_sin,
93
- guidance,
94
- return_dict,
95
- )
96
-
97
- return_dict = not isinstance(output, tuple)
98
- sample = output["x"]
99
- sample = get_sp_group().all_gather(sample, dim=split_dim)
100
- output["x"] = sample
101
- return output
102
-
103
- new_forward = new_forward.__get__(transformer)
104
- transformer.forward = new_forward
105
-
106
-
107
- class Inference(object):
108
- def __init__(
109
- self,
110
- args,
111
- vae,
112
- vae_kwargs,
113
- text_encoder,
114
- model,
115
- text_encoder_2=None,
116
- pipeline=None,
117
- use_cpu_offload=False,
118
- device=None,
119
- logger=None,
120
- parallel_args=None,
121
- ):
122
- self.vae = vae
123
- self.vae_kwargs = vae_kwargs
124
-
125
- self.text_encoder = text_encoder
126
- self.text_encoder_2 = text_encoder_2
127
-
128
- self.model = model
129
- self.pipeline = pipeline
130
- self.use_cpu_offload = use_cpu_offload
131
-
132
- self.args = args
133
- self.device = (
134
- device
135
- if device is not None
136
- else "cuda"
137
- if torch.cuda.is_available()
138
- else "cpu"
139
- )
140
- self.logger = logger
141
- self.parallel_args = parallel_args
142
-
143
- @classmethod
144
- def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
145
- """
146
- Initialize the Inference pipeline.
147
-
148
- Args:
149
- pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
150
- args (argparse.Namespace): The arguments for the pipeline.
151
- device (int): The device for inference. Default is 0.
152
- """
153
- # ========================================================================
154
- logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
155
-
156
- # ==================== Initialize Distributed Environment ================
157
- if args.ulysses_degree > 1 or args.ring_degree > 1:
158
- assert xfuser is not None, \
159
- "Ulysses Attention and Ring Attention require the xfuser package."
160
-
161
- assert args.use_cpu_offload is False, \
162
- "Cannot enable use_cpu_offload in the distributed environment."
163
-
164
- dist.init_process_group("nccl")
165
-
166
- assert dist.get_world_size() == args.ring_degree * args.ulysses_degree, \
167
- "number of GPUs should be equal to ring_degree * ulysses_degree."
168
-
169
- init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
170
-
171
- initialize_model_parallel(
172
- sequence_parallel_degree=dist.get_world_size(),
173
- ring_degree=args.ring_degree,
174
- ulysses_degree=args.ulysses_degree,
175
- )
176
- device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
177
- else:
178
- if device is None:
179
- device = "cuda" if torch.cuda.is_available() else "cpu"
180
-
181
- parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}
182
-
183
- # ======================== Get the args path =============================
184
-
185
- # Disable gradient
186
- torch.set_grad_enabled(False)
187
-
188
- # =========================== Build main model ===========================
189
- logger.info("Building model...")
190
- factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
191
- in_channels = args.latent_channels
192
- out_channels = args.latent_channels
193
-
194
- model = load_model(
195
- args,
196
- in_channels=in_channels,
197
- out_channels=out_channels,
198
- factor_kwargs=factor_kwargs,
199
- )
200
- if args.use_fp8:
201
- convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
202
- model = model.to(device)
203
- model = Inference.load_state_dict(args, model, pretrained_model_path)
204
- model.eval()
205
-
206
- # ============================= Build extra models ========================
207
- # VAE
208
- vae, _, s_ratio, t_ratio = load_vae(
209
- args.vae,
210
- args.vae_precision,
211
- logger=logger,
212
- device=device if not args.use_cpu_offload else "cpu",
213
- )
214
- vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
215
-
216
- # Text encoder
217
- if args.prompt_template_video is not None:
218
- crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get(
219
- "crop_start", 0
220
- )
221
- elif args.prompt_template is not None:
222
- crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
223
- else:
224
- crop_start = 0
225
- max_length = args.text_len + crop_start
226
-
227
- # prompt_template
228
- prompt_template = (
229
- PROMPT_TEMPLATE[args.prompt_template]
230
- if args.prompt_template is not None
231
- else None
232
- )
233
-
234
- # prompt_template_video
235
- prompt_template_video = (
236
- PROMPT_TEMPLATE[args.prompt_template_video]
237
- if args.prompt_template_video is not None
238
- else None
239
- )
240
-
241
- text_encoder = TextEncoder(
242
- text_encoder_type=args.text_encoder,
243
- max_length=max_length,
244
- text_encoder_precision=args.text_encoder_precision,
245
- tokenizer_type=args.tokenizer,
246
- prompt_template=prompt_template,
247
- prompt_template_video=prompt_template_video,
248
- hidden_state_skip_layer=args.hidden_state_skip_layer,
249
- apply_final_norm=args.apply_final_norm,
250
- reproduce=args.reproduce,
251
- logger=logger,
252
- device=device if not args.use_cpu_offload else "cpu",
253
- )
254
- text_encoder_2 = None
255
- if args.text_encoder_2 is not None:
256
- text_encoder_2 = TextEncoder(
257
- text_encoder_type=args.text_encoder_2,
258
- max_length=args.text_len_2,
259
- text_encoder_precision=args.text_encoder_precision_2,
260
- tokenizer_type=args.tokenizer_2,
261
- reproduce=args.reproduce,
262
- logger=logger,
263
- device=device if not args.use_cpu_offload else "cpu",
264
- )
265
-
266
- return cls(
267
- args=args,
268
- vae=vae,
269
- vae_kwargs=vae_kwargs,
270
- text_encoder=text_encoder,
271
- text_encoder_2=text_encoder_2,
272
- model=model,
273
- use_cpu_offload=args.use_cpu_offload,
274
- device=device,
275
- logger=logger,
276
- parallel_args=parallel_args
277
- )
278
-
279
- @staticmethod
280
- def load_state_dict(args, model, pretrained_model_path):
281
- load_key = args.load_key
282
- dit_weight = Path(args.dit_weight)
283
-
284
- if dit_weight is None:
285
- model_dir = pretrained_model_path / f"t2v_{args.model_resolution}"
286
- files = list(model_dir.glob("*.pt"))
287
- if len(files) == 0:
288
- raise ValueError(f"No model weights found in {model_dir}")
289
- if str(files[0]).startswith("pytorch_model_"):
290
- model_path = dit_weight / f"pytorch_model_{load_key}.pt"
291
- bare_model = True
292
- elif any(str(f).endswith("_model_states.pt") for f in files):
293
- files = [f for f in files if str(f).endswith("_model_states.pt")]
294
- model_path = files[0]
295
- if len(files) > 1:
296
- logger.warning(
297
- f"Multiple model weights found in {dit_weight}, using {model_path}"
298
- )
299
- bare_model = False
300
- else:
301
- raise ValueError(
302
- f"Invalid model path: {dit_weight} with unrecognized weight format: "
303
- f"{list(map(str, files))}. When given a directory as --dit-weight, only "
304
- f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
305
- f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
306
- f"specific weight file, please provide the full path to the file."
307
- )
308
- else:
309
- if dit_weight.is_dir():
310
- files = list(dit_weight.glob("*.pt"))
311
- if len(files) == 0:
312
- raise ValueError(f"No model weights found in {dit_weight}")
313
- if str(files[0]).startswith("pytorch_model_"):
314
- model_path = dit_weight / f"pytorch_model_{load_key}.pt"
315
- bare_model = True
316
- elif any(str(f).endswith("_model_states.pt") for f in files):
317
- files = [f for f in files if str(f).endswith("_model_states.pt")]
318
- model_path = files[0]
319
- if len(files) > 1:
320
- logger.warning(
321
- f"Multiple model weights found in {dit_weight}, using {model_path}"
322
- )
323
- bare_model = False
324
- else:
325
- raise ValueError(
326
- f"Invalid model path: {dit_weight} with unrecognized weight format: "
327
- f"{list(map(str, files))}. When given a directory as --dit-weight, only "
328
- f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
329
- f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
330
- f"specific weight file, please provide the full path to the file."
331
- )
332
- elif dit_weight.is_file():
333
- model_path = dit_weight
334
- bare_model = "unknown"
335
- else:
336
- raise ValueError(f"Invalid model path: {dit_weight}")
337
-
338
- if not model_path.exists():
339
- raise ValueError(f"model_path does not exist: {model_path}")
340
- logger.info(f"Loading torch model {model_path}...")
341
- state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
342
-
343
- if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
344
- bare_model = False
345
- if bare_model is False:
346
- if load_key in state_dict:
347
- state_dict = state_dict[load_key]
348
- else:
349
- raise KeyError(
350
- f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
351
- f"are: {list(state_dict.keys())}."
352
- )
353
- model.load_state_dict(state_dict, strict=True)
354
- return model
355
-
356
- @staticmethod
357
- def parse_size(size):
358
- if isinstance(size, int):
359
- size = [size]
360
- if not isinstance(size, (list, tuple)):
361
- raise ValueError(f"Size must be an integer or (height, width), got {size}.")
362
- if len(size) == 1:
363
- size = [size[0], size[0]]
364
- if len(size) != 2:
365
- raise ValueError(f"Size must be an integer or (height, width), got {size}.")
366
- return size
367
-
368
-
369
- class HunyuanVideoSampler(Inference):
370
- def __init__(
371
- self,
372
- args,
373
- vae,
374
- vae_kwargs,
375
- text_encoder,
376
- model,
377
- text_encoder_2=None,
378
- pipeline=None,
379
- use_cpu_offload=False,
380
- device=0,
381
- logger=None,
382
- parallel_args=None
383
- ):
384
- super().__init__(
385
- args,
386
- vae,
387
- vae_kwargs,
388
- text_encoder,
389
- model,
390
- text_encoder_2=text_encoder_2,
391
- pipeline=pipeline,
392
- use_cpu_offload=use_cpu_offload,
393
- device=device,
394
- logger=logger,
395
- parallel_args=parallel_args
396
- )
397
-
398
- self.pipeline = self.load_diffusion_pipeline(
399
- args=args,
400
- vae=self.vae,
401
- text_encoder=self.text_encoder,
402
- text_encoder_2=self.text_encoder_2,
403
- model=self.model,
404
- device=self.device,
405
- )
406
-
407
- self.default_negative_prompt = NEGATIVE_PROMPT
408
- if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
409
- parallelize_transformer(self.pipeline)
410
-
411
- def load_diffusion_pipeline(
412
- self,
413
- args,
414
- vae,
415
- text_encoder,
416
- text_encoder_2,
417
- model,
418
- scheduler=None,
419
- device=None,
420
- progress_bar_config=None,
421
- data_type="video",
422
- ):
423
- """Load the denoising scheduler for inference."""
424
- if scheduler is None:
425
- if args.denoise_type == "flow":
426
- scheduler = FlowMatchDiscreteScheduler(
427
- shift=args.flow_shift,
428
- reverse=args.flow_reverse,
429
- solver=args.flow_solver,
430
- )
431
- else:
432
- raise ValueError(f"Invalid denoise type {args.denoise_type}")
433
-
434
- pipeline = HunyuanVideoPipeline(
435
- vae=vae,
436
- text_encoder=text_encoder,
437
- text_encoder_2=text_encoder_2,
438
- transformer=model,
439
- scheduler=scheduler,
440
- progress_bar_config=progress_bar_config,
441
- args=args,
442
- )
443
- if self.use_cpu_offload:
444
- pipeline.enable_sequential_cpu_offload()
445
- else:
446
- pipeline = pipeline.to(device)
447
-
448
- return pipeline
449
-
450
- def get_rotary_pos_embed(self, video_length, height, width):
451
- target_ndim = 3
452
- ndim = 5 - 2
453
- # 884
454
- if "884" in self.args.vae:
455
- latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
456
- elif "888" in self.args.vae:
457
- latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
458
- else:
459
- latents_size = [video_length, height // 8, width // 8]
460
-
461
- if isinstance(self.model.patch_size, int):
462
- assert all(s % self.model.patch_size == 0 for s in latents_size), (
463
- f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
464
- f"but got {latents_size}."
465
- )
466
- rope_sizes = [s // self.model.patch_size for s in latents_size]
467
- elif isinstance(self.model.patch_size, list):
468
- assert all(
469
- s % self.model.patch_size[idx] == 0
470
- for idx, s in enumerate(latents_size)
471
- ), (
472
- f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
473
- f"but got {latents_size}."
474
- )
475
- rope_sizes = [
476
- s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
477
- ]
478
-
479
- if len(rope_sizes) != target_ndim:
480
- rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
481
- head_dim = self.model.hidden_size // self.model.heads_num
482
- rope_dim_list = self.model.rope_dim_list
483
- if rope_dim_list is None:
484
- rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
485
- assert (
486
- sum(rope_dim_list) == head_dim
487
- ), "sum(rope_dim_list) should equal the head_dim of the attention layer"
488
- freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
489
- rope_dim_list,
490
- rope_sizes,
491
- theta=self.args.rope_theta,
492
- use_real=True,
493
- theta_rescale_factor=1,
494
- )
495
- return freqs_cos, freqs_sin
496
-
497
- @torch.no_grad()
498
- def predict(
499
- self,
500
- prompt,
501
- height=192,
502
- width=336,
503
- video_length=129,
504
- seed=None,
505
- negative_prompt=None,
506
- infer_steps=50,
507
- guidance_scale=6,
508
- flow_shift=5.0,
509
- embedded_guidance_scale=None,
510
- batch_size=1,
511
- num_videos_per_prompt=1,
512
- **kwargs,
513
- ):
514
- """
515
- Predict the image/video from the given text.
516
-
517
- Args:
518
- prompt (str or List[str]): The input text.
519
- kwargs:
520
- height (int): The height of the output video. Default is 192.
521
- width (int): The width of the output video. Default is 336.
522
- video_length (int): The frame number of the output video. Default is 129.
523
- seed (int or List[int]): The random seed for the generation. Default is a random integer.
524
- negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
525
- guidance_scale (float): The guidance scale for the generation. Default is 6.0.
526
- num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
527
- infer_steps (int): The number of inference steps. Default is 50.
528
- """
529
- out_dict = dict()
530
-
531
- # ========================================================================
532
- # Arguments: seed
533
- # ========================================================================
534
- if isinstance(seed, torch.Tensor):
535
- seed = seed.tolist()
536
- if seed is None:
537
- seeds = [
538
- random.randint(0, 1_000_000)
539
- for _ in range(batch_size * num_videos_per_prompt)
540
- ]
541
- elif isinstance(seed, int):
542
- seeds = [
543
- seed + i
544
- for _ in range(batch_size)
545
- for i in range(num_videos_per_prompt)
546
- ]
547
- elif isinstance(seed, (list, tuple)):
548
- if len(seed) == batch_size:
549
- seeds = [
550
- int(seed[i]) + j
551
- for i in range(batch_size)
552
- for j in range(num_videos_per_prompt)
553
- ]
554
- elif len(seed) == batch_size * num_videos_per_prompt:
555
- seeds = [int(s) for s in seed]
556
- else:
557
- raise ValueError(
558
- f"Length of seed must be equal to the number of prompts (batch_size) or "
559
- f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
560
- )
561
- else:
562
- raise ValueError(
563
- f"Seed must be an integer, a list of integers, or None, got {seed}."
564
- )
565
- generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
566
- out_dict["seeds"] = seeds
567
-
568
- # ========================================================================
569
- # Arguments: target_width, target_height, target_video_length
570
- # ========================================================================
571
- if width <= 0 or height <= 0 or video_length <= 0:
572
- raise ValueError(
573
- f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
574
- )
575
- if (video_length - 1) % 4 != 0:
576
- raise ValueError(
577
- f"`video_length-1` must be a multiple of 4, got {video_length}"
578
- )
579
-
580
- logger.info(
581
- f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
582
- )
583
-
584
- target_height = align_to(height, 16)
585
- target_width = align_to(width, 16)
586
- target_video_length = video_length
587
-
588
- out_dict["size"] = (target_height, target_width, target_video_length)
589
-
590
- # ========================================================================
591
- # Arguments: prompt, new_prompt, negative_prompt
592
- # ========================================================================
593
- if not isinstance(prompt, str):
594
- raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
595
- prompt = [prompt.strip()]
596
-
597
- # negative prompt
598
- if negative_prompt is None or negative_prompt == "":
599
- negative_prompt = self.default_negative_prompt
600
- if not isinstance(negative_prompt, str):
601
- raise TypeError(
602
- f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
603
- )
604
- negative_prompt = [negative_prompt.strip()]
605
-
606
- # ========================================================================
607
- # Scheduler
608
- # ========================================================================
609
- scheduler = FlowMatchDiscreteScheduler(
610
- shift=flow_shift,
611
- reverse=self.args.flow_reverse,
612
- solver=self.args.flow_solver
613
- )
614
- self.pipeline.scheduler = scheduler
615
-
616
- # ========================================================================
617
- # Build Rope freqs
618
- # ========================================================================
619
- freqs_cos, freqs_sin = self.get_rotary_pos_embed(
620
- target_video_length, target_height, target_width
621
- )
622
- n_tokens = freqs_cos.shape[0]
623
-
624
- # ========================================================================
625
- # Print infer args
626
- # ========================================================================
627
- debug_str = f"""
628
- height: {target_height}
629
- width: {target_width}
630
- video_length: {target_video_length}
631
- prompt: {prompt}
632
- neg_prompt: {negative_prompt}
633
- seed: {seed}
634
- infer_steps: {infer_steps}
635
- num_videos_per_prompt: {num_videos_per_prompt}
636
- guidance_scale: {guidance_scale}
637
- n_tokens: {n_tokens}
638
- flow_shift: {flow_shift}
639
- embedded_guidance_scale: {embedded_guidance_scale}"""
640
- logger.debug(debug_str)
641
-
642
- # ========================================================================
643
- # Pipeline inference
644
- # ========================================================================
645
- start_time = time.time()
646
- samples = self.pipeline(
647
- prompt=prompt,
648
- height=target_height,
649
- width=target_width,
650
- video_length=target_video_length,
651
- num_inference_steps=infer_steps,
652
- guidance_scale=guidance_scale,
653
- negative_prompt=negative_prompt,
654
- num_videos_per_prompt=num_videos_per_prompt,
655
- generator=generator,
656
- output_type="pil",
657
- freqs_cis=(freqs_cos, freqs_sin),
658
- n_tokens=n_tokens,
659
- embedded_guidance_scale=embedded_guidance_scale,
660
- data_type="video" if target_video_length > 1 else "image",
661
- is_progress_bar=True,
662
- vae_ver=self.args.vae,
663
- enable_tiling=self.args.vae_tiling,
664
- )[0]
665
- out_dict["samples"] = samples
666
- out_dict["prompts"] = prompt
667
-
668
- gen_time = time.time() - start_time
669
- logger.info(f"Success, time: {gen_time}")
670
-
671
- return out_dict
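End to end, the sampler above is constructed once and then called per prompt. A minimal sketch, assuming `args` is the namespace produced by the project's argument parser and the checkpoint path is illustrative; note that with an "884" VAE a `video_length` of 129 maps to (129 - 1) // 4 + 1 = 33 latent frames:

    sampler = HunyuanVideoSampler.from_pretrained("ckpts/hunyuan-video-t2v", args=args)

    outputs = sampler.predict(
        prompt="a cat walking on grass",
        height=192,
        width=336,
        video_length=129,
        seed=42,
        infer_steps=50,
        guidance_scale=6.0,
    )
    videos = outputs["samples"]   # generated samples returned by the pipeline
    seeds = outputs["seeds"]      # the per-video seeds actually used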
 
 
hyvideo/modules/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
2
-
3
-
4
- def load_model(args, in_channels, out_channels, factor_kwargs):
5
- """load hunyuan video model
6
-
7
- Args:
8
- args (dict): model args
9
- in_channels (int): input channels number
10
- out_channels (int): output channels number
11
- factor_kwargs (dict): factor kwargs
12
-
13
- Returns:
14
- model (nn.Module): The hunyuan video model
15
- """
16
- if args.model in HUNYUAN_VIDEO_CONFIG.keys():
17
- model = HYVideoDiffusionTransformer(
18
- args,
19
- in_channels=in_channels,
20
- out_channels=out_channels,
21
- **HUNYUAN_VIDEO_CONFIG[args.model],
22
- **factor_kwargs,
23
- )
24
- return model
25
- else:
26
- raise NotImplementedError()
 
 
hyvideo/modules/activation_layers.py DELETED
@@ -1,23 +0,0 @@
1
- import torch.nn as nn
2
-
3
-
4
- def get_activation_layer(act_type):
5
- """get activation layer
6
-
7
- Args:
8
- act_type (str): the activation type
9
-
10
- Returns:
11
- torch.nn.functional: the activation layer
12
- """
13
- if act_type == "gelu":
14
- return lambda: nn.GELU()
15
- elif act_type == "gelu_tanh":
16
- # Approximate `tanh` requires torch >= 1.13
17
- return lambda: nn.GELU(approximate="tanh")
18
- elif act_type == "relu":
19
- return nn.ReLU
20
- elif act_type == "silu":
21
- return nn.SiLU
22
- else:
23
- raise ValueError(f"Unknown activation type: {act_type}")
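Every branch above returns a zero-argument factory (a lambda for the GELU variants, the class itself for ReLU/SiLU), so call sites instantiate the activation uniformly, roughly like this:

    act_layer = get_activation_layer("gelu_tanh")  # factory
    act = act_layer()                              # equivalent to nn.GELU(approximate="tanh")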
 
 
hyvideo/modules/attenion.py DELETED
@@ -1,212 +0,0 @@
1
- import importlib.metadata
2
- import math
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- try:
9
- import flash_attn
10
- from flash_attn.flash_attn_interface import _flash_attn_forward
11
- from flash_attn.flash_attn_interface import flash_attn_varlen_func
12
- except ImportError:
13
- flash_attn = None
14
- flash_attn_varlen_func = None
15
- _flash_attn_forward = None
16
-
17
-
18
- MEMORY_LAYOUT = {
19
- "flash": (
20
- lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
21
- lambda x: x,
22
- ),
23
- "torch": (
24
- lambda x: x.transpose(1, 2),
25
- lambda x: x.transpose(1, 2),
26
- ),
27
- "vanilla": (
28
- lambda x: x.transpose(1, 2),
29
- lambda x: x.transpose(1, 2),
30
- ),
31
- }
32
-
33
-
34
- def get_cu_seqlens(text_mask, img_len):
35
- """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
36
-
37
- Args:
38
- text_mask (torch.Tensor): the mask of text
39
- img_len (int): the length of image
40
-
41
- Returns:
42
- torch.Tensor: the calculated cu_seqlens for flash attention
43
- """
44
- batch_size = text_mask.shape[0]
45
- text_len = text_mask.sum(dim=1)
46
- max_len = text_mask.shape[1] + img_len
47
-
48
- cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
49
-
50
- for i in range(batch_size):
51
- s = text_len[i] + img_len
52
- s1 = i * max_len + s
53
- s2 = (i + 1) * max_len
54
- cu_seqlens[2 * i + 1] = s1
55
- cu_seqlens[2 * i + 2] = s2
56
-
57
- return cu_seqlens
58
-
59
-
60
- def attention(
61
- q,
62
- k,
63
- v,
64
- mode="flash",
65
- drop_rate=0,
66
- attn_mask=None,
67
- causal=False,
68
- cu_seqlens_q=None,
69
- cu_seqlens_kv=None,
70
- max_seqlen_q=None,
71
- max_seqlen_kv=None,
72
- batch_size=1,
73
- ):
74
- """
75
- Perform QKV self attention.
76
-
77
- Args:
78
- q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
79
- k (torch.Tensor): Key tensor with shape [b, s1, a, d]
80
- v (torch.Tensor): Value tensor with shape [b, s1, a, d]
81
- mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
82
- drop_rate (float): Dropout rate in attention map. (default: 0)
83
- attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
84
- (default: None)
85
- causal (bool): Whether to use causal attention. (default: False)
86
- cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
87
- used to index into q.
88
- cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
89
- used to index into kv.
90
- max_seqlen_q (int): The maximum sequence length in the batch of q.
91
- max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
92
-
93
- Returns:
94
- torch.Tensor: Output tensor after self attention with shape [b, s, ad]
95
- """
96
- pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
97
- q = pre_attn_layout(q)
98
- k = pre_attn_layout(k)
99
- v = pre_attn_layout(v)
100
-
101
- if mode == "torch":
102
- if attn_mask is not None and attn_mask.dtype != torch.bool:
103
- attn_mask = attn_mask.to(q.dtype)
104
- x = F.scaled_dot_product_attention(
105
- q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
106
- )
107
- elif mode == "flash":
108
- x = flash_attn_varlen_func(
109
- q,
110
- k,
111
- v,
112
- cu_seqlens_q,
113
- cu_seqlens_kv,
114
- max_seqlen_q,
115
- max_seqlen_kv,
116
- )
117
- # x with shape [(bxs), a, d]
118
- x = x.view(
119
- batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
120
- ) # reshape x to [b, s, a, d]
121
- elif mode == "vanilla":
122
- scale_factor = 1 / math.sqrt(q.size(-1))
123
-
124
- b, a, s, _ = q.shape
125
- s1 = k.size(2)
126
- attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
127
- if causal:
128
- # Only applied to self attention
129
- assert (
130
- attn_mask is None
131
- ), "Causal mask and attn_mask cannot be used together"
132
- temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
133
- diagonal=0
134
- )
135
- attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
136
- attn_bias.to(q.dtype)
137
-
138
- if attn_mask is not None:
139
- if attn_mask.dtype == torch.bool:
140
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
141
- else:
142
- attn_bias += attn_mask
143
-
144
- # TODO: Maybe force q and k to be float32 to avoid numerical overflow
145
- attn = (q @ k.transpose(-2, -1)) * scale_factor
146
- attn += attn_bias
147
- attn = attn.softmax(dim=-1)
148
- attn = torch.dropout(attn, p=drop_rate, train=True)
149
- x = attn @ v
150
- else:
151
- raise NotImplementedError(f"Unsupported attention mode: {mode}")
152
-
153
- x = post_attn_layout(x)
154
- b, s, a, d = x.shape
155
- out = x.reshape(b, s, -1)
156
- return out
157
-
158
-
159
- def parallel_attention(
160
- hybrid_seq_parallel_attn,
161
- q,
162
- k,
163
- v,
164
- img_q_len,
165
- img_kv_len,
166
- cu_seqlens_q,
167
- cu_seqlens_kv
168
- ):
169
- attn1 = hybrid_seq_parallel_attn(
170
- None,
171
- q[:, :img_q_len, :, :],
172
- k[:, :img_kv_len, :, :],
173
- v[:, :img_kv_len, :, :],
174
- dropout_p=0.0,
175
- causal=False,
176
- joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
177
- joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
178
- joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
179
- joint_strategy="rear",
180
- )
181
- if flash_attn.__version__ >= '2.7.0':
182
- attn2, *_ = _flash_attn_forward(
183
- q[:,cu_seqlens_q[1]:],
184
- k[:,cu_seqlens_kv[1]:],
185
- v[:,cu_seqlens_kv[1]:],
186
- dropout_p=0.0,
187
- softmax_scale=q.shape[-1] ** (-0.5),
188
- causal=False,
189
- window_size_left=-1,
190
- window_size_right=-1,
191
- softcap=0.0,
192
- alibi_slopes=None,
193
- return_softmax=False,
194
- )
195
- else:
196
- attn2, *_ = _flash_attn_forward(
197
- q[:,cu_seqlens_q[1]:],
198
- k[:,cu_seqlens_kv[1]:],
199
- v[:,cu_seqlens_kv[1]:],
200
- dropout_p=0.0,
201
- softmax_scale=q.shape[-1] ** (-0.5),
202
- causal=False,
203
- window_size=(-1, -1),
204
- softcap=0.0,
205
- alibi_slopes=None,
206
- return_softmax=False,
207
- )
208
- attn = torch.cat([attn1, attn2], dim=1)
209
- b, s, a, d = attn.shape
210
- attn = attn.reshape(b, s, -1)
211
-
212
- return attn
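The cu_seqlens layout built by get_cu_seqlens packs each sample as [valid text + image tokens | padding] inside a fixed max_len slot. A small worked example with illustrative sizes (the original helper allocates its tensor on CUDA; the same arithmetic is shown on CPU here):

    import torch

    text_mask = torch.tensor([[1, 1, 0, 0],
                              [1, 1, 1, 0]])    # two prompts, valid lengths 2 and 3
    img_len = 6                                 # video tokens per sample
    max_len = text_mask.shape[1] + img_len      # 10 tokens per padded slot

    cu_seqlens = [0]
    for i, t_len in enumerate(text_mask.sum(dim=1).tolist()):
        cu_seqlens.append(i * max_len + t_len + img_len)  # end of valid tokens of sample i
        cu_seqlens.append((i + 1) * max_len)              # end of sample i's padded slot
    print(cu_seqlens)  # [0, 8, 10, 19, 20]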
 
 
hyvideo/modules/embed_layers.py DELETED
@@ -1,157 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
- from einops import rearrange, repeat
5
-
6
- from ..utils.helpers import to_2tuple
7
-
8
-
9
- class PatchEmbed(nn.Module):
10
- """3D Video to Patch Embedding
11
-
12
- Video to Patch Embedding using Conv3d
13
-
14
- A convolution based approach to patchifying a video w/ embedding projection.
15
-
16
- Based on the impl in https://github.com/google-research/vision_transformer
17
-
18
- Hacked together by / Copyright 2020 Ross Wightman
19
-
20
- Remove the _assert function in forward function to be compatible with multi-resolution images.
21
- """
22
-
23
- def __init__(
24
- self,
25
- patch_size=16,
26
- in_chans=3,
27
- embed_dim=768,
28
- norm_layer=None,
29
- flatten=True,
30
- bias=True,
31
- dtype=None,
32
- device=None,
33
- ):
34
- factory_kwargs = {"dtype": dtype, "device": device}
35
- super().__init__()
36
- patch_size = to_2tuple(patch_size)
37
- self.patch_size = patch_size
38
- self.flatten = flatten
39
-
40
- self.proj = nn.Conv3d(
41
- in_chans,
42
- embed_dim,
43
- kernel_size=patch_size,
44
- stride=patch_size,
45
- bias=bias,
46
- **factory_kwargs
47
- )
48
- nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
49
- if bias:
50
- nn.init.zeros_(self.proj.bias)
51
-
52
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
53
-
54
- def forward(self, x):
55
- x = self.proj(x)
56
- if self.flatten:
57
- x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
58
- x = self.norm(x)
59
- return x
60
-
61
-
62
- class TextProjection(nn.Module):
63
- """
64
- Projects text embeddings. Also handles dropout for classifier-free guidance.
65
-
66
- Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
67
- """
68
-
69
- def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
70
- factory_kwargs = {"dtype": dtype, "device": device}
71
- super().__init__()
72
- self.linear_1 = nn.Linear(
73
- in_features=in_channels,
74
- out_features=hidden_size,
75
- bias=True,
76
- **factory_kwargs
77
- )
78
- self.act_1 = act_layer()
79
- self.linear_2 = nn.Linear(
80
- in_features=hidden_size,
81
- out_features=hidden_size,
82
- bias=True,
83
- **factory_kwargs
84
- )
85
-
86
- def forward(self, caption):
87
- hidden_states = self.linear_1(caption)
88
- hidden_states = self.act_1(hidden_states)
89
- hidden_states = self.linear_2(hidden_states)
90
- return hidden_states
91
-
92
-
93
- def timestep_embedding(t, dim, max_period=10000):
94
- """
95
- Create sinusoidal timestep embeddings.
96
-
97
- Args:
98
- t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
99
- dim (int): the dimension of the output.
100
- max_period (int): controls the minimum frequency of the embeddings.
101
-
102
- Returns:
103
- embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
104
-
105
- .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
106
- """
107
- half = dim // 2
108
- freqs = torch.exp(
109
- -math.log(max_period)
110
- * torch.arange(start=0, end=half, dtype=torch.float32)
111
- / half
112
- ).to(device=t.device)
113
- args = t[:, None].float() * freqs[None]
114
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
115
- if dim % 2:
116
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
117
- return embedding
118
-
119
-
120
- class TimestepEmbedder(nn.Module):
121
- """
122
- Embeds scalar timesteps into vector representations.
123
- """
124
-
125
- def __init__(
126
- self,
127
- hidden_size,
128
- act_layer,
129
- frequency_embedding_size=256,
130
- max_period=10000,
131
- out_size=None,
132
- dtype=None,
133
- device=None,
134
- ):
135
- factory_kwargs = {"dtype": dtype, "device": device}
136
- super().__init__()
137
- self.frequency_embedding_size = frequency_embedding_size
138
- self.max_period = max_period
139
- if out_size is None:
140
- out_size = hidden_size
141
-
142
- self.mlp = nn.Sequential(
143
- nn.Linear(
144
- frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
145
- ),
146
- act_layer(),
147
- nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
148
- )
149
- nn.init.normal_(self.mlp[0].weight, std=0.02)
150
- nn.init.normal_(self.mlp[2].weight, std=0.02)
151
-
152
- def forward(self, t):
153
- t_freq = timestep_embedding(
154
- t, self.frequency_embedding_size, self.max_period
155
- ).type(self.mlp[0].weight.dtype)
156
- t_emb = self.mlp(t_freq)
157
- return t_emb
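timestep_embedding above is the standard sinusoidal encoding: frequencies w_i = exp(-ln(max_period) * i / half) for i = 0..half-1, and the embedding is the concatenation [cos(t * w), sin(t * w)]. A quick standalone check with illustrative values:

    import math
    import torch

    t = torch.tensor([0.0, 500.0])     # two example timesteps
    dim, max_period = 8, 10000
    half = dim // 2

    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None] * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # shape (2, 8)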
 
 
hyvideo/modules/fp8_optimization.py DELETED
@@ -1,102 +0,0 @@
1
- import os
2
-
3
- import torch
4
- import torch.nn as nn
5
- from torch.nn import functional as F
6
-
7
- def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
8
- _bits = torch.tensor(bits)
9
- _mantissa_bit = torch.tensor(mantissa_bit)
10
- _sign_bits = torch.tensor(sign_bits)
11
- M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
12
- E = _bits - _sign_bits - M
13
- bias = 2 ** (E - 1) - 1
14
- mantissa = 1
15
- for i in range(mantissa_bit - 1):
16
- mantissa += 1 / (2 ** (i+1))
17
- maxval = mantissa * 2 ** (2**E - 1 - bias)
18
- return maxval
19
-
20
- def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
21
- """
22
- Default is E4M3.
23
- """
24
- bits = torch.tensor(bits)
25
- mantissa_bit = torch.tensor(mantissa_bit)
26
- sign_bits = torch.tensor(sign_bits)
27
- M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
28
- E = bits - sign_bits - M
29
- bias = 2 ** (E - 1) - 1
30
- mantissa = 1
31
- for i in range(mantissa_bit - 1):
32
- mantissa += 1 / (2 ** (i+1))
33
- maxval = mantissa * 2 ** (2**E - 1 - bias)
34
- minval = - maxval
35
- minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
36
- input_clamp = torch.min(torch.max(x, minval), maxval)
37
- log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
38
- log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
39
- # dequant
40
- qdq_out = torch.round(input_clamp / log_scales) * log_scales
41
- return qdq_out, log_scales
42
-
43
- def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
44
- for i in range(len(x.shape) - 1):
45
- scale = scale.unsqueeze(-1)
46
- new_x = x / scale
47
- quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
48
- return quant_dequant_x, scale, log_scales
49
-
50
- def fp8_activation_dequant(qdq_out, scale, dtype):
51
- qdq_out = qdq_out.type(dtype)
52
- quant_dequant_x = qdq_out * scale.to(dtype)
53
- return quant_dequant_x
54
-
55
- def fp8_linear_forward(cls, original_dtype, input):
56
- weight_dtype = cls.weight.dtype
57
- #####
58
- if cls.weight.dtype != torch.float8_e4m3fn:
59
- maxval = get_fp_maxval()
60
- scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
61
- linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
62
- linear_weight = linear_weight.to(torch.float8_e4m3fn)
63
- weight_dtype = linear_weight.dtype
64
- else:
65
- scale = cls.fp8_scale.to(cls.weight.device)
66
- linear_weight = cls.weight
67
- #####
68
-
69
- if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
70
- if True or len(input.shape) == 3:
71
- cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
72
- if cls.bias is not None:
73
- output = F.linear(input, cls_dequant, cls.bias)
74
- else:
75
- output = F.linear(input, cls_dequant)
76
- return output
77
- else:
78
- return cls.original_forward(input.to(original_dtype))
79
- else:
80
- return cls.original_forward(input)
81
-
82
- def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
83
- setattr(module, "fp8_matmul_enabled", True)
84
-
85
- # loading fp8 mapping file
86
- fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
87
- if os.path.exists(fp8_map_path):
88
- fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
89
- else:
90
- raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
91
-
92
- fp8_layers = []
93
- for key, layer in module.named_modules():
94
- if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
95
- fp8_layers.append(key)
96
- original_forward = layer.forward
97
- layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
98
- setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
99
- setattr(layer, "original_forward", original_forward)
100
- setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
101
-
102
-
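A minimal sketch of the quantize/dequantize round trip these helpers implement (E4M3 with a single per-tensor scale, as in fp8_linear_forward). The tensor shape is arbitrary and purely illustrative.

    import torch
    from hyvideo.modules.fp8_optimization import (
        get_fp_maxval, fp8_tensor_quant, fp8_activation_dequant,
    )

    w = torch.randn(64, 64)                                      # stand-in for a linear weight
    scale = torch.max(torch.abs(w.flatten())) / get_fp_maxval()  # per-tensor scale
    qdq, scale, _ = fp8_tensor_quant(w, scale)                   # quantize-dequantize in E4M3
    w_restored = fp8_activation_dequant(qdq, scale, torch.float32)
    print((w - w_restored).abs().max())                          # error introduced by fp8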
hyvideo/modules/mlp_layers.py DELETED
@@ -1,118 +0,0 @@
1
- # Modified from timm library:
2
- # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
-
4
- from functools import partial
5
-
6
- import torch
7
- import torch.nn as nn
8
-
9
- from .modulate_layers import modulate
10
- from ..utils.helpers import to_2tuple
11
-
12
-
13
- class MLP(nn.Module):
14
- """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15
-
16
- def __init__(
17
- self,
18
- in_channels,
19
- hidden_channels=None,
20
- out_features=None,
21
- act_layer=nn.GELU,
22
- norm_layer=None,
23
- bias=True,
24
- drop=0.0,
25
- use_conv=False,
26
- device=None,
27
- dtype=None,
28
- ):
29
- factory_kwargs = {"device": device, "dtype": dtype}
30
- super().__init__()
31
- out_features = out_features or in_channels
32
- hidden_channels = hidden_channels or in_channels
33
- bias = to_2tuple(bias)
34
- drop_probs = to_2tuple(drop)
35
- linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36
-
37
- self.fc1 = linear_layer(
38
- in_channels, hidden_channels, bias=bias[0], **factory_kwargs
39
- )
40
- self.act = act_layer()
41
- self.drop1 = nn.Dropout(drop_probs[0])
42
- self.norm = (
43
- norm_layer(hidden_channels, **factory_kwargs)
44
- if norm_layer is not None
45
- else nn.Identity()
46
- )
47
- self.fc2 = linear_layer(
48
- hidden_channels, out_features, bias=bias[1], **factory_kwargs
49
- )
50
- self.drop2 = nn.Dropout(drop_probs[1])
51
-
52
- def forward(self, x):
53
- x = self.fc1(x)
54
- x = self.act(x)
55
- x = self.drop1(x)
56
- x = self.norm(x)
57
- x = self.fc2(x)
58
- x = self.drop2(x)
59
- return x
60
-
61
-
62
- #
63
- class MLPEmbedder(nn.Module):
64
- """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
65
- def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
66
- factory_kwargs = {"device": device, "dtype": dtype}
67
- super().__init__()
68
- self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
69
- self.silu = nn.SiLU()
70
- self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
71
-
72
- def forward(self, x: torch.Tensor) -> torch.Tensor:
73
- return self.out_layer(self.silu(self.in_layer(x)))
74
-
75
-
76
- class FinalLayer(nn.Module):
77
- """The final layer of DiT."""
78
-
79
- def __init__(
80
- self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
81
- ):
82
- factory_kwargs = {"device": device, "dtype": dtype}
83
- super().__init__()
84
-
85
- # Just use LayerNorm for the final layer
86
- self.norm_final = nn.LayerNorm(
87
- hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
88
- )
89
- if isinstance(patch_size, int):
90
- self.linear = nn.Linear(
91
- hidden_size,
92
- patch_size * patch_size * out_channels,
93
- bias=True,
94
- **factory_kwargs
95
- )
96
- else:
97
- self.linear = nn.Linear(
98
- hidden_size,
99
- patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
100
- bias=True,
101
- )
102
- nn.init.zeros_(self.linear.weight)
103
- nn.init.zeros_(self.linear.bias)
104
-
105
- # Here we don't distinguish between the modulate types. Just use the simple one.
106
- self.adaLN_modulation = nn.Sequential(
107
- act_layer(),
108
- nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
109
- )
110
- # Zero-initialize the modulation
111
- nn.init.zeros_(self.adaLN_modulation[1].weight)
112
- nn.init.zeros_(self.adaLN_modulation[1].bias)
113
-
114
- def forward(self, x, c):
115
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
116
- x = modulate(self.norm_final(x), shift=shift, scale=scale)
117
- x = self.linear(x)
118
- return x
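A minimal sketch of how these layers are wired in the backbone, assuming the default 3072 hidden size, a [1, 2, 2] patch size, 16 latent channels and a 768-d pooled text embedding (all shapes are illustrative).

    import torch
    import torch.nn as nn
    from hyvideo.modules.mlp_layers import MLPEmbedder, FinalLayer

    vector_in = MLPEmbedder(in_dim=768, hidden_dim=3072)   # pooled text embedding -> modulation vector
    final = FinalLayer(hidden_size=3072, patch_size=[1, 2, 2], out_channels=16, act_layer=nn.SiLU)

    vec = vector_in(torch.randn(2, 768))      # (2, 3072)
    tokens = torch.randn(2, 128, 3072)        # (B, L, hidden)
    out = final(tokens, vec)                  # (2, 128, 1 * 2 * 2 * 16)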
hyvideo/modules/models.py DELETED
@@ -1,760 +0,0 @@
1
- from typing import Any, List, Tuple, Optional, Union, Dict
2
- from einops import rearrange
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- from diffusers.models import ModelMixin
9
- from diffusers.configuration_utils import ConfigMixin, register_to_config
10
-
11
- from .activation_layers import get_activation_layer
12
- from .norm_layers import get_norm_layer
13
- from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
14
- from .attenion import attention, parallel_attention, get_cu_seqlens
15
- from .posemb_layers import apply_rotary_emb
16
- from .mlp_layers import MLP, MLPEmbedder, FinalLayer
17
- from .modulate_layers import ModulateDiT, modulate, apply_gate
18
- from .token_refiner import SingleTokenRefiner
19
-
20
-
21
- class MMDoubleStreamBlock(nn.Module):
22
- """
23
- A multimodal DiT block with separate modulation for
24
- text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
25
- (Flux.1): https://github.com/black-forest-labs/flux
26
- """
27
-
28
- def __init__(
29
- self,
30
- hidden_size: int,
31
- heads_num: int,
32
- mlp_width_ratio: float,
33
- mlp_act_type: str = "gelu_tanh",
34
- qk_norm: bool = True,
35
- qk_norm_type: str = "rms",
36
- qkv_bias: bool = False,
37
- dtype: Optional[torch.dtype] = None,
38
- device: Optional[torch.device] = None,
39
- ):
40
- factory_kwargs = {"device": device, "dtype": dtype}
41
- super().__init__()
42
-
43
- self.deterministic = False
44
- self.heads_num = heads_num
45
- head_dim = hidden_size // heads_num
46
- mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
47
-
48
- self.img_mod = ModulateDiT(
49
- hidden_size,
50
- factor=6,
51
- act_layer=get_activation_layer("silu"),
52
- **factory_kwargs,
53
- )
54
- self.img_norm1 = nn.LayerNorm(
55
- hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
56
- )
57
-
58
- self.img_attn_qkv = nn.Linear(
59
- hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
60
- )
61
- qk_norm_layer = get_norm_layer(qk_norm_type)
62
- self.img_attn_q_norm = (
63
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
64
- if qk_norm
65
- else nn.Identity()
66
- )
67
- self.img_attn_k_norm = (
68
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
69
- if qk_norm
70
- else nn.Identity()
71
- )
72
- self.img_attn_proj = nn.Linear(
73
- hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
74
- )
75
-
76
- self.img_norm2 = nn.LayerNorm(
77
- hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
78
- )
79
- self.img_mlp = MLP(
80
- hidden_size,
81
- mlp_hidden_dim,
82
- act_layer=get_activation_layer(mlp_act_type),
83
- bias=True,
84
- **factory_kwargs,
85
- )
86
-
87
- self.txt_mod = ModulateDiT(
88
- hidden_size,
89
- factor=6,
90
- act_layer=get_activation_layer("silu"),
91
- **factory_kwargs,
92
- )
93
- self.txt_norm1 = nn.LayerNorm(
94
- hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
95
- )
96
-
97
- self.txt_attn_qkv = nn.Linear(
98
- hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
99
- )
100
- self.txt_attn_q_norm = (
101
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
102
- if qk_norm
103
- else nn.Identity()
104
- )
105
- self.txt_attn_k_norm = (
106
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
107
- if qk_norm
108
- else nn.Identity()
109
- )
110
- self.txt_attn_proj = nn.Linear(
111
- hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
112
- )
113
-
114
- self.txt_norm2 = nn.LayerNorm(
115
- hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
116
- )
117
- self.txt_mlp = MLP(
118
- hidden_size,
119
- mlp_hidden_dim,
120
- act_layer=get_activation_layer(mlp_act_type),
121
- bias=True,
122
- **factory_kwargs,
123
- )
124
- self.hybrid_seq_parallel_attn = None
125
-
126
- def enable_deterministic(self):
127
- self.deterministic = True
128
-
129
- def disable_deterministic(self):
130
- self.deterministic = False
131
-
132
- def forward(
133
- self,
134
- img: torch.Tensor,
135
- txt: torch.Tensor,
136
- vec: torch.Tensor,
137
- cu_seqlens_q: Optional[torch.Tensor] = None,
138
- cu_seqlens_kv: Optional[torch.Tensor] = None,
139
- max_seqlen_q: Optional[int] = None,
140
- max_seqlen_kv: Optional[int] = None,
141
- freqs_cis: tuple = None,
142
- ) -> Tuple[torch.Tensor, torch.Tensor]:
143
- (
144
- img_mod1_shift,
145
- img_mod1_scale,
146
- img_mod1_gate,
147
- img_mod2_shift,
148
- img_mod2_scale,
149
- img_mod2_gate,
150
- ) = self.img_mod(vec).chunk(6, dim=-1)
151
- (
152
- txt_mod1_shift,
153
- txt_mod1_scale,
154
- txt_mod1_gate,
155
- txt_mod2_shift,
156
- txt_mod2_scale,
157
- txt_mod2_gate,
158
- ) = self.txt_mod(vec).chunk(6, dim=-1)
159
-
160
- # Prepare image for attention.
161
- img_modulated = self.img_norm1(img)
162
- img_modulated = modulate(
163
- img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
164
- )
165
- img_qkv = self.img_attn_qkv(img_modulated)
166
- img_q, img_k, img_v = rearrange(
167
- img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
168
- )
169
- # Apply QK-Norm if needed
170
- img_q = self.img_attn_q_norm(img_q).to(img_v)
171
- img_k = self.img_attn_k_norm(img_k).to(img_v)
172
-
173
- # Apply RoPE if needed.
174
- if freqs_cis is not None:
175
- img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
176
- assert (
177
- img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
178
- ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
179
- img_q, img_k = img_qq, img_kk
180
-
181
- # Prepare txt for attention.
182
- txt_modulated = self.txt_norm1(txt)
183
- txt_modulated = modulate(
184
- txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
185
- )
186
- txt_qkv = self.txt_attn_qkv(txt_modulated)
187
- txt_q, txt_k, txt_v = rearrange(
188
- txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
189
- )
190
- # Apply QK-Norm if needed.
191
- txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
192
- txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
193
-
194
- # Run actual attention.
195
- q = torch.cat((img_q, txt_q), dim=1)
196
- k = torch.cat((img_k, txt_k), dim=1)
197
- v = torch.cat((img_v, txt_v), dim=1)
198
- assert (
199
- cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
200
- ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
201
-
202
- # attention computation start
203
- if not self.hybrid_seq_parallel_attn:
204
- attn = attention(
205
- q,
206
- k,
207
- v,
208
- cu_seqlens_q=cu_seqlens_q,
209
- cu_seqlens_kv=cu_seqlens_kv,
210
- max_seqlen_q=max_seqlen_q,
211
- max_seqlen_kv=max_seqlen_kv,
212
- batch_size=img_k.shape[0],
213
- )
214
- else:
215
- attn = parallel_attention(
216
- self.hybrid_seq_parallel_attn,
217
- q,
218
- k,
219
- v,
220
- img_q_len=img_q.shape[1],
221
- img_kv_len=img_k.shape[1],
222
- cu_seqlens_q=cu_seqlens_q,
223
- cu_seqlens_kv=cu_seqlens_kv
224
- )
225
-
226
- # attention computation end
227
-
228
- img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
229
-
230
- # Calculate the img blocks.
231
- img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
232
- img = img + apply_gate(
233
- self.img_mlp(
234
- modulate(
235
- self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
236
- )
237
- ),
238
- gate=img_mod2_gate,
239
- )
240
-
241
- # Calculate the txt blocks.
242
- txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
243
- txt = txt + apply_gate(
244
- self.txt_mlp(
245
- modulate(
246
- self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
247
- )
248
- ),
249
- gate=txt_mod2_gate,
250
- )
251
-
252
- return img, txt
253
-
254
-
255
- class MMSingleStreamBlock(nn.Module):
256
- """
257
- A DiT block with parallel linear layers as described in
258
- https://arxiv.org/abs/2302.05442 and adapted modulation interface.
259
- Also refer to (SD3): https://arxiv.org/abs/2403.03206
260
- (Flux.1): https://github.com/black-forest-labs/flux
261
- """
262
-
263
- def __init__(
264
- self,
265
- hidden_size: int,
266
- heads_num: int,
267
- mlp_width_ratio: float = 4.0,
268
- mlp_act_type: str = "gelu_tanh",
269
- qk_norm: bool = True,
270
- qk_norm_type: str = "rms",
271
- qk_scale: float = None,
272
- dtype: Optional[torch.dtype] = None,
273
- device: Optional[torch.device] = None,
274
- ):
275
- factory_kwargs = {"device": device, "dtype": dtype}
276
- super().__init__()
277
-
278
- self.deterministic = False
279
- self.hidden_size = hidden_size
280
- self.heads_num = heads_num
281
- head_dim = hidden_size // heads_num
282
- mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
283
- self.mlp_hidden_dim = mlp_hidden_dim
284
- self.scale = qk_scale or head_dim ** -0.5
285
-
286
- # qkv and mlp_in
287
- self.linear1 = nn.Linear(
288
- hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
289
- )
290
- # proj and mlp_out
291
- self.linear2 = nn.Linear(
292
- hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
293
- )
294
-
295
- qk_norm_layer = get_norm_layer(qk_norm_type)
296
- self.q_norm = (
297
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
298
- if qk_norm
299
- else nn.Identity()
300
- )
301
- self.k_norm = (
302
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
303
- if qk_norm
304
- else nn.Identity()
305
- )
306
-
307
- self.pre_norm = nn.LayerNorm(
308
- hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
309
- )
310
-
311
- self.mlp_act = get_activation_layer(mlp_act_type)()
312
- self.modulation = ModulateDiT(
313
- hidden_size,
314
- factor=3,
315
- act_layer=get_activation_layer("silu"),
316
- **factory_kwargs,
317
- )
318
- self.hybrid_seq_parallel_attn = None
319
-
320
- def enable_deterministic(self):
321
- self.deterministic = True
322
-
323
- def disable_deterministic(self):
324
- self.deterministic = False
325
-
326
- def forward(
327
- self,
328
- x: torch.Tensor,
329
- vec: torch.Tensor,
330
- txt_len: int,
331
- cu_seqlens_q: Optional[torch.Tensor] = None,
332
- cu_seqlens_kv: Optional[torch.Tensor] = None,
333
- max_seqlen_q: Optional[int] = None,
334
- max_seqlen_kv: Optional[int] = None,
335
- freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
336
- ) -> torch.Tensor:
337
- mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
338
- x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
339
- qkv, mlp = torch.split(
340
- self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
341
- )
342
-
343
- q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
344
-
345
- # Apply QK-Norm if needed.
346
- q = self.q_norm(q).to(v)
347
- k = self.k_norm(k).to(v)
348
-
349
- # Apply RoPE if needed.
350
- if freqs_cis is not None:
351
- img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
352
- img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
353
- img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
354
- assert (
355
- img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
356
- ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
357
- img_q, img_k = img_qq, img_kk
358
- q = torch.cat((img_q, txt_q), dim=1)
359
- k = torch.cat((img_k, txt_k), dim=1)
360
-
361
- # Compute attention.
362
- assert (
363
- cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
364
- ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
365
-
366
- # attention computation start
367
- if not self.hybrid_seq_parallel_attn:
368
- attn = attention(
369
- q,
370
- k,
371
- v,
372
- cu_seqlens_q=cu_seqlens_q,
373
- cu_seqlens_kv=cu_seqlens_kv,
374
- max_seqlen_q=max_seqlen_q,
375
- max_seqlen_kv=max_seqlen_kv,
376
- batch_size=x.shape[0],
377
- )
378
- else:
379
- attn = parallel_attention(
380
- self.hybrid_seq_parallel_attn,
381
- q,
382
- k,
383
- v,
384
- img_q_len=img_q.shape[1],
385
- img_kv_len=img_k.shape[1],
386
- cu_seqlens_q=cu_seqlens_q,
387
- cu_seqlens_kv=cu_seqlens_kv
388
- )
389
- # attention computation end
390
-
391
- # Compute activation in mlp stream, cat again and run second linear layer.
392
- output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
393
- return x + apply_gate(output, gate=mod_gate)
394
-
395
-
396
- class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
397
- """
398
- HunyuanVideo Transformer backbone
399
-
400
- Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
401
-
402
- Reference:
403
- [1] Flux.1: https://github.com/black-forest-labs/flux
404
- [2] MMDiT: http://arxiv.org/abs/2403.03206
405
-
406
- Parameters
407
- ----------
408
- args: argparse.Namespace
409
- The arguments parsed by argparse.
410
- patch_size: list
411
- The size of the patch.
412
- in_channels: int
413
- The number of input channels.
414
- out_channels: int
415
- The number of output channels.
416
- hidden_size: int
417
- The hidden size of the transformer backbone.
418
- heads_num: int
419
- The number of attention heads.
420
- mlp_width_ratio: float
421
- The ratio of the hidden size of the MLP in the transformer block.
422
- mlp_act_type: str
423
- The activation function of the MLP in the transformer block.
424
- mm_double_blocks_depth: int
425
- The number of transformer blocks in the double blocks.
426
- mm_single_blocks_depth: int
427
- The number of transformer blocks in the single blocks.
428
- rope_dim_list: list
429
- The dimension of the rotary embedding for t, h, w.
430
- qkv_bias: bool
431
- Whether to use bias in the qkv linear layer.
432
- qk_norm: bool
433
- Whether to use qk norm.
434
- qk_norm_type: str
435
- The type of qk norm.
436
- guidance_embed: bool
437
- Whether to use guidance embedding for distillation.
438
- text_projection: str
439
- The type of the text projection, default is single_refiner.
440
- use_attention_mask: bool
441
- Whether to use attention mask for text encoder.
442
- dtype: torch.dtype
443
- The dtype of the model.
444
- device: torch.device
445
- The device of the model.
446
- """
447
-
448
- @register_to_config
449
- def __init__(
450
- self,
451
- args: Any,
452
- patch_size: list = [1, 2, 2],
453
- in_channels: int = 4, # Should be VAE.config.latent_channels.
454
- out_channels: int = None,
455
- hidden_size: int = 3072,
456
- heads_num: int = 24,
457
- mlp_width_ratio: float = 4.0,
458
- mlp_act_type: str = "gelu_tanh",
459
- mm_double_blocks_depth: int = 20,
460
- mm_single_blocks_depth: int = 40,
461
- rope_dim_list: List[int] = [16, 56, 56],
462
- qkv_bias: bool = True,
463
- qk_norm: bool = True,
464
- qk_norm_type: str = "rms",
465
- guidance_embed: bool = False, # For modulation.
466
- text_projection: str = "single_refiner",
467
- use_attention_mask: bool = True,
468
- dtype: Optional[torch.dtype] = None,
469
- device: Optional[torch.device] = None,
470
- ):
471
- factory_kwargs = {"device": device, "dtype": dtype}
472
- super().__init__()
473
-
474
- self.patch_size = patch_size
475
- self.in_channels = in_channels
476
- self.out_channels = in_channels if out_channels is None else out_channels
477
- self.unpatchify_channels = self.out_channels
478
- self.guidance_embed = guidance_embed
479
- self.rope_dim_list = rope_dim_list
480
-
481
- # Text projection. Default to linear projection.
482
- # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
483
- self.use_attention_mask = use_attention_mask
484
- self.text_projection = text_projection
485
-
486
- self.text_states_dim = args.text_states_dim
487
- self.text_states_dim_2 = args.text_states_dim_2
488
-
489
- if hidden_size % heads_num != 0:
490
- raise ValueError(
491
- f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
492
- )
493
- pe_dim = hidden_size // heads_num
494
- if sum(rope_dim_list) != pe_dim:
495
- raise ValueError(
496
- f"Got {rope_dim_list} but expected positional dim {pe_dim}"
497
- )
498
- self.hidden_size = hidden_size
499
- self.heads_num = heads_num
500
-
501
- # image projection
502
- self.img_in = PatchEmbed(
503
- self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
504
- )
505
-
506
- # text projection
507
- if self.text_projection == "linear":
508
- self.txt_in = TextProjection(
509
- self.text_states_dim,
510
- self.hidden_size,
511
- get_activation_layer("silu"),
512
- **factory_kwargs,
513
- )
514
- elif self.text_projection == "single_refiner":
515
- self.txt_in = SingleTokenRefiner(
516
- self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
517
- )
518
- else:
519
- raise NotImplementedError(
520
- f"Unsupported text_projection: {self.text_projection}"
521
- )
522
-
523
- # time modulation
524
- self.time_in = TimestepEmbedder(
525
- self.hidden_size, get_activation_layer("silu"), **factory_kwargs
526
- )
527
-
528
- # text modulation
529
- self.vector_in = MLPEmbedder(
530
- self.text_states_dim_2, self.hidden_size, **factory_kwargs
531
- )
532
-
533
- # guidance modulation
534
- self.guidance_in = (
535
- TimestepEmbedder(
536
- self.hidden_size, get_activation_layer("silu"), **factory_kwargs
537
- )
538
- if guidance_embed
539
- else None
540
- )
541
-
542
- # double blocks
543
- self.double_blocks = nn.ModuleList(
544
- [
545
- MMDoubleStreamBlock(
546
- self.hidden_size,
547
- self.heads_num,
548
- mlp_width_ratio=mlp_width_ratio,
549
- mlp_act_type=mlp_act_type,
550
- qk_norm=qk_norm,
551
- qk_norm_type=qk_norm_type,
552
- qkv_bias=qkv_bias,
553
- **factory_kwargs,
554
- )
555
- for _ in range(mm_double_blocks_depth)
556
- ]
557
- )
558
-
559
- # single blocks
560
- self.single_blocks = nn.ModuleList(
561
- [
562
- MMSingleStreamBlock(
563
- self.hidden_size,
564
- self.heads_num,
565
- mlp_width_ratio=mlp_width_ratio,
566
- mlp_act_type=mlp_act_type,
567
- qk_norm=qk_norm,
568
- qk_norm_type=qk_norm_type,
569
- **factory_kwargs,
570
- )
571
- for _ in range(mm_single_blocks_depth)
572
- ]
573
- )
574
-
575
- self.final_layer = FinalLayer(
576
- self.hidden_size,
577
- self.patch_size,
578
- self.out_channels,
579
- get_activation_layer("silu"),
580
- **factory_kwargs,
581
- )
582
-
583
- def enable_deterministic(self):
584
- for block in self.double_blocks:
585
- block.enable_deterministic()
586
- for block in self.single_blocks:
587
- block.enable_deterministic()
588
-
589
- def disable_deterministic(self):
590
- for block in self.double_blocks:
591
- block.disable_deterministic()
592
- for block in self.single_blocks:
593
- block.disable_deterministic()
594
-
595
- def forward(
596
- self,
597
- x: torch.Tensor,
598
- t: torch.Tensor, # Should be in range(0, 1000).
599
- text_states: torch.Tensor = None,
600
- text_mask: torch.Tensor = None, # Now we don't use it.
601
- text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
602
- freqs_cos: Optional[torch.Tensor] = None,
603
- freqs_sin: Optional[torch.Tensor] = None,
604
- guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
605
- return_dict: bool = True,
606
- ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
607
- out = {}
608
- img = x
609
- txt = text_states
610
- _, _, ot, oh, ow = x.shape
611
- tt, th, tw = (
612
- ot // self.patch_size[0],
613
- oh // self.patch_size[1],
614
- ow // self.patch_size[2],
615
- )
616
-
617
- # Prepare modulation vectors.
618
- vec = self.time_in(t)
619
-
620
- # text modulation
621
- vec = vec + self.vector_in(text_states_2)
622
-
623
- # guidance modulation
624
- if self.guidance_embed:
625
- if guidance is None:
626
- raise ValueError(
627
- "Didn't get guidance strength for guidance distilled model."
628
- )
629
-
630
- # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
631
- vec = vec + self.guidance_in(guidance)
632
-
633
- # Embed image and text.
634
- img = self.img_in(img)
635
- if self.text_projection == "linear":
636
- txt = self.txt_in(txt)
637
- elif self.text_projection == "single_refiner":
638
- txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
639
- else:
640
- raise NotImplementedError(
641
- f"Unsupported text_projection: {self.text_projection}"
642
- )
643
-
644
- txt_seq_len = txt.shape[1]
645
- img_seq_len = img.shape[1]
646
-
647
- # Compute cu_seqlens and max_seqlen for flash attention
648
- cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
649
- cu_seqlens_kv = cu_seqlens_q
650
- max_seqlen_q = img_seq_len + txt_seq_len
651
- max_seqlen_kv = max_seqlen_q
652
-
653
- freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
654
- # --------------------- Pass through DiT blocks ------------------------
655
- for _, block in enumerate(self.double_blocks):
656
- double_block_args = [
657
- img,
658
- txt,
659
- vec,
660
- cu_seqlens_q,
661
- cu_seqlens_kv,
662
- max_seqlen_q,
663
- max_seqlen_kv,
664
- freqs_cis,
665
- ]
666
-
667
- img, txt = block(*double_block_args)
668
-
669
- # Merge txt and img to pass through single stream blocks.
670
- x = torch.cat((img, txt), 1)
671
- if len(self.single_blocks) > 0:
672
- for _, block in enumerate(self.single_blocks):
673
- single_block_args = [
674
- x,
675
- vec,
676
- txt_seq_len,
677
- cu_seqlens_q,
678
- cu_seqlens_kv,
679
- max_seqlen_q,
680
- max_seqlen_kv,
681
- (freqs_cos, freqs_sin),
682
- ]
683
-
684
- x = block(*single_block_args)
685
-
686
- img = x[:, :img_seq_len, ...]
687
-
688
- # ---------------------------- Final layer ------------------------------
689
- img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
690
-
691
- img = self.unpatchify(img, tt, th, tw)
692
- if return_dict:
693
- out["x"] = img
694
- return out
695
- return img
696
-
697
- def unpatchify(self, x, t, h, w):
698
- """
699
- x: (N, t * h * w, pt * ph * pw * C)
700
- imgs: (N, C, t * pt, h * ph, w * pw)
701
- """
702
- c = self.unpatchify_channels
703
- pt, ph, pw = self.patch_size
704
- assert t * h * w == x.shape[1]
705
-
706
- x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
707
- x = torch.einsum("nthwcopq->nctohpwq", x)
708
- imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
709
-
710
- return imgs
711
-
712
- def params_count(self):
713
- counts = {
714
- "double": sum(
715
- [
716
- sum(p.numel() for p in block.img_attn_qkv.parameters())
717
- + sum(p.numel() for p in block.img_attn_proj.parameters())
718
- + sum(p.numel() for p in block.img_mlp.parameters())
719
- + sum(p.numel() for p in block.txt_attn_qkv.parameters())
720
- + sum(p.numel() for p in block.txt_attn_proj.parameters())
721
- + sum(p.numel() for p in block.txt_mlp.parameters())
722
- for block in self.double_blocks
723
- ]
724
- ),
725
- "single": sum(
726
- [
727
- sum(p.numel() for p in block.linear1.parameters())
728
- + sum(p.numel() for p in block.linear2.parameters())
729
- for block in self.single_blocks
730
- ]
731
- ),
732
- "total": sum(p.numel() for p in self.parameters()),
733
- }
734
- counts["attn+mlp"] = counts["double"] + counts["single"]
735
- return counts
736
-
737
-
738
- #################################################################################
739
- # HunyuanVideo Configs #
740
- #################################################################################
741
-
742
- HUNYUAN_VIDEO_CONFIG = {
743
- "HYVideo-T/2": {
744
- "mm_double_blocks_depth": 20,
745
- "mm_single_blocks_depth": 40,
746
- "rope_dim_list": [16, 56, 56],
747
- "hidden_size": 3072,
748
- "heads_num": 24,
749
- "mlp_width_ratio": 4,
750
- },
751
- "HYVideo-T/2-cfgdistill": {
752
- "mm_double_blocks_depth": 20,
753
- "mm_single_blocks_depth": 40,
754
- "rope_dim_list": [16, 56, 56],
755
- "hidden_size": 3072,
756
- "heads_num": 24,
757
- "mlp_width_ratio": 4,
758
- "guidance_embed": True,
759
- },
760
- }
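A minimal sketch of building the cfg-distilled backbone from this config dict. The `args` namespace stands in for the argparse namespace produced by parse_args(), and text_states_dim_2=768 is an assumption (a CLIP-sized pooled text state); note this allocates the full multi-billion-parameter model.

    import torch
    from types import SimpleNamespace
    from hyvideo.modules.models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG

    args = SimpleNamespace(text_states_dim=4096, text_states_dim_2=768)
    model = HYVideoDiffusionTransformer(
        args,
        in_channels=16,   # latent channels of the 884-16c VAE
        **HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"],
        dtype=torch.bfloat16,
    )
    print(model.params_count()["total"])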
hyvideo/modules/modulate_layers.py DELETED
@@ -1,76 +0,0 @@
1
- from typing import Callable
2
-
3
- import torch
4
- import torch.nn as nn
5
-
6
-
7
- class ModulateDiT(nn.Module):
8
- """Modulation layer for DiT."""
9
- def __init__(
10
- self,
11
- hidden_size: int,
12
- factor: int,
13
- act_layer: Callable,
14
- dtype=None,
15
- device=None,
16
- ):
17
- factory_kwargs = {"dtype": dtype, "device": device}
18
- super().__init__()
19
- self.act = act_layer()
20
- self.linear = nn.Linear(
21
- hidden_size, factor * hidden_size, bias=True, **factory_kwargs
22
- )
23
- # Zero-initialize the modulation
24
- nn.init.zeros_(self.linear.weight)
25
- nn.init.zeros_(self.linear.bias)
26
-
27
- def forward(self, x: torch.Tensor) -> torch.Tensor:
28
- return self.linear(self.act(x))
29
-
30
-
31
- def modulate(x, shift=None, scale=None):
32
- """modulate by shift and scale
33
-
34
- Args:
35
- x (torch.Tensor): input tensor.
36
- shift (torch.Tensor, optional): shift tensor. Defaults to None.
37
- scale (torch.Tensor, optional): scale tensor. Defaults to None.
38
-
39
- Returns:
40
- torch.Tensor: the output tensor after modulation.
41
- """
42
- if scale is None and shift is None:
43
- return x
44
- elif shift is None:
45
- return x * (1 + scale.unsqueeze(1))
46
- elif scale is None:
47
- return x + shift.unsqueeze(1)
48
- else:
49
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
50
-
51
-
52
- def apply_gate(x, gate=None, tanh=False):
53
- """Apply an optional gate (with optional tanh) to the input tensor.
54
-
55
- Args:
56
- x (torch.Tensor): input tensor.
57
- gate (torch.Tensor, optional): gate tensor. Defaults to None.
58
- tanh (bool, optional): whether to use tanh function. Defaults to False.
59
-
60
- Returns:
61
- torch.Tensor: the output tensor after applying the gate.
62
- """
63
- if gate is None:
64
- return x
65
- if tanh:
66
- return x * gate.unsqueeze(1).tanh()
67
- else:
68
- return x * gate.unsqueeze(1)
69
-
70
-
71
- def ckpt_wrapper(module):
72
- def ckpt_forward(*inputs):
73
- outputs = module(*inputs)
74
- return outputs
75
-
76
- return ckpt_forward
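A minimal sketch of the adaLN pattern these helpers implement: project the conditioning vector, scale/shift the normalized activations, then gate the branch output before the residual add (dimensions are illustrative).

    import torch
    import torch.nn as nn
    from hyvideo.modules.modulate_layers import ModulateDiT, modulate, apply_gate

    mod = ModulateDiT(hidden_size=64, factor=3, act_layer=nn.SiLU)
    norm = nn.LayerNorm(64, elementwise_affine=False, eps=1e-6)

    x = torch.randn(2, 10, 64)       # (B, L, D) tokens
    vec = torch.randn(2, 64)         # conditioning vector
    shift, scale, gate = mod(vec).chunk(3, dim=-1)
    x = x + apply_gate(modulate(norm(x), shift=shift, scale=scale), gate=gate)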
hyvideo/modules/norm_layers.py DELETED
@@ -1,77 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
-
5
- class RMSNorm(nn.Module):
6
- def __init__(
7
- self,
8
- dim: int,
9
- elementwise_affine=True,
10
- eps: float = 1e-6,
11
- device=None,
12
- dtype=None,
13
- ):
14
- """
15
- Initialize the RMSNorm normalization layer.
16
-
17
- Args:
18
- dim (int): The dimension of the input tensor.
19
- eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
20
-
21
- Attributes:
22
- eps (float): A small value added to the denominator for numerical stability.
23
- weight (nn.Parameter): Learnable scaling parameter.
24
-
25
- """
26
- factory_kwargs = {"device": device, "dtype": dtype}
27
- super().__init__()
28
- self.eps = eps
29
- if elementwise_affine:
30
- self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
31
-
32
- def _norm(self, x):
33
- """
34
- Apply the RMSNorm normalization to the input tensor.
35
-
36
- Args:
37
- x (torch.Tensor): The input tensor.
38
-
39
- Returns:
40
- torch.Tensor: The normalized tensor.
41
-
42
- """
43
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44
-
45
- def forward(self, x):
46
- """
47
- Forward pass through the RMSNorm layer.
48
-
49
- Args:
50
- x (torch.Tensor): The input tensor.
51
-
52
- Returns:
53
- torch.Tensor: The output tensor after applying RMSNorm.
54
-
55
- """
56
- output = self._norm(x.float()).type_as(x)
57
- if hasattr(self, "weight"):
58
- output = output * self.weight
59
- return output
60
-
61
-
62
- def get_norm_layer(norm_layer):
63
- """
64
- Get the normalization layer.
65
-
66
- Args:
67
- norm_layer (str): The type of normalization layer.
68
-
69
- Returns:
70
- norm_layer (nn.Module): The normalization layer.
71
- """
72
- if norm_layer == "layer":
73
- return nn.LayerNorm
74
- elif norm_layer == "rms":
75
- return RMSNorm
76
- else:
77
- raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
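A minimal sketch of how the QK-norm layer is resolved and applied per attention head; head_dim=128 and the (B, L, heads, head_dim) layout are assumptions matching the backbone above.

    import torch
    from hyvideo.modules.norm_layers import get_norm_layer

    RMSNorm = get_norm_layer("rms")
    q_norm = RMSNorm(128, elementwise_affine=True, eps=1e-6)   # head_dim = 128
    q = torch.randn(2, 256, 24, 128)                           # (B, L, heads, head_dim)
    q = q_norm(q)                                              # RMS-normalized over the last dim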
hyvideo/modules/posemb_layers.py DELETED
@@ -1,310 +0,0 @@
1
- import torch
2
- from typing import Union, Tuple, List
3
-
4
-
5
- def _to_tuple(x, dim=2):
6
- if isinstance(x, int):
7
- return (x,) * dim
8
- elif len(x) == dim:
9
- return x
10
- else:
11
- raise ValueError(f"Expected length {dim} or int, but got {x}")
12
-
13
-
14
- def get_meshgrid_nd(start, *args, dim=2):
15
- """
16
- Get n-D meshgrid with start, stop and num.
17
-
18
- Args:
19
- start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20
- step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21
- should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22
- n-tuples.
23
- *args: See above.
24
- dim (int): Dimension of the meshgrid. Defaults to 2.
25
-
26
- Returns:
27
- grid (torch.Tensor): [dim, ...]
28
- """
29
- if len(args) == 0:
30
- # start is grid_size
31
- num = _to_tuple(start, dim=dim)
32
- start = (0,) * dim
33
- stop = num
34
- elif len(args) == 1:
35
- # start is start, args[0] is stop, step is 1
36
- start = _to_tuple(start, dim=dim)
37
- stop = _to_tuple(args[0], dim=dim)
38
- num = [stop[i] - start[i] for i in range(dim)]
39
- elif len(args) == 2:
40
- # start is start, args[0] is stop, args[1] is num
41
- start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42
- stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43
- num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44
- else:
45
- raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46
-
47
- # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48
- axis_grid = []
49
- for i in range(dim):
50
- a, b, n = start[i], stop[i], num[i]
51
- g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52
- axis_grid.append(g)
53
- grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54
- grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55
-
56
- return grid
57
-
58
-
59
- #################################################################################
60
- # Rotary Positional Embedding Functions #
61
- #################################################################################
62
- # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63
-
64
-
65
- def reshape_for_broadcast(
66
- freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
67
- x: torch.Tensor,
68
- head_first=False,
69
- ):
70
- """
71
- Reshape frequency tensor for broadcasting it with another tensor.
72
-
73
- This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
74
- for the purpose of broadcasting the frequency tensor during element-wise operations.
75
-
76
- Notes:
77
- When using FlashMHAModified, head_first should be False.
78
- When using Attention, head_first should be True.
79
-
80
- Args:
81
- freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
82
- x (torch.Tensor): Target tensor for broadcasting compatibility.
83
- head_first (bool): head dimension first (except batch dim) or not.
84
-
85
- Returns:
86
- torch.Tensor: Reshaped frequency tensor.
87
-
88
- Raises:
89
- AssertionError: If the frequency tensor doesn't match the expected shape.
90
- AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
91
- """
92
- ndim = x.ndim
93
- assert 0 <= 1 < ndim
94
-
95
- if isinstance(freqs_cis, tuple):
96
- # freqs_cis: (cos, sin) in real space
97
- if head_first:
98
- assert freqs_cis[0].shape == (
99
- x.shape[-2],
100
- x.shape[-1],
101
- ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
102
- shape = [
103
- d if i == ndim - 2 or i == ndim - 1 else 1
104
- for i, d in enumerate(x.shape)
105
- ]
106
- else:
107
- assert freqs_cis[0].shape == (
108
- x.shape[1],
109
- x.shape[-1],
110
- ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
111
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
112
- return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
113
- else:
114
- # freqs_cis: values in complex space
115
- if head_first:
116
- assert freqs_cis.shape == (
117
- x.shape[-2],
118
- x.shape[-1],
119
- ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
120
- shape = [
121
- d if i == ndim - 2 or i == ndim - 1 else 1
122
- for i, d in enumerate(x.shape)
123
- ]
124
- else:
125
- assert freqs_cis.shape == (
126
- x.shape[1],
127
- x.shape[-1],
128
- ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
129
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
130
- return freqs_cis.view(*shape)
131
-
132
-
133
- def rotate_half(x):
134
- x_real, x_imag = (
135
- x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
136
- ) # [B, S, H, D//2]
137
- return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
138
-
139
-
140
- def apply_rotary_emb(
141
- xq: torch.Tensor,
142
- xk: torch.Tensor,
143
- freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
144
- head_first: bool = False,
145
- ) -> Tuple[torch.Tensor, torch.Tensor]:
146
- """
147
- Apply rotary embeddings to input tensors using the given frequency tensor.
148
-
149
- This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
150
- frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
151
- is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
152
- returned as real tensors.
153
-
154
- Args:
155
- xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
156
- xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
157
- freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
158
- head_first (bool): head dimension first (except batch dim) or not.
159
-
160
- Returns:
161
- Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
162
-
163
- """
164
- xk_out = None
165
- if isinstance(freqs_cis, tuple):
166
- cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
167
- cos, sin = cos.to(xq.device), sin.to(xq.device)
168
- # real * cos - imag * sin
169
- # imag * cos + real * sin
170
- xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
171
- xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
172
- else:
173
- # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
174
- xq_ = torch.view_as_complex(
175
- xq.float().reshape(*xq.shape[:-1], -1, 2)
176
- ) # [B, S, H, D//2]
177
- freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
178
- xq.device
179
- ) # [S, D//2] --> [1, S, 1, D//2]
180
- # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
181
- # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
182
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
183
- xk_ = torch.view_as_complex(
184
- xk.float().reshape(*xk.shape[:-1], -1, 2)
185
- ) # [B, S, H, D//2]
186
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
187
-
188
- return xq_out, xk_out
189
-
190
-
191
- def get_nd_rotary_pos_embed(
192
- rope_dim_list,
193
- start,
194
- *args,
195
- theta=10000.0,
196
- use_real=False,
197
- theta_rescale_factor: Union[float, List[float]] = 1.0,
198
- interpolation_factor: Union[float, List[float]] = 1.0,
199
- ):
200
- """
201
- This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
202
-
203
- Args:
204
- rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
205
- sum(rope_dim_list) should equal to head_dim of attention layer.
206
- start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
207
- args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
208
- *args: See above.
209
- theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
210
- use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
211
- Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
212
- part and an imaginary part separately.
213
- theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
214
-
215
- Returns:
216
- pos_embed (torch.Tensor): [HW, D/2]
217
- """
218
-
219
- grid = get_meshgrid_nd(
220
- start, *args, dim=len(rope_dim_list)
221
- ) # [3, W, H, D] / [2, W, H]
222
-
223
- if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
224
- theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
225
- elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
226
- theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
227
- assert len(theta_rescale_factor) == len(
228
- rope_dim_list
229
- ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
230
-
231
- if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
232
- interpolation_factor = [interpolation_factor] * len(rope_dim_list)
233
- elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
234
- interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
235
- assert len(interpolation_factor) == len(
236
- rope_dim_list
237
- ), "len(interpolation_factor) should equal to len(rope_dim_list)"
238
-
239
- # use 1/ndim of dimensions to encode grid_axis
240
- embs = []
241
- for i in range(len(rope_dim_list)):
242
- emb = get_1d_rotary_pos_embed(
243
- rope_dim_list[i],
244
- grid[i].reshape(-1),
245
- theta,
246
- use_real=use_real,
247
- theta_rescale_factor=theta_rescale_factor[i],
248
- interpolation_factor=interpolation_factor[i],
249
- ) # 2 x [WHD, rope_dim_list[i]]
250
- embs.append(emb)
251
-
252
- if use_real:
253
- cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
254
- sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
255
- return cos, sin
256
- else:
257
- emb = torch.cat(embs, dim=1) # (WHD, D/2)
258
- return emb
259
-
260
-
261
- def get_1d_rotary_pos_embed(
262
- dim: int,
263
- pos: Union[torch.FloatTensor, int],
264
- theta: float = 10000.0,
265
- use_real: bool = False,
266
- theta_rescale_factor: float = 1.0,
267
- interpolation_factor: float = 1.0,
268
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
269
- """
270
- Precompute the frequency tensor for complex exponential (cis) with given dimensions.
271
- (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
272
-
273
- This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
274
- and the end index 'end'. The 'theta' parameter scales the frequencies.
275
- The returned tensor contains complex values in complex64 data type.
276
-
277
- Args:
278
- dim (int): Dimension of the frequency tensor.
279
- pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
280
- theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
281
- use_real (bool, optional): If True, return real part and imaginary part separately.
282
- Otherwise, return complex numbers.
283
- theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
284
-
285
- Returns:
286
- freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
287
- freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
288
- """
289
- if isinstance(pos, int):
290
- pos = torch.arange(pos).float()
291
-
292
- # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
293
- # has some connection to NTK literature
294
- if theta_rescale_factor != 1.0:
295
- theta *= theta_rescale_factor ** (dim / (dim - 2))
296
-
297
- freqs = 1.0 / (
298
- theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
299
- ) # [D/2]
300
- # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
301
- freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
302
- if use_real:
303
- freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
304
- freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
305
- return freqs_cos, freqs_sin
306
- else:
307
- freqs_cis = torch.polar(
308
- torch.ones_like(freqs), freqs
309
- ) # complex64 # [S, D/2]
310
- return freqs_cis
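A minimal sketch of precomputing the 3-D RoPE tables for a small latent grid and applying them to query/key tensors, as the backbone does for (t, h, w) positions; the grid size is arbitrary and theta=256 follows the CLI default for rope-theta.

    import torch
    from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, apply_rotary_emb

    rope_dim_list = [16, 56, 56]          # sums to head_dim = 128
    grid_size = (4, 8, 8)                 # latent (T, H, W) after patchify
    cos, sin = get_nd_rotary_pos_embed(rope_dim_list, grid_size, theta=256, use_real=True)

    q = torch.randn(1, 4 * 8 * 8, 24, 128)   # (B, S, heads, head_dim)
    k = torch.randn_like(q)
    q, k = apply_rotary_emb(q, k, (cos, sin), head_first=False)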
hyvideo/modules/token_refiner.py DELETED
@@ -1,236 +0,0 @@
1
- from typing import Optional
2
-
3
- from einops import rearrange
4
- import torch
5
- import torch.nn as nn
6
-
7
- from .activation_layers import get_activation_layer
8
- from .attenion import attention
9
- from .norm_layers import get_norm_layer
10
- from .embed_layers import TimestepEmbedder, TextProjection
11
- from .attenion import attention
12
- from .mlp_layers import MLP
13
- from .modulate_layers import modulate, apply_gate
14
-
15
-
16
- class IndividualTokenRefinerBlock(nn.Module):
17
- def __init__(
18
- self,
19
- hidden_size,
20
- heads_num,
21
- mlp_width_ratio: str = 4.0,
22
- mlp_drop_rate: float = 0.0,
23
- act_type: str = "silu",
24
- qk_norm: bool = False,
25
- qk_norm_type: str = "layer",
26
- qkv_bias: bool = True,
27
- dtype: Optional[torch.dtype] = None,
28
- device: Optional[torch.device] = None,
29
- ):
30
- factory_kwargs = {"device": device, "dtype": dtype}
31
- super().__init__()
32
- self.heads_num = heads_num
33
- head_dim = hidden_size // heads_num
34
- mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
35
-
36
- self.norm1 = nn.LayerNorm(
37
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
38
- )
39
- self.self_attn_qkv = nn.Linear(
40
- hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
41
- )
42
- qk_norm_layer = get_norm_layer(qk_norm_type)
43
- self.self_attn_q_norm = (
44
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
45
- if qk_norm
46
- else nn.Identity()
47
- )
48
- self.self_attn_k_norm = (
49
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
50
- if qk_norm
51
- else nn.Identity()
52
- )
53
- self.self_attn_proj = nn.Linear(
54
- hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
55
- )
56
-
57
- self.norm2 = nn.LayerNorm(
58
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
59
- )
60
- act_layer = get_activation_layer(act_type)
61
- self.mlp = MLP(
62
- in_channels=hidden_size,
63
- hidden_channels=mlp_hidden_dim,
64
- act_layer=act_layer,
65
- drop=mlp_drop_rate,
66
- **factory_kwargs,
67
- )
68
-
69
- self.adaLN_modulation = nn.Sequential(
70
- act_layer(),
71
- nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
72
- )
73
- # Zero-initialize the modulation
74
- nn.init.zeros_(self.adaLN_modulation[1].weight)
75
- nn.init.zeros_(self.adaLN_modulation[1].bias)
76
-
77
- def forward(
78
- self,
79
- x: torch.Tensor,
80
- c: torch.Tensor, # timestep_aware_representations + context_aware_representations
81
- attn_mask: torch.Tensor = None,
82
- ):
83
- gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
84
-
85
- norm_x = self.norm1(x)
86
- qkv = self.self_attn_qkv(norm_x)
87
- q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
88
- # Apply QK-Norm if needed
89
- q = self.self_attn_q_norm(q).to(v)
90
- k = self.self_attn_k_norm(k).to(v)
91
-
92
- # Self-Attention
93
- attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
94
-
95
- x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
96
-
97
- # FFN Layer
98
- x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
99
-
100
- return x
101
-
102
-
103
- class IndividualTokenRefiner(nn.Module):
104
- def __init__(
105
- self,
106
- hidden_size,
107
- heads_num,
108
- depth,
109
- mlp_width_ratio: float = 4.0,
110
- mlp_drop_rate: float = 0.0,
111
- act_type: str = "silu",
112
- qk_norm: bool = False,
113
- qk_norm_type: str = "layer",
114
- qkv_bias: bool = True,
115
- dtype: Optional[torch.dtype] = None,
116
- device: Optional[torch.device] = None,
117
- ):
118
- factory_kwargs = {"device": device, "dtype": dtype}
119
- super().__init__()
120
- self.blocks = nn.ModuleList(
121
- [
122
- IndividualTokenRefinerBlock(
123
- hidden_size=hidden_size,
124
- heads_num=heads_num,
125
- mlp_width_ratio=mlp_width_ratio,
126
- mlp_drop_rate=mlp_drop_rate,
127
- act_type=act_type,
128
- qk_norm=qk_norm,
129
- qk_norm_type=qk_norm_type,
130
- qkv_bias=qkv_bias,
131
- **factory_kwargs,
132
- )
133
- for _ in range(depth)
134
- ]
135
- )
136
-
137
- def forward(
138
- self,
139
- x: torch.Tensor,
140
- c: torch.LongTensor,
141
- mask: Optional[torch.Tensor] = None,
142
- ):
143
- self_attn_mask = None
144
- if mask is not None:
145
- batch_size = mask.shape[0]
146
- seq_len = mask.shape[1]
147
- mask = mask.to(x.device)
148
- # batch_size x 1 x seq_len x seq_len
149
- self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
150
- 1, 1, seq_len, 1
151
- )
152
- # batch_size x 1 x seq_len x seq_len
153
- self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
154
- # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
155
- self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
156
- # avoids self-attention weight being NaN for padding tokens
157
- self_attn_mask[:, :, :, 0] = True
158
-
159
- for block in self.blocks:
160
- x = block(x, c, self_attn_mask)
161
- return x
162
-
163
-
164
- class SingleTokenRefiner(nn.Module):
165
- """
166
- A single-token refiner block for refining LLM text embeddings.
167
- """
168
- def __init__(
169
- self,
170
- in_channels,
171
- hidden_size,
172
- heads_num,
173
- depth,
174
- mlp_width_ratio: float = 4.0,
175
- mlp_drop_rate: float = 0.0,
176
- act_type: str = "silu",
177
- qk_norm: bool = False,
178
- qk_norm_type: str = "layer",
179
- qkv_bias: bool = True,
180
- attn_mode: str = "torch",
181
- dtype: Optional[torch.dtype] = None,
182
- device: Optional[torch.device] = None,
183
- ):
184
- factory_kwargs = {"device": device, "dtype": dtype}
185
- super().__init__()
186
- self.attn_mode = attn_mode
187
- assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
188
-
189
- self.input_embedder = nn.Linear(
190
- in_channels, hidden_size, bias=True, **factory_kwargs
191
- )
192
-
193
- act_layer = get_activation_layer(act_type)
194
- # Build timestep embedding layer
195
- self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
196
- # Build context embedding layer
197
- self.c_embedder = TextProjection(
198
- in_channels, hidden_size, act_layer, **factory_kwargs
199
- )
200
-
201
- self.individual_token_refiner = IndividualTokenRefiner(
202
- hidden_size=hidden_size,
203
- heads_num=heads_num,
204
- depth=depth,
205
- mlp_width_ratio=mlp_width_ratio,
206
- mlp_drop_rate=mlp_drop_rate,
207
- act_type=act_type,
208
- qk_norm=qk_norm,
209
- qk_norm_type=qk_norm_type,
210
- qkv_bias=qkv_bias,
211
- **factory_kwargs,
212
- )
213
-
214
- def forward(
215
- self,
216
- x: torch.Tensor,
217
- t: torch.LongTensor,
218
- mask: Optional[torch.LongTensor] = None,
219
- ):
220
- timestep_aware_representations = self.t_embedder(t)
221
-
222
- if mask is None:
223
- context_aware_representations = x.mean(dim=1)
224
- else:
225
- mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
226
- context_aware_representations = (x * mask_float).sum(
227
- dim=1
228
- ) / mask_float.sum(dim=1)
229
- context_aware_representations = self.c_embedder(context_aware_representations)
230
- c = timestep_aware_representations + context_aware_representations
231
-
232
- x = self.input_embedder(x)
233
-
234
- x = self.individual_token_refiner(x, c, mask)
235
-
236
- return x
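A minimal sketch of refining LLM text states into backbone tokens; the 4096-d text states, 3072-d hidden size, 24 heads and 256-token text length are assumptions taken from the CLI defaults.

    import torch
    from hyvideo.modules.token_refiner import SingleTokenRefiner

    refiner = SingleTokenRefiner(in_channels=4096, hidden_size=3072, heads_num=24, depth=2)
    text_states = torch.randn(2, 256, 4096)            # (B, text_len, text_states_dim)
    text_mask = torch.ones(2, 256, dtype=torch.long)   # 1 = real token, 0 = padding
    t = torch.randint(0, 1000, (2,))
    txt = refiner(text_states, t, text_mask)           # -> (2, 256, 3072)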
hyvideo/prompt_rewrite.py DELETED
@@ -1,51 +0,0 @@
1
- normal_mode_prompt = """Normal mode - Video Recaption Task:
2
-
3
- You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
4
-
5
- 0. Preserve ALL information, including style words and technical terms.
6
-
7
- 1. If the input is in Chinese, translate the entire description to English.
8
-
9
- 2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.
10
-
11
- 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
12
-
13
- 4. Output ALL must be in English.
14
-
15
- Given Input:
16
- input: "{input}"
17
- """
18
-
19
-
20
- master_mode_prompt = """Master mode - Video Recaption Task:
21
-
22
- You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
23
-
24
- 0. Preserve ALL information, including style words and technical terms.
25
-
26
- 1. If the input is in Chinese, translate the entire description to English.
27
-
28
- 2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.
29
-
30
- 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
31
-
32
- 4. Output ALL must be in English.
33
-
34
- Given Input:
35
- input: "{input}"
36
- """
37
-
38
- def get_rewrite_prompt(ori_prompt, mode="Normal"):
39
- if mode == "Normal":
40
- prompt = normal_mode_prompt.format(input=ori_prompt)
41
- elif mode == "Master":
42
- prompt = master_mode_prompt.format(input=ori_prompt)
43
- else:
44
- raise Exception("Only supports Normal and Normal", mode)
45
- return prompt
46
-
47
- ori_prompt = "一只小狗在草地上奔跑。"
48
- normal_prompt = get_rewrite_prompt(ori_prompt, mode="Normal")
49
- master_prompt = get_rewrite_prompt(ori_prompt, mode="Master")
50
-
51
- # Then you can use the normal_prompt or master_prompt to access the hunyuan-large rewrite model to get the final prompt.
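
A quick, model-free check of how the template above expands (the English example prompt is an assumption):

p = get_rewrite_prompt("A corgi runs along a beach at sunset.", mode="Master")
print(p.startswith("Master mode"))                    # True
print("A corgi runs along a beach at sunset." in p)   # True
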
hyvideo/text_encoder/__init__.py DELETED
@@ -1,357 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Optional, Tuple
3
- from copy import deepcopy
4
-
5
- import torch
6
- import torch.nn as nn
7
- from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, AutoModel
8
- from transformers.utils import ModelOutput
9
-
10
- from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH
11
- from ..constants import PRECISION_TO_TYPE
12
-
13
-
14
- def use_default(value, default):
15
- return value if value is not None else default
16
-
17
-
18
- def load_text_encoder(
19
- text_encoder_type,
20
- text_encoder_precision=None,
21
- text_encoder_path=None,
22
- logger=None,
23
- device=None,
24
- ):
25
- if text_encoder_path is None:
26
- text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type]
27
- if logger is not None:
28
- logger.info(
29
- f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}"
30
- )
31
-
32
- if text_encoder_type == "clipL":
33
- text_encoder = CLIPTextModel.from_pretrained(text_encoder_path)
34
- text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
35
- elif text_encoder_type == "llm":
36
- text_encoder = AutoModel.from_pretrained(
37
- text_encoder_path, low_cpu_mem_usage=True
38
- )
39
- text_encoder.final_layer_norm = text_encoder.norm
40
- else:
41
- raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
42
- # from_pretrained will ensure that the model is in eval mode.
43
-
44
- if text_encoder_precision is not None:
45
- text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision])
46
-
47
- text_encoder.requires_grad_(False)
48
-
49
- if logger is not None:
50
- logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
51
-
52
- if device is not None:
53
- text_encoder = text_encoder.to(device)
54
-
55
- return text_encoder, text_encoder_path
56
-
57
-
58
- def load_tokenizer(
59
- tokenizer_type, tokenizer_path=None, padding_side="right", logger=None
60
- ):
61
- if tokenizer_path is None:
62
- tokenizer_path = TOKENIZER_PATH[tokenizer_type]
63
- if logger is not None:
64
- logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
65
-
66
- if tokenizer_type == "clipL":
67
- tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
68
- elif tokenizer_type == "llm":
69
- tokenizer = AutoTokenizer.from_pretrained(
70
- tokenizer_path, padding_side=padding_side
71
- )
72
- else:
73
- raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
74
-
75
- return tokenizer, tokenizer_path
76
-
77
-
78
- @dataclass
79
- class TextEncoderModelOutput(ModelOutput):
80
- """
81
- Base class for model's outputs that also contains a pooling of the last hidden states.
82
-
83
- Args:
84
- hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
85
- Sequence of hidden-states at the output of the last layer of the model.
86
- attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
87
- Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
88
- hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
89
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
90
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
91
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
92
- text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
93
- List of decoded texts.
94
- """
95
-
96
- hidden_state: torch.FloatTensor = None
97
- attention_mask: Optional[torch.LongTensor] = None
98
- hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
99
- text_outputs: Optional[list] = None
100
-
101
-
102
- class TextEncoder(nn.Module):
103
- def __init__(
104
- self,
105
- text_encoder_type: str,
106
- max_length: int,
107
- text_encoder_precision: Optional[str] = None,
108
- text_encoder_path: Optional[str] = None,
109
- tokenizer_type: Optional[str] = None,
110
- tokenizer_path: Optional[str] = None,
111
- output_key: Optional[str] = None,
112
- use_attention_mask: bool = True,
113
- input_max_length: Optional[int] = None,
114
- prompt_template: Optional[dict] = None,
115
- prompt_template_video: Optional[dict] = None,
116
- hidden_state_skip_layer: Optional[int] = None,
117
- apply_final_norm: bool = False,
118
- reproduce: bool = False,
119
- logger=None,
120
- device=None,
121
- ):
122
- super().__init__()
123
- self.text_encoder_type = text_encoder_type
124
- self.max_length = max_length
125
- self.precision = text_encoder_precision
126
- self.model_path = text_encoder_path
127
- self.tokenizer_type = (
128
- tokenizer_type if tokenizer_type is not None else text_encoder_type
129
- )
130
- self.tokenizer_path = (
131
- tokenizer_path if tokenizer_path is not None else text_encoder_path
132
- )
133
- self.use_attention_mask = use_attention_mask
134
- if prompt_template_video is not None:
135
- assert (
136
- use_attention_mask is True
137
- ), "Attention mask is True required when training videos."
138
- self.input_max_length = (
139
- input_max_length if input_max_length is not None else max_length
140
- )
141
- self.prompt_template = prompt_template
142
- self.prompt_template_video = prompt_template_video
143
- self.hidden_state_skip_layer = hidden_state_skip_layer
144
- self.apply_final_norm = apply_final_norm
145
- self.reproduce = reproduce
146
- self.logger = logger
147
-
148
- self.use_template = self.prompt_template is not None
149
- if self.use_template:
150
- assert (
151
- isinstance(self.prompt_template, dict)
152
- and "template" in self.prompt_template
153
- ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
154
- assert "{}" in str(self.prompt_template["template"]), (
155
- "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
156
- f"got {self.prompt_template['template']}"
157
- )
158
-
159
- self.use_video_template = self.prompt_template_video is not None
160
- if self.use_video_template:
161
- if self.prompt_template_video is not None:
162
- assert (
163
- isinstance(self.prompt_template_video, dict)
164
- and "template" in self.prompt_template_video
165
- ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
166
- assert "{}" in str(self.prompt_template_video["template"]), (
167
- "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
168
- f"got {self.prompt_template_video['template']}"
169
- )
170
-
171
- if "t5" in text_encoder_type:
172
- self.output_key = output_key or "last_hidden_state"
173
- elif "clip" in text_encoder_type:
174
- self.output_key = output_key or "pooler_output"
175
- elif "llm" in text_encoder_type or "glm" in text_encoder_type:
176
- self.output_key = output_key or "last_hidden_state"
177
- else:
178
- raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
179
-
180
- self.model, self.model_path = load_text_encoder(
181
- text_encoder_type=self.text_encoder_type,
182
- text_encoder_precision=self.precision,
183
- text_encoder_path=self.model_path,
184
- logger=self.logger,
185
- device=device,
186
- )
187
- self.dtype = self.model.dtype
188
- self.device = self.model.device
189
-
190
- self.tokenizer, self.tokenizer_path = load_tokenizer(
191
- tokenizer_type=self.tokenizer_type,
192
- tokenizer_path=self.tokenizer_path,
193
- padding_side="right",
194
- logger=self.logger,
195
- )
196
-
197
- def __repr__(self):
198
- return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
199
-
200
- @staticmethod
201
- def apply_text_to_template(text, template, prevent_empty_text=True):
202
- """
203
- Apply text to template.
204
-
205
- Args:
206
- text (str): Input text.
207
- template (str or list): Template string or list of chat conversation.
208
- prevent_empty_text (bool): If True, prevent the user text from being empty
209
- by adding a space. Defaults to True.
210
- """
211
- if isinstance(template, str):
212
- # Will send string to tokenizer. Used for llm
213
- return template.format(text)
214
- else:
215
- raise TypeError(f"Unsupported template type: {type(template)}")
216
-
217
- def text2tokens(self, text, data_type="image"):
218
- """
219
- Tokenize the input text.
220
-
221
- Args:
222
- text (str or list): Input text.
223
- """
224
- tokenize_input_type = "str"
225
- if self.use_template:
226
- if data_type == "image":
227
- prompt_template = self.prompt_template["template"]
228
- elif data_type == "video":
229
- prompt_template = self.prompt_template_video["template"]
230
- else:
231
- raise ValueError(f"Unsupported data type: {data_type}")
232
- if isinstance(text, (list, tuple)):
233
- text = [
234
- self.apply_text_to_template(one_text, prompt_template)
235
- for one_text in text
236
- ]
237
- if isinstance(text[0], list):
238
- tokenize_input_type = "list"
239
- elif isinstance(text, str):
240
- text = self.apply_text_to_template(text, prompt_template)
241
- if isinstance(text, list):
242
- tokenize_input_type = "list"
243
- else:
244
- raise TypeError(f"Unsupported text type: {type(text)}")
245
-
246
- kwargs = dict(
247
- truncation=True,
248
- max_length=self.max_length,
249
- padding="max_length",
250
- return_tensors="pt",
251
- )
252
- if tokenize_input_type == "str":
253
- return self.tokenizer(
254
- text,
255
- return_length=False,
256
- return_overflowing_tokens=False,
257
- return_attention_mask=True,
258
- **kwargs,
259
- )
260
- elif tokenize_input_type == "list":
261
- return self.tokenizer.apply_chat_template(
262
- text,
263
- add_generation_prompt=True,
264
- tokenize=True,
265
- return_dict=True,
266
- **kwargs,
267
- )
268
- else:
269
- raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
270
-
271
- def encode(
272
- self,
273
- batch_encoding,
274
- use_attention_mask=None,
275
- output_hidden_states=False,
276
- do_sample=None,
277
- hidden_state_skip_layer=None,
278
- return_texts=False,
279
- data_type="image",
280
- device=None,
281
- ):
282
- """
283
- Args:
284
- batch_encoding (dict): Batch encoding from tokenizer.
285
- use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
286
- Defaults to None.
287
- output_hidden_states (bool): Whether to output hidden states. If False, return the value of
288
- self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
289
- output_hidden_states will be set True. Defaults to False.
290
- do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
291
- When self.reproduce is False, do_sample is set to True by default.
292
- hidden_state_skip_layer (int): Number of final layers to skip when selecting the hidden state. 0 means the last layer.
294
- If None, self.hidden_state_skip_layer will be used. Defaults to None.
294
- return_texts (bool): Whether to return the decoded texts. Defaults to False.
295
- """
296
- device = self.model.device if device is None else device
297
- use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
298
- hidden_state_skip_layer = use_default(
299
- hidden_state_skip_layer, self.hidden_state_skip_layer
300
- )
301
- do_sample = use_default(do_sample, not self.reproduce)
302
- attention_mask = (
303
- batch_encoding["attention_mask"].to(device) if use_attention_mask else None
304
- )
305
- outputs = self.model(
306
- input_ids=batch_encoding["input_ids"].to(device),
307
- attention_mask=attention_mask,
308
- output_hidden_states=output_hidden_states
309
- or hidden_state_skip_layer is not None,
310
- )
311
- if hidden_state_skip_layer is not None:
312
- last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
313
- # Real last hidden state already has layer norm applied. So here we only apply it
314
- # for intermediate layers.
315
- if hidden_state_skip_layer > 0 and self.apply_final_norm:
316
- last_hidden_state = self.model.final_layer_norm(last_hidden_state)
317
- else:
318
- last_hidden_state = outputs[self.output_key]
319
-
320
- # Remove hidden states of instruction tokens, only keep prompt tokens.
321
- if self.use_template:
322
- if data_type == "image":
323
- crop_start = self.prompt_template.get("crop_start", -1)
324
- elif data_type == "video":
325
- crop_start = self.prompt_template_video.get("crop_start", -1)
326
- else:
327
- raise ValueError(f"Unsupported data type: {data_type}")
328
- if crop_start > 0:
329
- last_hidden_state = last_hidden_state[:, crop_start:]
330
- attention_mask = (
331
- attention_mask[:, crop_start:] if use_attention_mask else None
332
- )
333
-
334
- if output_hidden_states:
335
- return TextEncoderModelOutput(
336
- last_hidden_state, attention_mask, outputs.hidden_states
337
- )
338
- return TextEncoderModelOutput(last_hidden_state, attention_mask)
339
-
340
- def forward(
341
- self,
342
- text,
343
- use_attention_mask=None,
344
- output_hidden_states=False,
345
- do_sample=False,
346
- hidden_state_skip_layer=None,
347
- return_texts=False,
348
- ):
349
- batch_encoding = self.text2tokens(text)
350
- return self.encode(
351
- batch_encoding,
352
- use_attention_mask=use_attention_mask,
353
- output_hidden_states=output_hidden_states,
354
- do_sample=do_sample,
355
- hidden_state_skip_layer=hidden_state_skip_layer,
356
- return_texts=return_texts,
357
- )
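
A small self-contained sketch of the `hidden_state_skip_layer` indexing performed in `encode` above (dummy tensors stand in for a real `output_hidden_states=True` tuple; no model is loaded):

import torch

# Pretend these are the hidden states a HF model returns with output_hidden_states=True:
# index 0 is the embedding output, the last entry is the final transformer layer.
hidden_states = tuple(torch.full((1, 4, 8), float(i)) for i in range(5))

skip = 2
selected = hidden_states[-(skip + 1)]  # skip=0 -> last layer, skip=2 -> two layers earlier
print(selected[0, 0, 0].item())        # 2.0
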
hyvideo/utils/__init__.py DELETED
File without changes
hyvideo/utils/data_utils.py DELETED
@@ -1,15 +0,0 @@
1
- import numpy as np
2
- import math
3
-
4
-
5
- def align_to(value, alignment):
6
- """align hight, width according to alignment
7
-
8
- Args:
9
- value (int): height or width
10
- alignment (int): target alignment factor
11
-
12
- Returns:
13
- int: the aligned value
14
- """
15
- return int(math.ceil(value / alignment) * alignment)
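
For reference, what `align_to` returns for a couple of typical values (16 is just an example alignment factor):

print(align_to(720, 16))  # 720  (already a multiple of 16)
print(align_to(725, 16))  # 736  (rounded up to the next multiple of 16)
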
hyvideo/utils/file_utils.py DELETED
@@ -1,70 +0,0 @@
1
- import os
2
- from pathlib import Path
3
- from einops import rearrange
4
-
5
- import torch
6
- import torchvision
7
- import numpy as np
8
- import imageio
9
-
10
- CODE_SUFFIXES = {
11
- ".py", # Python codes
12
- ".sh", # Shell scripts
13
- ".yaml",
14
- ".yml", # Configuration files
15
- }
16
-
17
-
18
- def safe_dir(path):
19
- """
20
- Create a directory (or the parent directory of a file) if it does not exist.
21
-
22
- Args:
23
- path (str or Path): Path to the directory.
24
-
25
- Returns:
26
- path (Path): Path object of the directory.
27
- """
28
- path = Path(path)
29
- path.mkdir(exist_ok=True, parents=True)
30
- return path
31
-
32
-
33
- def safe_file(path):
34
- """
35
- Create the parent directory of a file if it does not exist.
36
-
37
- Args:
38
- path (str or Path): Path to the file.
39
-
40
- Returns:
41
- path (Path): Path object of the file.
42
- """
43
- path = Path(path)
44
- path.parent.mkdir(exist_ok=True, parents=True)
45
- return path
46
-
47
- def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
48
- """save videos by video tensor
49
- copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
50
-
51
- Args:
52
- videos (torch.Tensor): video tensor predicted by the model
53
- path (str): path to save video
54
- rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
55
- n_rows (int, optional): Defaults to 1.
56
- fps (int, optional): video save fps. Defaults to 24.
57
- """
58
- videos = rearrange(videos, "b c t h w -> t b c h w")
59
- outputs = []
60
- for x in videos:
61
- x = torchvision.utils.make_grid(x, nrow=n_rows)
62
- x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
63
- if rescale:
64
- x = (x + 1.0) / 2.0 # -1,1 -> 0,1
65
- x = torch.clamp(x, 0, 1)
66
- x = (x * 255).numpy().astype(np.uint8)
67
- outputs.append(x)
68
-
69
- os.makedirs(os.path.dirname(path), exist_ok=True)
70
- imageio.mimsave(path, outputs, fps=fps)
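
A minimal smoke test of `save_videos_grid` under stated assumptions: random frames in [0, 1], a tiny resolution, a hypothetical output path, and an imageio backend (e.g. imageio-ffmpeg) that can write .mp4:

import torch

videos = torch.rand(1, 3, 8, 64, 64)  # (batch, channels, frames, height, width), values in [0, 1]
save_videos_grid(videos, "outputs/demo.mp4", rescale=False, n_rows=1, fps=8)
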
hyvideo/utils/helpers.py DELETED
@@ -1,40 +0,0 @@
1
- import collections.abc
2
-
3
- from itertools import repeat
4
-
5
-
6
- def _ntuple(n):
7
- def parse(x):
8
- if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
9
- x = tuple(x)
10
- if len(x) == 1:
11
- x = tuple(repeat(x[0], n))
12
- return x
13
- return tuple(repeat(x, n))
14
- return parse
15
-
16
-
17
- to_1tuple = _ntuple(1)
18
- to_2tuple = _ntuple(2)
19
- to_3tuple = _ntuple(3)
20
- to_4tuple = _ntuple(4)
21
-
22
-
23
- def as_tuple(x):
24
- if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25
- return tuple(x)
26
- if x is None or isinstance(x, (int, float, str)):
27
- return (x,)
28
- else:
29
- raise ValueError(f"Unknown type {type(x)}")
30
-
31
-
32
- def as_list_of_2tuple(x):
33
- x = as_tuple(x)
34
- if len(x) == 1:
35
- x = (x[0], x[0])
36
- assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
37
- lst = []
38
- for i in range(0, len(x), 2):
39
- lst.append((x[i], x[i + 1]))
40
- return lst
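
A few quick examples of the tuple helpers above:

print(to_2tuple(3))                      # (3, 3)
print(to_2tuple((4, 5)))                 # (4, 5)
print(as_list_of_2tuple((1, 2, 3, 4)))   # [(1, 2), (3, 4)]
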
hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py DELETED
@@ -1,46 +0,0 @@
1
- import argparse
2
- import torch
3
- from transformers import (
4
- AutoProcessor,
5
- LlavaForConditionalGeneration,
6
- )
7
-
8
-
9
- def preprocess_text_encoder_tokenizer(args):
10
-
11
- processor = AutoProcessor.from_pretrained(args.input_dir)
12
- model = LlavaForConditionalGeneration.from_pretrained(
13
- args.input_dir,
14
- torch_dtype=torch.float16,
15
- low_cpu_mem_usage=True,
16
- ).to(0)
17
-
18
- model.language_model.save_pretrained(
19
- f"{args.output_dir}"
20
- )
21
- processor.tokenizer.save_pretrained(
22
- f"{args.output_dir}"
23
- )
24
-
25
- if __name__ == "__main__":
26
-
27
- parser = argparse.ArgumentParser()
28
- parser.add_argument(
29
- "--input_dir",
30
- type=str,
31
- required=True,
32
- help="The path to the llava-llama-3-8b-v1_1-transformers.",
33
- )
34
- parser.add_argument(
35
- "--output_dir",
36
- type=str,
37
- default="",
38
- help="The output path of the llava-llama-3-8b-text-encoder-tokenizer."
39
- "if '', the parent dir of output will be the same as input dir.",
40
- )
41
- args = parser.parse_args()
42
-
43
- if len(args.output_dir) == 0:
44
- args.output_dir = "/".join(args.input_dir.split("/")[:-1])
45
-
46
- preprocess_text_encoder_tokenizer(args)
hyvideo/vae/__init__.py DELETED
@@ -1,62 +0,0 @@
1
- from pathlib import Path
2
-
3
- import torch
4
-
5
- from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
6
- from ..constants import VAE_PATH, PRECISION_TO_TYPE
7
-
8
- def load_vae(vae_type: str="884-16c-hy",
9
- vae_precision: str=None,
10
- sample_size: tuple=None,
11
- vae_path: str=None,
12
- logger=None,
13
- device=None
14
- ):
15
- """the fucntion to load the 3D VAE model
16
-
17
- Args:
18
- vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
19
- vae_precision (str, optional): the precision to load vae. Defaults to None.
20
- sample_size (tuple, optional): the tiling size. Defaults to None.
21
- vae_path (str, optional): the path to vae. Defaults to None.
22
- logger (_type_, optional): logger. Defaults to None.
23
- device (_type_, optional): device to load vae. Defaults to None.
24
- """
25
- if vae_path is None:
26
- vae_path = VAE_PATH[vae_type]
27
-
28
- if logger is not None:
29
- logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
30
- config = AutoencoderKLCausal3D.load_config(vae_path)
31
- if sample_size:
32
- vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
33
- else:
34
- vae = AutoencoderKLCausal3D.from_config(config)
35
-
36
- vae_ckpt = Path(vae_path) / "pytorch_model.pt"
37
- assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
38
-
39
- ckpt = torch.load(vae_ckpt, map_location=vae.device)
40
- if "state_dict" in ckpt:
41
- ckpt = ckpt["state_dict"]
42
- if any(k.startswith("vae.") for k in ckpt.keys()):
43
- ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
44
- vae.load_state_dict(ckpt)
45
-
46
- spatial_compression_ratio = vae.config.spatial_compression_ratio
47
- time_compression_ratio = vae.config.time_compression_ratio
48
-
49
- if vae_precision is not None:
50
- vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])
51
-
52
- vae.requires_grad_(False)
53
-
54
- if logger is not None:
55
- logger.info(f"VAE to dtype: {vae.dtype}")
56
-
57
- if device is not None:
58
- vae = vae.to(device)
59
-
60
- vae.eval()
61
-
62
- return vae, vae_path, spatial_compression_ratio, time_compression_ratio
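
A hedged usage sketch of `load_vae`; the local checkpoint path and device are assumptions, and the call requires the VAE config and `pytorch_model.pt` to exist at that path:

import torch

vae, vae_path, spatial_ratio, time_ratio = load_vae(
    vae_type="884-16c-hy",
    vae_precision="fp16",
    vae_path="ckpts/hunyuan-video-t2v-720p/vae",  # hypothetical local path
    device=torch.device("cuda"),
)
vae.enable_tiling()  # optional: lowers peak memory at some cost in speed
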
hyvideo/vae/autoencoder_kl_causal_3d.py DELETED
@@ -1,603 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- #
16
- # Modified from diffusers==0.29.2
17
- #
18
- # ==============================================================================
19
- from typing import Dict, Optional, Tuple, Union
20
- from dataclasses import dataclass
21
-
22
- import torch
23
- import torch.nn as nn
24
-
25
- from diffusers.configuration_utils import ConfigMixin, register_to_config
26
-
27
- try:
28
- # This diffusers is modified and packed in the mirror.
29
- from diffusers.loaders import FromOriginalVAEMixin
30
- except ImportError:
31
- # Use this to be compatible with the original diffusers.
32
- from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
33
- from diffusers.utils.accelerate_utils import apply_forward_hook
34
- from diffusers.models.attention_processor import (
35
- ADDED_KV_ATTENTION_PROCESSORS,
36
- CROSS_ATTENTION_PROCESSORS,
37
- Attention,
38
- AttentionProcessor,
39
- AttnAddedKVProcessor,
40
- AttnProcessor,
41
- )
42
- from diffusers.models.modeling_outputs import AutoencoderKLOutput
43
- from diffusers.models.modeling_utils import ModelMixin
44
- from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
45
-
46
-
47
- @dataclass
48
- class DecoderOutput2(BaseOutput):
49
- sample: torch.FloatTensor
50
- posterior: Optional[DiagonalGaussianDistribution] = None
51
-
52
-
53
- class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
54
- r"""
55
- A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
56
-
57
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
58
- for all models (such as downloading or saving).
59
- """
60
-
61
- _supports_gradient_checkpointing = True
62
-
63
- @register_to_config
64
- def __init__(
65
- self,
66
- in_channels: int = 3,
67
- out_channels: int = 3,
68
- down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
69
- up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
70
- block_out_channels: Tuple[int] = (64,),
71
- layers_per_block: int = 1,
72
- act_fn: str = "silu",
73
- latent_channels: int = 4,
74
- norm_num_groups: int = 32,
75
- sample_size: int = 32,
76
- sample_tsize: int = 64,
77
- scaling_factor: float = 0.18215,
78
- force_upcast: float = True,
79
- spatial_compression_ratio: int = 8,
80
- time_compression_ratio: int = 4,
81
- mid_block_add_attention: bool = True,
82
- ):
83
- super().__init__()
84
-
85
- self.time_compression_ratio = time_compression_ratio
86
-
87
- self.encoder = EncoderCausal3D(
88
- in_channels=in_channels,
89
- out_channels=latent_channels,
90
- down_block_types=down_block_types,
91
- block_out_channels=block_out_channels,
92
- layers_per_block=layers_per_block,
93
- act_fn=act_fn,
94
- norm_num_groups=norm_num_groups,
95
- double_z=True,
96
- time_compression_ratio=time_compression_ratio,
97
- spatial_compression_ratio=spatial_compression_ratio,
98
- mid_block_add_attention=mid_block_add_attention,
99
- )
100
-
101
- self.decoder = DecoderCausal3D(
102
- in_channels=latent_channels,
103
- out_channels=out_channels,
104
- up_block_types=up_block_types,
105
- block_out_channels=block_out_channels,
106
- layers_per_block=layers_per_block,
107
- norm_num_groups=norm_num_groups,
108
- act_fn=act_fn,
109
- time_compression_ratio=time_compression_ratio,
110
- spatial_compression_ratio=spatial_compression_ratio,
111
- mid_block_add_attention=mid_block_add_attention,
112
- )
113
-
114
- self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
115
- self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
116
-
117
- self.use_slicing = False
118
- self.use_spatial_tiling = False
119
- self.use_temporal_tiling = False
120
-
121
- # only relevant if vae tiling is enabled
122
- self.tile_sample_min_tsize = sample_tsize
123
- self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
124
-
125
- self.tile_sample_min_size = self.config.sample_size
126
- sample_size = (
127
- self.config.sample_size[0]
128
- if isinstance(self.config.sample_size, (list, tuple))
129
- else self.config.sample_size
130
- )
131
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
132
- self.tile_overlap_factor = 0.25
133
-
134
- def _set_gradient_checkpointing(self, module, value=False):
135
- if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
136
- module.gradient_checkpointing = value
137
-
138
- def enable_temporal_tiling(self, use_tiling: bool = True):
139
- self.use_temporal_tiling = use_tiling
140
-
141
- def disable_temporal_tiling(self):
142
- self.enable_temporal_tiling(False)
143
-
144
- def enable_spatial_tiling(self, use_tiling: bool = True):
145
- self.use_spatial_tiling = use_tiling
146
-
147
- def disable_spatial_tiling(self):
148
- self.enable_spatial_tiling(False)
149
-
150
- def enable_tiling(self, use_tiling: bool = True):
151
- r"""
152
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
153
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
154
- processing larger videos.
155
- """
156
- self.enable_spatial_tiling(use_tiling)
157
- self.enable_temporal_tiling(use_tiling)
158
-
159
- def disable_tiling(self):
160
- r"""
161
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
162
- decoding in one step.
163
- """
164
- self.disable_spatial_tiling()
165
- self.disable_temporal_tiling()
166
-
167
- def enable_slicing(self):
168
- r"""
169
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
170
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
171
- """
172
- self.use_slicing = True
173
-
174
- def disable_slicing(self):
175
- r"""
176
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
177
- decoding in one step.
178
- """
179
- self.use_slicing = False
180
-
181
- @property
182
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
183
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
184
- r"""
185
- Returns:
186
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
187
- indexed by its weight name.
188
- """
189
- # set recursively
190
- processors = {}
191
-
192
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
193
- if hasattr(module, "get_processor"):
194
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
195
-
196
- for sub_name, child in module.named_children():
197
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
198
-
199
- return processors
200
-
201
- for name, module in self.named_children():
202
- fn_recursive_add_processors(name, module, processors)
203
-
204
- return processors
205
-
206
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
207
- def set_attn_processor(
208
- self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
209
- ):
210
- r"""
211
- Sets the attention processor to use to compute attention.
212
-
213
- Parameters:
214
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
215
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
216
- for **all** `Attention` layers.
217
-
218
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
219
- processor. This is strongly recommended when setting trainable attention processors.
220
-
221
- """
222
- count = len(self.attn_processors.keys())
223
-
224
- if isinstance(processor, dict) and len(processor) != count:
225
- raise ValueError(
226
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
227
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
228
- )
229
-
230
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
231
- if hasattr(module, "set_processor"):
232
- if not isinstance(processor, dict):
233
- module.set_processor(processor, _remove_lora=_remove_lora)
234
- else:
235
- module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
236
-
237
- for sub_name, child in module.named_children():
238
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
239
-
240
- for name, module in self.named_children():
241
- fn_recursive_attn_processor(name, module, processor)
242
-
243
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
244
- def set_default_attn_processor(self):
245
- """
246
- Disables custom attention processors and sets the default attention implementation.
247
- """
248
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
249
- processor = AttnAddedKVProcessor()
250
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
251
- processor = AttnProcessor()
252
- else:
253
- raise ValueError(
254
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
255
- )
256
-
257
- self.set_attn_processor(processor, _remove_lora=True)
258
-
259
- @apply_forward_hook
260
- def encode(
261
- self, x: torch.FloatTensor, return_dict: bool = True
262
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
263
- """
264
- Encode a batch of images/videos into latents.
265
-
266
- Args:
267
- x (`torch.FloatTensor`): Input batch of images/videos.
268
- return_dict (`bool`, *optional*, defaults to `True`):
269
- Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
270
-
271
- Returns:
272
- The latent representations of the encoded images/videos. If `return_dict` is True, a
273
- [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
274
- """
275
- assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
276
-
277
- if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
278
- return self.temporal_tiled_encode(x, return_dict=return_dict)
279
-
280
- if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
281
- return self.spatial_tiled_encode(x, return_dict=return_dict)
282
-
283
- if self.use_slicing and x.shape[0] > 1:
284
- encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
285
- h = torch.cat(encoded_slices)
286
- else:
287
- h = self.encoder(x)
288
-
289
- moments = self.quant_conv(h)
290
- posterior = DiagonalGaussianDistribution(moments)
291
-
292
- if not return_dict:
293
- return (posterior,)
294
-
295
- return AutoencoderKLOutput(latent_dist=posterior)
296
-
297
- def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
298
- assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
299
-
300
- if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
301
- return self.temporal_tiled_decode(z, return_dict=return_dict)
302
-
303
- if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
304
- return self.spatial_tiled_decode(z, return_dict=return_dict)
305
-
306
- z = self.post_quant_conv(z)
307
- dec = self.decoder(z)
308
-
309
- if not return_dict:
310
- return (dec,)
311
-
312
- return DecoderOutput(sample=dec)
313
-
314
- @apply_forward_hook
315
- def decode(
316
- self, z: torch.FloatTensor, return_dict: bool = True, generator=None
317
- ) -> Union[DecoderOutput, torch.FloatTensor]:
318
- """
319
- Decode a batch of images/videos.
320
-
321
- Args:
322
- z (`torch.FloatTensor`): Input batch of latent vectors.
323
- return_dict (`bool`, *optional*, defaults to `True`):
324
- Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
325
-
326
- Returns:
327
- [`~models.vae.DecoderOutput`] or `tuple`:
328
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
329
- returned.
330
-
331
- """
332
- if self.use_slicing and z.shape[0] > 1:
333
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
334
- decoded = torch.cat(decoded_slices)
335
- else:
336
- decoded = self._decode(z).sample
337
-
338
- if not return_dict:
339
- return (decoded,)
340
-
341
- return DecoderOutput(sample=decoded)
342
-
343
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
344
- blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
345
- for y in range(blend_extent):
346
- b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
347
- return b
348
-
349
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
350
- blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
351
- for x in range(blend_extent):
352
- b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
353
- return b
354
-
355
- def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
356
- blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
357
- for x in range(blend_extent):
358
- b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
359
- return b
360
-
361
- def spatial_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False) -> AutoencoderKLOutput:
362
- r"""Encode a batch of images/videos using a tiled encoder.
363
-
364
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
365
- steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
366
- different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
367
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
368
- output, but they should be much less noticeable.
369
-
370
- Args:
371
- x (`torch.FloatTensor`): Input batch of images/videos.
372
- return_dict (`bool`, *optional*, defaults to `True`):
373
- Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
374
-
375
- Returns:
376
- [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
377
- If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
378
- `tuple` is returned.
379
- """
380
- overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
381
- blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
382
- row_limit = self.tile_latent_min_size - blend_extent
383
-
384
- # Split video into tiles and encode them separately.
385
- rows = []
386
- for i in range(0, x.shape[-2], overlap_size):
387
- row = []
388
- for j in range(0, x.shape[-1], overlap_size):
389
- tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size]
390
- tile = self.encoder(tile)
391
- tile = self.quant_conv(tile)
392
- row.append(tile)
393
- rows.append(row)
394
- result_rows = []
395
- for i, row in enumerate(rows):
396
- result_row = []
397
- for j, tile in enumerate(row):
398
- # blend the above tile and the left tile
399
- # to the current tile and add the current tile to the result row
400
- if i > 0:
401
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
402
- if j > 0:
403
- tile = self.blend_h(row[j - 1], tile, blend_extent)
404
- result_row.append(tile[:, :, :, :row_limit, :row_limit])
405
- result_rows.append(torch.cat(result_row, dim=-1))
406
-
407
- moments = torch.cat(result_rows, dim=-2)
408
- if return_moments:
409
- return moments
410
-
411
- posterior = DiagonalGaussianDistribution(moments)
412
- if not return_dict:
413
- return (posterior,)
414
-
415
- return AutoencoderKLOutput(latent_dist=posterior)
416
-
417
- def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
418
- r"""
419
- Decode a batch of images/videos using a tiled decoder.
420
-
421
- Args:
422
- z (`torch.FloatTensor`): Input batch of latent vectors.
423
- return_dict (`bool`, *optional*, defaults to `True`):
424
- Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
425
-
426
- Returns:
427
- [`~models.vae.DecoderOutput`] or `tuple`:
428
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
429
- returned.
430
- """
431
- overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
432
- blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
433
- row_limit = self.tile_sample_min_size - blend_extent
434
-
435
- # Split z into overlapping tiles and decode them separately.
436
- # The tiles have an overlap to avoid seams between tiles.
437
- rows = []
438
- for i in range(0, z.shape[-2], overlap_size):
439
- row = []
440
- for j in range(0, z.shape[-1], overlap_size):
441
- tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
442
- tile = self.post_quant_conv(tile)
443
- decoded = self.decoder(tile)
444
- row.append(decoded)
445
- rows.append(row)
446
- result_rows = []
447
- for i, row in enumerate(rows):
448
- result_row = []
449
- for j, tile in enumerate(row):
450
- # blend the above tile and the left tile
451
- # to the current tile and add the current tile to the result row
452
- if i > 0:
453
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
454
- if j > 0:
455
- tile = self.blend_h(row[j - 1], tile, blend_extent)
456
- result_row.append(tile[:, :, :, :row_limit, :row_limit])
457
- result_rows.append(torch.cat(result_row, dim=-1))
458
-
459
- dec = torch.cat(result_rows, dim=-2)
460
- if not return_dict:
461
- return (dec,)
462
-
463
- return DecoderOutput(sample=dec)
464
-
465
- def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
466
-
467
- B, C, T, H, W = x.shape
468
- overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
469
- blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
470
- t_limit = self.tile_latent_min_tsize - blend_extent
471
-
472
- # Split the video into tiles and encode them separately.
473
- row = []
474
- for i in range(0, T, overlap_size):
475
- tile = x[:, :, i: i + self.tile_sample_min_tsize + 1, :, :]
476
- if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
477
- tile = self.spatial_tiled_encode(tile, return_moments=True)
478
- else:
479
- tile = self.encoder(tile)
480
- tile = self.quant_conv(tile)
481
- if i > 0:
482
- tile = tile[:, :, 1:, :, :]
483
- row.append(tile)
484
- result_row = []
485
- for i, tile in enumerate(row):
486
- if i > 0:
487
- tile = self.blend_t(row[i - 1], tile, blend_extent)
488
- result_row.append(tile[:, :, :t_limit, :, :])
489
- else:
490
- result_row.append(tile[:, :, :t_limit + 1, :, :])
491
-
492
- moments = torch.cat(result_row, dim=2)
493
- posterior = DiagonalGaussianDistribution(moments)
494
-
495
- if not return_dict:
496
- return (posterior,)
497
-
498
- return AutoencoderKLOutput(latent_dist=posterior)
499
-
500
- def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
501
- # Split z into overlapping tiles and decode them separately.
502
-
503
- B, C, T, H, W = z.shape
504
- overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
505
- blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
506
- t_limit = self.tile_sample_min_tsize - blend_extent
507
-
508
- row = []
509
- for i in range(0, T, overlap_size):
510
- tile = z[:, :, i: i + self.tile_latent_min_tsize + 1, :, :]
511
- if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
512
- decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
513
- else:
514
- tile = self.post_quant_conv(tile)
515
- decoded = self.decoder(tile)
516
- if i > 0:
517
- decoded = decoded[:, :, 1:, :, :]
518
- row.append(decoded)
519
- result_row = []
520
- for i, tile in enumerate(row):
521
- if i > 0:
522
- tile = self.blend_t(row[i - 1], tile, blend_extent)
523
- result_row.append(tile[:, :, :t_limit, :, :])
524
- else:
525
- result_row.append(tile[:, :, :t_limit + 1, :, :])
526
-
527
- dec = torch.cat(result_row, dim=2)
528
- if not return_dict:
529
- return (dec,)
530
-
531
- return DecoderOutput(sample=dec)
532
-
533
- def forward(
534
- self,
535
- sample: torch.FloatTensor,
536
- sample_posterior: bool = False,
537
- return_dict: bool = True,
538
- return_posterior: bool = False,
539
- generator: Optional[torch.Generator] = None,
540
- ) -> Union[DecoderOutput2, torch.FloatTensor]:
541
- r"""
542
- Args:
543
- sample (`torch.FloatTensor`): Input sample.
544
- sample_posterior (`bool`, *optional*, defaults to `False`):
545
- Whether to sample from the posterior.
546
- return_dict (`bool`, *optional*, defaults to `True`):
547
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
548
- """
549
- x = sample
550
- posterior = self.encode(x).latent_dist
551
- if sample_posterior:
552
- z = posterior.sample(generator=generator)
553
- else:
554
- z = posterior.mode()
555
- dec = self.decode(z).sample
556
-
557
- if not return_dict:
558
- if return_posterior:
559
- return (dec, posterior)
560
- else:
561
- return (dec,)
562
- if return_posterior:
563
- return DecoderOutput2(sample=dec, posterior=posterior)
564
- else:
565
- return DecoderOutput2(sample=dec)
566
-
567
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
568
- def fuse_qkv_projections(self):
569
- """
570
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
571
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
572
-
573
- <Tip warning={true}>
574
-
575
- This API is 🧪 experimental.
576
-
577
- </Tip>
578
- """
579
- self.original_attn_processors = None
580
-
581
- for _, attn_processor in self.attn_processors.items():
582
- if "Added" in str(attn_processor.__class__.__name__):
583
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
584
-
585
- self.original_attn_processors = self.attn_processors
586
-
587
- for module in self.modules():
588
- if isinstance(module, Attention):
589
- module.fuse_projections(fuse=True)
590
-
591
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
592
- def unfuse_qkv_projections(self):
593
- """Disables the fused QKV projection if enabled.
594
-
595
- <Tip warning={true}>
596
-
597
- This API is 🧪 experimental.
598
-
599
- </Tip>
600
-
601
- """
602
- if self.original_attn_processors is not None:
603
- self.set_attn_processor(self.original_attn_processors)
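
A standalone check of the linear cross-fade used by the tiled paths above (the same arithmetic as `blend_h`, on tiny dummy tensors):

import torch

a = torch.ones(1, 1, 1, 4, 8)    # "left" tile
b = torch.zeros(1, 1, 1, 4, 8)   # "right" tile
blend_extent = 4
for x in range(blend_extent):
    b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
print(b[0, 0, 0, 0, :blend_extent])  # tensor([1.0000, 0.7500, 0.5000, 0.2500])
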
hyvideo/vae/unet_causal_3d_blocks.py DELETED
@@ -1,764 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- #
16
- # Modified from diffusers==0.29.2
17
- #
18
- # ==============================================================================
19
-
20
- from typing import Optional, Tuple, Union
21
-
22
- import torch
23
- import torch.nn.functional as F
24
- from torch import nn
25
- from einops import rearrange
26
-
27
- from diffusers.utils import logging
28
- from diffusers.models.activations import get_activation
29
- from diffusers.models.attention_processor import SpatialNorm
30
- from diffusers.models.attention_processor import Attention
31
- from diffusers.models.normalization import AdaGroupNorm
32
- from diffusers.models.normalization import RMSNorm
33
-
34
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
-
36
-
37
- def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
38
- seq_len = n_frame * n_hw
39
- mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
40
- for i in range(seq_len):
41
- i_frame = i // n_hw
42
- mask[i, : (i_frame + 1) * n_hw] = 0
43
- if batch_size is not None:
44
- mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
45
- return mask
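
A quick sanity check of the frame-causal mask above (2 frames, 2 spatial tokens per frame; 1 marks an allowed position):

import torch

m = prepare_causal_attention_mask(n_frame=2, n_hw=2, dtype=torch.float32, device="cpu")
print((m == 0).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]], dtype=torch.int32)
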
46
-
47
-
48
- class CausalConv3d(nn.Module):
49
- """
50
- Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
51
- This maintains temporal causality in video generation tasks.
52
- """
53
-
54
- def __init__(
55
- self,
56
- chan_in,
57
- chan_out,
58
- kernel_size: Union[int, Tuple[int, int, int]],
59
- stride: Union[int, Tuple[int, int, int]] = 1,
60
- dilation: Union[int, Tuple[int, int, int]] = 1,
61
- pad_mode='replicate',
62
- **kwargs
63
- ):
64
- super().__init__()
65
-
66
- self.pad_mode = pad_mode
67
- padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
68
- self.time_causal_padding = padding
69
-
70
- self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
71
-
72
- def forward(self, x):
73
- x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
74
- return self.conv(x)
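
A small check that `CausalConv3d` preserves the temporal length and never looks at future frames (randomly initialized weights; kernel size 3):

import torch

conv = CausalConv3d(1, 1, kernel_size=3)
x = torch.randn(1, 1, 5, 8, 8)
y1 = conv(x)

x2 = x.clone()
x2[:, :, -1] = 0.0  # perturb only the last frame
y2 = conv(x2)

print(y1.shape)                                     # torch.Size([1, 1, 5, 8, 8])
print(torch.allclose(y1[:, :, :4], y2[:, :, :4]))   # True: earlier outputs are unaffected
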
75
-
76
-
77
- class UpsampleCausal3D(nn.Module):
78
- """
79
- A 3D upsampling layer with an optional convolution.
80
- """
81
-
82
- def __init__(
83
- self,
84
- channels: int,
85
- use_conv: bool = False,
86
- use_conv_transpose: bool = False,
87
- out_channels: Optional[int] = None,
88
- name: str = "conv",
89
- kernel_size: Optional[int] = None,
90
- padding=1,
91
- norm_type=None,
92
- eps=None,
93
- elementwise_affine=None,
94
- bias=True,
95
- interpolate=True,
96
- upsample_factor=(2, 2, 2),
97
- ):
98
- super().__init__()
99
- self.channels = channels
100
- self.out_channels = out_channels or channels
101
- self.use_conv = use_conv
102
- self.use_conv_transpose = use_conv_transpose
103
- self.name = name
104
- self.interpolate = interpolate
105
- self.upsample_factor = upsample_factor
106
-
107
- if norm_type == "ln_norm":
108
- self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
109
- elif norm_type == "rms_norm":
110
- self.norm = RMSNorm(channels, eps, elementwise_affine)
111
- elif norm_type is None:
112
- self.norm = None
113
- else:
114
- raise ValueError(f"unknown norm_type: {norm_type}")
115
-
116
- conv = None
117
- if use_conv_transpose:
118
- raise NotImplementedError
119
- elif use_conv:
120
- if kernel_size is None:
121
- kernel_size = 3
122
- conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
123
-
124
- if name == "conv":
125
- self.conv = conv
126
- else:
127
- self.Conv2d_0 = conv
128
-
129
- def forward(
130
- self,
131
- hidden_states: torch.FloatTensor,
132
- output_size: Optional[int] = None,
133
- scale: float = 1.0,
134
- ) -> torch.FloatTensor:
135
- assert hidden_states.shape[1] == self.channels
136
-
137
- if self.norm is not None:
138
- raise NotImplementedError
139
-
140
- if self.use_conv_transpose:
141
- return self.conv(hidden_states)
142
-
143
- # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
144
- dtype = hidden_states.dtype
145
- if dtype == torch.bfloat16:
146
- hidden_states = hidden_states.to(torch.float32)
147
-
148
- # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
149
- if hidden_states.shape[0] >= 64:
150
- hidden_states = hidden_states.contiguous()
151
-
152
- # if `output_size` is passed we force the interpolation output
153
- # size and do not make use of `scale_factor=2`
154
- if self.interpolate:
155
- B, C, T, H, W = hidden_states.shape
156
- first_h, other_h = hidden_states.split((1, T - 1), dim=2)
157
- if output_size is None:
158
- if T > 1:
159
- other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
160
-
161
- first_h = first_h.squeeze(2)
162
- first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
163
- first_h = first_h.unsqueeze(2)
164
- else:
165
- raise NotImplementedError
166
-
167
- if T > 1:
168
- hidden_states = torch.cat((first_h, other_h), dim=2)
169
- else:
170
- hidden_states = first_h
171
-
172
- # If the input is bfloat16, we cast back to bfloat16
173
- if dtype == torch.bfloat16:
174
- hidden_states = hidden_states.to(dtype)
175
-
176
- if self.use_conv:
177
- if self.name == "conv":
178
- hidden_states = self.conv(hidden_states)
179
- else:
180
- hidden_states = self.Conv2d_0(hidden_states)
181
-
182
- return hidden_states
183
-
184
-
185
- class DownsampleCausal3D(nn.Module):
186
- """
187
- A 3D downsampling layer with an optional convolution.
188
- """
189
-
190
- def __init__(
191
- self,
192
- channels: int,
193
- use_conv: bool = False,
194
- out_channels: Optional[int] = None,
195
- padding: int = 1,
196
- name: str = "conv",
197
- kernel_size=3,
198
- norm_type=None,
199
- eps=None,
200
- elementwise_affine=None,
201
- bias=True,
202
- stride=2,
203
- ):
204
- super().__init__()
205
- self.channels = channels
206
- self.out_channels = out_channels or channels
207
- self.use_conv = use_conv
208
- self.padding = padding
209
- stride = stride
210
- self.name = name
211
-
212
- if norm_type == "ln_norm":
213
- self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
214
- elif norm_type == "rms_norm":
215
- self.norm = RMSNorm(channels, eps, elementwise_affine)
216
- elif norm_type is None:
217
- self.norm = None
218
- else:
219
- raise ValueError(f"unknown norm_type: {norm_type}")
220
-
221
- if use_conv:
222
- conv = CausalConv3d(
223
- self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
224
- )
225
- else:
226
- raise NotImplementedError
227
-
228
- if name == "conv":
229
- self.Conv2d_0 = conv
230
- self.conv = conv
231
- elif name == "Conv2d_0":
232
- self.conv = conv
233
- else:
234
- self.conv = conv
235
-
236
- def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
237
- assert hidden_states.shape[1] == self.channels
238
-
239
- if self.norm is not None:
240
- hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
241
-
242
- assert hidden_states.shape[1] == self.channels
243
-
244
- hidden_states = self.conv(hidden_states)
245
-
246
- return hidden_states
247
-
248
-
249
- class ResnetBlockCausal3D(nn.Module):
250
- r"""
251
- A Resnet block.
252
- """
253
-
254
- def __init__(
255
- self,
256
- *,
257
- in_channels: int,
258
- out_channels: Optional[int] = None,
259
- conv_shortcut: bool = False,
260
- dropout: float = 0.0,
261
- temb_channels: int = 512,
262
- groups: int = 32,
263
- groups_out: Optional[int] = None,
264
- pre_norm: bool = True,
265
- eps: float = 1e-6,
266
- non_linearity: str = "swish",
267
- skip_time_act: bool = False,
268
- # default, scale_shift, ada_group, spatial
269
- time_embedding_norm: str = "default",
270
- kernel: Optional[torch.FloatTensor] = None,
271
- output_scale_factor: float = 1.0,
272
- use_in_shortcut: Optional[bool] = None,
273
- up: bool = False,
274
- down: bool = False,
275
- conv_shortcut_bias: bool = True,
276
- conv_3d_out_channels: Optional[int] = None,
277
- ):
278
- super().__init__()
279
- self.pre_norm = pre_norm
280
- self.pre_norm = True
281
- self.in_channels = in_channels
282
- out_channels = in_channels if out_channels is None else out_channels
283
- self.out_channels = out_channels
284
- self.use_conv_shortcut = conv_shortcut
285
- self.up = up
286
- self.down = down
287
- self.output_scale_factor = output_scale_factor
288
- self.time_embedding_norm = time_embedding_norm
289
- self.skip_time_act = skip_time_act
290
-
291
- linear_cls = nn.Linear
292
-
293
- if groups_out is None:
294
- groups_out = groups
295
-
296
- if self.time_embedding_norm == "ada_group":
297
- self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
298
- elif self.time_embedding_norm == "spatial":
299
- self.norm1 = SpatialNorm(in_channels, temb_channels)
300
- else:
301
- self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
302
-
303
- self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
304
-
305
- if temb_channels is not None:
306
- if self.time_embedding_norm == "default":
307
- self.time_emb_proj = linear_cls(temb_channels, out_channels)
308
- elif self.time_embedding_norm == "scale_shift":
309
- self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
310
- elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
311
- self.time_emb_proj = None
312
- else:
313
- raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
314
- else:
315
- self.time_emb_proj = None
316
-
317
- if self.time_embedding_norm == "ada_group":
318
- self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
319
- elif self.time_embedding_norm == "spatial":
320
- self.norm2 = SpatialNorm(out_channels, temb_channels)
321
- else:
322
- self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
323
-
324
- self.dropout = torch.nn.Dropout(dropout)
325
- conv_3d_out_channels = conv_3d_out_channels or out_channels
326
- self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
327
-
328
- self.nonlinearity = get_activation(non_linearity)
329
-
330
- self.upsample = self.downsample = None
331
- if self.up:
332
- self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
333
- elif self.down:
334
- self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
335
-
336
- self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
337
-
338
- self.conv_shortcut = None
339
- if self.use_in_shortcut:
340
- self.conv_shortcut = CausalConv3d(
341
- in_channels,
342
- conv_3d_out_channels,
343
- kernel_size=1,
344
- stride=1,
345
- bias=conv_shortcut_bias,
346
- )
347
-
348
- def forward(
349
- self,
350
- input_tensor: torch.FloatTensor,
351
- temb: torch.FloatTensor,
352
- scale: float = 1.0,
353
- ) -> torch.FloatTensor:
354
- hidden_states = input_tensor
355
-
356
- if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
357
- hidden_states = self.norm1(hidden_states, temb)
358
- else:
359
- hidden_states = self.norm1(hidden_states)
360
-
361
- hidden_states = self.nonlinearity(hidden_states)
362
-
363
- if self.upsample is not None:
364
- # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
365
- if hidden_states.shape[0] >= 64:
366
- input_tensor = input_tensor.contiguous()
367
- hidden_states = hidden_states.contiguous()
368
- input_tensor = (
369
- self.upsample(input_tensor, scale=scale)
370
- )
371
- hidden_states = (
372
- self.upsample(hidden_states, scale=scale)
373
- )
374
- elif self.downsample is not None:
375
- input_tensor = (
376
- self.downsample(input_tensor, scale=scale)
377
- )
378
- hidden_states = (
379
- self.downsample(hidden_states, scale=scale)
380
- )
381
-
382
- hidden_states = self.conv1(hidden_states)
383
-
384
- if self.time_emb_proj is not None:
385
- if not self.skip_time_act:
386
- temb = self.nonlinearity(temb)
387
- temb = (
388
- self.time_emb_proj(temb)[:, :, None, None, None]
389
- )
390
-
391
- if temb is not None and self.time_embedding_norm == "default":
392
- hidden_states = hidden_states + temb
393
-
394
- if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
395
- hidden_states = self.norm2(hidden_states, temb)
396
- else:
397
- hidden_states = self.norm2(hidden_states)
398
-
399
- if temb is not None and self.time_embedding_norm == "scale_shift":
400
- scale, shift = torch.chunk(temb, 2, dim=1)
401
- hidden_states = hidden_states * (1 + scale) + shift
402
-
403
- hidden_states = self.nonlinearity(hidden_states)
404
-
405
- hidden_states = self.dropout(hidden_states)
406
- hidden_states = self.conv2(hidden_states)
407
-
408
- if self.conv_shortcut is not None:
409
- input_tensor = (
410
- self.conv_shortcut(input_tensor)
411
- )
412
-
413
- output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
414
-
415
- return output_tensor
416
-
417
-
418
- def get_down_block3d(
419
- down_block_type: str,
420
- num_layers: int,
421
- in_channels: int,
422
- out_channels: int,
423
- temb_channels: int,
424
- add_downsample: bool,
425
- downsample_stride: int,
426
- resnet_eps: float,
427
- resnet_act_fn: str,
428
- transformer_layers_per_block: int = 1,
429
- num_attention_heads: Optional[int] = None,
430
- resnet_groups: Optional[int] = None,
431
- cross_attention_dim: Optional[int] = None,
432
- downsample_padding: Optional[int] = None,
433
- dual_cross_attention: bool = False,
434
- use_linear_projection: bool = False,
435
- only_cross_attention: bool = False,
436
- upcast_attention: bool = False,
437
- resnet_time_scale_shift: str = "default",
438
- attention_type: str = "default",
439
- resnet_skip_time_act: bool = False,
440
- resnet_out_scale_factor: float = 1.0,
441
- cross_attention_norm: Optional[str] = None,
442
- attention_head_dim: Optional[int] = None,
443
- downsample_type: Optional[str] = None,
444
- dropout: float = 0.0,
445
- ):
446
- # If attn head dim is not defined, we default it to the number of heads
447
- if attention_head_dim is None:
448
- logger.warn(
449
- f"It is recommended to provide `attention_head_dim` when calling `get_down_block3d`. Defaulting `attention_head_dim` to {num_attention_heads}."
450
- )
451
- attention_head_dim = num_attention_heads
452
-
453
- down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
454
- if down_block_type == "DownEncoderBlockCausal3D":
455
- return DownEncoderBlockCausal3D(
456
- num_layers=num_layers,
457
- in_channels=in_channels,
458
- out_channels=out_channels,
459
- dropout=dropout,
460
- add_downsample=add_downsample,
461
- downsample_stride=downsample_stride,
462
- resnet_eps=resnet_eps,
463
- resnet_act_fn=resnet_act_fn,
464
- resnet_groups=resnet_groups,
465
- downsample_padding=downsample_padding,
466
- resnet_time_scale_shift=resnet_time_scale_shift,
467
- )
468
- raise ValueError(f"{down_block_type} does not exist.")
469
-
470
-
471
- def get_up_block3d(
472
- up_block_type: str,
473
- num_layers: int,
474
- in_channels: int,
475
- out_channels: int,
476
- prev_output_channel: int,
477
- temb_channels: int,
478
- add_upsample: bool,
479
- upsample_scale_factor: Tuple,
480
- resnet_eps: float,
481
- resnet_act_fn: str,
482
- resolution_idx: Optional[int] = None,
483
- transformer_layers_per_block: int = 1,
484
- num_attention_heads: Optional[int] = None,
485
- resnet_groups: Optional[int] = None,
486
- cross_attention_dim: Optional[int] = None,
487
- dual_cross_attention: bool = False,
488
- use_linear_projection: bool = False,
489
- only_cross_attention: bool = False,
490
- upcast_attention: bool = False,
491
- resnet_time_scale_shift: str = "default",
492
- attention_type: str = "default",
493
- resnet_skip_time_act: bool = False,
494
- resnet_out_scale_factor: float = 1.0,
495
- cross_attention_norm: Optional[str] = None,
496
- attention_head_dim: Optional[int] = None,
497
- upsample_type: Optional[str] = None,
498
- dropout: float = 0.0,
499
- ) -> nn.Module:
500
- # If attn head dim is not defined, we default it to the number of heads
501
- if attention_head_dim is None:
502
- logger.warn(
503
- f"It is recommended to provide `attention_head_dim` when calling `get_up_block3d`. Defaulting `attention_head_dim` to {num_attention_heads}."
504
- )
505
- attention_head_dim = num_attention_heads
506
-
507
- up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
508
- if up_block_type == "UpDecoderBlockCausal3D":
509
- return UpDecoderBlockCausal3D(
510
- num_layers=num_layers,
511
- in_channels=in_channels,
512
- out_channels=out_channels,
513
- resolution_idx=resolution_idx,
514
- dropout=dropout,
515
- add_upsample=add_upsample,
516
- upsample_scale_factor=upsample_scale_factor,
517
- resnet_eps=resnet_eps,
518
- resnet_act_fn=resnet_act_fn,
519
- resnet_groups=resnet_groups,
520
- resnet_time_scale_shift=resnet_time_scale_shift,
521
- temb_channels=temb_channels,
522
- )
523
- raise ValueError(f"{up_block_type} does not exist.")
524
-
525
-
526
- class UNetMidBlockCausal3D(nn.Module):
527
- """
528
- A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
529
- """
530
-
531
- def __init__(
532
- self,
533
- in_channels: int,
534
- temb_channels: int,
535
- dropout: float = 0.0,
536
- num_layers: int = 1,
537
- resnet_eps: float = 1e-6,
538
- resnet_time_scale_shift: str = "default", # default, spatial
539
- resnet_act_fn: str = "swish",
540
- resnet_groups: int = 32,
541
- attn_groups: Optional[int] = None,
542
- resnet_pre_norm: bool = True,
543
- add_attention: bool = True,
544
- attention_head_dim: int = 1,
545
- output_scale_factor: float = 1.0,
546
- ):
547
- super().__init__()
548
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
549
- self.add_attention = add_attention
550
-
551
- if attn_groups is None:
552
- attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
553
-
554
- # there is always at least one resnet
555
- resnets = [
556
- ResnetBlockCausal3D(
557
- in_channels=in_channels,
558
- out_channels=in_channels,
559
- temb_channels=temb_channels,
560
- eps=resnet_eps,
561
- groups=resnet_groups,
562
- dropout=dropout,
563
- time_embedding_norm=resnet_time_scale_shift,
564
- non_linearity=resnet_act_fn,
565
- output_scale_factor=output_scale_factor,
566
- pre_norm=resnet_pre_norm,
567
- )
568
- ]
569
- attentions = []
570
-
571
- if attention_head_dim is None:
572
- logger.warn(
573
- f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
574
- )
575
- attention_head_dim = in_channels
576
-
577
- for _ in range(num_layers):
578
- if self.add_attention:
579
- attentions.append(
580
- Attention(
581
- in_channels,
582
- heads=in_channels // attention_head_dim,
583
- dim_head=attention_head_dim,
584
- rescale_output_factor=output_scale_factor,
585
- eps=resnet_eps,
586
- norm_num_groups=attn_groups,
587
- spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
588
- residual_connection=True,
589
- bias=True,
590
- upcast_softmax=True,
591
- _from_deprecated_attn_block=True,
592
- )
593
- )
594
- else:
595
- attentions.append(None)
596
-
597
- resnets.append(
598
- ResnetBlockCausal3D(
599
- in_channels=in_channels,
600
- out_channels=in_channels,
601
- temb_channels=temb_channels,
602
- eps=resnet_eps,
603
- groups=resnet_groups,
604
- dropout=dropout,
605
- time_embedding_norm=resnet_time_scale_shift,
606
- non_linearity=resnet_act_fn,
607
- output_scale_factor=output_scale_factor,
608
- pre_norm=resnet_pre_norm,
609
- )
610
- )
611
-
612
- self.attentions = nn.ModuleList(attentions)
613
- self.resnets = nn.ModuleList(resnets)
614
-
615
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
616
- hidden_states = self.resnets[0](hidden_states, temb)
617
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
618
- if attn is not None:
619
- B, C, T, H, W = hidden_states.shape
620
- hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
621
- attention_mask = prepare_causal_attention_mask(
622
- T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
623
- )
624
- hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
625
- hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
626
- hidden_states = resnet(hidden_states, temb)
627
-
628
- return hidden_states
629
-
630
-
631
- class DownEncoderBlockCausal3D(nn.Module):
632
- def __init__(
633
- self,
634
- in_channels: int,
635
- out_channels: int,
636
- dropout: float = 0.0,
637
- num_layers: int = 1,
638
- resnet_eps: float = 1e-6,
639
- resnet_time_scale_shift: str = "default",
640
- resnet_act_fn: str = "swish",
641
- resnet_groups: int = 32,
642
- resnet_pre_norm: bool = True,
643
- output_scale_factor: float = 1.0,
644
- add_downsample: bool = True,
645
- downsample_stride: int = 2,
646
- downsample_padding: int = 1,
647
- ):
648
- super().__init__()
649
- resnets = []
650
-
651
- for i in range(num_layers):
652
- in_channels = in_channels if i == 0 else out_channels
653
- resnets.append(
654
- ResnetBlockCausal3D(
655
- in_channels=in_channels,
656
- out_channels=out_channels,
657
- temb_channels=None,
658
- eps=resnet_eps,
659
- groups=resnet_groups,
660
- dropout=dropout,
661
- time_embedding_norm=resnet_time_scale_shift,
662
- non_linearity=resnet_act_fn,
663
- output_scale_factor=output_scale_factor,
664
- pre_norm=resnet_pre_norm,
665
- )
666
- )
667
-
668
- self.resnets = nn.ModuleList(resnets)
669
-
670
- if add_downsample:
671
- self.downsamplers = nn.ModuleList(
672
- [
673
- DownsampleCausal3D(
674
- out_channels,
675
- use_conv=True,
676
- out_channels=out_channels,
677
- padding=downsample_padding,
678
- name="op",
679
- stride=downsample_stride,
680
- )
681
- ]
682
- )
683
- else:
684
- self.downsamplers = None
685
-
686
- def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
687
- for resnet in self.resnets:
688
- hidden_states = resnet(hidden_states, temb=None, scale=scale)
689
-
690
- if self.downsamplers is not None:
691
- for downsampler in self.downsamplers:
692
- hidden_states = downsampler(hidden_states, scale)
693
-
694
- return hidden_states
695
-
696
-
697
- class UpDecoderBlockCausal3D(nn.Module):
698
- def __init__(
699
- self,
700
- in_channels: int,
701
- out_channels: int,
702
- resolution_idx: Optional[int] = None,
703
- dropout: float = 0.0,
704
- num_layers: int = 1,
705
- resnet_eps: float = 1e-6,
706
- resnet_time_scale_shift: str = "default", # default, spatial
707
- resnet_act_fn: str = "swish",
708
- resnet_groups: int = 32,
709
- resnet_pre_norm: bool = True,
710
- output_scale_factor: float = 1.0,
711
- add_upsample: bool = True,
712
- upsample_scale_factor=(2, 2, 2),
713
- temb_channels: Optional[int] = None,
714
- ):
715
- super().__init__()
716
- resnets = []
717
-
718
- for i in range(num_layers):
719
- input_channels = in_channels if i == 0 else out_channels
720
-
721
- resnets.append(
722
- ResnetBlockCausal3D(
723
- in_channels=input_channels,
724
- out_channels=out_channels,
725
- temb_channels=temb_channels,
726
- eps=resnet_eps,
727
- groups=resnet_groups,
728
- dropout=dropout,
729
- time_embedding_norm=resnet_time_scale_shift,
730
- non_linearity=resnet_act_fn,
731
- output_scale_factor=output_scale_factor,
732
- pre_norm=resnet_pre_norm,
733
- )
734
- )
735
-
736
- self.resnets = nn.ModuleList(resnets)
737
-
738
- if add_upsample:
739
- self.upsamplers = nn.ModuleList(
740
- [
741
- UpsampleCausal3D(
742
- out_channels,
743
- use_conv=True,
744
- out_channels=out_channels,
745
- upsample_factor=upsample_scale_factor,
746
- )
747
- ]
748
- )
749
- else:
750
- self.upsamplers = None
751
-
752
- self.resolution_idx = resolution_idx
753
-
754
- def forward(
755
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
756
- ) -> torch.FloatTensor:
757
- for resnet in self.resnets:
758
- hidden_states = resnet(hidden_states, temb=temb, scale=scale)
759
-
760
- if self.upsamplers is not None:
761
- for upsampler in self.upsamplers:
762
- hidden_states = upsampler(hidden_states)
763
-
764
- return hidden_states
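Note on the deleted unet_causal_3d_blocks.py above: the "causal" part of these blocks is that frame 0 is handled separately from the rest of the clip; for example, UpsampleCausal3D.forward upsamples the first frame spatially only and the remaining frames in time as well. A minimal sketch of that behaviour, assuming the classes listed above were still in scope; the tensor shape is purely illustrative:

import torch

# Toy 5D latent: (batch, channels, frames, height, width).
x = torch.randn(1, 64, 5, 16, 16)

# use_conv=False exercises only the interpolation path, so no CausalConv3d weights are involved.
up = UpsampleCausal3D(channels=64, use_conv=False)
y = up(x)

# Frame 0 is upsampled spatially only; the other 4 frames are upsampled in
# time and space, so 5 frames become 1 + 2 * 4 = 9.
print(y.shape)  # torch.Size([1, 64, 9, 32, 32])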
hyvideo/vae/vae.py DELETED
@@ -1,355 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Optional, Tuple
3
-
4
- import numpy as np
5
- import torch
6
- import torch.nn as nn
7
-
8
- from diffusers.utils import BaseOutput, is_torch_version
9
- from diffusers.utils.torch_utils import randn_tensor
10
- from diffusers.models.attention_processor import SpatialNorm
11
- from .unet_causal_3d_blocks import (
12
- CausalConv3d,
13
- UNetMidBlockCausal3D,
14
- get_down_block3d,
15
- get_up_block3d,
16
- )
17
-
18
-
19
- @dataclass
20
- class DecoderOutput(BaseOutput):
21
- r"""
22
- Output of decoding method.
23
-
24
- Args:
25
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
26
- The decoded output sample from the last layer of the model.
27
- """
28
-
29
- sample: torch.FloatTensor
30
-
31
-
32
- class EncoderCausal3D(nn.Module):
33
- r"""
34
- The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
35
- """
36
-
37
- def __init__(
38
- self,
39
- in_channels: int = 3,
40
- out_channels: int = 3,
41
- down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
42
- block_out_channels: Tuple[int, ...] = (64,),
43
- layers_per_block: int = 2,
44
- norm_num_groups: int = 32,
45
- act_fn: str = "silu",
46
- double_z: bool = True,
47
- mid_block_add_attention=True,
48
- time_compression_ratio: int = 4,
49
- spatial_compression_ratio: int = 8,
50
- ):
51
- super().__init__()
52
- self.layers_per_block = layers_per_block
53
-
54
- self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
55
- self.mid_block = None
56
- self.down_blocks = nn.ModuleList([])
57
-
58
- # down
59
- output_channel = block_out_channels[0]
60
- for i, down_block_type in enumerate(down_block_types):
61
- input_channel = output_channel
62
- output_channel = block_out_channels[i]
63
- is_final_block = i == len(block_out_channels) - 1
64
- num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
65
- num_time_downsample_layers = int(np.log2(time_compression_ratio))
66
-
67
- if time_compression_ratio == 4:
68
- add_spatial_downsample = bool(i < num_spatial_downsample_layers)
69
- add_time_downsample = bool(
70
- i >= (len(block_out_channels) - 1 - num_time_downsample_layers)
71
- and not is_final_block
72
- )
73
- else:
74
- raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
75
-
76
- downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
77
- downsample_stride_T = (2,) if add_time_downsample else (1,)
78
- downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
79
- down_block = get_down_block3d(
80
- down_block_type,
81
- num_layers=self.layers_per_block,
82
- in_channels=input_channel,
83
- out_channels=output_channel,
84
- add_downsample=bool(add_spatial_downsample or add_time_downsample),
85
- downsample_stride=downsample_stride,
86
- resnet_eps=1e-6,
87
- downsample_padding=0,
88
- resnet_act_fn=act_fn,
89
- resnet_groups=norm_num_groups,
90
- attention_head_dim=output_channel,
91
- temb_channels=None,
92
- )
93
- self.down_blocks.append(down_block)
94
-
95
- # mid
96
- self.mid_block = UNetMidBlockCausal3D(
97
- in_channels=block_out_channels[-1],
98
- resnet_eps=1e-6,
99
- resnet_act_fn=act_fn,
100
- output_scale_factor=1,
101
- resnet_time_scale_shift="default",
102
- attention_head_dim=block_out_channels[-1],
103
- resnet_groups=norm_num_groups,
104
- temb_channels=None,
105
- add_attention=mid_block_add_attention,
106
- )
107
-
108
- # out
109
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
110
- self.conv_act = nn.SiLU()
111
-
112
- conv_out_channels = 2 * out_channels if double_z else out_channels
113
- self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
114
-
115
- def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
116
- r"""The forward method of the `EncoderCausal3D` class."""
117
- assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
118
-
119
- sample = self.conv_in(sample)
120
-
121
- # down
122
- for down_block in self.down_blocks:
123
- sample = down_block(sample)
124
-
125
- # middle
126
- sample = self.mid_block(sample)
127
-
128
- # post-process
129
- sample = self.conv_norm_out(sample)
130
- sample = self.conv_act(sample)
131
- sample = self.conv_out(sample)
132
-
133
- return sample
134
-
135
-
136
- class DecoderCausal3D(nn.Module):
137
- r"""
138
- The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
139
- """
140
-
141
- def __init__(
142
- self,
143
- in_channels: int = 3,
144
- out_channels: int = 3,
145
- up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
146
- block_out_channels: Tuple[int, ...] = (64,),
147
- layers_per_block: int = 2,
148
- norm_num_groups: int = 32,
149
- act_fn: str = "silu",
150
- norm_type: str = "group", # group, spatial
151
- mid_block_add_attention=True,
152
- time_compression_ratio: int = 4,
153
- spatial_compression_ratio: int = 8,
154
- ):
155
- super().__init__()
156
- self.layers_per_block = layers_per_block
157
-
158
- self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
159
- self.mid_block = None
160
- self.up_blocks = nn.ModuleList([])
161
-
162
- temb_channels = in_channels if norm_type == "spatial" else None
163
-
164
- # mid
165
- self.mid_block = UNetMidBlockCausal3D(
166
- in_channels=block_out_channels[-1],
167
- resnet_eps=1e-6,
168
- resnet_act_fn=act_fn,
169
- output_scale_factor=1,
170
- resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
171
- attention_head_dim=block_out_channels[-1],
172
- resnet_groups=norm_num_groups,
173
- temb_channels=temb_channels,
174
- add_attention=mid_block_add_attention,
175
- )
176
-
177
- # up
178
- reversed_block_out_channels = list(reversed(block_out_channels))
179
- output_channel = reversed_block_out_channels[0]
180
- for i, up_block_type in enumerate(up_block_types):
181
- prev_output_channel = output_channel
182
- output_channel = reversed_block_out_channels[i]
183
- is_final_block = i == len(block_out_channels) - 1
184
- num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
185
- num_time_upsample_layers = int(np.log2(time_compression_ratio))
186
-
187
- if time_compression_ratio == 4:
188
- add_spatial_upsample = bool(i < num_spatial_upsample_layers)
189
- add_time_upsample = bool(
190
- i >= len(block_out_channels) - 1 - num_time_upsample_layers
191
- and not is_final_block
192
- )
193
- else:
194
- raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
195
-
196
- upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
197
- upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
198
- upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
199
- up_block = get_up_block3d(
200
- up_block_type,
201
- num_layers=self.layers_per_block + 1,
202
- in_channels=prev_output_channel,
203
- out_channels=output_channel,
204
- prev_output_channel=None,
205
- add_upsample=bool(add_spatial_upsample or add_time_upsample),
206
- upsample_scale_factor=upsample_scale_factor,
207
- resnet_eps=1e-6,
208
- resnet_act_fn=act_fn,
209
- resnet_groups=norm_num_groups,
210
- attention_head_dim=output_channel,
211
- temb_channels=temb_channels,
212
- resnet_time_scale_shift=norm_type,
213
- )
214
- self.up_blocks.append(up_block)
215
- prev_output_channel = output_channel
216
-
217
- # out
218
- if norm_type == "spatial":
219
- self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
220
- else:
221
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
222
- self.conv_act = nn.SiLU()
223
- self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
224
-
225
- self.gradient_checkpointing = False
226
-
227
- def forward(
228
- self,
229
- sample: torch.FloatTensor,
230
- latent_embeds: Optional[torch.FloatTensor] = None,
231
- ) -> torch.FloatTensor:
232
- r"""The forward method of the `DecoderCausal3D` class."""
233
- assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
234
-
235
- sample = self.conv_in(sample)
236
-
237
- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
238
- if self.training and self.gradient_checkpointing:
239
-
240
- def create_custom_forward(module):
241
- def custom_forward(*inputs):
242
- return module(*inputs)
243
-
244
- return custom_forward
245
-
246
- if is_torch_version(">=", "1.11.0"):
247
- # middle
248
- sample = torch.utils.checkpoint.checkpoint(
249
- create_custom_forward(self.mid_block),
250
- sample,
251
- latent_embeds,
252
- use_reentrant=False,
253
- )
254
- sample = sample.to(upscale_dtype)
255
-
256
- # up
257
- for up_block in self.up_blocks:
258
- sample = torch.utils.checkpoint.checkpoint(
259
- create_custom_forward(up_block),
260
- sample,
261
- latent_embeds,
262
- use_reentrant=False,
263
- )
264
- else:
265
- # middle
266
- sample = torch.utils.checkpoint.checkpoint(
267
- create_custom_forward(self.mid_block), sample, latent_embeds
268
- )
269
- sample = sample.to(upscale_dtype)
270
-
271
- # up
272
- for up_block in self.up_blocks:
273
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
274
- else:
275
- # middle
276
- sample = self.mid_block(sample, latent_embeds)
277
- sample = sample.to(upscale_dtype)
278
-
279
- # up
280
- for up_block in self.up_blocks:
281
- sample = up_block(sample, latent_embeds)
282
-
283
- # post-process
284
- if latent_embeds is None:
285
- sample = self.conv_norm_out(sample)
286
- else:
287
- sample = self.conv_norm_out(sample, latent_embeds)
288
- sample = self.conv_act(sample)
289
- sample = self.conv_out(sample)
290
-
291
- return sample
292
-
293
-
294
- class DiagonalGaussianDistribution(object):
295
- def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
296
- if parameters.ndim == 3:
297
- dim = 2 # (B, L, C)
298
- elif parameters.ndim == 5 or parameters.ndim == 4:
299
- dim = 1 # (B, C, T, H, W) / (B, C, H, W)
300
- else:
301
- raise NotImplementedError
302
- self.parameters = parameters
303
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
304
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
305
- self.deterministic = deterministic
306
- self.std = torch.exp(0.5 * self.logvar)
307
- self.var = torch.exp(self.logvar)
308
- if self.deterministic:
309
- self.var = self.std = torch.zeros_like(
310
- self.mean, device=self.parameters.device, dtype=self.parameters.dtype
311
- )
312
-
313
- def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
314
- # make sure sample is on the same device as the parameters and has same dtype
315
- sample = randn_tensor(
316
- self.mean.shape,
317
- generator=generator,
318
- device=self.parameters.device,
319
- dtype=self.parameters.dtype,
320
- )
321
- x = self.mean + self.std * sample
322
- return x
323
-
324
- def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
325
- if self.deterministic:
326
- return torch.Tensor([0.0])
327
- else:
328
- reduce_dim = list(range(1, self.mean.ndim))
329
- if other is None:
330
- return 0.5 * torch.sum(
331
- torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
332
- dim=reduce_dim,
333
- )
334
- else:
335
- return 0.5 * torch.sum(
336
- torch.pow(self.mean - other.mean, 2) / other.var
337
- + self.var / other.var
338
- - 1.0
339
- - self.logvar
340
- + other.logvar,
341
- dim=reduce_dim,
342
- )
343
-
344
- def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
345
- if self.deterministic:
346
- return torch.Tensor([0.0])
347
- logtwopi = np.log(2.0 * np.pi)
348
- return 0.5 * torch.sum(
349
- logtwopi + self.logvar +
350
- torch.pow(sample - self.mean, 2) / self.var,
351
- dim=dims,
352
- )
353
-
354
- def mode(self) -> torch.Tensor:
355
- return self.mean
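Note on the deleted vae.py above: with double_z=True, EncoderCausal3D's final CausalConv3d emits 2 * out_channels, i.e. mean and log-variance stacked along the channel axis, and DiagonalGaussianDistribution turns those moments into the latent that DecoderCausal3D maps back to pixels. A minimal sketch of that middle step, assuming the class above were still in scope; the moments tensor is a random stand-in, not real encoder output:

import torch

B, C, T, H, W = 1, 16, 5, 8, 8
moments = torch.randn(B, 2 * C, T, H, W)  # stand-in for EncoderCausal3D output

posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()     # (1, 16, 5, 8, 8): mean + std * noise
z_mode = posterior.mode()  # deterministic latent, just the mean
kl = posterior.kl()        # KL against a standard normal, one value per batch item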