studyOverflow commited on
Commit
6189cd3
·
verified ·
1 Parent(s): 70df87e

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. data/StableDiffusion/safety_checker/config.json +171 -0
  2. data/StableDiffusion/scheduler/.ipynb_checkpoints/scheduler_config-checkpoint.json +9 -0
  3. data/StableDiffusion/unet/config.json +36 -0
  4. fastvideo/models/flux_hf/pipeline_flux.py +988 -0
  5. fastvideo/models/hunyuan/__init__.py +0 -0
  6. fastvideo/models/hunyuan/constants.py +89 -0
  7. fastvideo/models/hunyuan/idle_config.py +415 -0
  8. fastvideo/models/hunyuan/inference.py +534 -0
  9. fastvideo/models/hunyuan/modules/mlp_layers.py +133 -0
  10. fastvideo/models/hunyuan/prompt_rewrite.py +52 -0
  11. fastvideo/models/hunyuan_hf/__pycache__/modeling_hunyuan.cpython-310.pyc +0 -0
  12. fastvideo/models/hunyuan_hf/__pycache__/modeling_hunyuan.cpython-312.pyc +0 -0
  13. fastvideo/models/hunyuan_hf/modeling_hunyuan.py +952 -0
  14. fastvideo/models/hunyuan_hf/pipeline_hunyuan.py +756 -0
  15. fastvideo/models/mochi_hf/__pycache__/modeling_mochi.cpython-310.pyc +0 -0
  16. fastvideo/models/mochi_hf/__pycache__/modeling_mochi.cpython-312.pyc +0 -0
  17. fastvideo/models/mochi_hf/__pycache__/norm.cpython-310.pyc +0 -0
  18. fastvideo/models/mochi_hf/__pycache__/norm.cpython-312.pyc +0 -0
  19. fastvideo/models/mochi_hf/__pycache__/pipeline_mochi.cpython-312.pyc +0 -0
  20. fastvideo/models/mochi_hf/convert_diffusers_to_mochi.py +502 -0
  21. fastvideo/models/mochi_hf/mochi_latents_utils.py +47 -0
  22. fastvideo/models/mochi_hf/modeling_mochi.py +729 -0
  23. fastvideo/models/mochi_hf/norm.py +132 -0
  24. fastvideo/models/mochi_hf/pipeline_mochi.py +829 -0
  25. fastvideo/models/qwenimage/__init__.py +0 -0
  26. fastvideo/models/qwenimage/autoencoder_kl_qwenimage.py +1070 -0
  27. fastvideo/models/qwenimage/pipeline_output.py +21 -0
  28. fastvideo/models/qwenimage/pipeline_qwenimage.py +727 -0
  29. fastvideo/models/qwenimage/transformer_qwenimage.py +645 -0
  30. fastvideo/models/stable_diffusion/ddim_with_logprob.py +215 -0
  31. fastvideo/models/stable_diffusion/ddim_with_logprob_v6.py +201 -0
  32. fastvideo/models/stable_diffusion/ddim_with_logprob_v6_2.py +200 -0
  33. fastvideo/models/stable_diffusion/ddim_with_logprob_v6_8.py +201 -0
  34. fastvideo/models/stable_diffusion/ddim_with_logprob_v8.py +201 -0
  35. fastvideo/models/stable_diffusion/ddim_with_logprob_w_x0.py +201 -0
  36. fastvideo/models/stable_diffusion/ddim_with_logprob_w_x0_2.py +201 -0
  37. fastvideo/models/stable_diffusion/ddim_with_logprob_w_x0_v7.py +221 -0
  38. fastvideo/models/stable_diffusion/ddim_with_logprob_wo_eta.py +200 -0
  39. fastvideo/models/stable_diffusion/pipeline_with_logprob.py +250 -0
  40. fastvideo/models/stable_diffusion/pipeline_with_logprob_p1.py +258 -0
  41. fastvideo/models/stable_diffusion/pipeline_with_logprob_p2.py +324 -0
  42. fastvideo/models/stable_diffusion/pipeline_with_logprob_prefix.py +256 -0
  43. fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta.py +261 -0
  44. fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta_bid.py +326 -0
  45. fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta_mask.py +267 -0
  46. fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta_mask2.py +267 -0
  47. fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta_v7.py +267 -0
  48. fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta_v8.py +270 -0
  49. fastvideo/models/stable_diffusion/pipeline_with_logprob_wo_eta.py +252 -0
  50. fastvideo/models/stable_diffusion/pipeline_with_logprob_wo_eta_2.py +257 -0
data/StableDiffusion/safety_checker/config.json ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./safety_module",
3
+ "architectures": [
4
+ "StableDiffusionSafetyChecker"
5
+ ],
6
+ "initializer_factor": 1.0,
7
+ "logit_scale_init_value": 2.6592,
8
+ "model_type": "clip",
9
+ "projection_dim": 768,
10
+ "text_config": {
11
+ "_name_or_path": "",
12
+ "add_cross_attention": false,
13
+ "architectures": null,
14
+ "attention_dropout": 0.0,
15
+ "bad_words_ids": null,
16
+ "bos_token_id": 0,
17
+ "chunk_size_feed_forward": 0,
18
+ "cross_attention_hidden_size": null,
19
+ "decoder_start_token_id": null,
20
+ "diversity_penalty": 0.0,
21
+ "do_sample": false,
22
+ "dropout": 0.0,
23
+ "early_stopping": false,
24
+ "encoder_no_repeat_ngram_size": 0,
25
+ "eos_token_id": 2,
26
+ "exponential_decay_length_penalty": null,
27
+ "finetuning_task": null,
28
+ "forced_bos_token_id": null,
29
+ "forced_eos_token_id": null,
30
+ "hidden_act": "quick_gelu",
31
+ "hidden_size": 768,
32
+ "id2label": {
33
+ "0": "LABEL_0",
34
+ "1": "LABEL_1"
35
+ },
36
+ "initializer_factor": 1.0,
37
+ "initializer_range": 0.02,
38
+ "intermediate_size": 3072,
39
+ "is_decoder": false,
40
+ "is_encoder_decoder": false,
41
+ "label2id": {
42
+ "LABEL_0": 0,
43
+ "LABEL_1": 1
44
+ },
45
+ "layer_norm_eps": 1e-05,
46
+ "length_penalty": 1.0,
47
+ "max_length": 20,
48
+ "max_position_embeddings": 77,
49
+ "min_length": 0,
50
+ "model_type": "clip_text_model",
51
+ "no_repeat_ngram_size": 0,
52
+ "num_attention_heads": 12,
53
+ "num_beam_groups": 1,
54
+ "num_beams": 1,
55
+ "num_hidden_layers": 12,
56
+ "num_return_sequences": 1,
57
+ "output_attentions": false,
58
+ "output_hidden_states": false,
59
+ "output_scores": false,
60
+ "pad_token_id": 1,
61
+ "prefix": null,
62
+ "problem_type": null,
63
+ "pruned_heads": {},
64
+ "remove_invalid_values": false,
65
+ "repetition_penalty": 1.0,
66
+ "return_dict": true,
67
+ "return_dict_in_generate": false,
68
+ "sep_token_id": null,
69
+ "task_specific_params": null,
70
+ "temperature": 1.0,
71
+ "tie_encoder_decoder": false,
72
+ "tie_word_embeddings": true,
73
+ "tokenizer_class": null,
74
+ "top_k": 50,
75
+ "top_p": 1.0,
76
+ "torch_dtype": null,
77
+ "torchscript": false,
78
+ "transformers_version": "4.21.0.dev0",
79
+ "typical_p": 1.0,
80
+ "use_bfloat16": false,
81
+ "vocab_size": 49408
82
+ },
83
+ "text_config_dict": {
84
+ "hidden_size": 768,
85
+ "intermediate_size": 3072,
86
+ "num_attention_heads": 12,
87
+ "num_hidden_layers": 12
88
+ },
89
+ "torch_dtype": "float32",
90
+ "transformers_version": null,
91
+ "vision_config": {
92
+ "_name_or_path": "",
93
+ "add_cross_attention": false,
94
+ "architectures": null,
95
+ "attention_dropout": 0.0,
96
+ "bad_words_ids": null,
97
+ "bos_token_id": null,
98
+ "chunk_size_feed_forward": 0,
99
+ "cross_attention_hidden_size": null,
100
+ "decoder_start_token_id": null,
101
+ "diversity_penalty": 0.0,
102
+ "do_sample": false,
103
+ "dropout": 0.0,
104
+ "early_stopping": false,
105
+ "encoder_no_repeat_ngram_size": 0,
106
+ "eos_token_id": null,
107
+ "exponential_decay_length_penalty": null,
108
+ "finetuning_task": null,
109
+ "forced_bos_token_id": null,
110
+ "forced_eos_token_id": null,
111
+ "hidden_act": "quick_gelu",
112
+ "hidden_size": 1024,
113
+ "id2label": {
114
+ "0": "LABEL_0",
115
+ "1": "LABEL_1"
116
+ },
117
+ "image_size": 224,
118
+ "initializer_factor": 1.0,
119
+ "initializer_range": 0.02,
120
+ "intermediate_size": 4096,
121
+ "is_decoder": false,
122
+ "is_encoder_decoder": false,
123
+ "label2id": {
124
+ "LABEL_0": 0,
125
+ "LABEL_1": 1
126
+ },
127
+ "layer_norm_eps": 1e-05,
128
+ "length_penalty": 1.0,
129
+ "max_length": 20,
130
+ "min_length": 0,
131
+ "model_type": "clip_vision_model",
132
+ "no_repeat_ngram_size": 0,
133
+ "num_attention_heads": 16,
134
+ "num_beam_groups": 1,
135
+ "num_beams": 1,
136
+ "num_hidden_layers": 24,
137
+ "num_return_sequences": 1,
138
+ "output_attentions": false,
139
+ "output_hidden_states": false,
140
+ "output_scores": false,
141
+ "pad_token_id": null,
142
+ "patch_size": 14,
143
+ "prefix": null,
144
+ "problem_type": null,
145
+ "pruned_heads": {},
146
+ "remove_invalid_values": false,
147
+ "repetition_penalty": 1.0,
148
+ "return_dict": true,
149
+ "return_dict_in_generate": false,
150
+ "sep_token_id": null,
151
+ "task_specific_params": null,
152
+ "temperature": 1.0,
153
+ "tie_encoder_decoder": false,
154
+ "tie_word_embeddings": true,
155
+ "tokenizer_class": null,
156
+ "top_k": 50,
157
+ "top_p": 1.0,
158
+ "torch_dtype": null,
159
+ "torchscript": false,
160
+ "transformers_version": "4.21.0.dev0",
161
+ "typical_p": 1.0,
162
+ "use_bfloat16": false
163
+ },
164
+ "vision_config_dict": {
165
+ "hidden_size": 1024,
166
+ "intermediate_size": 4096,
167
+ "num_attention_heads": 16,
168
+ "num_hidden_layers": 24,
169
+ "patch_size": 14
170
+ }
171
+ }
data/StableDiffusion/scheduler/.ipynb_checkpoints/scheduler_config-checkpoint.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PNDMScheduler",
3
+ "_diffusers_version": "0.2.2",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "num_train_timesteps": 1000,
8
+ "skip_prk_steps": true
9
+ }
data/StableDiffusion/unet/config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.2.2",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": 8,
6
+ "block_out_channels": [
7
+ 320,
8
+ 640,
9
+ 1280,
10
+ 1280
11
+ ],
12
+ "center_input_sample": false,
13
+ "cross_attention_dim": 768,
14
+ "down_block_types": [
15
+ "CrossAttnDownBlock2D",
16
+ "CrossAttnDownBlock2D",
17
+ "CrossAttnDownBlock2D",
18
+ "DownBlock2D"
19
+ ],
20
+ "downsample_padding": 1,
21
+ "flip_sin_to_cos": true,
22
+ "freq_shift": 0,
23
+ "in_channels": 4,
24
+ "layers_per_block": 2,
25
+ "mid_block_scale_factor": 1,
26
+ "norm_eps": 1e-05,
27
+ "norm_num_groups": 32,
28
+ "out_channels": 4,
29
+ "sample_size": 64,
30
+ "up_block_types": [
31
+ "UpBlock2D",
32
+ "CrossAttnUpBlock2D",
33
+ "CrossAttnUpBlock2D",
34
+ "CrossAttnUpBlock2D"
35
+ ]
36
+ }
fastvideo/models/flux_hf/pipeline_flux.py ADDED
@@ -0,0 +1,988 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPTextModel,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
31
+ from diffusers.models import AutoencoderKL, FluxTransformer2DModel
32
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
44
+
45
+
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+ EXAMPLE_DOC_STRING = """
57
+ Examples:
58
+ ```py
59
+ >>> import torch
60
+ >>> from diffusers import FluxPipeline
61
+
62
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
63
+ >>> pipe.to("cuda")
64
+ >>> prompt = "A cat holding a sign that says hello world"
65
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
66
+ >>> # Refer to the pipeline documentation for more details.
67
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
68
+ >>> image.save("flux.png")
69
+ ```
70
+ """
71
+
72
+
73
+ def calculate_shift(
74
+ image_seq_len,
75
+ base_seq_len: int = 256,
76
+ max_seq_len: int = 4096,
77
+ base_shift: float = 0.5,
78
+ max_shift: float = 1.15,
79
+ ):
80
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
81
+ b = base_shift - m * base_seq_len
82
+ mu = image_seq_len * m + b
83
+ return mu
84
+
85
+
86
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
87
+ def retrieve_timesteps(
88
+ scheduler,
89
+ num_inference_steps: Optional[int] = None,
90
+ device: Optional[Union[str, torch.device]] = None,
91
+ timesteps: Optional[List[int]] = None,
92
+ sigmas: Optional[List[float]] = None,
93
+ **kwargs,
94
+ ):
95
+ r"""
96
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
97
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
98
+
99
+ Args:
100
+ scheduler (`SchedulerMixin`):
101
+ The scheduler to get timesteps from.
102
+ num_inference_steps (`int`):
103
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
104
+ must be `None`.
105
+ device (`str` or `torch.device`, *optional*):
106
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
107
+ timesteps (`List[int]`, *optional*):
108
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
109
+ `num_inference_steps` and `sigmas` must be `None`.
110
+ sigmas (`List[float]`, *optional*):
111
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
112
+ `num_inference_steps` and `timesteps` must be `None`.
113
+
114
+ Returns:
115
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
116
+ second element is the number of inference steps.
117
+ """
118
+ if timesteps is not None and sigmas is not None:
119
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
120
+ if timesteps is not None:
121
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
122
+ if not accepts_timesteps:
123
+ raise ValueError(
124
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
125
+ f" timestep schedules. Please check whether you are using the correct scheduler."
126
+ )
127
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
128
+ timesteps = scheduler.timesteps
129
+ num_inference_steps = len(timesteps)
130
+ elif sigmas is not None:
131
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
132
+ if not accept_sigmas:
133
+ raise ValueError(
134
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
135
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
136
+ )
137
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
138
+ timesteps = scheduler.timesteps
139
+ num_inference_steps = len(timesteps)
140
+ else:
141
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
142
+ timesteps = scheduler.timesteps
143
+ return timesteps, num_inference_steps
144
+
145
+
146
+ class FluxPipeline(
147
+ DiffusionPipeline,
148
+ FluxLoraLoaderMixin,
149
+ FromSingleFileMixin,
150
+ TextualInversionLoaderMixin,
151
+ FluxIPAdapterMixin,
152
+ ):
153
+ r"""
154
+ The Flux pipeline for text-to-image generation.
155
+
156
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
157
+
158
+ Args:
159
+ transformer ([`FluxTransformer2DModel`]):
160
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
161
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
162
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
163
+ vae ([`AutoencoderKL`]):
164
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
165
+ text_encoder ([`CLIPTextModel`]):
166
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
167
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
168
+ text_encoder_2 ([`T5EncoderModel`]):
169
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
170
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
171
+ tokenizer (`CLIPTokenizer`):
172
+ Tokenizer of class
173
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
174
+ tokenizer_2 (`T5TokenizerFast`):
175
+ Second Tokenizer of class
176
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
177
+ """
178
+
179
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
180
+ _optional_components = ["image_encoder", "feature_extractor"]
181
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
182
+
183
+ def __init__(
184
+ self,
185
+ scheduler: FlowMatchEulerDiscreteScheduler,
186
+ vae: AutoencoderKL,
187
+ text_encoder: CLIPTextModel,
188
+ tokenizer: CLIPTokenizer,
189
+ text_encoder_2: T5EncoderModel,
190
+ tokenizer_2: T5TokenizerFast,
191
+ transformer: FluxTransformer2DModel,
192
+ image_encoder: CLIPVisionModelWithProjection = None,
193
+ feature_extractor: CLIPImageProcessor = None,
194
+ ):
195
+ super().__init__()
196
+
197
+ self.register_modules(
198
+ vae=vae,
199
+ text_encoder=text_encoder,
200
+ text_encoder_2=text_encoder_2,
201
+ tokenizer=tokenizer,
202
+ tokenizer_2=tokenizer_2,
203
+ transformer=transformer,
204
+ scheduler=scheduler,
205
+ image_encoder=image_encoder,
206
+ feature_extractor=feature_extractor,
207
+ )
208
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
209
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
210
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
211
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
212
+ self.tokenizer_max_length = (
213
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
214
+ )
215
+ self.default_sample_size = 128
216
+
217
+ def _get_t5_prompt_embeds(
218
+ self,
219
+ prompt: Union[str, List[str]] = None,
220
+ num_images_per_prompt: int = 1,
221
+ max_sequence_length: int = 512,
222
+ device: Optional[torch.device] = None,
223
+ dtype: Optional[torch.dtype] = None,
224
+ ):
225
+ device = device or self._execution_device
226
+ dtype = dtype or self.text_encoder.dtype
227
+
228
+ prompt = [prompt] if isinstance(prompt, str) else prompt
229
+ batch_size = len(prompt)
230
+
231
+ if isinstance(self, TextualInversionLoaderMixin):
232
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
233
+
234
+ text_inputs = self.tokenizer_2(
235
+ prompt,
236
+ padding="max_length",
237
+ max_length=max_sequence_length,
238
+ truncation=True,
239
+ return_length=False,
240
+ return_overflowing_tokens=False,
241
+ return_tensors="pt",
242
+ )
243
+ text_input_ids = text_inputs.input_ids
244
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
245
+
246
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
247
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
248
+ logger.warning(
249
+ "The following part of your input was truncated because `max_sequence_length` is set to "
250
+ f" {max_sequence_length} tokens: {removed_text}"
251
+ )
252
+
253
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
254
+
255
+ dtype = self.text_encoder_2.dtype
256
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
257
+
258
+ _, seq_len, _ = prompt_embeds.shape
259
+
260
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
261
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
262
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
263
+
264
+ return prompt_embeds
265
+
266
+ def _get_clip_prompt_embeds(
267
+ self,
268
+ prompt: Union[str, List[str]],
269
+ num_images_per_prompt: int = 1,
270
+ device: Optional[torch.device] = None,
271
+ ):
272
+ device = device or self._execution_device
273
+
274
+ prompt = [prompt] if isinstance(prompt, str) else prompt
275
+ batch_size = len(prompt)
276
+
277
+ if isinstance(self, TextualInversionLoaderMixin):
278
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
279
+
280
+ text_inputs = self.tokenizer(
281
+ prompt,
282
+ padding="max_length",
283
+ max_length=self.tokenizer_max_length,
284
+ truncation=True,
285
+ return_overflowing_tokens=False,
286
+ return_length=False,
287
+ return_tensors="pt",
288
+ )
289
+
290
+ text_input_ids = text_inputs.input_ids
291
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
292
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
293
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
294
+ logger.warning(
295
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
296
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
297
+ )
298
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
299
+
300
+ # Use pooled output of CLIPTextModel
301
+ prompt_embeds = prompt_embeds.pooler_output
302
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
303
+
304
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
305
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
306
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
307
+
308
+ return prompt_embeds
309
+
310
+ def encode_prompt(
311
+ self,
312
+ prompt: Union[str, List[str]],
313
+ prompt_2: Union[str, List[str]],
314
+ device: Optional[torch.device] = None,
315
+ num_images_per_prompt: int = 1,
316
+ prompt_embeds: Optional[torch.FloatTensor] = None,
317
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
318
+ max_sequence_length: int = 512,
319
+ lora_scale: Optional[float] = None,
320
+ ):
321
+ r"""
322
+
323
+ Args:
324
+ prompt (`str` or `List[str]`, *optional*):
325
+ prompt to be encoded
326
+ prompt_2 (`str` or `List[str]`, *optional*):
327
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
328
+ used in all text-encoders
329
+ device: (`torch.device`):
330
+ torch device
331
+ num_images_per_prompt (`int`):
332
+ number of images that should be generated per prompt
333
+ prompt_embeds (`torch.FloatTensor`, *optional*):
334
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
335
+ provided, text embeddings will be generated from `prompt` input argument.
336
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
337
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
338
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
339
+ lora_scale (`float`, *optional*):
340
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
341
+ """
342
+ device = device or self._execution_device
343
+
344
+ # set lora scale so that monkey patched LoRA
345
+ # function of text encoder can correctly access it
346
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
347
+ self._lora_scale = lora_scale
348
+
349
+ # dynamically adjust the LoRA scale
350
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
351
+ scale_lora_layers(self.text_encoder, lora_scale)
352
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
353
+ scale_lora_layers(self.text_encoder_2, lora_scale)
354
+
355
+ prompt = [prompt] if isinstance(prompt, str) else prompt
356
+
357
+ if prompt_embeds is None:
358
+ prompt_2 = prompt_2 or prompt
359
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
360
+
361
+ # We only use the pooled prompt output from the CLIPTextModel
362
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
363
+ prompt=prompt,
364
+ device=device,
365
+ num_images_per_prompt=num_images_per_prompt,
366
+ )
367
+ prompt_embeds = self._get_t5_prompt_embeds(
368
+ prompt=prompt_2,
369
+ num_images_per_prompt=num_images_per_prompt,
370
+ max_sequence_length=max_sequence_length,
371
+ device=device,
372
+ )
373
+
374
+ if self.text_encoder is not None:
375
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
376
+ # Retrieve the original scale by scaling back the LoRA layers
377
+ unscale_lora_layers(self.text_encoder, lora_scale)
378
+
379
+ if self.text_encoder_2 is not None:
380
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
381
+ # Retrieve the original scale by scaling back the LoRA layers
382
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
383
+
384
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
385
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
386
+
387
+ return prompt_embeds, pooled_prompt_embeds, text_ids
388
+
389
+ def encode_image(self, image, device, num_images_per_prompt):
390
+ dtype = next(self.image_encoder.parameters()).dtype
391
+
392
+ if not isinstance(image, torch.Tensor):
393
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
394
+
395
+ image = image.to(device=device, dtype=dtype)
396
+ image_embeds = self.image_encoder(image).image_embeds
397
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
398
+ return image_embeds
399
+
400
+ def prepare_ip_adapter_image_embeds(
401
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
402
+ ):
403
+ image_embeds = []
404
+ if ip_adapter_image_embeds is None:
405
+ if not isinstance(ip_adapter_image, list):
406
+ ip_adapter_image = [ip_adapter_image]
407
+
408
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
409
+ raise ValueError(
410
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
411
+ )
412
+
413
+ for single_ip_adapter_image in ip_adapter_image:
414
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
415
+ image_embeds.append(single_image_embeds[None, :])
416
+ else:
417
+ if not isinstance(ip_adapter_image_embeds, list):
418
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
419
+
420
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
421
+ raise ValueError(
422
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
423
+ )
424
+
425
+ for single_image_embeds in ip_adapter_image_embeds:
426
+ image_embeds.append(single_image_embeds)
427
+
428
+ ip_adapter_image_embeds = []
429
+ for single_image_embeds in image_embeds:
430
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
431
+ single_image_embeds = single_image_embeds.to(device=device)
432
+ ip_adapter_image_embeds.append(single_image_embeds)
433
+
434
+ return ip_adapter_image_embeds
435
+
436
    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=None,
        negative_prompt_2=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        """Validate `__call__` arguments before any heavy work is done.

        Checks, in order: height/width divisibility (warning only), callback
        tensor names, mutually-exclusive prompt vs. prompt-embedding inputs,
        prompt types, negative-prompt exclusivity, embedding shape agreement,
        presence of pooled embeddings whenever raw embeddings are passed, and
        the 512-token cap on `max_sequence_length`.

        Raises:
            ValueError: on any invalid argument combination described above.
        """
        # Packing requires the latent grid to be even, so pixel dims must be
        # divisible by 2 * vae_scale_factor. Not fatal: dims get resized later.
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
            )

        # Callback may only request tensors the pipeline actually exposes.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # `prompt`/`prompt_2` and `prompt_embeds` are mutually exclusive, and
        # at least one of `prompt` / `prompt_embeds` must be supplied.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        # Same exclusivity rules for the negative prompt inputs.
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Positive/negative embeddings must be batch-compatible.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        # Raw embeddings always travel with their pooled counterpart.
        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        # T5 text-encoder context limit for Flux.
        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
512
+
513
+ @staticmethod
514
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
515
+ latent_image_ids = torch.zeros(height, width, 3)
516
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
517
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
518
+
519
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
520
+
521
+ latent_image_ids = latent_image_ids.reshape(
522
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
523
+ )
524
+
525
+ return latent_image_ids.to(device=device, dtype=dtype)
526
+
527
+ @staticmethod
528
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
529
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
530
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
531
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
532
+
533
+ return latents
534
+
535
+ @staticmethod
536
+ def _unpack_latents(latents, height, width, vae_scale_factor):
537
+ batch_size, num_patches, channels = latents.shape
538
+
539
+ # VAE applies 8x compression on images but we must also account for packing which requires
540
+ # latent height and width to be divisible by 2.
541
+ height = 2 * (int(height) // (vae_scale_factor * 2))
542
+ width = 2 * (int(width) // (vae_scale_factor * 2))
543
+
544
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
545
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
546
+
547
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
548
+
549
+ return latents
550
+
551
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        # Thin delegation to the underlying VAE.
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        # Thin delegation to the underlying VAE.
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        # Thin delegation to the underlying VAE.
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        # Thin delegation to the underlying VAE.
        self.vae.disable_tiling()
579
+
580
+ def prepare_latents(
581
+ self,
582
+ batch_size,
583
+ num_channels_latents,
584
+ height,
585
+ width,
586
+ dtype,
587
+ device,
588
+ generator,
589
+ latents=None,
590
+ ):
591
+ # VAE applies 8x compression on images but we must also account for packing which requires
592
+ # latent height and width to be divisible by 2.
593
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
594
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
595
+
596
+ shape = (batch_size, num_channels_latents, height, width)
597
+
598
+ if latents is not None:
599
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
600
+ return latents.to(device=device, dtype=dtype), latent_image_ids
601
+
602
+ if isinstance(generator, list) and len(generator) != batch_size:
603
+ raise ValueError(
604
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
605
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
606
+ )
607
+
608
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
609
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
610
+
611
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
612
+
613
+ return latents, latent_image_ids
614
+
615
    @property
    def guidance_scale(self):
        # Guidance value captured from the latest `__call__` invocation.
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors during denoising.
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        # Number of scheduler timesteps used by the latest `__call__`.
        return self._num_timesteps

    @property
    def current_timestep(self):
        # Timestep being denoised right now; None outside the denoising loop.
        return self._current_timestep

    @property
    def interrupt(self):
        # When True, the denoising loop skips its remaining steps (see `__call__`).
        return self._interrupt
634
+
635
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt: Union[str, List[str]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        true_cfg_scale: float = 1.0,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                will be used instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
                not greater than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
            true_cfg_scale (`float`, *optional*, defaults to 1.0):
                When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will ge generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            negative_ip_adapter_image:
                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode the prompt (and, if doing true CFG, the negative prompt).
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        has_neg_prompt = negative_prompt is not None or (
            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
        )
        # True CFG runs a second transformer pass per step; only worthwhile
        # when a negative prompt is supplied and the scale exceeds 1.
        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
        if do_true_cfg:
            (
                negative_prompt_embeds,
                negative_pooled_prompt_embeds,
                _,
            ) = self.encode_prompt(
                prompt=negative_prompt,
                prompt_2=negative_prompt_2,
                prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=negative_pooled_prompt_embeds,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                lora_scale=lora_scale,
            )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        # Resolution-dependent timestep shift for the flow-matching scheduler.
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        # Guidance-distilled checkpoints take the scale as a model input
        # instead of doing classifier-free guidance at sampling time.
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # When only one side of the IP-adapter pair is supplied, fabricate a
        # zero image for the other side so both passes see adapter inputs.
        # NOTE(review): the zero image is built as (width, height, 3); confirm
        # this axis order matches the IP-adapter image preprocessing.
        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
        ):
            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
        ):
            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        image_embeds = None
        negative_image_embeds = None
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )
        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
                negative_ip_adapter_image,
                negative_ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                if image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)
                # The transformer consumes timesteps scaled down by 1/1000.
                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                if do_true_cfg:
                    if negative_image_embeds is not None:
                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
                    neg_noise_pred = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        pooled_projections=negative_pooled_prompt_embeds,
                        encoder_hidden_states=negative_prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]
                    # Classic CFG combination of the two predictions.
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # Only names vetted by check_inputs ever reach locals() here.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        # 7. Decode latents (unless the caller wants the raw packed latents).
        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
fastvideo/models/hunyuan/__init__.py ADDED
File without changes
fastvideo/models/hunyuan/constants.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import torch

# Public API of this module. PRECISION_TO_TYPE is included so star-imports
# expose it alongside PRECISIONS (it was previously defined but omitted).
__all__ = [
    "C_SCALE",
    "PROMPT_TEMPLATE",
    "MODEL_BASE",
    "PRECISIONS",
    "PRECISION_TO_TYPE",
    "NORMALIZATION_TYPE",
    "ACTIVATION_TYPE",
    "VAE_PATH",
    "TEXT_ENCODER_PATH",
    "TOKENIZER_PATH",
    "TEXT_PROJECTION",
    "DATA_TYPE",
    "NEGATIVE_PROMPT",
]

# Map CLI precision names to torch dtypes.
PRECISION_TO_TYPE = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

# =================== Constant Values =====================
# Computation scale factor, 1P = 1_000_000_000_000_000. Tensorboard will display the value in PetaFLOPS to avoid
# overflow error when tensorboard logging values.
C_SCALE = 1_000_000_000_000_000

# When using decoder-only models, we must provide a prompt template to instruct the text encoder
# on how to generate the text.
# --------------------------------------------------------------------
PROMPT_TEMPLATE_ENCODE = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
PROMPT_TEMPLATE_ENCODE_VIDEO = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
    "1. The main content and theme of the video."
    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "4. background environment, light, style and atmosphere."
    "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")

NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"

# `crop_start` is the number of template tokens to strip from the encoder
# output so only the user-prompt tokens remain.
PROMPT_TEMPLATE = {
    "dit-llm-encode": {
        "template": PROMPT_TEMPLATE_ENCODE,
        "crop_start": 36,
    },
    "dit-llm-encode-video": {
        "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
        "crop_start": 95,
    },
}

# ======================= Model ======================
PRECISIONS = {"fp32", "fp16", "bf16"}
NORMALIZATION_TYPE = {"layer", "rms"}
ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"}

# =================== Model Path =====================
MODEL_BASE = os.getenv("MODEL_BASE", "./data/hunyuan")

# =================== Data =======================
DATA_TYPE = {"image", "video", "image_video"}

# 3D VAE
VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"}

# Text Encoder
TEXT_ENCODER_PATH = {
    "clipL": f"{MODEL_BASE}/text_encoder_2",
    "llm": f"{MODEL_BASE}/text_encoder",
}

# Tokenizer
TOKENIZER_PATH = {
    "clipL": f"{MODEL_BASE}/text_encoder_2",
    "llm": f"{MODEL_BASE}/text_encoder",
}

TEXT_PROJECTION = {
    "linear",  # Default, an nn.Linear() layer
    "single_refiner",  # Single TokenRefiner. Refer to LI-DiT
}
89
+ }
fastvideo/models/hunyuan/idle_config.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ruff: noqa: F405, F403
2
+ import argparse
3
+ import re
4
+
5
+ from .constants import *
6
+ from .modules.models import HUNYUAN_VIDEO_CONFIG
7
+
8
+
9
def parse_args(namespace=None):
    """Parse HunyuanVideo inference CLI arguments.

    Argument groups are registered by the dedicated ``add_*_args`` helpers;
    the parsed namespace is validated by ``sanity_check_args`` before being
    returned.
    """
    parser = argparse.ArgumentParser(
        description="HunyuanVideo inference script")

    # Register every option group on the shared parser.
    for register in (
            add_network_args,
            add_extra_models_args,
            add_denoise_schedule_args,
            add_inference_args,
            add_parallel_args,
    ):
        parser = register(parser)

    parsed = parser.parse_args(namespace=namespace)
    return sanity_check_args(parsed)
23
+
24
+
25
def add_network_args(parser: argparse.ArgumentParser):
    """Register backbone-network options on ``parser`` and return it.

    Covers the DiT variant, latent channel count, compute precision and the
    RoPE theta.
    """
    group = parser.add_argument_group(title="HunyuanVideo network args")

    # Main model
    group.add_argument(
        "--model",
        type=str,
        choices=list(HUNYUAN_VIDEO_CONFIG.keys()),
        default="HYVideo-T/2-cfgdistill",
    )
    group.add_argument(
        "--latent-channels",
        # Fix: this was `type=str`, which made CLI-supplied values strings
        # while the default stayed an int; a channel count must parse as int.
        type=int,
        default=16,
        help=
        "Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
        "it still needs to match the latent channels of the VAE model.",
    )
    group.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=PRECISIONS,
        help=
        "Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.",
    )

    # RoPE
    group.add_argument("--rope-theta",
                       type=int,
                       default=256,
                       help="Theta used in RoPE.")
    return parser
58
+
59
+
60
def add_extra_models_args(parser: argparse.ArgumentParser):
    """Register auxiliary-model options on ``parser`` and return it.

    Covers the 3D VAE, the primary (LLM) text encoder with its prompt
    templates, and the secondary (CLIP) text encoder/tokenizer.
    """
    group = parser.add_argument_group(
        title="Extra models args, including vae, text encoders and tokenizers)"
    )

    # - VAE
    group.add_argument(
        "--vae",
        type=str,
        default="884-16c-hy",
        choices=list(VAE_PATH),
        help="Name of the VAE model.",
    )
    group.add_argument(
        "--vae-precision",
        type=str,
        default="fp16",
        choices=PRECISIONS,
        help="Precision mode for the VAE model.",
    )
    group.add_argument(
        "--vae-tiling",
        action="store_true",
        help="Enable tiling for the VAE model to save GPU memory.",
    )
    # Tiling defaults to on even though the flag is store_true.
    group.set_defaults(vae_tiling=True)

    group.add_argument(
        "--text-encoder",
        type=str,
        default="llm",
        choices=list(TEXT_ENCODER_PATH),
        help="Name of the text encoder model.",
    )
    group.add_argument(
        "--text-encoder-precision",
        type=str,
        default="fp16",
        choices=PRECISIONS,
        help="Precision mode for the text encoder model.",
    )
    group.add_argument(
        "--text-states-dim",
        type=int,
        default=4096,
        help="Dimension of the text encoder hidden states.",
    )
    group.add_argument("--text-len",
                       type=int,
                       default=256,
                       help="Maximum length of the text input.")
    group.add_argument(
        "--tokenizer",
        type=str,
        default="llm",
        choices=list(TOKENIZER_PATH),
        help="Name of the tokenizer model.",
    )
    group.add_argument(
        "--prompt-template",
        type=str,
        default="dit-llm-encode",
        choices=PROMPT_TEMPLATE,
        help="Image prompt template for the decoder-only text encoder model.",
    )
    group.add_argument(
        "--prompt-template-video",
        type=str,
        default="dit-llm-encode-video",
        choices=PROMPT_TEMPLATE,
        help="Video prompt template for the decoder-only text encoder model.",
    )
    group.add_argument(
        "--hidden-state-skip-layer",
        type=int,
        default=2,
        help="Skip layer for hidden states.",
    )
    group.add_argument(
        "--apply-final-norm",
        action="store_true",
        help=
        "Apply final normalization to the used text encoder hidden states.",
    )

    # - CLIP
    group.add_argument(
        "--text-encoder-2",
        type=str,
        default="clipL",
        choices=list(TEXT_ENCODER_PATH),
        help="Name of the second text encoder model.",
    )
    group.add_argument(
        "--text-encoder-precision-2",
        type=str,
        default="fp16",
        choices=PRECISIONS,
        help="Precision mode for the second text encoder model.",
    )
    group.add_argument(
        "--text-states-dim-2",
        type=int,
        default=768,
        help="Dimension of the second text encoder hidden states.",
    )
    group.add_argument(
        "--tokenizer-2",
        type=str,
        default="clipL",
        choices=list(TOKENIZER_PATH),
        help="Name of the second tokenizer model.",
    )
    group.add_argument(
        "--text-len-2",
        type=int,
        default=77,
        help="Maximum length of the second text input.",
    )

    return parser
181
+
182
+
183
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
    """Register CLI options controlling the denoising schedule.

    Args:
        parser: Parser to extend in place.

    Returns:
        The same parser, for chaining.
    """
    grp = parser.add_argument_group(title="Denoise schedule args")

    grp.add_argument("--denoise-type",
                     type=str,
                     default="flow",
                     help="Denoise type for noised inputs.")

    # Flow-matching specific knobs.
    grp.add_argument("--flow-shift",
                     type=float,
                     default=7.0,
                     help="Shift factor for flow matching schedulers.")
    grp.add_argument("--flow-reverse",
                     action="store_true",
                     help="If reverse, learning/sampling from t=1 -> t=0.")
    grp.add_argument("--flow-solver",
                     type=str,
                     default="euler",
                     help="Solver for flow matching.")
    grp.add_argument(
        "--use-linear-quadratic-schedule",
        action="store_true",
        help="Use linear quadratic schedule for flow matching."
        "Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
    )
    grp.add_argument(
        "--linear-schedule-end",
        type=int,
        default=25,
        help="End step for linear quadratic schedule for flow matching.",
    )

    return parser
225
+
226
+
227
def add_inference_args(parser: argparse.ArgumentParser):
    """Register CLI options for inference: model loading, sampling sizes,
    prompts, seeding and classifier-free guidance.

    Args:
        parser: Parser to extend in place.

    Returns:
        The same parser, for chaining.
    """
    g = parser.add_argument_group(title="Inference args")

    # ======================== Model loads ========================
    g.add_argument(
        "--model-base",
        type=str,
        default="ckpts",
        help=
        "Root path of all the models, including t2v models and extra models.",
    )
    g.add_argument(
        "--dit-weight",
        type=str,
        default=
        "ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
        help=
        "Path to the HunyuanVideo model. If None, search the model in the args.model_root."
        "1. If it is a file, load the model directly."
        "2. If it is a directory, search the model in the directory. Support two types of models: "
        "1) named `pytorch_model_*.pt`"
        "2) named `*_model_states.pt`, where * can be `mp_rank_00`.",
    )
    g.add_argument(
        "--model-resolution",
        type=str,
        default="540p",
        choices=["540p", "720p"],
        help=
        "Root path of all the models, including t2v models and extra models.",
    )
    g.add_argument(
        "--load-key",
        type=str,
        default="module",
        help=
        "Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
    )
    g.add_argument(
        "--use-cpu-offload",
        action="store_true",
        help="Use CPU offload for the model load.",
    )

    # ======================== Inference general setting ========================
    g.add_argument("--batch-size",
                   type=int,
                   default=1,
                   help="Batch size for inference and evaluation.")
    g.add_argument("--infer-steps",
                   type=int,
                   default=50,
                   help="Number of denoising steps for inference.")
    g.add_argument(
        "--disable-autocast",
        action="store_true",
        help=
        "Disable autocast for denoising loop and vae decoding in pipeline sampling.",
    )
    g.add_argument("--save-path",
                   type=str,
                   default="./results",
                   help="Path to save the generated samples.")
    g.add_argument("--save-path-suffix",
                   type=str,
                   default="",
                   help="Suffix for the directory of saved samples.")
    g.add_argument("--name-suffix",
                   type=str,
                   default="",
                   help="Suffix for the names of saved samples.")
    g.add_argument("--num-videos",
                   type=int,
                   default=1,
                   help="Number of videos to generate for each prompt.")

    # ---sample size---
    g.add_argument(
        "--video-size",
        type=int,
        nargs="+",
        default=(720, 1280),
        help=
        "Video size for training. If a single value is provided, it will be used for both height "
        "and width. If two values are provided, they will be used for height and width "
        "respectively.",
    )
    g.add_argument(
        "--video-length",
        type=int,
        default=129,
        help=
        "How many frames to sample from a video. if using 3d vae, the number should be 4n+1",
    )

    # --- prompt ---
    g.add_argument("--prompt",
                   type=str,
                   default=None,
                   help="Prompt for sampling during evaluation.")
    g.add_argument(
        "--seed-type",
        type=str,
        default="auto",
        choices=["file", "random", "fixed", "auto"],
        help=
        "Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a "
        "random seed. If fixed, use the fixed seed given by `--seed`. If auto, `csv` will use the "
        "seed column if available, otherwise use the fixed `seed` value. `prompt` will use the "
        "fixed `seed` value.",
    )
    g.add_argument("--seed",
                   type=int,
                   default=None,
                   help="Seed for evaluation.")

    # Classifier-Free Guidance
    g.add_argument("--neg-prompt",
                   type=str,
                   default=None,
                   help="Negative prompt for sampling.")
    g.add_argument("--cfg-scale",
                   type=float,
                   default=1.0,
                   help="Classifier free guidance scale.")
    g.add_argument("--embedded-cfg-scale",
                   type=float,
                   default=6.0,
                   help="Embedded classifier free guidance scale.")

    g.add_argument(
        "--reproduce",
        action="store_true",
        help=
        "Enable reproducibility by setting random seeds and deterministic algorithms.",
    )

    return parser
379
+
380
+
381
def add_parallel_args(parser: argparse.ArgumentParser):
    """Register CLI options for sequence-parallel inference (Ulysses / ring).

    Args:
        parser: Parser to extend in place.

    Returns:
        The same parser, for chaining.
    """
    group = parser.add_argument_group(title="Parallel args")

    group.add_argument(
        "--ulysses-degree",
        type=int,
        default=1,
        help="Ulysses degree.",
    )
    group.add_argument(
        "--ring-degree",
        type=int,
        default=1,
        # Fixed: help previously said "Ulysses degree." (copy-paste error).
        help="Ring attention degree.",
    )

    return parser
399
+
400
+
401
def sanity_check_args(args):
    """Validate VAE-related arguments and fill in derived defaults.

    Checks that ``args.vae`` follows the ``<ratios>-<N>c-<name>`` naming
    scheme (e.g. ``884-16c-hy``), derives ``args.latent_channels`` from the
    VAE name when unset, and verifies the two agree.

    Args:
        args: Parsed argparse namespace with at least ``vae`` and
            ``latent_channels`` attributes.

    Returns:
        The same namespace, with ``latent_channels`` normalized to an int.

    Raises:
        ValueError: If the VAE name is malformed or the channel counts
            disagree.
    """
    # VAE channels
    vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
    if not re.match(vae_pattern, args.vae):
        raise ValueError(
            f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
        )
    vae_channels = int(args.vae.split("-")[1][:-1])
    if args.latent_channels is None:
        args.latent_channels = vae_channels
    else:
        # Fixed: `--latent-channels` is declared with type=str, so a value
        # supplied on the command line arrives as a string and the int
        # comparison below would spuriously fail. Coerce before comparing.
        args.latent_channels = int(args.latent_channels)
    if vae_channels != args.latent_channels:
        raise ValueError(
            f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
        )
    return args
fastvideo/models/hunyuan/inference.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ from loguru import logger
8
+ from safetensors.torch import load_file as safetensors_load_file
9
+
10
+ from fastvideo.models.hunyuan.constants import (NEGATIVE_PROMPT,
11
+ PRECISION_TO_TYPE,
12
+ PROMPT_TEMPLATE)
13
+ from fastvideo.models.hunyuan.diffusion.pipelines import HunyuanVideoPipeline
14
+ from fastvideo.models.hunyuan.diffusion.schedulers import \
15
+ FlowMatchDiscreteScheduler
16
+ from fastvideo.models.hunyuan.modules import load_model
17
+ from fastvideo.models.hunyuan.text_encoder import TextEncoder
18
+ from fastvideo.models.hunyuan.utils.data_utils import align_to
19
+ from fastvideo.models.hunyuan.vae import load_vae
20
+ from fastvideo.utils.parallel_states import nccl_info
21
+
22
+
23
class Inference(object):
    """Bundle of the HunyuanVideo DiT, VAE and text encoder(s) for inference.

    Instances are normally created via :meth:`from_pretrained`, which builds
    every sub-model, loads checkpoints and places them on the right device.
    """

    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=None,
        logger=None,
        parallel_args=None,
    ):
        self.vae = vae
        self.vae_kwargs = vae_kwargs

        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2

        self.model = model
        self.pipeline = pipeline
        self.use_cpu_offload = use_cpu_offload

        self.args = args
        self.device = (device if device is not None else
                       "cuda" if torch.cuda.is_available() else "cpu")
        self.logger = logger
        self.parallel_args = parallel_args

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_path,
                        args,
                        device=None,
                        **kwargs):
        """
        Initialize the Inference pipeline.

        Args:
            pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
            args (argparse.Namespace): The arguments for the pipeline.
            device (int): The device for inference. Default is 0.
        """
        logger.info(
            f"Got text-to-video model root path: {pretrained_model_path}")

        # Under sequence parallelism each rank binds to its local GPU.
        if nccl_info.sp_size > 1:
            device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        parallel_args = None  # {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}

        # Inference only: nothing below needs gradients.
        torch.set_grad_enabled(False)

        # =========================== Build main model ===========================
        logger.info("Building model...")
        factor_kwargs = {
            "device": device,
            "dtype": PRECISION_TO_TYPE[args.precision]
        }
        in_channels = args.latent_channels
        out_channels = args.latent_channels

        model = load_model(
            args,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )
        model = model.to(device)
        model = Inference.load_state_dict(args, model, pretrained_model_path)
        model.eval()

        # ============================= Build extra models ========================
        # VAE (kept on CPU when offloading is requested).
        vae, _, s_ratio, t_ratio = load_vae(
            args.vae,
            args.vae_precision,
            logger=logger,
            device=device if not args.use_cpu_offload else "cpu",
        )
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

        # The prompt template prepends tokens before the user text, so the
        # tokenizer budget must grow by the template's crop_start offset.
        if args.prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get(
                "crop_start", 0)
        elif args.prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template].get(
                "crop_start", 0)
        else:
            crop_start = 0
        max_length = args.text_len + crop_start

        prompt_template = (PROMPT_TEMPLATE[args.prompt_template]
                           if args.prompt_template is not None else None)
        prompt_template_video = (PROMPT_TEMPLATE[args.prompt_template_video]
                                 if args.prompt_template_video is not None else
                                 None)

        text_encoder = TextEncoder(
            text_encoder_type=args.text_encoder,
            max_length=max_length,
            text_encoder_precision=args.text_encoder_precision,
            tokenizer_type=args.tokenizer,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=args.hidden_state_skip_layer,
            apply_final_norm=args.apply_final_norm,
            reproduce=args.reproduce,
            logger=logger,
            device=device if not args.use_cpu_offload else "cpu",
        )
        text_encoder_2 = None
        if args.text_encoder_2 is not None:
            text_encoder_2 = TextEncoder(
                text_encoder_type=args.text_encoder_2,
                max_length=args.text_len_2,
                text_encoder_precision=args.text_encoder_precision_2,
                tokenizer_type=args.tokenizer_2,
                reproduce=args.reproduce,
                logger=logger,
                device=device if not args.use_cpu_offload else "cpu",
            )

        return cls(
            args=args,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            use_cpu_offload=args.use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args,
        )

    @staticmethod
    def _resolve_weight_in_dir(weight_dir, load_key):
        """Locate a checkpoint file inside ``weight_dir``.

        Supports the two layouts the project ships:
        1) ``pytorch_model_<load_key>.pt`` (HunyuanDiT official), and
        2) ``*_model_states.pt`` (saved by deepspeed).

        Returns:
            tuple: ``(model_path, bare_model)`` where ``bare_model`` is True
            when the file holds a plain state dict (no 'module'/'ema' wrapper).

        Raises:
            ValueError: If no ``*.pt`` file or no recognized layout is found.
        """
        files = list(weight_dir.glob("*.pt"))
        if len(files) == 0:
            raise ValueError(f"No model weights found in {weight_dir}")
        # Fixed: compare against the file *name*; the glob results are full
        # paths, so str(f).startswith("pytorch_model_") could never match.
        if any(f.name.startswith("pytorch_model_") for f in files):
            model_path = weight_dir / f"pytorch_model_{load_key}.pt"
            bare_model = True
        elif any(f.name.endswith("_model_states.pt") for f in files):
            files = [f for f in files if f.name.endswith("_model_states.pt")]
            model_path = files[0]
            if len(files) > 1:
                logger.warning(
                    f"Multiple model weights found in {weight_dir}, using {model_path}"
                )
            bare_model = False
        else:
            raise ValueError(
                f"Invalid model path: {weight_dir} with unrecognized weight format: "
                f"{list(map(str, files))}. When given a directory as --dit-weight, only "
                f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
                f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
                f"specific weight file, please provide the full path to the file."
            )
        return model_path, bare_model

    @staticmethod
    def load_state_dict(args, model, pretrained_model_path):
        """Load DiT weights into ``model`` from ``--dit-weight`` or the default layout.

        Args:
            args: Namespace with ``dit_weight``, ``load_key`` and
                ``model_resolution``.
            model: The instantiated DiT module to populate.
            pretrained_model_path: Model root used when no explicit weight
                path is given.

        Returns:
            The model with the checkpoint loaded (strict matching).
        """
        load_key = args.load_key
        # Fixed: the original wrapped args.dit_weight in Path() *before* the
        # None check, so Path(None) raised TypeError and the fallback search
        # branch was unreachable (and it resolved paths against None).
        if args.dit_weight is None:
            model_dir = Path(
                pretrained_model_path) / f"t2v_{args.model_resolution}"
            model_path, bare_model = Inference._resolve_weight_in_dir(
                model_dir, load_key)
        else:
            dit_weight = Path(args.dit_weight)
            if dit_weight.is_dir():
                model_path, bare_model = Inference._resolve_weight_in_dir(
                    dit_weight, load_key)
            elif dit_weight.is_file():
                model_path = dit_weight
                # Wrapper layout unknown for an explicit file; sniff below.
                bare_model = "unknown"
            else:
                raise ValueError(f"Invalid model path: {dit_weight}")

        if not model_path.exists():
            raise ValueError(f"model_path not exists: {model_path}")
        logger.info(f"Loading torch model {model_path}...")
        if model_path.suffix == ".safetensors":
            # Use safetensors library for .safetensors files
            state_dict = safetensors_load_file(model_path)
        elif model_path.suffix == ".pt":
            # Use torch for .pt files
            state_dict = torch.load(model_path,
                                    map_location=lambda storage, loc: storage)
        else:
            raise ValueError(f"Unsupported file format: {model_path}")

        if bare_model == "unknown" and ("ema" in state_dict
                                        or "module" in state_dict):
            bare_model = False
        if bare_model is False:
            if load_key in state_dict:
                state_dict = state_dict[load_key]
            else:
                raise KeyError(
                    f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
                    f"are: {list(state_dict.keys())}.")
        model.load_state_dict(state_dict, strict=True)
        return model

    @staticmethod
    def parse_size(size):
        """Normalize ``size`` to a ``[height, width]`` pair.

        Accepts a single int (square) or a 1- or 2-element list/tuple.

        Raises:
            ValueError: For any other shape or type.
        """
        if isinstance(size, int):
            size = [size]
        if not isinstance(size, (list, tuple)):
            raise ValueError(
                f"Size must be an integer or (height, width), got {size}.")
        if len(size) == 1:
            size = [size[0], size[0]]
        if len(size) != 2:
            raise ValueError(
                f"Size must be an integer or (height, width), got {size}.")
        return size
+
276
+
277
class HunyuanVideoSampler(Inference):
    """Text-to-video sampler wrapping :class:`HunyuanVideoPipeline`.

    Builds the flow-matching scheduler and diffusion pipeline on top of the
    models loaded by :class:`Inference`, and exposes :meth:`predict`.
    """

    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=0,
        logger=None,
        parallel_args=None,
    ):
        super().__init__(
            args,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            use_cpu_offload=use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args,
        )

        self.pipeline = self.load_diffusion_pipeline(
            args=args,
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )

        self.default_negative_prompt = NEGATIVE_PROMPT

    def load_diffusion_pipeline(
        self,
        args,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        data_type="video",
    ):
        """Build the HunyuanVideo pipeline with a flow-matching scheduler.

        Raises:
            ValueError: If ``args.denoise_type`` is not ``"flow"``.
        """
        if scheduler is None:
            if args.denoise_type == "flow":
                scheduler = FlowMatchDiscreteScheduler(
                    shift=args.flow_shift,
                    reverse=args.flow_reverse,
                    solver=args.flow_solver,
                )
            else:
                raise ValueError(f"Invalid denoise type {args.denoise_type}")

        pipeline = HunyuanVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            transformer=model,
            scheduler=scheduler,
            progress_bar_config=progress_bar_config,
            args=args,
        )
        # Either offload sub-models to CPU on demand or pin everything to
        # the target device.
        if self.use_cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline = pipeline.to(device)

        return pipeline

    @torch.no_grad()
    def predict(
        self,
        prompt,
        height=192,
        width=336,
        video_length=129,
        seed=None,
        negative_prompt=None,
        infer_steps=50,
        guidance_scale=6,
        flow_shift=5.0,
        embedded_guidance_scale=None,
        batch_size=1,
        num_videos_per_prompt=1,
        **kwargs,
    ):
        """
        Predict the image/video from the given text.

        Args:
            prompt (str or List[str]): The input text.
            kwargs:
                height (int): The height of the output video. Default is 192.
                width (int): The width of the output video. Default is 336.
                video_length (int): The frame number of the output video. Default is 129.
                seed (int or List[str]): The random seed for the generation. Default is a random integer.
                negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
                guidance_scale (float): The guidance scale for the generation. Default is 6.0.
                num_images_per_prompt (int): The number of images per prompt. Default is 1.
                infer_steps (int): The number of inference steps. Default is 100.

        Returns:
            dict: ``seeds``, ``size``, ``samples`` and ``prompts``.
        """

        out_dict = dict()

        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [
                random.randint(0, 1_000_000)
                for _ in range(batch_size * num_videos_per_prompt)
            ]
        elif isinstance(seed, int):
            # NOTE(review): this repeats seed+0..seed+(num_videos-1) for every
            # batch item, so different batch items share seeds — preserved as
            # in the original; confirm whether per-batch offsets were intended.
            seeds = [
                seed + i for _ in range(batch_size)
                for i in range(num_videos_per_prompt)
            ]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [
                    int(seed[i]) + j for i in range(batch_size)
                    for j in range(num_videos_per_prompt)
                ]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must be equal to number of prompt(batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(
                f"Seed must be an integer, a list of integers, or None, got {seed}."
            )
        # Peiyuan: using GPU seed will cause A100 and H100 to generate different results...
        generator = [
            torch.Generator("cpu").manual_seed(seed) for seed in seeds
        ]
        out_dict["seeds"] = seeds

        # ========================================================================
        # Arguments: target_width, target_height, target_video_length
        # ========================================================================
        if width <= 0 or height <= 0 or video_length <= 0:
            raise ValueError(
                f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
            )
        if (video_length - 1) % 4 != 0:
            raise ValueError(
                f"`video_length-1` must be a multiple of 4, got {video_length}"
            )

        logger.info(
            f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
        )

        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_video_length = video_length

        out_dict["size"] = (target_height, target_width, target_video_length)

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(prompt, str):
            raise TypeError(
                f"`prompt` must be a string, but got {type(prompt)}")
        prompt = [prompt.strip()]

        # negative prompt
        if negative_prompt is None or negative_prompt == "":
            negative_prompt = self.default_negative_prompt
        if not isinstance(negative_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
            )
        negative_prompt = [negative_prompt.strip()]

        # ========================================================================
        # Scheduler
        # ========================================================================
        # Rebuild the scheduler so the per-call flow_shift takes effect.
        scheduler = FlowMatchDiscreteScheduler(
            shift=flow_shift,
            reverse=self.args.flow_reverse,
            solver=self.args.flow_solver,
        )
        self.pipeline.scheduler = scheduler

        # Latent grid: "884" VAEs compress time by 4, "888" by 8; both
        # compress space by 8.
        if "884" in self.args.vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8,
                            width // 8]
        elif "888" in self.args.vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8,
                            width // 8]
        else:
            # Fixed: previously fell through and raised a confusing NameError
            # on `latents_size` for any other VAE name.
            raise ValueError(
                f"Unsupported VAE temporal compression in {self.args.vae!r}; "
                f"expected '884' or '888' in the VAE name.")
        n_tokens = latents_size[0] * latents_size[1] * latents_size[2]

        # ========================================================================
        # Print infer args
        # ========================================================================
        debug_str = f"""
                      height: {target_height}
                       width: {target_width}
                video_length: {target_video_length}
                      prompt: {prompt}
                  neg_prompt: {negative_prompt}
                        seed: {seed}
                 infer_steps: {infer_steps}
       num_videos_per_prompt: {num_videos_per_prompt}
              guidance_scale: {guidance_scale}
                    n_tokens: {n_tokens}
                  flow_shift: {flow_shift}
     embedded_guidance_scale: {embedded_guidance_scale}"""
        logger.debug(debug_str)

        # ========================================================================
        # Pipeline inference
        # ========================================================================
        start_time = time.time()
        samples = self.pipeline(
            prompt=prompt,
            height=target_height,
            width=target_width,
            video_length=target_video_length,
            num_inference_steps=infer_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            generator=generator,
            output_type="pil",
            n_tokens=n_tokens,
            embedded_guidance_scale=embedded_guidance_scale,
            data_type="video" if target_video_length > 1 else "image",
            is_progress_bar=True,
            vae_ver=self.args.vae,
            enable_tiling=self.args.vae_tiling,
            # NOTE(review): args.vae_sp is not declared by the visible arg
            # parsers — confirm it is set elsewhere before this is called.
            enable_vae_sp=self.args.vae_sp,
        )[0]
        out_dict["samples"] = samples
        out_dict["prompts"] = prompt

        gen_time = time.time() - start_time
        logger.info(f"Success, time: {gen_time}")

        return out_dict
fastvideo/models/hunyuan/modules/mlp_layers.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from timm library:
2
+ # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ..utils.helpers import to_2tuple
10
+ from .modulate_layers import modulate
11
+
12
+
13
class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        biases = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        # A 1x1 convolution behaves as a per-position linear layer.
        if use_conv:
            make_layer = partial(nn.Conv2d, kernel_size=1)
        else:
            make_layer = nn.Linear

        self.fc1 = make_layer(in_channels,
                              hidden_channels,
                              bias=biases[0],
                              **factory_kwargs)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_channels, **factory_kwargs)
        self.fc2 = make_layer(hidden_channels,
                              out_features,
                              bias=biases[1],
                              **factory_kwargs)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        """Apply fc1 -> act -> drop1 -> norm -> fc2 -> drop2."""
        h = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(self.norm(h)))
60
+
61
+
62
+ #
63
class MLPEmbedder(nn.Module):
    """Two-layer SiLU MLP embedder.

    Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
    """

    def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True,
                                  **factory_kwargs)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True,
                                   **factory_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project to hidden_dim, gate with SiLU, then a hidden->hidden linear."""
        hidden = self.silu(self.in_layer(x))
        return self.out_layer(hidden)
+
82
+
83
class FinalLayer(nn.Module):
    """The final layer of DiT.

    Applies a non-affine LayerNorm, modulates it with a (shift, scale) pair
    derived from the conditioning vector, and projects each token to a
    flattened patch of output channels. The projection and the modulation
    MLP are zero-initialized so the layer starts as an identity-like no-op.
    """

    def __init__(self,
                 hidden_size,
                 patch_size,
                 out_channels,
                 act_layer,
                 device=None,
                 dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # Just use LayerNorm for the final layer
        self.norm_final = nn.LayerNorm(hidden_size,
                                       elementwise_affine=False,
                                       eps=1e-6,
                                       **factory_kwargs)
        # patch_size may be a single int (square 2D patch) or a 3-tuple
        # (temporal, height, width) for video patches.
        if isinstance(patch_size, int):
            out_dim = patch_size * patch_size * out_channels
        else:
            out_dim = patch_size[0] * patch_size[1] * patch_size[2] * out_channels
        # Fixed: the tuple branch previously omitted **factory_kwargs, so the
        # requested device/dtype were silently ignored for 3D patch sizes.
        self.linear = nn.Linear(
            hidden_size,
            out_dim,
            bias=True,
            **factory_kwargs,
        )
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        # Here we don't distinguish between the modulate types. Just use the simple one.
        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size,
                      2 * hidden_size,
                      bias=True,
                      **factory_kwargs),
        )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        """Modulate normalized ``x`` with (shift, scale) from ``c``, then project."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift=shift, scale=scale)
        x = self.linear(x)
        return x
fastvideo/models/hunyuan/prompt_rewrite.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Template for the default ("Normal") rewrite mode.
normal_mode_prompt = """Normal mode - Video Recaption Task:

You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.

0. Preserve ALL information, including style words and technical terms.

1. If the input is in Chinese, translate the entire description to English.

2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.

3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.

4. Output ALL must be in English.

Given Input:
input: "{input}"
"""

# Template for the "Master" rewrite mode.
master_mode_prompt = """Master mode - Video Recaption Task:

You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.

0. Preserve ALL information, including style words and technical terms.

1. If the input is in Chinese, translate the entire description to English.

2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.

3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.

4. Output ALL must be in English.

Given Input:
input: "{input}"
"""


def get_rewrite_prompt(ori_prompt, mode="Normal"):
    """Build the rewrite-model prompt for ``ori_prompt``.

    Args:
        ori_prompt: The original user prompt to be rewritten.
        mode: Either ``"Normal"`` or ``"Master"``, selecting the template.

    Returns:
        The selected template with ``ori_prompt`` substituted in.

    Raises:
        ValueError: If ``mode`` is not ``"Normal"`` or ``"Master"``.
    """
    if mode == "Normal":
        return normal_mode_prompt.format(input=ori_prompt)
    if mode == "Master":
        return master_mode_prompt.format(input=ori_prompt)
    # Bug fix: the original message read "Only supports Normal and Normal".
    # ValueError is a subclass of Exception, so callers catching Exception
    # are unaffected.
    raise ValueError(f"Only supports Normal and Master modes, got: {mode}")
46
+
47
+
48
# Example usage: build both prompt variants for a sample (Chinese) caption.
ori_prompt = "一只小狗在草地上奔跑。"
normal_prompt = get_rewrite_prompt(ori_prompt, mode="Normal")
master_prompt = get_rewrite_prompt(ori_prompt, mode="Master")

# Then you can use the normal_prompt or master_prompt to access the hunyuan-large rewrite model to get the final prompt.
fastvideo/models/hunyuan_hf/__pycache__/modeling_hunyuan.cpython-310.pyc ADDED
Binary file (24.7 kB). View file
 
fastvideo/models/hunyuan_hf/__pycache__/modeling_hunyuan.cpython-312.pyc ADDED
Binary file (39.7 kB). View file
 
fastvideo/models/hunyuan_hf/modeling_hunyuan.py ADDED
@@ -0,0 +1,952 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
22
+ from diffusers.models.attention import FeedForward
23
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
24
+ from diffusers.models.embeddings import (
25
+ CombinedTimestepGuidanceTextProjEmbeddings,
26
+ CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed)
27
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.models.normalization import (AdaLayerNormContinuous,
30
+ AdaLayerNormZero,
31
+ AdaLayerNormZeroSingle)
32
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
33
+ scale_lora_layers, unscale_lora_layers)
34
+
35
+ from fastvideo.models.flash_attn_no_pad import flash_attn_no_pad
36
+ from fastvideo.utils.communications import all_gather, all_to_all_4D
37
+ from fastvideo.utils.parallel_states import (get_sequence_parallel_state,
38
+ nccl_info)
39
+
40
# Module-level logger, following the diffusers logging convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
41
+
42
+
43
def shrink_head(encoder_state, dim):
    """Return this rank's contiguous share of attention heads along ``dim``.

    The size of ``dim`` is split evenly across the ``nccl_info.sp_size``
    sequence-parallel ranks; the slice owned by
    ``nccl_info.rank_within_group`` is returned (as a narrow view).
    """
    heads_per_rank = encoder_state.shape[dim] // nccl_info.sp_size
    start = nccl_info.rank_within_group * heads_per_rank
    return encoder_state.narrow(dim, start, heads_per_rank)
47
+
48
+
49
class HunyuanVideoAttnProcessor2_0:
    """Attention processor for HunyuanVideo joint latent/text attention.

    Computes QKV over the concatenated [latent | text] token sequence using
    ``flash_attn_no_pad``, with optional sequence parallelism: heads are
    scattered / sequence gathered across ranks via ``all_to_all_4D`` before
    attention and the transform is undone afterwards.

    NOTE(review): despite the ``Optional`` annotations, ``__call__`` indexes
    ``encoder_hidden_states.size(1)`` and ``image_rotary_emb[0]``
    unconditionally, so both must be provided in practice — confirm callers.
    """

    def __init__(self):
        # Requires the PyTorch 2.0 SDPA API surface.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        sequence_length = hidden_states.size(1)
        encoder_sequence_length = encoder_hidden_states.size(1)
        # Single-stream blocks (no add_q_proj) share one QKV projection for
        # latent and text tokens, so concatenate before projecting.
        if attn.add_q_proj is None and encoder_hidden_states is not None:
            hidden_states = torch.cat([hidden_states, encoder_hidden_states],
                                      dim=1)

        # 1. QKV projections -> (B, heads, seq, head_dim)
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        # 2. QK normalization (cast back to value's dtype after norm)
        if attn.norm_q is not None:
            query = attn.norm_q(query).to(value)
        if attn.norm_k is not None:
            key = attn.norm_k(key).to(value)

        # Keep only this rank's share of the rotary tables (head/feature dim 0).
        image_rotary_emb = (
            shrink_head(image_rotary_emb[0], dim=0),
            shrink_head(image_rotary_emb[1], dim=0),
        )

        # 3. Rotational positional embeddings applied to latent stream only;
        # text tokens (appended last) are left unrotated.
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            if attn.add_q_proj is None and encoder_hidden_states is not None:
                query = torch.cat(
                    [
                        apply_rotary_emb(
                            query[:, :, :-encoder_hidden_states.shape[1]],
                            image_rotary_emb),
                        query[:, :, -encoder_hidden_states.shape[1]:],
                    ],
                    dim=2,
                )
                key = torch.cat(
                    [
                        apply_rotary_emb(
                            key[:, :, :-encoder_hidden_states.shape[1]],
                            image_rotary_emb),
                        key[:, :, -encoder_hidden_states.shape[1]:],
                    ],
                    dim=2,
                )
            else:
                query = apply_rotary_emb(query, image_rotary_emb)
                key = apply_rotary_emb(key, image_rotary_emb)

        # 4. Encoder condition QKV projection and normalization (dual-stream
        # blocks project the text stream separately, then append its tokens).
        if attn.add_q_proj is not None and encoder_hidden_states is not None:
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)

            encoder_query = encoder_query.unflatten(
                2, (attn.heads, -1)).transpose(1, 2)
            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(
                1, 2)
            encoder_value = encoder_value.unflatten(
                2, (attn.heads, -1)).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_query = attn.norm_added_q(encoder_query).to(
                    encoder_value)
            if attn.norm_added_k is not None:
                encoder_key = attn.norm_added_k(encoder_key).to(encoder_value)

            query = torch.cat([query, encoder_query], dim=2)
            key = torch.cat([key, encoder_key], dim=2)
            value = torch.cat([value, encoder_value], dim=2)

        # Sequence parallelism: for the latent tokens, trade the head axis
        # for the (sharded) sequence axis so each rank attends over the full
        # sequence with a subset of heads; text tokens just drop to the
        # local heads.
        if get_sequence_parallel_state():
            query_img, query_txt = query[:, :, :
                                         sequence_length, :], query[:, :,
                                                                    sequence_length:, :]
            key_img, key_txt = key[:, :, :
                                   sequence_length, :], key[:, :,
                                                            sequence_length:, :]
            value_img, value_txt = value[:, :, :
                                         sequence_length, :], value[:, :,
                                                                    sequence_length:, :]
            query_img = all_to_all_4D(query_img, scatter_dim=1, gather_dim=2)
            key_img = all_to_all_4D(key_img, scatter_dim=1, gather_dim=2)
            value_img = all_to_all_4D(value_img, scatter_dim=1, gather_dim=2)

            query_txt = shrink_head(query_txt, dim=1)
            key_txt = shrink_head(key_txt, dim=1)
            value_txt = shrink_head(value_txt, dim=1)
            query = torch.cat([query_img, query_txt], dim=2)
            key = torch.cat([key_img, key_txt], dim=2)
            value = torch.cat([value_img, value_txt], dim=2)

        # Pack to the (B, seq, 3, heads, head_dim) layout flash_attn expects.
        query = query.unsqueeze(2)
        key = key.unsqueeze(2)
        value = value.unsqueeze(2)
        qkv = torch.cat([query, key, value], dim=2)
        qkv = qkv.transpose(1, 3)

        # 5. Attention. The text mask is left-padded with True so latent
        # tokens (which precede the text tokens) are always attended.
        attention_mask = attention_mask[:, 0, :]
        seq_len = qkv.shape[1]
        attn_len = attention_mask.shape[1]
        attention_mask = F.pad(attention_mask, (seq_len - attn_len, 0),
                               value=True)

        hidden_states = flash_attn_no_pad(qkv,
                                          attention_mask,
                                          causal=False,
                                          dropout_p=0.0,
                                          softmax_scale=None)

        if get_sequence_parallel_state():
            # Undo the all-to-all: re-shard the sequence, re-gather the heads.
            hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
                (sequence_length * nccl_info.sp_size, encoder_sequence_length),
                dim=1)
            hidden_states = all_to_all_4D(hidden_states,
                                          scatter_dim=1,
                                          gather_dim=2)
            encoder_hidden_states = all_gather(encoder_hidden_states,
                                               dim=2).contiguous()
            hidden_states = hidden_states.flatten(2, 3)
            hidden_states = hidden_states.to(query.dtype)
            encoder_hidden_states = encoder_hidden_states.flatten(2, 3)
            encoder_hidden_states = encoder_hidden_states.to(query.dtype)
        else:
            hidden_states = hidden_states.flatten(2, 3)
            hidden_states = hidden_states.to(query.dtype)

        # 6. Output projection: split the joint sequence back into latent and
        # text streams (text tokens are the trailing ones).
        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, :-encoder_hidden_states.shape[1]],
                hidden_states[:, -encoder_hidden_states.shape[1]:],
            )

        if encoder_hidden_states is not None:
            if getattr(attn, "to_out", None) is not None:
                hidden_states = attn.to_out[0](hidden_states)
                hidden_states = attn.to_out[1](hidden_states)

            if getattr(attn, "to_add_out", None) is not None:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states
216
+
217
+
218
class HunyuanVideoPatchEmbed(nn.Module):
    """Embed a latent video into a token sequence via a 3D convolution."""

    def __init__(
        self,
        patch_size: Union[int, Tuple[int, int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()

        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size, patch_size)
        # Kernel size equals stride: non-overlapping patches, one token each.
        self.proj = nn.Conv3d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project (B, C, F, H, W) latents to a (B, N, C') token sequence."""
        embedded = self.proj(hidden_states)
        # Collapse the (F', H', W') patch grid into one token axis: BCFHW -> BNC.
        return embedded.flatten(2).transpose(1, 2)
240
+
241
+
242
class HunyuanVideoAdaNorm(nn.Module):
    """Project a conditioning embedding into attention and MLP branch gates."""

    def __init__(self,
                 in_features: int,
                 out_features: Optional[int] = None) -> None:
        super().__init__()

        out_features = out_features or 2 * in_features
        self.linear = nn.Linear(in_features, out_features)
        self.nonlinearity = nn.SiLU()

    def forward(
        self, temb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Return ``(gate_msa, gate_mlp)``, each shaped (B, 1, C) for broadcasting."""
        projected = self.linear(self.nonlinearity(temb))
        gate_msa, gate_mlp = projected.chunk(2, dim=1)
        # Insert a sequence axis so the gates broadcast over tokens.
        return gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
261
+
262
+
263
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
    """One token-refiner block: self-attention and feed-forward, each gated by ``temb``."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_width_ratio: str = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim

        self.norm1 = nn.LayerNorm(hidden_size,
                                  elementwise_affine=True,
                                  eps=1e-6)
        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
        )

        self.norm2 = nn.LayerNorm(hidden_size,
                                  elementwise_affine=True,
                                  eps=1e-6)
        self.ff = FeedForward(hidden_size,
                              mult=mlp_width_ratio,
                              activation_fn="linear-silu",
                              dropout=mlp_drop_rate)

        # Produces the two residual gates from the conditioning embedding.
        self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply gated self-attention, then a gated feed-forward residual."""
        attn_output = self.attn(
            hidden_states=self.norm1(hidden_states),
            encoder_hidden_states=None,
            attention_mask=attention_mask,
        )

        gate_msa, gate_mlp = self.norm_out(temb)
        hidden_states = hidden_states + attn_output * gate_msa

        ff_output = self.ff(self.norm2(hidden_states))
        hidden_states = hidden_states + ff_output * gate_mlp
        return hidden_states
319
+
320
+
321
class HunyuanVideoIndividualTokenRefiner(nn.Module):
    """A stack of refiner blocks that share one bidirectional padding mask."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        num_layers: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        self.refiner_blocks = nn.ModuleList([
            HunyuanVideoIndividualTokenRefinerBlock(
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                mlp_width_ratio=mlp_width_ratio,
                mlp_drop_rate=mlp_drop_rate,
                attention_bias=attention_bias,
            ) for _ in range(num_layers)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> None:
        self_attn_mask = None
        if attention_mask is not None:
            batch_size, seq_len = attention_mask.shape[0], attention_mask.shape[1]
            keep = attention_mask.to(hidden_states.device).bool()
            # Pairwise mask: (i, j) is attendable only when tokens i and j
            # are both valid (non-padding).
            rows = keep.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            cols = rows.transpose(2, 3)
            self_attn_mask = (rows & cols).bool()
            # Always allow attending to the first token so fully-masked rows
            # cannot yield an all-False row (NaNs in softmax).
            self_attn_mask[:, :, :, 0] = True

        for refiner_block in self.refiner_blocks:
            hidden_states = refiner_block(hidden_states, temb, self_attn_mask)

        return hidden_states
366
+
367
+
368
class HunyuanVideoTokenRefiner(nn.Module):
    """Project text-encoder tokens and refine them, conditioned on timestep + pooled text."""

    def __init__(
        self,
        in_channels: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_layers: int,
        mlp_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim

        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=hidden_size, pooled_projection_dim=in_channels)
        self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
        self.token_refiner = HunyuanVideoIndividualTokenRefiner(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            num_layers=num_layers,
            mlp_width_ratio=mlp_ratio,
            mlp_drop_rate=mlp_drop_rate,
            attention_bias=attention_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        if attention_mask is None:
            # No padding info: plain mean over the token axis.
            pooled_projections = hidden_states.mean(dim=1)
        else:
            # Masked mean: only valid (unpadded) tokens contribute.
            original_dtype = hidden_states.dtype
            mask_float = attention_mask.float().unsqueeze(-1)
            token_sum = (hidden_states * mask_float).sum(dim=1)
            pooled_projections = (token_sum /
                                  mask_float.sum(dim=1)).to(original_dtype)

        temb = self.time_text_embed(timestep, pooled_projections)
        hidden_states = self.proj_in(hidden_states)
        return self.token_refiner(hidden_states, temb, attention_mask)
416
+
417
+
418
class HunyuanVideoRotaryPosEmbed(nn.Module):
    """Compute 3D (temporal + spatial) rotary position embedding tables.

    Args:
        patch_size: Spatial patch size (applied to height and width).
        patch_size_t: Temporal patch size (applied to the frame axis).
        rope_dim: Per-axis rotary dims, ordered (temporal, height, width).
        theta: RoPE frequency base.
    """

    def __init__(self,
                 patch_size: int,
                 patch_size_t: int,
                 rope_dim: List[int],
                 theta: float = 256.0) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.rope_dim = rope_dim
        self.theta = theta

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (B, C, F, H, W) latent video.
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        # Under sequence parallelism each rank holds only num_frames frames;
        # multiplying by sp_size sizes the table for the full video.
        rope_sizes = [
            num_frames * nccl_info.sp_size // self.patch_size_t,
            height // self.patch_size, width // self.patch_size
        ]

        axes_grids = []
        for i in range(3):
            # Note: The following line diverges from original behaviour. We create the grid on the device, whereas
            # original implementation creates it on CPU and then moves it to device. This results in numerical
            # differences in layerwise debugging outputs, but visually it is the same.
            grid = torch.arange(0,
                                rope_sizes[i],
                                device=hidden_states.device,
                                dtype=torch.float32)
            axes_grids.append(grid)
        # With indexing="ij" the axes keep rope_sizes' (T, H, W) ordering.
        grid = torch.meshgrid(*axes_grids, indexing="ij")  # 3 x (T, H, W)
        grid = torch.stack(grid, dim=0)  # (3, T, H, W)

        # One 1D rotary table per axis, concatenated along the feature dim.
        freqs = []
        for i in range(3):
            freq = get_1d_rotary_pos_embed(self.rope_dim[i],
                                           grid[i].reshape(-1),
                                           self.theta,
                                           use_real=True)
            freqs.append(freq)

        freqs_cos = torch.cat([f[0] for f in freqs],
                              dim=1)  # (T * H * W, D / 2)
        freqs_sin = torch.cat([f[1] for f in freqs],
                              dim=1)  # (T * H * W, D / 2)
        return freqs_cos, freqs_sin
465
+
466
+
467
class HunyuanVideoSingleTransformerBlock(nn.Module):
    """Single-stream DiT block: latent and text tokens share one joint pass.

    Attention and MLP both read the same adaLN-normalized input and are fused
    by a single output projection (parallel-attention design).
    """

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 4.0,
        qk_norm: str = "rms_norm",
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim
        mlp_dim = int(hidden_size * mlp_ratio)

        # pre_only=True: the Attention module skips its own output projection;
        # proj_out below projects the fused [attention | MLP] features instead.
        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=hidden_size,
            bias=True,
            processor=HunyuanVideoAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=1e-6,
            pre_only=True,
        )

        self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        # Concatenate latent and text tokens (text last) for the joint pass.
        text_seq_length = encoder_hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, encoder_hidden_states],
                                  dim=1)

        residual = hidden_states

        # 1. Input normalization (adaLN-Zero, single gate)
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

        # Split back for the attention call, which takes the streams separately.
        norm_hidden_states, norm_encoder_hidden_states = (
            norm_hidden_states[:, :-text_seq_length, :],
            norm_hidden_states[:, -text_seq_length:, :],
        )

        # 2. Attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        attn_output = torch.cat([attn_output, context_attn_output], dim=1)

        # 3. Modulation and residual connection: fuse attention and MLP
        # features along the channel axis, project, gate, then add residual.
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
        hidden_states = hidden_states + residual

        # Split the joint sequence back into latent and text streams.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, :-text_seq_length, :],
            hidden_states[:, -text_seq_length:, :],
        )
        return hidden_states, encoder_hidden_states
541
+
542
+
543
class HunyuanVideoTransformerBlock(nn.Module):
    """Dual-stream (MM-DiT style) block: latent and text streams keep separate
    weights but attend jointly through a shared attention operation."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float,
        qk_norm: str = "rms_norm",
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim

        # Separate adaLN-Zero modulation for the latent and text streams.
        self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
        self.norm1_context = AdaLayerNormZero(hidden_size,
                                              norm_type="layer_norm")

        # added_kv_proj_dim enables the extra text-stream QKV projections.
        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            added_kv_proj_dim=hidden_size,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=hidden_size,
            context_pre_only=False,
            bias=True,
            processor=HunyuanVideoAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=1e-6,
        )

        self.norm2 = nn.LayerNorm(hidden_size,
                                  elementwise_affine=False,
                                  eps=1e-6)
        self.ff = FeedForward(hidden_size,
                              mult=mlp_ratio,
                              activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(hidden_size,
                                          elementwise_affine=False,
                                          eps=1e-6)
        self.ff_context = FeedForward(hidden_size,
                                      mult=mlp_ratio,
                                      activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Input normalization: adaLN-Zero yields shift/scale/gate for both
        # the attention and MLP sub-layers of each stream.
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb)

        # 2. Joint attention over both streams.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=freqs_cis,
        )

        # 3. Modulation and residual connection (gated attention residual).
        hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(
            1)

        norm_hidden_states = self.norm2(hidden_states)
        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)

        # Scale/shift modulation before each stream's feed-forward.
        norm_hidden_states = norm_hidden_states * (
            1 + scale_mlp[:, None]) + shift_mlp[:, None]
        norm_encoder_hidden_states = norm_encoder_hidden_states * (
            1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        # 4. Feed-forward (gated residual per stream).
        ff_output = self.ff(norm_hidden_states)
        context_ff_output = self.ff_context(norm_encoder_hidden_states)

        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(
            1) * context_ff_output

        return hidden_states, encoder_hidden_states
632
+
633
+
634
+ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
635
+ FromOriginalModelMixin):
636
+ r"""
637
+ A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
638
+
639
+ Args:
640
+ in_channels (`int`, defaults to `16`):
641
+ The number of channels in the input.
642
+ out_channels (`int`, defaults to `16`):
643
+ The number of channels in the output.
644
+ num_attention_heads (`int`, defaults to `24`):
645
+ The number of heads to use for multi-head attention.
646
+ attention_head_dim (`int`, defaults to `128`):
647
+ The number of channels in each head.
648
+ num_layers (`int`, defaults to `20`):
649
+ The number of layers of dual-stream blocks to use.
650
+ num_single_layers (`int`, defaults to `40`):
651
+ The number of layers of single-stream blocks to use.
652
+ num_refiner_layers (`int`, defaults to `2`):
653
+ The number of layers of refiner blocks to use.
654
+ mlp_ratio (`float`, defaults to `4.0`):
655
+ The ratio of the hidden layer size to the input size in the feedforward network.
656
+ patch_size (`int`, defaults to `2`):
657
+ The size of the spatial patches to use in the patch embedding layer.
658
+ patch_size_t (`int`, defaults to `1`):
659
+ The size of the tmeporal patches to use in the patch embedding layer.
660
+ qk_norm (`str`, defaults to `rms_norm`):
661
+ The normalization to use for the query and key projections in the attention layers.
662
+ guidance_embeds (`bool`, defaults to `True`):
663
+ Whether to use guidance embeddings in the model.
664
+ text_embed_dim (`int`, defaults to `4096`):
665
+ Input dimension of text embeddings from the text encoder.
666
+ pooled_projection_dim (`int`, defaults to `768`):
667
+ The dimension of the pooled projection of the text embeddings.
668
+ rope_theta (`float`, defaults to `256.0`):
669
+ The value of theta to use in the RoPE layer.
670
+ rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
671
+ The dimensions of the axes to use in the RoPE layer.
672
+ """
673
+
674
+ _supports_gradient_checkpointing = True
675
+
676
    @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 20,
        num_single_layers: int = 40,
        num_refiner_layers: int = 2,
        mlp_ratio: float = 4.0,
        patch_size: int = 2,
        patch_size_t: int = 1,
        qk_norm: str = "rms_norm",
        guidance_embeds: bool = True,
        text_embed_dim: int = 4096,
        pooled_projection_dim: int = 768,
        rope_theta: float = 256.0,
        rope_axes_dim: Tuple[int] = (16, 56, 56),
    ) -> None:
        """Build the HunyuanVideo DiT; see the class docstring for argument
        meanings. All arguments are also recorded via @register_to_config.

        NOTE(review): ``guidance_embeds`` is stored in the config but not
        referenced in this constructor's body — confirm whether the guidance
        embedder should be made conditional on it.
        """
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        # Fall back to in_channels when out_channels is falsy (None/0).
        out_channels = out_channels or in_channels

        # 1. Latent and condition embedders
        self.x_embedder = HunyuanVideoPatchEmbed(
            (patch_size_t, patch_size, patch_size), in_channels, inner_dim)
        self.context_embedder = HunyuanVideoTokenRefiner(
            text_embed_dim,
            num_attention_heads,
            attention_head_dim,
            num_layers=num_refiner_layers)
        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(
            inner_dim, pooled_projection_dim)

        # 2. RoPE
        self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t,
                                               rope_axes_dim, rope_theta)

        # 3. Dual stream transformer blocks
        self.transformer_blocks = nn.ModuleList([
            HunyuanVideoTransformerBlock(num_attention_heads,
                                         attention_head_dim,
                                         mlp_ratio=mlp_ratio,
                                         qk_norm=qk_norm)
            for _ in range(num_layers)
        ])

        # 4. Single stream transformer blocks
        self.single_transformer_blocks = nn.ModuleList([
            HunyuanVideoSingleTransformerBlock(num_attention_heads,
                                               attention_head_dim,
                                               mlp_ratio=mlp_ratio,
                                               qk_norm=qk_norm)
            for _ in range(num_single_layers)
        ])

        # 5. Output projection: each token maps back to a flattened
        # (t, h, w) patch of out_channels values.
        self.norm_out = AdaLayerNormContinuous(inner_dim,
                                               inner_dim,
                                               elementwise_affine=False,
                                               eps=1e-6)
        self.proj_out = nn.Linear(
            inner_dim, patch_size_t * patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False
743
+
744
    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module,
                                        processors: Dict[str,
                                                         AttentionProcessor]):
            # Any submodule exposing get_processor is treated as an
            # attention layer; keys mirror the module's dotted path.
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child,
                                            processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors
771
+
772
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor,
                                                  Dict[str,
                                                       AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        # A dict must supply exactly one processor per attention layer.
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module,
                                        processor):
            # Depth-first walk mirroring attn_processors' key scheme.
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child,
                                            processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
810
+
811
+ def _set_gradient_checkpointing(self, module, value=False):
812
+ if hasattr(module, "gradient_checkpointing"):
813
+ module.gradient_checkpointing = value
814
+
815
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        guidance: torch.Tensor = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Denoise a batch of video latents.

        Args:
            hidden_states: Latent video tensor of shape [B, C, F, H, W].
            encoder_hidden_states: Text-encoder states; token 0 carries the
                pooled projection in its first `pooled_projection_dim` channels,
                the remaining tokens are per-token text embeddings.
            timestep: Diffusion timestep(s) for the batch.
            encoder_attention_mask: [B, T] mask marking valid text tokens.
            guidance: Embedded (distilled) guidance values; defaults below.
            attention_kwargs: Optional processor kwargs; only `scale` (LoRA) is
                consumed here.
            return_dict: Return a `Transformer2DModelOutput` instead of a tuple.

        Returns:
            The predicted latent sample, unpatchified back to [B, C, F, H, W].
        """
        if guidance is None:
            # NOTE(review): 6016.0 looks like guidance_scale * 1000 (~6.0);
            # confirm the intended default — upstream diffusers uses 6000.0.
            guidance = torch.tensor([6016.0],
                                    device=hidden_states.device,
                                    dtype=torch.bfloat16)

        # Extract the LoRA scale from a copy so the caller's dict is untouched.
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get(
                    "scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        # Spatial patch size `p` and temporal patch size `p_t`; the output is
        # unpatchified with these at the end.
        p, p_t = self.config.patch_size, self.config.patch_size_t
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p

        # Token 0 of the text stream doubles as the pooled projection; the
        # remaining tokens are the actual conditioning sequence.
        pooled_projections = encoder_hidden_states[:, 0, :self.config.
                                                   pooled_projection_dim]
        encoder_hidden_states = encoder_hidden_states[:, 1:]

        # 1. RoPE
        image_rotary_emb = self.rope(hidden_states)

        # 2. Conditional embeddings
        temb = self.time_text_embed(timestep, guidance, pooled_projections)
        hidden_states = self.x_embedder(hidden_states)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states,
                                                      timestep,
                                                      encoder_attention_mask)

        # 3. Attention mask preparation: each sample may attend over all latent
        # tokens plus only its *valid* (unpadded) text tokens.
        latent_sequence_length = hidden_states.shape[1]
        condition_sequence_length = encoder_hidden_states.shape[1]
        sequence_length = latent_sequence_length + condition_sequence_length
        attention_mask = torch.zeros(batch_size,
                                     sequence_length,
                                     sequence_length,
                                     device=hidden_states.device,
                                     dtype=torch.bool)  # [B, N, N]

        effective_condition_sequence_length = encoder_attention_mask.sum(
            dim=1, dtype=torch.int)
        effective_sequence_length = latent_sequence_length + effective_condition_sequence_length

        # Per-sample square mask over the effective (latent + valid text) span.
        for i in range(batch_size):
            attention_mask[i, :effective_sequence_length[i], :
                           effective_sequence_length[i]] = True

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):

                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            # Non-reentrant checkpointing is preferred when available (>=1.11).
            ckpt_kwargs: Dict[str, Any] = {
                "use_reentrant": False
            } if is_torch_version(">=", "1.11.0") else {}

            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            for block in self.single_transformer_blocks:
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

        else:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(
                    hidden_states, encoder_hidden_states, temb, attention_mask,
                    image_rotary_emb)

            for block in self.single_transformer_blocks:
                hidden_states, encoder_hidden_states = block(
                    hidden_states, encoder_hidden_states, temb, attention_mask,
                    image_rotary_emb)

        # 5. Output projection
        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify: [B, N, p_t*p*p*C] -> [B, C, F, H, W].
        hidden_states = hidden_states.reshape(batch_size,
                                              post_patch_num_frames,
                                              post_patch_height,
                                              post_patch_width, -1, p_t, p, p)
        hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (hidden_states, )

        return Transformer2DModelOutput(sample=hidden_states)
fastvideo/models/hunyuan_hf/pipeline_hunyuan.py ADDED
@@ -0,0 +1,756 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
22
+ from diffusers.loaders import HunyuanVideoLoraLoaderMixin
23
+ from diffusers.models import (AutoencoderKLHunyuanVideo,
24
+ HunyuanVideoTransformer3DModel)
25
+ from diffusers.pipelines.hunyuan_video.pipeline_output import \
26
+ HunyuanVideoPipelineOutput
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils import logging, replace_example_docstring
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.video_processor import VideoProcessor
32
+ from einops import rearrange
33
+ from transformers import (CLIPTextModel, CLIPTokenizer, LlamaModel,
34
+ LlamaTokenizerFast)
35
+
36
+ from fastvideo.utils.communications import all_gather
37
+ from fastvideo.utils.parallel_states import (get_sequence_parallel_state,
38
+ nccl_info)
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+ EXAMPLE_DOC_STRING = """
43
+ Examples:
44
+ ```python
45
+ >>> import torch
46
+ >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
47
+ >>> from diffusers.utils import export_to_video
48
+
49
+ >>> model_id = "tencent/HunyuanVideo"
50
+ >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
51
+ ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
52
+ ... )
53
+ >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
54
+ >>> pipe.vae.enable_tiling()
55
+ >>> pipe.to("cuda")
56
+
57
+ >>> output = pipe(
58
+ ... prompt="A cat walks on the grass, realistic",
59
+ ... height=320,
60
+ ... width=512,
61
+ ... num_frames=61,
62
+ ... num_inference_steps=30,
63
+ ... ).frames[0]
64
+ >>> export_to_video(output, "output.mp4", fps=15)
65
+ ```
66
+ """
67
+
68
# Chat template wrapped around the user prompt before it is fed to the Llama
# text encoder. `crop_start` is the number of leading (system-prompt) tokens
# later stripped from the encoder output so that only the user-prompt
# embeddings remain.
DEFAULT_PROMPT_TEMPLATE = {
    "template":
    ("<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
     "1. The main content and theme of the video."
     "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
     "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
     "4. background environment, light, style and atmosphere."
     "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
     "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"),
    "crop_start":
    95,
}
80
+
81
+
82
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
83
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # The two custom-schedule arguments are mutually exclusive.
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    if timesteps is not None:
        # Custom timestep schedule: the scheduler must accept `timesteps`.
        supported = set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if "timesteps" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        result_timesteps = scheduler.timesteps
        num_inference_steps = len(result_timesteps)
    elif sigmas is not None:
        # Custom sigma schedule: the scheduler must accept `sigmas`.
        supported = set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if "sigmas" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        result_timesteps = scheduler.timesteps
        num_inference_steps = len(result_timesteps)
    else:
        # Default path: let the scheduler derive the schedule from the count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        result_timesteps = scheduler.timesteps

    return result_timesteps, num_inference_steps
144
+
145
+
146
+ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
147
+ r"""
148
+ Pipeline for text-to-video generation using HunyuanVideo.
149
+
150
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
151
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
152
+
153
+ Args:
154
+ text_encoder ([`LlamaModel`]):
155
+ [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
156
+ tokenizer_2 (`LlamaTokenizer`):
157
+ Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
158
+ transformer ([`HunyuanVideoTransformer3DModel`]):
159
+ Conditional Transformer to denoise the encoded image latents.
160
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
161
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
162
+ vae ([`AutoencoderKLHunyuanVideo`]):
163
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
164
+ text_encoder_2 ([`CLIPTextModel`]):
165
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
166
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
167
+ tokenizer_2 (`CLIPTokenizer`):
168
+ Tokenizer of class
169
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
170
+ """
171
+
172
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
173
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
174
+
175
    def __init__(
        self,
        text_encoder: LlamaModel,
        tokenizer: LlamaTokenizerFast,
        transformer: HunyuanVideoTransformer3DModel,
        vae: AutoencoderKLHunyuanVideo,
        scheduler: FlowMatchEulerDiscreteScheduler,
        text_encoder_2: CLIPTextModel,
        tokenizer_2: CLIPTokenizer,
    ):
        """Register the pipeline components and derive VAE scale factors.

        All modules are registered with the diffusers `DiffusionPipeline`
        machinery so device placement / offloading works uniformly.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
        )

        # Fall back to HunyuanVideo's default compression ratios
        # (4x temporal, 8x spatial) when no VAE is registered.
        self.vae_scale_factor_temporal = (self.vae.temporal_compression_ratio
                                          if hasattr(self, "vae")
                                          and self.vae is not None else 4)
        self.vae_scale_factor_spatial = (self.vae.spatial_compression_ratio
                                         if hasattr(self, "vae")
                                         and self.vae is not None else 8)
        # Handles post-decode resizing / tensor->PIL conversion of frames.
        self.video_processor = VideoProcessor(
            vae_scale_factor=self.vae_scale_factor_spatial)
205
+
206
    def _get_llama_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        prompt_template: Dict[str, Any],
        num_videos_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        max_sequence_length: int = 256,
        num_hidden_layers_to_skip: int = 2,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode `prompt` with the Llama text encoder.

        The prompt is wrapped in `prompt_template`, tokenized, and run through
        the encoder; the hidden state `num_hidden_layers_to_skip` layers before
        the last is used, and the template's leading tokens are cropped off.

        Returns:
            Tuple of (prompt_embeds, prompt_attention_mask), both repeated
            `num_videos_per_prompt` times along the batch dimension.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Wrap each user prompt in the system/user chat template.
        prompt = [prompt_template["template"].format(p) for p in prompt]

        crop_start = prompt_template.get("crop_start", None)
        if crop_start is None:
            # Derive how many leading template tokens to crop by tokenizing
            # the bare template once.
            prompt_template_input = self.tokenizer(
                prompt_template["template"],
                padding="max_length",
                return_tensors="pt",
                return_length=False,
                return_overflowing_tokens=False,
                return_attention_mask=False,
            )
            crop_start = prompt_template_input["input_ids"].shape[-1]
            # Remove <|eot_id|> token and placeholder {}
            crop_start -= 2

        # Budget extra room for the template tokens that will be cropped.
        max_sequence_length += crop_start
        text_inputs = self.tokenizer(
            prompt,
            max_length=max_sequence_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_length=False,
            return_overflowing_tokens=False,
            return_attention_mask=True,
        )
        text_input_ids = text_inputs.input_ids.to(device=device)
        prompt_attention_mask = text_inputs.attention_mask.to(device=device)

        # Take an intermediate hidden state rather than the final layer.
        prompt_embeds = self.text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_attention_mask,
            output_hidden_states=True,
        ).hidden_states[-(num_hidden_layers_to_skip + 1)]
        prompt_embeds = prompt_embeds.to(dtype=dtype)

        # Drop the template prefix so only user-prompt tokens remain.
        if crop_start is not None and crop_start > 0:
            prompt_embeds = prompt_embeds[:, crop_start:]
            prompt_attention_mask = prompt_attention_mask[:, crop_start:]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt,
                                           seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(
            1, num_videos_per_prompt)
        prompt_attention_mask = prompt_attention_mask.view(
            batch_size * num_videos_per_prompt, seq_len)

        return prompt_embeds, prompt_attention_mask
274
+
275
    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_videos_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        max_sequence_length: int = 77,
    ) -> torch.Tensor:
        """Compute the pooled CLIP embedding for `prompt`.

        Returns:
            Pooled prompt embeddings of shape
            [batch_size * num_videos_per_prompt, embed_dim].
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder_2.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )

        # Warn (don't fail) when the prompt exceeds CLIP's context window.
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt,
                                           padding="longest",
                                           return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[
                -1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_2.batch_decode(
                untruncated_ids[:, max_sequence_length - 1:-1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {max_sequence_length} tokens: {removed_text}")

        # Only the pooled output is used (no per-token CLIP states).
        prompt_embeds = self.text_encoder_2(
            text_input_ids.to(device),
            output_hidden_states=False).pooler_output

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt,
                                           -1)

        return prompt_embeds
319
+
320
+ def encode_prompt(
321
+ self,
322
+ prompt: Union[str, List[str]],
323
+ prompt_2: Union[str, List[str]] = None,
324
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
325
+ num_videos_per_prompt: int = 1,
326
+ prompt_embeds: Optional[torch.Tensor] = None,
327
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
328
+ prompt_attention_mask: Optional[torch.Tensor] = None,
329
+ device: Optional[torch.device] = None,
330
+ dtype: Optional[torch.dtype] = None,
331
+ max_sequence_length: int = 256,
332
+ ):
333
+
334
+ if prompt_embeds is None:
335
+ prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
336
+ prompt,
337
+ prompt_template,
338
+ num_videos_per_prompt,
339
+ device=device,
340
+ dtype=dtype,
341
+ max_sequence_length=max_sequence_length,
342
+ )
343
+
344
+ if pooled_prompt_embeds is None:
345
+ if prompt_2 is None and pooled_prompt_embeds is None:
346
+ prompt_2 = prompt
347
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
348
+ prompt,
349
+ num_videos_per_prompt,
350
+ device=device,
351
+ dtype=dtype,
352
+ max_sequence_length=77,
353
+ )
354
+
355
+ return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
356
+
357
+ def check_inputs(
358
+ self,
359
+ prompt,
360
+ prompt_2,
361
+ height,
362
+ width,
363
+ prompt_embeds=None,
364
+ callback_on_step_end_tensor_inputs=None,
365
+ prompt_template=None,
366
+ ):
367
+ if height % 16 != 0 or width % 16 != 0:
368
+ raise ValueError(
369
+ f"`height` and `width` have to be divisible by 16 but are {height} and {width}."
370
+ )
371
+
372
+ if callback_on_step_end_tensor_inputs is not None and not all(
373
+ k in self._callback_tensor_inputs
374
+ for k in callback_on_step_end_tensor_inputs):
375
+ raise ValueError(
376
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
377
+ )
378
+
379
+ if prompt is not None and prompt_embeds is not None:
380
+ raise ValueError(
381
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
382
+ " only forward one of the two.")
383
+ elif prompt_2 is not None and prompt_embeds is not None:
384
+ raise ValueError(
385
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
386
+ " only forward one of the two.")
387
+ elif prompt is None and prompt_embeds is None:
388
+ raise ValueError(
389
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
390
+ )
391
+ elif prompt is not None and (not isinstance(prompt, str)
392
+ and not isinstance(prompt, list)):
393
+ raise ValueError(
394
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
395
+ )
396
+ elif prompt_2 is not None and (not isinstance(prompt_2, str)
397
+ and not isinstance(prompt_2, list)):
398
+ raise ValueError(
399
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
400
+ )
401
+
402
+ if prompt_template is not None:
403
+ if not isinstance(prompt_template, dict):
404
+ raise ValueError(
405
+ f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}"
406
+ )
407
+ if "template" not in prompt_template:
408
+ raise ValueError(
409
+ f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
410
+ )
411
+
412
    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int,
        height: int = 720,
        width: int = 1280,
        num_frames: int = 129,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Create (or pass through) the initial latent noise tensor.

        Note: only the spatial dims are divided by the VAE scale factor here;
        `num_frames` is used as-is — presumably the caller already converts
        pixel frames to latent frames (TODO confirm against `__call__`).
        """
        # User-supplied latents are only moved to the right device/dtype.
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        # Latent shape: [B, C, F, H/8, W/8] for the default spatial factor.
        shape = (
            batch_size,
            num_channels_latents,
            num_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        # Per-sample generators must match the batch size exactly.
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape,
                               generator=generator,
                               device=device,
                               dtype=dtype)
        return latents
446
+
447
    def enable_vae_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()
475
+
476
    @property
    def guidance_scale(self):
        # Embedded guidance scale recorded at the start of `__call__`.
        return self._guidance_scale

    @property
    def num_timesteps(self):
        # Number of denoising steps scheduled for the current call.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the transformer's attention processors.
        return self._attention_kwargs

    @property
    def interrupt(self):
        # Flag callbacks can set to abort the denoising loop early.
        return self._interrupt
491
+
492
+ @torch.no_grad()
493
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
494
+ def __call__(
495
+ self,
496
+ prompt: Union[str, List[str]] = None,
497
+ prompt_2: Union[str, List[str]] = None,
498
+ height: int = 720,
499
+ width: int = 1280,
500
+ num_frames: int = 129,
501
+ num_inference_steps: int = 50,
502
+ sigmas: List[float] = None,
503
+ guidance_scale: float = 6.0,
504
+ num_videos_per_prompt: Optional[int] = 1,
505
+ generator: Optional[Union[torch.Generator,
506
+ List[torch.Generator]]] = None,
507
+ latents: Optional[torch.Tensor] = None,
508
+ prompt_embeds: Optional[torch.Tensor] = None,
509
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
510
+ prompt_attention_mask: Optional[torch.Tensor] = None,
511
+ output_type: Optional[str] = "pil",
512
+ return_dict: bool = True,
513
+ attention_kwargs: Optional[Dict[str, Any]] = None,
514
+ callback_on_step_end: Optional[Union[Callable[[int, int, Dict],
515
+ None], PipelineCallback,
516
+ MultiPipelineCallbacks]] = None,
517
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
518
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
519
+ max_sequence_length: int = 256,
520
+ ):
521
+ r"""
522
+ The call function to the pipeline for generation.
523
+
524
+ Args:
525
+ prompt (`str` or `List[str]`, *optional*):
526
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
527
+ instead.
528
+ prompt_2 (`str` or `List[str]`, *optional*):
529
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
530
+ will be used instead.
531
+ height (`int`, defaults to `720`):
532
+ The height in pixels of the generated image.
533
+ width (`int`, defaults to `1280`):
534
+ The width in pixels of the generated image.
535
+ num_frames (`int`, defaults to `129`):
536
+ The number of frames in the generated video.
537
+ num_inference_steps (`int`, defaults to `50`):
538
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
539
+ expense of slower inference.
540
+ sigmas (`List[float]`, *optional*):
541
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
542
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
543
+ will be used.
544
+ guidance_scale (`float`, defaults to `6.0`):
545
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
546
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
547
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
548
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
549
+ usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
550
+ CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
551
+ not applied.
552
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
553
+ The number of images to generate per prompt.
554
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
555
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
556
+ generation deterministic.
557
+ latents (`torch.Tensor`, *optional*):
558
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
559
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
560
+ tensor is generated by sampling using the supplied random `generator`.
561
+ prompt_embeds (`torch.Tensor`, *optional*):
562
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
563
+ provided, text embeddings are generated from the `prompt` input argument.
564
+ output_type (`str`, *optional*, defaults to `"pil"`):
565
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
566
+ return_dict (`bool`, *optional*, defaults to `True`):
567
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
568
+ attention_kwargs (`dict`, *optional*):
569
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
570
+ `self.processor` in
571
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
572
+ clip_skip (`int`, *optional*):
573
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
574
+ the output of the pre-final layer will be used for computing the prompt embeddings.
575
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
576
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
577
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
578
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
579
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
580
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
581
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
582
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
583
+ `._callback_tensor_inputs` attribute of your pipeline class.
584
+
585
+ Examples:
586
+
587
+ Returns:
588
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
589
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
590
+ where the first element is a list with the generated images and the second element is a list of `bool`s
591
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
592
+ """
593
+
594
+ if isinstance(callback_on_step_end,
595
+ (PipelineCallback, MultiPipelineCallbacks)):
596
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
597
+
598
+ # 1. Check inputs. Raise error if not correct
599
+ self.check_inputs(
600
+ prompt,
601
+ prompt_2,
602
+ height,
603
+ width,
604
+ prompt_embeds,
605
+ callback_on_step_end_tensor_inputs,
606
+ prompt_template,
607
+ )
608
+
609
+ self._guidance_scale = guidance_scale
610
+ self._attention_kwargs = attention_kwargs
611
+ self._interrupt = False
612
+
613
+ device = self._execution_device
614
+
615
+ # 2. Define call parameters
616
+ if prompt is not None and isinstance(prompt, str):
617
+ batch_size = 1
618
+ elif prompt is not None and isinstance(prompt, list):
619
+ batch_size = len(prompt)
620
+ else:
621
+ batch_size = prompt_embeds.shape[0]
622
+
623
+ # 3. Encode input prompt
624
+ prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
625
+ prompt=prompt,
626
+ prompt_2=prompt,
627
+ prompt_template=prompt_template,
628
+ num_videos_per_prompt=num_videos_per_prompt,
629
+ prompt_embeds=prompt_embeds,
630
+ pooled_prompt_embeds=pooled_prompt_embeds,
631
+ prompt_attention_mask=prompt_attention_mask,
632
+ device=device,
633
+ max_sequence_length=max_sequence_length,
634
+ )
635
+
636
+ transformer_dtype = self.transformer.dtype
637
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
638
+ prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
639
+ if pooled_prompt_embeds is not None:
640
+ pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
641
+
642
+ # 4. Prepare timesteps
643
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps +
644
+ 1)[:-1] if sigmas is None else sigmas
645
+ timesteps, num_inference_steps = retrieve_timesteps(
646
+ self.scheduler,
647
+ num_inference_steps,
648
+ device,
649
+ sigmas=sigmas,
650
+ )
651
+
652
+ # 5. Prepare latent variables
653
+ num_channels_latents = self.transformer.config.in_channels
654
+ num_latent_frames = (num_frames -
655
+ 1) // self.vae_scale_factor_temporal + 1
656
+
657
+ latents = self.prepare_latents(
658
+ batch_size * num_videos_per_prompt,
659
+ num_channels_latents,
660
+ height,
661
+ width,
662
+ num_latent_frames,
663
+ torch.float32,
664
+ device,
665
+ generator,
666
+ latents,
667
+ )
668
+ # check sequence_parallel
669
+ world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group
670
+ if get_sequence_parallel_state():
671
+ latents = rearrange(latents,
672
+ "b t (n s) h w -> b t n s h w",
673
+ n=world_size).contiguous()
674
+ latents = latents[:, :, rank, :, :, :]
675
+
676
+ # 6. Prepare guidance condition
677
+ guidance = torch.tensor([guidance_scale] * latents.shape[0],
678
+ dtype=transformer_dtype,
679
+ device=device) * 1000.0
680
+
681
+ # 7. Denoising loop
682
+ num_warmup_steps = len(
683
+ timesteps) - num_inference_steps * self.scheduler.order
684
+ self._num_timesteps = len(timesteps)
685
+
686
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
687
+ for i, t in enumerate(timesteps):
688
+ if self.interrupt:
689
+ continue
690
+
691
+ latent_model_input = latents.to(transformer_dtype)
692
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
693
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
694
+ if pooled_prompt_embeds.shape[-1] != prompt_embeds.shape[-1]:
695
+ pooled_prompt_embeds_padding = F.pad(
696
+ pooled_prompt_embeds,
697
+ (0, prompt_embeds.shape[2] -
698
+ pooled_prompt_embeds.shape[1]),
699
+ value=0,
700
+ ).unsqueeze(1)
701
+ encoder_hidden_states = torch.cat(
702
+ [pooled_prompt_embeds_padding, prompt_embeds], dim=1)
703
+
704
+ noise_pred = self.transformer(
705
+ hidden_states=latent_model_input,
706
+ encoder_hidden_states=
707
+ encoder_hidden_states, # [1, 257, 4096]
708
+ timestep=timestep,
709
+ encoder_attention_mask=prompt_attention_mask,
710
+ guidance=guidance,
711
+ attention_kwargs=attention_kwargs,
712
+ return_dict=False,
713
+ )[0]
714
+
715
+ # compute the previous noisy sample x_t -> x_t-1
716
+ latents = self.scheduler.step(noise_pred,
717
+ t,
718
+ latents,
719
+ return_dict=False)[0]
720
+
721
+ if callback_on_step_end is not None:
722
+ callback_kwargs = {}
723
+ for k in callback_on_step_end_tensor_inputs:
724
+ callback_kwargs[k] = locals()[k]
725
+ callback_outputs = callback_on_step_end(
726
+ self, i, t, callback_kwargs)
727
+
728
+ latents = callback_outputs.pop("latents", latents)
729
+ prompt_embeds = callback_outputs.pop(
730
+ "prompt_embeds", prompt_embeds)
731
+
732
+ # call the callback, if provided
733
+ if i == len(timesteps) - 1 or (
734
+ (i + 1) > num_warmup_steps and
735
+ (i + 1) % self.scheduler.order == 0):
736
+ progress_bar.update()
737
+
738
+ if get_sequence_parallel_state():
739
+ latents = all_gather(latents, dim=2)
740
+
741
+ if not output_type == "latent":
742
+ latents = latents.to(
743
+ self.vae.dtype) / self.vae.config.scaling_factor
744
+ video = self.vae.decode(latents, return_dict=False)[0]
745
+ video = self.video_processor.postprocess_video(
746
+ video, output_type=output_type)
747
+ else:
748
+ video = latents
749
+
750
+ # Offload all models
751
+ self.maybe_free_model_hooks()
752
+
753
+ if not return_dict:
754
+ return (video, )
755
+
756
+ return HunyuanVideoPipelineOutput(frames=video)
fastvideo/models/mochi_hf/__pycache__/modeling_mochi.cpython-310.pyc ADDED
Binary file (18.2 kB). View file
 
fastvideo/models/mochi_hf/__pycache__/modeling_mochi.cpython-312.pyc ADDED
Binary file (29.4 kB). View file
 
fastvideo/models/mochi_hf/__pycache__/norm.cpython-310.pyc ADDED
Binary file (3.72 kB). View file
 
fastvideo/models/mochi_hf/__pycache__/norm.cpython-312.pyc ADDED
Binary file (6.37 kB). View file
 
fastvideo/models/mochi_hf/__pycache__/pipeline_mochi.cpython-312.pyc ADDED
Binary file (36.2 kB). View file
 
fastvideo/models/mochi_hf/convert_diffusers_to_mochi.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ import argparse
4
+ import os
5
+
6
+ import torch
7
+ from safetensors.torch import save_file
8
+
9
# Command-line interface for the Diffusers -> Mochi checkpoint converter.
# NOTE(review): arguments are parsed at import time (module level), so this
# file is meant to be executed as a standalone script, not imported.
parser = argparse.ArgumentParser()
# Source: a local Diffusers-format MochiPipeline directory (required).
parser.add_argument("--diffusers_path", required=True, type=str)
# Output paths are optional; only the components whose paths are given
# get converted by main() below.
parser.add_argument("--transformer_path",
                    type=str,
                    default=None,
                    help="Path to save transformer model")
parser.add_argument("--vae_encoder_path",
                    type=str,
                    default=None,
                    help="Path to save VAE encoder model")
parser.add_argument("--vae_decoder_path",
                    type=str,
                    default=None,
                    help="Path to save VAE decoder model")

args = parser.parse_args()
25
+
26
+
27
+ def reverse_scale_shift(weight, dim):
28
+ scale, shift = weight.chunk(2, dim=0)
29
+ new_weight = torch.cat([shift, scale], dim=0)
30
+ return new_weight
31
+
32
+
33
def reverse_proj_gate(weight):
    """Reorder a fused ``[gate, proj]`` weight into ``[proj, gate]`` layout.

    Diffusers' SwiGLU feed-forward stores the gate half first along dim 0;
    the original Mochi layout stores the projection half first.
    """
    gate_half, proj_half = weight.chunk(2, dim=0)
    return torch.cat((proj_half, gate_half), dim=0)
37
+
38
+
39
def convert_diffusers_transformer_to_mochi(state_dict):
    """Map a Diffusers Mochi transformer state dict to the original Mochi key layout.

    Keys are ``pop``-ed from a shallow copy of *state_dict* so that any
    unconverted leftovers can be printed at the end; the input dict itself is
    not modified.

    Args:
        state_dict: Diffusers-format transformer state dict.

    Returns:
        A new state dict using the original Mochi repository's key names.
    """
    original_state_dict = state_dict.copy()
    new_state_dict = {}

    # Convert patch_embed
    new_state_dict["x_embedder.proj.weight"] = original_state_dict.pop(
        "patch_embed.proj.weight")
    new_state_dict["x_embedder.proj.bias"] = original_state_dict.pop(
        "patch_embed.proj.bias")

    # Convert time_embed
    new_state_dict["t_embedder.mlp.0.weight"] = original_state_dict.pop(
        "time_embed.timestep_embedder.linear_1.weight")
    new_state_dict["t_embedder.mlp.0.bias"] = original_state_dict.pop(
        "time_embed.timestep_embedder.linear_1.bias")
    new_state_dict["t_embedder.mlp.2.weight"] = original_state_dict.pop(
        "time_embed.timestep_embedder.linear_2.weight")
    new_state_dict["t_embedder.mlp.2.bias"] = original_state_dict.pop(
        "time_embed.timestep_embedder.linear_2.bias")
    new_state_dict["t5_y_embedder.to_kv.weight"] = original_state_dict.pop(
        "time_embed.pooler.to_kv.weight")
    new_state_dict["t5_y_embedder.to_kv.bias"] = original_state_dict.pop(
        "time_embed.pooler.to_kv.bias")
    new_state_dict["t5_y_embedder.to_q.weight"] = original_state_dict.pop(
        "time_embed.pooler.to_q.weight")
    new_state_dict["t5_y_embedder.to_q.bias"] = original_state_dict.pop(
        "time_embed.pooler.to_q.bias")
    new_state_dict["t5_y_embedder.to_out.weight"] = original_state_dict.pop(
        "time_embed.pooler.to_out.weight")
    new_state_dict["t5_y_embedder.to_out.bias"] = original_state_dict.pop(
        "time_embed.pooler.to_out.bias")
    new_state_dict["t5_yproj.weight"] = original_state_dict.pop(
        "time_embed.caption_proj.weight")
    new_state_dict["t5_yproj.bias"] = original_state_dict.pop(
        "time_embed.caption_proj.bias")

    # Convert transformer blocks
    # NOTE(review): hard-coded for the released Mochi-1 checkpoint (48
    # blocks) — confirm if converting a differently sized model.
    num_layers = 48
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        new_prefix = f"blocks.{i}."

        # norm1
        new_state_dict[new_prefix + "mod_x.weight"] = original_state_dict.pop(
            block_prefix + "norm1.linear.weight")
        new_state_dict[new_prefix + "mod_x.bias"] = original_state_dict.pop(
            block_prefix + "norm1.linear.bias")

        # The final block's context modulation uses a different Diffusers
        # key name ("linear_1" instead of "linear"), hence the special case.
        if i < num_layers - 1:
            new_state_dict[new_prefix +
                           "mod_y.weight"] = original_state_dict.pop(
                               block_prefix + "norm1_context.linear.weight")
            new_state_dict[new_prefix +
                           "mod_y.bias"] = original_state_dict.pop(
                               block_prefix + "norm1_context.linear.bias")
        else:
            new_state_dict[new_prefix +
                           "mod_y.weight"] = original_state_dict.pop(
                               block_prefix + "norm1_context.linear_1.weight")
            new_state_dict[new_prefix +
                           "mod_y.bias"] = original_state_dict.pop(
                               block_prefix + "norm1_context.linear_1.bias")

        # Visual attention: fuse separate Q/K/V projections into one qkv weight.
        q = original_state_dict.pop(block_prefix + "attn1.to_q.weight")
        k = original_state_dict.pop(block_prefix + "attn1.to_k.weight")
        v = original_state_dict.pop(block_prefix + "attn1.to_v.weight")
        qkv_weight = torch.cat([q, k, v], dim=0)
        new_state_dict[new_prefix + "attn.qkv_x.weight"] = qkv_weight

        new_state_dict[new_prefix +
                       "attn.q_norm_x.weight"] = original_state_dict.pop(
                           block_prefix + "attn1.norm_q.weight")
        new_state_dict[new_prefix +
                       "attn.k_norm_x.weight"] = original_state_dict.pop(
                           block_prefix + "attn1.norm_k.weight")
        new_state_dict[new_prefix +
                       "attn.proj_x.weight"] = original_state_dict.pop(
                           block_prefix + "attn1.to_out.0.weight")
        new_state_dict[new_prefix +
                       "attn.proj_x.bias"] = original_state_dict.pop(
                           block_prefix + "attn1.to_out.0.bias")

        # Context attention (text stream), same Q/K/V fusion.
        q = original_state_dict.pop(block_prefix + "attn1.add_q_proj.weight")
        k = original_state_dict.pop(block_prefix + "attn1.add_k_proj.weight")
        v = original_state_dict.pop(block_prefix + "attn1.add_v_proj.weight")
        qkv_weight = torch.cat([q, k, v], dim=0)
        new_state_dict[new_prefix + "attn.qkv_y.weight"] = qkv_weight

        new_state_dict[new_prefix +
                       "attn.q_norm_y.weight"] = original_state_dict.pop(
                           block_prefix + "attn1.norm_added_q.weight")
        new_state_dict[new_prefix +
                       "attn.k_norm_y.weight"] = original_state_dict.pop(
                           block_prefix + "attn1.norm_added_k.weight")
        # The last block has no context output projection in the checkpoint.
        if i < num_layers - 1:
            new_state_dict[new_prefix +
                           "attn.proj_y.weight"] = original_state_dict.pop(
                               block_prefix + "attn1.to_add_out.weight")
            new_state_dict[new_prefix +
                           "attn.proj_y.bias"] = original_state_dict.pop(
                               block_prefix + "attn1.to_add_out.bias")

        # MLP: reverse_proj_gate reorders the fused [gate, proj] layout.
        new_state_dict[new_prefix + "mlp_x.w1.weight"] = reverse_proj_gate(
            original_state_dict.pop(block_prefix + "ff.net.0.proj.weight"))
        new_state_dict[new_prefix +
                       "mlp_x.w2.weight"] = original_state_dict.pop(
                           block_prefix + "ff.net.2.weight")
        if i < num_layers - 1:
            new_state_dict[new_prefix + "mlp_y.w1.weight"] = reverse_proj_gate(
                original_state_dict.pop(block_prefix +
                                        "ff_context.net.0.proj.weight"))
            new_state_dict[new_prefix +
                           "mlp_y.w2.weight"] = original_state_dict.pop(
                               block_prefix + "ff_context.net.2.weight")

    # Output layers: reverse_scale_shift reorders the fused [scale, shift]
    # modulation halves into Mochi's [shift, scale] layout.
    new_state_dict["final_layer.mod.weight"] = reverse_scale_shift(
        original_state_dict.pop("norm_out.linear.weight"), dim=0)
    new_state_dict["final_layer.mod.bias"] = reverse_scale_shift(
        original_state_dict.pop("norm_out.linear.bias"), dim=0)
    new_state_dict["final_layer.linear.weight"] = original_state_dict.pop(
        "proj_out.weight")
    new_state_dict["final_layer.linear.bias"] = original_state_dict.pop(
        "proj_out.bias")

    new_state_dict["pos_frequencies"] = original_state_dict.pop(
        "pos_frequencies")

    # Anything left over was not recognized — surface it for manual review.
    print("Remaining Keys:", original_state_dict.keys())

    return new_state_dict
173
+
174
+
175
def convert_diffusers_vae_to_mochi(state_dict):
    """Map a Diffusers Mochi VAE state dict to the original Mochi key layout.

    The Diffusers checkpoint stores a single autoencoder, while the original
    Mochi release ships the encoder and decoder as separate files, so two
    dicts are returned.  Keys are ``pop``-ed from a shallow copy of
    *state_dict*; the input dict itself is not modified.

    Args:
        state_dict: Diffusers-format VAE state dict (encoder + decoder keys).

    Returns:
        Tuple ``(encoder_state_dict, decoder_state_dict)`` in Mochi layout.
    """
    original_state_dict = state_dict.copy()
    encoder_state_dict = {}
    decoder_state_dict = {}

    # Convert encoder
    prefix = "encoder."

    encoder_state_dict["layers.0.weight"] = original_state_dict.pop(
        f"{prefix}proj_in.weight")
    encoder_state_dict["layers.0.bias"] = original_state_dict.pop(
        f"{prefix}proj_in.bias")

    # Convert block_in
    # Each resnet maps onto a flat "stack" on the Mochi side: indices 0/3
    # are the norms, 2/5 the convs (1 and 4 presumably activations — the
    # target module is not visible here).
    for i in range(3):
        encoder_state_dict[
            f"layers.{i+1}.stack.0.weight"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight")
        encoder_state_dict[
            f"layers.{i+1}.stack.0.bias"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias")
        encoder_state_dict[
            f"layers.{i+1}.stack.2.weight"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.conv1.conv.weight")
        encoder_state_dict[
            f"layers.{i+1}.stack.2.bias"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.conv1.conv.bias")
        encoder_state_dict[
            f"layers.{i+1}.stack.3.weight"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight")
        encoder_state_dict[
            f"layers.{i+1}.stack.3.bias"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias")
        encoder_state_dict[
            f"layers.{i+1}.stack.5.weight"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.conv2.conv.weight")
        encoder_state_dict[
            f"layers.{i+1}.stack.5.bias"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.conv2.conv.bias")

    # Convert down_blocks (resnet counts per stage; mirrors up_block_layers
    # in the decoder below, reversed).
    down_block_layers = [3, 4, 6]
    for block in range(3):
        encoder_state_dict[
            f"layers.{block+4}.layers.0.weight"] = original_state_dict.pop(
                f"{prefix}down_blocks.{block}.conv_in.conv.weight")
        encoder_state_dict[
            f"layers.{block+4}.layers.0.bias"] = original_state_dict.pop(
                f"{prefix}down_blocks.{block}.conv_in.conv.bias")

        for i in range(down_block_layers[block]):
            # Convert resnets
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.stack.0.weight"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
                )
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.stack.0.bias"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"
                )
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.stack.2.weight"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"
                )
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.stack.2.bias"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias")
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.stack.3.weight"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
                )
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.stack.3.bias"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"
                )
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.stack.5.weight"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"
                )
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.stack.5.bias"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias")

            # Convert attentions: fuse Q/K/V into one qkv weight.
            q = original_state_dict.pop(
                f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight")
            k = original_state_dict.pop(
                f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight")
            v = original_state_dict.pop(
                f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight")
            qkv_weight = torch.cat([q, k, v], dim=0)
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight"] = qkv_weight

            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"
                )
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"
                )
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight")
            encoder_state_dict[
                f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"] = original_state_dict.pop(
                    f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias")

    # Convert block_out
    for i in range(3):
        encoder_state_dict[
            f"layers.{i+7}.stack.0.weight"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight")
        encoder_state_dict[
            f"layers.{i+7}.stack.0.bias"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias")
        encoder_state_dict[
            f"layers.{i+7}.stack.2.weight"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.conv1.conv.weight")
        encoder_state_dict[
            f"layers.{i+7}.stack.2.bias"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.conv1.conv.bias")
        encoder_state_dict[
            f"layers.{i+7}.stack.3.weight"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight")
        encoder_state_dict[
            f"layers.{i+7}.stack.3.bias"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias")
        encoder_state_dict[
            f"layers.{i+7}.stack.5.weight"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.conv2.conv.weight")
        encoder_state_dict[
            f"layers.{i+7}.stack.5.bias"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.conv2.conv.bias")

        q = original_state_dict.pop(
            f"{prefix}block_out.attentions.{i}.to_q.weight")
        k = original_state_dict.pop(
            f"{prefix}block_out.attentions.{i}.to_k.weight")
        v = original_state_dict.pop(
            f"{prefix}block_out.attentions.{i}.to_v.weight")
        qkv_weight = torch.cat([q, k, v], dim=0)
        encoder_state_dict[
            f"layers.{i+7}.attn_block.attn.qkv.weight"] = qkv_weight

        encoder_state_dict[
            f"layers.{i+7}.attn_block.attn.out.weight"] = original_state_dict.pop(
                f"{prefix}block_out.attentions.{i}.to_out.0.weight")
        encoder_state_dict[
            f"layers.{i+7}.attn_block.attn.out.bias"] = original_state_dict.pop(
                f"{prefix}block_out.attentions.{i}.to_out.0.bias")
        encoder_state_dict[
            f"layers.{i+7}.attn_block.norm.weight"] = original_state_dict.pop(
                f"{prefix}block_out.norms.{i}.norm_layer.weight")
        encoder_state_dict[
            f"layers.{i+7}.attn_block.norm.bias"] = original_state_dict.pop(
                f"{prefix}block_out.norms.{i}.norm_layer.bias")

    # Convert output layers (the encoder's proj_out has no bias key).
    encoder_state_dict["output_norm.weight"] = original_state_dict.pop(
        f"{prefix}norm_out.norm_layer.weight")
    encoder_state_dict["output_norm.bias"] = original_state_dict.pop(
        f"{prefix}norm_out.norm_layer.bias")
    encoder_state_dict["output_proj.weight"] = original_state_dict.pop(
        f"{prefix}proj_out.weight")

    # Convert decoder
    prefix = "decoder."

    decoder_state_dict["blocks.0.0.weight"] = original_state_dict.pop(
        f"{prefix}conv_in.weight")
    decoder_state_dict["blocks.0.0.bias"] = original_state_dict.pop(
        f"{prefix}conv_in.bias")

    # Convert block_in
    for i in range(3):
        decoder_state_dict[
            f"blocks.0.{i+1}.stack.0.weight"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight")
        decoder_state_dict[
            f"blocks.0.{i+1}.stack.0.bias"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias")
        decoder_state_dict[
            f"blocks.0.{i+1}.stack.2.weight"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.conv1.conv.weight")
        decoder_state_dict[
            f"blocks.0.{i+1}.stack.2.bias"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.conv1.conv.bias")
        decoder_state_dict[
            f"blocks.0.{i+1}.stack.3.weight"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight")
        decoder_state_dict[
            f"blocks.0.{i+1}.stack.3.bias"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias")
        decoder_state_dict[
            f"blocks.0.{i+1}.stack.5.weight"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.conv2.conv.weight")
        decoder_state_dict[
            f"blocks.0.{i+1}.stack.5.bias"] = original_state_dict.pop(
                f"{prefix}block_in.resnets.{i}.conv2.conv.bias")

    # Convert up_blocks (decoder stages carry no attentions, unlike the
    # encoder's down_blocks; each stage ends with a proj layer).
    up_block_layers = [6, 4, 3]
    for block in range(3):
        for i in range(up_block_layers[block]):
            decoder_state_dict[
                f"blocks.{block+1}.blocks.{i}.stack.0.weight"] = original_state_dict.pop(
                    f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
                )
            decoder_state_dict[
                f"blocks.{block+1}.blocks.{i}.stack.0.bias"] = original_state_dict.pop(
                    f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"
                )
            decoder_state_dict[
                f"blocks.{block+1}.blocks.{i}.stack.2.weight"] = original_state_dict.pop(
                    f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight")
            decoder_state_dict[
                f"blocks.{block+1}.blocks.{i}.stack.2.bias"] = original_state_dict.pop(
                    f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias")
            decoder_state_dict[
                f"blocks.{block+1}.blocks.{i}.stack.3.weight"] = original_state_dict.pop(
                    f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
                )
            decoder_state_dict[
                f"blocks.{block+1}.blocks.{i}.stack.3.bias"] = original_state_dict.pop(
                    f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"
                )
            decoder_state_dict[
                f"blocks.{block+1}.blocks.{i}.stack.5.weight"] = original_state_dict.pop(
                    f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight")
            decoder_state_dict[
                f"blocks.{block+1}.blocks.{i}.stack.5.bias"] = original_state_dict.pop(
                    f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias")
        decoder_state_dict[
            f"blocks.{block+1}.proj.weight"] = original_state_dict.pop(
                f"{prefix}up_blocks.{block}.proj.weight")
        decoder_state_dict[
            f"blocks.{block+1}.proj.bias"] = original_state_dict.pop(
                f"{prefix}up_blocks.{block}.proj.bias")

    # Convert block_out
    for i in range(3):
        decoder_state_dict[
            f"blocks.4.{i}.stack.0.weight"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight")
        decoder_state_dict[
            f"blocks.4.{i}.stack.0.bias"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias")
        decoder_state_dict[
            f"blocks.4.{i}.stack.2.weight"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.conv1.conv.weight")
        decoder_state_dict[
            f"blocks.4.{i}.stack.2.bias"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.conv1.conv.bias")
        decoder_state_dict[
            f"blocks.4.{i}.stack.3.weight"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight")
        decoder_state_dict[
            f"blocks.4.{i}.stack.3.bias"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias")
        decoder_state_dict[
            f"blocks.4.{i}.stack.5.weight"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.conv2.conv.weight")
        decoder_state_dict[
            f"blocks.4.{i}.stack.5.bias"] = original_state_dict.pop(
                f"{prefix}block_out.resnets.{i}.conv2.conv.bias")

    # Convert output layers
    decoder_state_dict["output_proj.weight"] = original_state_dict.pop(
        f"{prefix}proj_out.weight")
    decoder_state_dict["output_proj.bias"] = original_state_dict.pop(
        f"{prefix}proj_out.bias")

    return encoder_state_dict, decoder_state_dict
450
+
451
+
452
def ensure_safetensors_extension(path):
    """Return *path* with a ``.safetensors`` suffix appended if missing."""
    return path if path.endswith(".safetensors") else f"{path}.safetensors"
456
+
457
+
458
def ensure_directory_exists(path):
    """Create the parent directory of *path* (if any) when it does not exist."""
    parent = os.path.dirname(path)
    if not parent:
        # Bare filename: the current working directory already exists.
        return
    os.makedirs(parent, exist_ok=True)
462
+
463
+
464
def main(args):
    """Convert a Diffusers MochiPipeline checkpoint into original-Mochi files.

    Loads the pipeline from ``args.diffusers_path`` and, depending on which
    output paths were supplied on the command line, converts the transformer
    and/or the VAE (encoder + decoder as separate files) and saves them as
    safetensors.
    """
    # Imported lazily so the module can be inspected without diffusers installed.
    from diffusers import MochiPipeline

    pipe = MochiPipeline.from_pretrained(args.diffusers_path)

    if args.transformer_path:
        transformer_path = ensure_safetensors_extension(args.transformer_path)
        ensure_directory_exists(transformer_path)

        print("Converting transformer model...")
        transformer_state_dict = convert_diffusers_transformer_to_mochi(
            pipe.transformer.state_dict())
        save_file(transformer_state_dict, transformer_path)
        print(f"Saved transformer to {transformer_path}")

    # VAE conversion produces two files, so both paths are required together.
    if args.vae_encoder_path and args.vae_decoder_path:
        encoder_path = ensure_safetensors_extension(args.vae_encoder_path)
        decoder_path = ensure_safetensors_extension(args.vae_decoder_path)

        ensure_directory_exists(encoder_path)
        ensure_directory_exists(decoder_path)

        print("Converting VAE models...")
        encoder_state_dict, decoder_state_dict = convert_diffusers_vae_to_mochi(
            pipe.vae.state_dict())

        save_file(encoder_state_dict, encoder_path)
        print(f"Saved VAE encoder to {encoder_path}")

        save_file(decoder_state_dict, decoder_path)
        print(f"Saved VAE decoder to {decoder_path}")
    elif args.vae_encoder_path or args.vae_decoder_path:
        print(
            "Warning: Both VAE encoder and decoder paths must be specified to convert VAE models."
        )


if __name__ == "__main__":
    # `args` is parsed at module level (see the parser block above).
    main(args)
fastvideo/models/mochi_hf/mochi_latents_utils.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
2
+
3
+ import torch
4
+
5
# Per-channel statistics of the Mochi VAE latent space (12 latent channels),
# shaped (1, C, 1, 1, 1) so they broadcast over (B, C, T, H, W) video latents.
# NOTE(review): values presumably taken from the Mochi release's latent
# statistics — confirm against the upstream checkpoint config.
mochi_latents_mean = torch.tensor([
    -0.06730895953510081,
    -0.038011381506090416,
    -0.07477820912866141,
    -0.05565264470995561,
    0.012767231469026969,
    -0.04703542746246419,
    0.043896967884726704,
    -0.09346305707025976,
    -0.09918314763016893,
    -0.008729793427399178,
    -0.011931556316503654,
    -0.0321993391887285,
]).view(1, 12, 1, 1, 1)
mochi_latents_std = torch.tensor([
    0.9263795028493863,
    0.9248894543193766,
    0.9393059390890617,
    0.959253732819592,
    0.8244560132752793,
    0.917259975397747,
    0.9294154431013696,
    1.3720942357788521,
    0.881393668867029,
    0.9168315692124348,
    0.9185249279345552,
    0.9274757570805041,
]).view(1, 12, 1, 1, 1)
# Additional global scaling applied on top of the per-channel normalization
# (1.0 means the mean/std normalization above is the only transform).
mochi_scaling_factor = 1.0
34
+
35
+
36
def normalize_dit_input(model_type, latents):
    """Normalize raw VAE latents for the given DiT model family.

    Mochi latents are standardized with per-channel mean/std; both Hunyuan
    variants are scaled by a fixed constant.

    Args:
        model_type: One of ``"mochi"``, ``"hunyuan_hf"``, ``"hunyuan"``.
        latents: Latent tensor to normalize.

    Returns:
        The normalized latent tensor.

    Raises:
        NotImplementedError: For any unrecognized ``model_type``.
    """
    if model_type == "mochi":
        mean = mochi_latents_mean.to(latents.device, latents.dtype)
        std = mochi_latents_std.to(latents.device, latents.dtype)
        return (latents - mean) / std
    if model_type in ("hunyuan_hf", "hunyuan"):
        return latents * 0.476986
    raise NotImplementedError(f"model_type {model_type} not supported")
fastvideo/models/mochi_hf/modeling_mochi.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The Genmo team and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import Any, Dict, Optional, Tuple
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.loaders import PeftAdapterMixin
22
+ from diffusers.models.attention import FeedForward as HF_FeedForward
23
+ from diffusers.models.attention_processor import Attention
24
+ from diffusers.models.embeddings import (MochiCombinedTimestepCaptionEmbedding,
25
+ PatchEmbed)
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from diffusers.models.normalization import AdaLayerNormContinuous
28
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
29
+ scale_lora_layers, unscale_lora_layers)
30
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
31
+ from liger_kernel.ops.swiglu import LigerSiLUMulFunction
32
+
33
+ from fastvideo.models.flash_attn_no_pad import flash_attn_no_pad
34
+ from fastvideo.models.mochi_hf.norm import (MochiLayerNormContinuous,
35
+ MochiModulatedRMSNorm,
36
+ MochiRMSNorm, MochiRMSNormZero)
37
+ from fastvideo.utils.communications import all_gather, all_to_all_4D
38
+ from fastvideo.utils.parallel_states import (get_sequence_parallel_state,
39
+ nccl_info)
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+
44
class FeedForward(HF_FeedForward):
    """SwiGLU feed-forward layer backed by the fused Liger SiLU-mul kernel.

    Construction is inherited from the HF ``FeedForward`` (so ``self.net``
    holds the input projection, activation and output projection), but the
    forward pass bypasses the stock activation and runs the fused
    ``LigerSiLUMulFunction`` instead.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__(dim, dim_out, mult, dropout, activation_fn,
                         final_dropout, inner_dim, bias)
        # The fused forward below is only correct for the SwiGLU split.
        assert activation_fn == "swiglu"

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Single projection produces the value and gate halves side by side.
        projected = self.net[0].proj(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        # Fused SiLU(gate) * value, then the output projection.
        return self.net[2](LigerSiLUMulFunction.apply(gate, value))
66
+
67
+
68
class MochiAttention(nn.Module):
    """Joint video/text multi-head attention used by Mochi's DiT blocks.

    Holds separate QKV projections for the video stream (``to_q/k/v``) and
    the text stream (``add_q/k/v_proj``), with per-head RMS normalization of
    queries and keys in both streams. The attention computation itself is
    delegated to ``processor``.
    """

    def __init__(
        self,
        query_dim: int,
        processor: "MochiAttnProcessor2_0",
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_dim: int = None,
        out_context_dim: int = None,
        out_bias: bool = True,
        context_pre_only: bool = False,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim else query_dim
        self.context_pre_only = context_pre_only

        # When an explicit out_dim is given, derive the head count from it.
        self.heads = out_dim // dim_head if out_dim is not None else heads

        # Per-head RMS norms for query/key of the video and text streams.
        self.norm_q = MochiRMSNorm(dim_head, eps)
        self.norm_k = MochiRMSNorm(dim_head, eps)
        self.norm_added_q = MochiRMSNorm(dim_head, eps)
        self.norm_added_k = MochiRMSNorm(dim_head, eps)

        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        self.add_k_proj = nn.Linear(added_kv_proj_dim,
                                    self.inner_dim,
                                    bias=added_proj_bias)
        self.add_v_proj = nn.Linear(added_kv_proj_dim,
                                    self.inner_dim,
                                    bias=added_proj_bias)
        # NOTE(review): `context_pre_only` is always a bool here, so this
        # condition is always true and `add_q_proj` is always created — the
        # processor also unconditionally uses it. This mirrors the upstream
        # diffusers implementation; do not "fix" without checking state-dict
        # compatibility.
        if self.context_pre_only is not None:
            self.add_q_proj = nn.Linear(added_kv_proj_dim,
                                        self.inner_dim,
                                        bias=added_proj_bias)

        self.to_out = nn.ModuleList([])
        self.to_out.append(
            nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        # Text-stream output projection is only needed when the block still
        # updates the context (i.e. not the final, context-pre-only block).
        if not self.context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim,
                                        self.out_context_dim,
                                        bias=out_bias)

        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Delegate to the configured attention processor; extra kwargs
        # (e.g. image_rotary_emb, encoder_attention_mask) pass through.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )
140
+
141
+
142
class MochiAttnProcessor2_0:
    """Attention processor used in Mochi.

    Concatenates the video and text token streams into one joint sequence,
    applies RoPE to the video queries/keys, and runs variable-length flash
    attention with text padding masked out. When sequence parallelism is
    enabled, the sequence dimension is gathered and the head dimension is
    sharded across ranks (all-to-all) before attention, and the original
    layout is restored afterwards.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Video-stream QKV projections: [b, s, h * d]
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # [b, s, h=24, d=128]
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        # Per-head RMS-normalize queries and keys.
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
        # Text-stream QKV projections: [b, 256, h * d]
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

        # [b, 256, h=24, d=128]
        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

        if attn.norm_added_q is not None:
            encoder_query = attn.norm_added_q(encoder_query)
        if attn.norm_added_k is not None:
            encoder_key = attn.norm_added_k(encoder_key)

        if image_rotary_emb is not None:
            freqs_cos, freqs_sin = image_rotary_emb[0], image_rotary_emb[1]
        # shard the head dimension
        if get_sequence_parallel_state():
            # B, S, H, D to (S, B,) H, D
            # batch_size, seq_len, attn_heads, head_dim
            # Gather the full video sequence on each rank while scattering
            # the heads, so every rank attends over the whole sequence with
            # a slice of the heads.
            query = all_to_all_4D(query, scatter_dim=2, gather_dim=1)
            key = all_to_all_4D(key, scatter_dim=2, gather_dim=1)
            value = all_to_all_4D(value, scatter_dim=2, gather_dim=1)

            def shrink_head(encoder_state, dim):
                # Keep only this rank's contiguous slice of heads.
                local_heads = encoder_state.shape[dim] // nccl_info.sp_size
                return encoder_state.narrow(
                    dim, nccl_info.rank_within_group * local_heads,
                    local_heads)

            encoder_query = shrink_head(encoder_query, dim=2)
            encoder_key = shrink_head(encoder_key, dim=2)
            encoder_value = shrink_head(encoder_value, dim=2)
            if image_rotary_emb is not None:
                # RoPE tables are per-head, so shard them the same way.
                freqs_cos = shrink_head(freqs_cos, dim=1)
                freqs_sin = shrink_head(freqs_sin, dim=1)

        if image_rotary_emb is not None:

            def apply_rotary_emb(x, freqs_cos, freqs_sin):
                # Rotate interleaved (even, odd) channel pairs; computed in
                # FP32 and cast back to the activation dtype.
                x_even = x[..., 0::2].float()
                x_odd = x[..., 1::2].float()
                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)

                return torch.stack([cos, sin], dim=-1).flatten(-2)

            query = apply_rotary_emb(query, freqs_cos, freqs_sin)
            key = apply_rotary_emb(key, freqs_cos, freqs_sin)

        # query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
        # encoder_query, encoder_key, encoder_value = (
        #     encoder_query.transpose(1, 2),
        #     encoder_key.transpose(1, 2),
        #     encoder_value.transpose(1, 2),
        # )
        # [b, s, h, d]
        sequence_length = query.size(1)
        encoder_sequence_length = encoder_query.size(1)

        # H
        # Join video tokens and text tokens into one sequence, then pack
        # Q/K/V into the [B, S, 3, H, D] layout flash attention expects.
        query = torch.cat([query, encoder_query], dim=1).unsqueeze(2)
        key = torch.cat([key, encoder_key], dim=1).unsqueeze(2)
        value = torch.cat([value, encoder_value], dim=1).unsqueeze(2)
        # B, S, 3, H, D
        qkv = torch.cat([query, key, value], dim=2)

        # Video tokens are always valid; only text padding is masked out.
        attn_mask = encoder_attention_mask[:, :].bool()
        attn_mask = F.pad(attn_mask, (sequence_length, 0), value=True)
        hidden_states = flash_attn_no_pad(qkv,
                                          attn_mask,
                                          causal=False,
                                          dropout_p=0.0,
                                          softmax_scale=None)

        # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask = None, dropout_p=0.0, is_causal=False)

        # valid_lengths = encoder_attention_mask.sum(dim=1) + sequence_length
        # def no_padding_mask(score, b, h, q_idx, kv_idx):
        #     return torch.where(kv_idx < valid_lengths[b],score, -float("inf"))

        # hidden_states = flex_attention(query, key, value, score_mod=no_padding_mask)
        if get_sequence_parallel_state():
            hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
                (sequence_length, encoder_sequence_length), dim=1)
            # B, S, H, D
            # Undo the earlier layout swap: scatter the sequence back and
            # gather the heads; text states are fully replicated via
            # all-gather over heads.
            hidden_states = all_to_all_4D(hidden_states,
                                          scatter_dim=1,
                                          gather_dim=2)
            encoder_hidden_states = all_gather(encoder_hidden_states,
                                               dim=2).contiguous()
            hidden_states = hidden_states.flatten(2, 3)
            hidden_states = hidden_states.to(query.dtype)
            encoder_hidden_states = encoder_hidden_states.flatten(2, 3)
            encoder_hidden_states = encoder_hidden_states.to(query.dtype)
        else:
            hidden_states = hidden_states.flatten(2, 3)
            hidden_states = hidden_states.to(query.dtype)

            hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
                (sequence_length, encoder_sequence_length), dim=1)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Final (context-pre-only) block has no text output projection.
        if hasattr(attn, "to_add_out"):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states
286
+
287
+
288
@maybe_allow_in_graph
class MochiTransformerBlock(nn.Module):
    r"""
    Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).

    Args:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        qk_norm (`str`, defaults to `"rms_norm"`):
            The normalization layer to use.
        activation_fn (`str`, defaults to `"swiglu"`):
            Activation function to use in feed-forward.
        context_pre_only (`bool`, defaults to `False`):
            Whether or not to process context-related conditions with additional layers.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        pooled_projection_dim: int,
        qk_norm: str = "rms_norm",
        activation_fn: str = "swiglu",
        context_pre_only: bool = False,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()

        self.context_pre_only = context_pre_only
        # SwiGLU hidden widths: 2/3 of the usual 4*dim expansion so parameter
        # count matches a standard MLP.
        self.ff_inner_dim = (4 * dim * 2) // 3
        self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3

        # AdaLN-zero style norm producing (scale/gate) modulation from temb.
        self.norm1 = MochiRMSNormZero(dim,
                                      4 * dim,
                                      eps=eps,
                                      elementwise_affine=False)

        if not context_pre_only:
            self.norm1_context = MochiRMSNormZero(dim,
                                                  4 * pooled_projection_dim,
                                                  eps=eps,
                                                  elementwise_affine=False)
        else:
            # Final block: context is only normalized, never updated further.
            self.norm1_context = MochiLayerNormContinuous(
                embedding_dim=pooled_projection_dim,
                conditioning_embedding_dim=dim,
                eps=eps,
            )

        self.attn1 = MochiAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=False,
            added_kv_proj_dim=pooled_projection_dim,
            added_proj_bias=False,
            out_dim=dim,
            out_context_dim=pooled_projection_dim,
            context_pre_only=context_pre_only,
            processor=MochiAttnProcessor2_0(),
            eps=1e-5,
        )

        # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
        self.norm2 = MochiModulatedRMSNorm(eps=eps)
        self.norm2_context = (MochiModulatedRMSNorm(
            eps=eps) if not self.context_pre_only else None)

        self.norm3 = MochiModulatedRMSNorm(eps)
        self.norm3_context = (MochiModulatedRMSNorm(
            eps=eps) if not self.context_pre_only else None)

        self.ff = FeedForward(dim,
                              inner_dim=self.ff_inner_dim,
                              activation_fn=activation_fn,
                              bias=False)
        self.ff_context = None
        if not context_pre_only:
            self.ff_context = FeedForward(
                pooled_projection_dim,
                inner_dim=self.ff_context_inner_dim,
                activation_fn=activation_fn,
                bias=False,
            )

        self.norm4 = MochiModulatedRMSNorm(eps=eps)
        self.norm4_context = MochiModulatedRMSNorm(eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[torch.Tensor] = None,
        output_attn=False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one joint attention + feed-forward step.

        Returns `(hidden_states, encoder_hidden_states, attn_hidden_states)`
        where `attn_hidden_states` is the raw video-stream attention output
        (or None when `output_attn` is False).
        """
        # Pre-attention modulation: normalized states plus MSA/MLP gates.
        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(
            hidden_states, temb)

        if not self.context_pre_only:
            (
                norm_encoder_hidden_states,
                enc_gate_msa,
                enc_scale_mlp,
                enc_gate_mlp,
            ) = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states = self.norm1_context(
                encoder_hidden_states, temb)

        attn_hidden_states, context_attn_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            encoder_attention_mask=encoder_attention_mask,
        )

        # Residual update with tanh-gated, RMS-normalized attention output.
        hidden_states = hidden_states + self.norm2(
            attn_hidden_states,
            torch.tanh(gate_msa).unsqueeze(1))
        norm_hidden_states = self.norm3(
            hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + self.norm4(
            ff_output,
            torch.tanh(gate_mlp).unsqueeze(1))

        # Mirror the same two residual updates for the text stream, unless
        # this is the final (context-pre-only) block.
        if not self.context_pre_only:
            encoder_hidden_states = encoder_hidden_states + self.norm2_context(
                context_attn_hidden_states,
                torch.tanh(enc_gate_msa).unsqueeze(1))
            norm_encoder_hidden_states = self.norm3_context(
                encoder_hidden_states,
                (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32)),
            )
            context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + self.norm4_context(
                context_ff_output,
                torch.tanh(enc_gate_mlp).unsqueeze(1))

        if not output_attn:
            attn_hidden_states = None
        return hidden_states, encoder_hidden_states, attn_hidden_states
439
+
440
+
441
class MochiRoPE(nn.Module):
    r"""
    RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).

    Produces cosine/sine frequency tables over a (time, height, width) grid,
    with spatial coordinates rescaled so that every resolution covers the
    same reference area.

    Args:
        base_height (`int`, defaults to `192`):
            Base height used to compute interpolation scale for rotary positional embeddings.
        base_width (`int`, defaults to `192`):
            Base width used to compute interpolation scale for rotary positional embeddings.
    """

    def __init__(self, base_height: int = 192, base_width: int = 192) -> None:
        super().__init__()
        self.target_area = base_height * base_width

    def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
        # Midpoints of `num` equal-width bins spanning [start, stop].
        edges = torch.linspace(start,
                               stop,
                               num + 1,
                               device=device,
                               dtype=dtype)
        return (edges[:-1] + edges[1:]) / 2

    def _get_positions(
        self,
        num_frames: int,
        height: int,
        width: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        # Scale spatial coordinates so height*width always maps onto the
        # reference area; time indices span the full (sequence-parallel
        # gathered) frame count.
        spatial_scale = (self.target_area / (height * width))**0.5
        frame_ids = torch.arange(num_frames * nccl_info.sp_size,
                                 device=device,
                                 dtype=dtype)
        row_centers = self._centers(-height * spatial_scale / 2,
                                    height * spatial_scale / 2, height,
                                    device, dtype)
        col_centers = self._centers(-width * spatial_scale / 2,
                                    width * spatial_scale / 2, width, device,
                                    dtype)

        grid_t, grid_h, grid_w = torch.meshgrid(frame_ids,
                                                row_centers,
                                                col_centers,
                                                indexing="ij")

        # Flatten the grid into one (t, h, w) coordinate row per token.
        return torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)

    def _create_rope(self, freqs: torch.Tensor,
                     pos: torch.Tensor) -> torch.Tensor:
        with torch.autocast(freqs.device.type, enabled=False):
            # Always run ROPE freqs computation in FP32
            angles = torch.einsum(
                "nd,dhf->nhf",  # codespell:ignore
                pos.to(torch.float32),  # codespell:ignore
                freqs.to(torch.float32))
            return torch.cos(angles), torch.sin(angles)

    def forward(
        self,
        pos_frequencies: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        token_positions = self._get_positions(num_frames, height, width,
                                              device, dtype)
        return self._create_rope(pos_frequencies, token_positions)
511
+
512
+
513
@maybe_allow_in_graph
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    r"""
    A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).

    Args:
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        num_attention_heads (`int`, defaults to `24`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        num_layers (`int`, defaults to `48`):
            The number of layers of Transformer blocks to use.
        in_channels (`int`, defaults to `12`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output.
        qk_norm (`str`, defaults to `"rms_norm"`):
            The normalization layer to use.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        time_embed_dim (`int`, defaults to `256`):
            Output dimension of timestep embeddings.
        activation_fn (`str`, defaults to `"swiglu"`):
            Activation function to use in feed-forward.
        max_sequence_length (`int`, defaults to `256`):
            The maximum sequence length of text embeddings supported.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 48,
        pooled_projection_dim: int = 1536,
        in_channels: int = 12,
        out_channels: Optional[int] = None,
        qk_norm: str = "rms_norm",
        text_embed_dim: int = 4096,
        time_embed_dim: int = 256,
        activation_fn: str = "swiglu",
        max_sequence_length: int = 256,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            pos_embed_type=None,
        )

        # Joint timestep + caption conditioning embedding.
        self.time_embed = MochiCombinedTimestepCaptionEmbedding(
            embedding_dim=inner_dim,
            pooled_projection_dim=pooled_projection_dim,
            text_embed_dim=text_embed_dim,
            time_embed_dim=time_embed_dim,
            num_attention_heads=8,
        )

        # Learned RoPE frequency table: (t/h/w axes, heads, head_dim/2).
        self.pos_frequencies = nn.Parameter(
            torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0))
        self.rope = MochiRoPE()

        # Only the last block is context-pre-only (text stream frozen).
        self.transformer_blocks = nn.ModuleList([
            MochiTransformerBlock(
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                pooled_projection_dim=pooled_projection_dim,
                qk_norm=qk_norm,
                activation_fn=activation_fn,
                context_pre_only=i == num_layers - 1,
            ) for i in range(num_layers)
        ])

        self.norm_out = AdaLayerNormContinuous(
            inner_dim,
            inner_dim,
            elementwise_affine=False,
            eps=1e-6,
            norm_type="layer_norm",
        )
        self.proj_out = nn.Linear(inner_dim,
                                  patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle checkpointing on any submodule that supports it.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        output_features=False,
        output_features_stride=8,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
    ) -> torch.Tensor:
        """Denoise a batch of video latents.

        `hidden_states` is (batch, channels, frames, height, width); returns
        a tuple of (model output with flipped sign — see the SD3/Flux
        convention hack below — and optionally stacked per-block attention
        features when `output_features` is True).
        """
        assert (return_dict is False
                ), "return_dict is not supported in MochiTransformer3DModel"

        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if (attention_kwargs is not None
                    and attention_kwargs.get("scale", None) is not None):
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p = self.config.patch_size

        post_patch_height = height // p
        post_patch_width = width // p
        # Peiyuan: This is hacked to force mochi to follow the behaviour of SD3 and Flux
        timestep = 1000 - timestep
        temb, encoder_hidden_states = self.time_embed(
            timestep,
            encoder_hidden_states,
            encoder_attention_mask,
            hidden_dtype=hidden_states.dtype,
        )

        # Patchify per frame, then flatten frames into one token sequence:
        # (B, C, T, H, W) -> (B, T*H'*W', inner_dim).
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(
            1, 2)

        image_rotary_emb = self.rope(
            self.pos_frequencies,
            num_frames,
            post_patch_height,
            post_patch_width,
            device=hidden_states.device,
            dtype=torch.float32,
        )
        attn_outputs_list = []
        for i, block in enumerate(self.transformer_blocks):
            if self.gradient_checkpointing:

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = ({
                    "use_reentrant": False
                } if is_torch_version(">=", "1.11.0") else {})
                # Positional args match the block's forward signature;
                # `output_features` maps onto its `output_attn` parameter.
                (
                    hidden_states,
                    encoder_hidden_states,
                    attn_outputs,
                ) = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    temb,
                    image_rotary_emb,
                    output_features,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states, attn_outputs = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    output_attn=output_features,
                )
            # Sample every `output_features_stride`-th block's attention
            # output (entries are None unless output_features is True, in
            # which case the list is discarded below anyway).
            if i % output_features_stride == 0:
                attn_outputs_list.append(attn_outputs)

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # Un-patchify back to (B, C_out, T, H, W).
        hidden_states = hidden_states.reshape(batch_size, num_frames,
                                              post_patch_height,
                                              post_patch_width, p, p, -1)
        hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
        output = hidden_states.reshape(batch_size, -1, num_frames, height,
                                       width)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not output_features:
            attn_outputs_list = None
        else:
            attn_outputs_list = torch.stack(attn_outputs_list, dim=0)
        # Peiyuan: This is hacked to force mochi to follow the behaviour of SD3 and Flux
        return (-output, attn_outputs_list)
fastvideo/models/mochi_hf/norm.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The Genmo team and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Tuple
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+
22
class MochiModulatedRMSNorm(nn.Module):
    """Weightless RMSNorm with an optional multiplicative modulation.

    Normalization runs in FP32 for numerical stability; the result is cast
    back to the input dtype. When `scale` is given it is applied after
    normalization (used by Mochi's AdaLN-style modulation).
    """

    def __init__(self, eps: float):
        super().__init__()
        self.eps = eps

    def forward(self, hidden_states, scale=None):
        original_dtype = hidden_states.dtype

        normed = hidden_states.to(torch.float32)
        mean_square = normed.pow(2).mean(-1, keepdim=True)
        normed = normed * torch.rsqrt(mean_square + self.eps)

        if scale is not None:
            normed = normed * scale

        return normed.to(original_dtype)
40
+
41
+
42
class MochiRMSNorm(nn.Module):
    """RMSNorm with an optional learnable per-channel weight.

    Normalization is computed in FP32 and the output is cast back to the
    input's dtype. With `elementwise_affine=False` no weight is created and
    the layer behaves like a plain RMSNorm.
    """

    def __init__(self, dim, eps: float, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype

        normed = hidden_states.to(torch.float32)
        mean_square = normed.pow(2).mean(-1, keepdim=True)
        normed = normed * torch.rsqrt(mean_square + self.eps)

        if self.weight is not None:
            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                normed = normed.to(self.weight.dtype)
            normed = normed * self.weight

        return normed.to(original_dtype)
66
+
67
+
68
class MochiLayerNormContinuous(nn.Module):
    """AdaLN-style continuous norm: a conditioning embedding is projected to
    a per-channel scale which modulates an RMS-normalized input.

    Used by Mochi's final (context-pre-only) transformer block to condition
    the text stream on the timestep embedding.
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        eps=1e-5,
        bias=True,
    ):
        super().__init__()

        # AdaLN: SiLU + linear map from the conditioning embedding to a
        # per-channel modulation scale.
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim,
                                  embedding_dim,
                                  bias=bias)
        self.norm = MochiModulatedRMSNorm(eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        input_dtype = x.dtype

        # Cast back to x's dtype in case the conditioning embedding was
        # upcast to float32 upstream (as hunyuanDiT does).
        activated = self.silu(conditioning_embedding).to(x.dtype)
        scale = self.linear_1(activated)
        modulated = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))

        return modulated.to(input_dtype)
98
+
99
+
100
class MochiRMSNormZero(nn.Module):
    r"""
    Adaptive RMS Norm used in Mochi.

    Projects the conditioning embedding into four chunks — MSA scale/gate and
    MLP scale/gate — applies the MSA scale to the RMS-normalized input, and
    returns the gates for the caller's residual updates.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        hidden_dim (`int`): Output width of the modulation projection
            (split into four equal chunks).
    """

    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        eps: float = 1e-5,
        elementwise_affine: bool = False,
    ) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, hidden_dim)
        self.norm = MochiModulatedRMSNorm(eps=eps)

    def forward(
        self, hidden_states: torch.Tensor, emb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        original_dtype = hidden_states.dtype

        modulation = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = modulation.chunk(4, dim=1)

        normed = self.norm(hidden_states,
                           (1 + scale_msa[:, None].to(torch.float32)))

        return normed.to(original_dtype), gate_msa, scale_mlp, gate_mlp
fastvideo/models/mochi_hf/pipeline_mochi.py ADDED
@@ -0,0 +1,829 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
22
+ from diffusers.loaders import Mochi1LoraLoaderMixin
23
+ from diffusers.models.autoencoders import AutoencoderKL
24
+ from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (is_torch_xla_available, logging,
28
+ replace_example_docstring)
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ from diffusers.video_processor import VideoProcessor
31
+ from einops import rearrange
32
+ from transformers import T5EncoderModel, T5TokenizerFast
33
+
34
+ from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel
35
+ from fastvideo.utils.communications import all_gather
36
+ from fastvideo.utils.parallel_states import (get_sequence_parallel_state,
37
+ nccl_info)
38
+
39
+ if is_torch_xla_available():
40
+ import torch_xla.core.xla_model as xm
41
+
42
+ XLA_AVAILABLE = True
43
+ else:
44
+ XLA_AVAILABLE = False
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+ EXAMPLE_DOC_STRING = """
49
+ Examples:
50
+ ```py
51
+ >>> import torch
52
+ >>> from diffusers import MochiPipeline
53
+ >>> from diffusers.utils import export_to_video
54
+
55
+ >>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
56
+ >>> pipe.to("cuda")
57
+ >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
58
+ >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
59
+ >>> export_to_video(frames, "mochi.mp4")
60
+ ```
61
+ """
62
+
63
+
64
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly interpolate the timestep-shift `mu` for a sequence length.

    Maps `base_seq_len -> base_shift` and `max_seq_len -> max_shift`, with
    lengths in between (or outside) following the same straight line.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return slope * image_seq_len + intercept
75
+
76
+
77
+ # from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
78
# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
    """Build Mochi's sigma schedule: a linear ramp followed by a quadratic tail.

    The first `linear_steps` values rise linearly from 0 toward
    `threshold_noise`; the remaining steps follow a quadratic chosen so the
    curve is continuous and reaches 1 one step past the end. The result is
    flipped so sigmas decrease from 1.0.
    """
    if linear_steps is None:
        linear_steps = num_steps // 2

    # Linear ramp: 0, t/L, 2t/L, ... up to (but excluding) threshold_noise.
    linear_part = [
        step * threshold_noise / linear_steps for step in range(linear_steps)
    ]

    quadratic_steps = num_steps - linear_steps
    step_diff = linear_steps - threshold_noise * num_steps
    quadratic_coef = step_diff / (linear_steps * quadratic_steps**2)
    linear_coef = threshold_noise / linear_steps - 2 * step_diff / (quadratic_steps**2)
    const = quadratic_coef * (linear_steps**2)

    quadratic_part = [
        quadratic_coef * (step**2) + linear_coef * step + const
        for step in range(linear_steps, num_steps)
    ]

    # Flip: the schedule is expressed as sigmas decreasing from 1.0.
    return [1.0 - value for value in linear_part + quadratic_part]
98
+
99
+
100
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
101
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` and return the resulting schedule.

    Supports three mutually exclusive modes: a plain `num_inference_steps`, a
    custom `timesteps` list, or a custom `sigmas` list. Extra kwargs are
    forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    custom_schedule = timesteps is not None or sigmas is not None

    if timesteps is not None:
        # Custom timesteps are only valid if the scheduler's signature accepts them.
        supported = set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if "timesteps" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        # Same check for a custom sigma schedule.
        supported = set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if "sigmas" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    timesteps = scheduler.timesteps
    if custom_schedule:
        # The scheduler may resample the schedule; report its actual length.
        num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps
162
+
163
+
164
+ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
165
+ r"""
166
+ The mochi pipeline for text-to-video generation.
167
+
168
+ Reference: https://github.com/genmoai/models
169
+
170
+ Args:
171
+ transformer ([`MochiTransformer3DModel`]):
172
+ Conditional Transformer architecture to denoise the encoded video latents.
173
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
174
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
175
+ vae ([`AutoencoderKL`]):
176
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
177
+ text_encoder ([`T5EncoderModel`]):
178
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
179
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
180
+ tokenizer (`CLIPTokenizer`):
181
+ Tokenizer of class
182
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
183
+ tokenizer (`T5TokenizerFast`):
184
+ Second Tokenizer of class
185
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
186
+ """
187
+
188
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
189
+ _optional_components = []
190
+ _callback_tensor_inputs = [
191
+ "latents", "prompt_embeds", "negative_prompt_embeds"
192
+ ]
193
+
194
+ def __init__(
195
+ self,
196
+ scheduler: FlowMatchEulerDiscreteScheduler,
197
+ vae: AutoencoderKL,
198
+ text_encoder: T5EncoderModel,
199
+ tokenizer: T5TokenizerFast,
200
+ transformer: MochiTransformer3DModel,
201
+ ):
202
+ super().__init__()
203
+
204
+ self.register_modules(
205
+ vae=vae,
206
+ text_encoder=text_encoder,
207
+ tokenizer=tokenizer,
208
+ transformer=transformer,
209
+ scheduler=scheduler,
210
+ )
211
+ self.vae_spatial_scale_factor = 8
212
+ self.vae_temporal_scale_factor = 6
213
+ self.patch_size = 2
214
+
215
+ self.video_processor = VideoProcessor(
216
+ vae_scale_factor=self.vae_spatial_scale_factor)
217
+ self.tokenizer_max_length = (self.tokenizer.model_max_length
218
+ if hasattr(self, "tokenizer")
219
+ and self.tokenizer is not None else 77)
220
+ self.default_height = 480
221
+ self.default_width = 848
222
+
223
+ # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
224
    # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 256,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize `prompt` and run it through the T5 encoder.

        Returns a `(prompt_embeds, prompt_attention_mask)` pair where the
        embeddings are repeated `num_videos_per_prompt` times along the batch
        dimension. Warns (does not raise) when the prompt is truncated to
        `max_sequence_length` tokens.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Pad/truncate every prompt to exactly `max_sequence_length` tokens.
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.bool().to(device)

        # Re-tokenize without truncation only to detect (and report) what was cut.
        untruncated_ids = self.tokenizer(prompt,
                                         padding="longest",
                                         return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[
                -1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, max_sequence_length - 1:-1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}")

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt,
                                           seq_len, -1)

        # NOTE(review): the mask is tiled with `repeat` while the embeddings are
        # interleaved via repeat+view — same count, but different ordering when
        # batch_size > 1 and num_videos_per_prompt > 1; verify this matches callers.
        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(
            num_videos_per_prompt, 1)

        return prompt_embeds, prompt_attention_mask
277
+
278
+ # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
279
    # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        max_sequence_length: int = 256,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            A 4-tuple `(prompt_embeds, prompt_attention_mask,
            negative_prompt_embeds, negative_prompt_attention_mask)`; the
            negative entries are the passed-through values (possibly `None`)
            when `do_classifier_free_guidance` is `False`.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Only encode when the caller did not supply precomputed embeddings.
        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # Empty-string negative prompt is the CFG default; broadcast a
            # single string to the whole batch.
            negative_prompt = negative_prompt or ""
            negative_prompt = (batch_size * [negative_prompt] if isinstance(
                negative_prompt, str) else negative_prompt)

            if prompt is not None and type(prompt) is not type(
                    negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}.")
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`.")

            (
                negative_prompt_embeds,
                negative_prompt_attention_mask,
            ) = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        )
369
+
370
    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_on_step_end_tensor_inputs=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
    ):
        """Validate `__call__` arguments before any expensive work.

        Raises `ValueError` (or `TypeError` for a mistyped prompt) when:
        height/width are not multiples of 8; an unknown callback tensor name
        is requested; `prompt` and `prompt_embeds` are both/neither given;
        embeddings are supplied without their attention masks; or positive
        and negative embeddings/masks have mismatched shapes.
        """
        # The VAE downsamples by 8, so pixel dimensions must be divisible by 8.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
                k in self._callback_tensor_inputs
                for k in callback_on_step_end_tensor_inputs):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of `prompt` / `prompt_embeds` must be provided.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two.")
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str)
                                     and not isinstance(prompt, list)):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        # Precomputed embeddings must come with their attention masks.
        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError(
                "Must provide `prompt_attention_mask` when specifying `prompt_embeds`."
            )

        if (negative_prompt_embeds is not None
                and negative_prompt_attention_mask is None):
            raise ValueError(
                "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`."
            )

        # CFG concatenates positive and negative along the batch axis, so
        # their shapes must match exactly.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}.")
            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
                raise ValueError(
                    "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
                    f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
                    f" {negative_prompt_attention_mask.shape}.")
429
+
430
    def enable_vae_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        # Thin delegation to the registered VAE.
        self.vae.enable_slicing()
436
+
437
    def disable_vae_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        # Thin delegation to the registered VAE.
        self.vae.disable_slicing()
443
+
444
    def enable_vae_tiling(self) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        # Thin delegation to the registered VAE.
        self.vae.enable_tiling()
451
+
452
    def disable_vae_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        # Thin delegation to the registered VAE.
        self.vae.disable_tiling()
458
+
459
+ def prepare_latents(
460
+ self,
461
+ batch_size,
462
+ num_channels_latents,
463
+ height,
464
+ width,
465
+ num_frames,
466
+ dtype,
467
+ device,
468
+ generator,
469
+ latents=None,
470
+ ):
471
+ height = height // self.vae_spatial_scale_factor
472
+ width = width // self.vae_spatial_scale_factor
473
+ num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
474
+
475
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
476
+
477
+ if latents is not None:
478
+ return latents.to(device=device, dtype=dtype)
479
+ if isinstance(generator, list) and len(generator) != batch_size:
480
+ raise ValueError(
481
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
482
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
483
+ )
484
+
485
+ latents = randn_tensor(shape,
486
+ generator=generator,
487
+ device=device,
488
+ dtype=torch.float32)
489
+ latents = latents.to(dtype)
490
+ return latents
491
+
492
    @property
    def guidance_scale(self):
        """Classifier-free guidance weight set at the start of `__call__`."""
        return self._guidance_scale
495
+
496
    @property
    def do_classifier_free_guidance(self):
        """CFG is active whenever the guidance scale exceeds 1."""
        return self._guidance_scale > 1.0
499
+
500
    @property
    def num_timesteps(self):
        """Number of timesteps in the most recent denoising schedule."""
        return self._num_timesteps
503
+
504
    @property
    def attention_kwargs(self):
        """Extra kwargs forwarded to the attention processors, set in `__call__`."""
        return self._attention_kwargs
507
+
508
    @property
    def interrupt(self):
        """When True, the denoising loop skips remaining steps."""
        return self._interrupt
511
+
512
+ @torch.no_grad()
513
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
514
+ def __call__(
515
+ self,
516
+ prompt: Union[str, List[str]] = None,
517
+ negative_prompt: Optional[Union[str, List[str]]] = None,
518
+ height: Optional[int] = None,
519
+ width: Optional[int] = None,
520
+ num_frames: int = 19,
521
+ num_inference_steps: int = 64,
522
+ timesteps: List[int] = None,
523
+ guidance_scale: float = 4.5,
524
+ num_videos_per_prompt: Optional[int] = 1,
525
+ generator: Optional[Union[torch.Generator,
526
+ List[torch.Generator]]] = None,
527
+ latents: Optional[torch.Tensor] = None,
528
+ prompt_embeds: Optional[torch.Tensor] = None,
529
+ prompt_attention_mask: Optional[torch.Tensor] = None,
530
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
531
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
532
+ output_type: Optional[str] = "pil",
533
+ return_dict: bool = True,
534
+ attention_kwargs: Optional[Dict[str, Any]] = None,
535
+ callback_on_step_end: Optional[Callable[[int, int, Dict],
536
+ None]] = None,
537
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
538
+ max_sequence_length: int = 256,
539
+ return_all_states=False,
540
+ ):
541
+ r"""
542
+ Function invoked when calling the pipeline for generation.
543
+
544
+ Args:
545
+ prompt (`str` or `List[str]`, *optional*):
546
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
547
+ instead.
548
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
549
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
550
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
551
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
552
+ num_frames (`int`, defaults to 16):
553
+ The number of video frames to generate
554
+ num_inference_steps (`int`, *optional*, defaults to 50):
555
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
556
+ expense of slower inference.
557
+ timesteps (`List[int]`, *optional*):
558
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
559
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
560
+ passed will be used. Must be in descending order.
561
+ guidance_scale (`float`, defaults to `4.5`):
562
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
563
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
564
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
565
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
566
+ usually at the expense of lower image quality.
567
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
568
+ The number of videos to generate per prompt.
569
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
570
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
571
+ to make generation deterministic.
572
+ latents (`torch.Tensor`, *optional*):
573
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
574
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
575
+ tensor will ge generated by sampling using the supplied random `generator`.
576
+ prompt_embeds (`torch.Tensor`, *optional*):
577
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
578
+ provided, text embeddings will be generated from `prompt` input argument.
579
+ prompt_attention_mask (`torch.Tensor`, *optional*):
580
+ Pre-generated attention mask for text embeddings.
581
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
582
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
583
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
584
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
585
+ Pre-generated attention mask for negative text embeddings.
586
+ output_type (`str`, *optional*, defaults to `"pil"`):
587
+ The output format of the generate image. Choose between
588
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
589
+ return_dict (`bool`, *optional*, defaults to `True`):
590
+ Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
591
+ attention_kwargs (`dict`, *optional*):
592
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
593
+ `self.processor` in
594
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
595
+ callback_on_step_end (`Callable`, *optional*):
596
+ A function that calls at the end of each denoising steps during the inference. The function is called
597
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
598
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
599
+ `callback_on_step_end_tensor_inputs`.
600
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
601
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
602
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
603
+ `._callback_tensor_inputs` attribute of your pipeline class.
604
+ max_sequence_length (`int` defaults to `256`):
605
+ Maximum sequence length to use with the `prompt`.
606
+
607
+ Examples:
608
+
609
+ Returns:
610
+ [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
611
+ If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
612
+ is returned where the first element is a list with the generated images.
613
+ """
614
+
615
+ if isinstance(callback_on_step_end,
616
+ (PipelineCallback, MultiPipelineCallbacks)):
617
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
618
+
619
+ height = height or self.default_height
620
+ width = width or self.default_width
621
+
622
+ # 1. Check inputs. Raise error if not correct
623
+ self.check_inputs(
624
+ prompt=prompt,
625
+ height=height,
626
+ width=width,
627
+ callback_on_step_end_tensor_inputs=
628
+ callback_on_step_end_tensor_inputs,
629
+ prompt_embeds=prompt_embeds,
630
+ negative_prompt_embeds=negative_prompt_embeds,
631
+ prompt_attention_mask=prompt_attention_mask,
632
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
633
+ )
634
+
635
+ self._guidance_scale = guidance_scale
636
+ self._attention_kwargs = attention_kwargs
637
+ self._interrupt = False
638
+
639
+ # 2. Define call parameters
640
+ if prompt is not None and isinstance(prompt, str):
641
+ batch_size = 1
642
+ elif prompt is not None and isinstance(prompt, list):
643
+ batch_size = len(prompt)
644
+ else:
645
+ batch_size = prompt_embeds.shape[0]
646
+
647
+ device = self._execution_device
648
+
649
+ # 3. Prepare text embeddings
650
+ (
651
+ prompt_embeds,
652
+ prompt_attention_mask,
653
+ negative_prompt_embeds,
654
+ negative_prompt_attention_mask,
655
+ ) = self.encode_prompt(
656
+ prompt=prompt,
657
+ negative_prompt=negative_prompt,
658
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
659
+ num_videos_per_prompt=num_videos_per_prompt,
660
+ prompt_embeds=prompt_embeds,
661
+ negative_prompt_embeds=negative_prompt_embeds,
662
+ prompt_attention_mask=prompt_attention_mask,
663
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
664
+ max_sequence_length=max_sequence_length,
665
+ device=device,
666
+ )
667
+ if self.do_classifier_free_guidance:
668
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds],
669
+ dim=0)
670
+ prompt_attention_mask = torch.cat(
671
+ [negative_prompt_attention_mask, prompt_attention_mask], dim=0)
672
+
673
+ # 4. Prepare latent variables
674
+ num_channels_latents = self.transformer.config.in_channels
675
+ latents = self.prepare_latents(
676
+ batch_size * num_videos_per_prompt,
677
+ num_channels_latents,
678
+ height,
679
+ width,
680
+ num_frames,
681
+ prompt_embeds.dtype,
682
+ device,
683
+ generator,
684
+ latents,
685
+ )
686
+ world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group
687
+ if get_sequence_parallel_state():
688
+ latents = rearrange(latents,
689
+ "b t (n s) h w -> b t n s h w",
690
+ n=world_size).contiguous()
691
+ latents = latents[:, :, rank, :, :, :]
692
+
693
+ original_noise = copy.deepcopy(latents)
694
+ # 5. Prepare timestep
695
+ # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
696
+ threshold_noise = 0.025
697
+ sigmas = linear_quadratic_schedule(num_inference_steps,
698
+ threshold_noise)
699
+ sigmas = np.array(sigmas)
700
+ # check if of type FlowMatchEulerDiscreteScheduler
701
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
702
+ timesteps, num_inference_steps = retrieve_timesteps(
703
+ self.scheduler,
704
+ num_inference_steps,
705
+ device,
706
+ timesteps,
707
+ sigmas,
708
+ )
709
+ else:
710
+ timesteps, num_inference_steps = retrieve_timesteps(
711
+ self.scheduler,
712
+ num_inference_steps,
713
+ device,
714
+ )
715
+ num_warmup_steps = max(
716
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0)
717
+ self._num_timesteps = len(timesteps)
718
+
719
+ # 6. Denoising loop
720
+ self._progress_bar_config = {
721
+ "disable": nccl_info.rank_within_group != 0
722
+ }
723
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
724
+ for i, t in enumerate(timesteps):
725
+ if self.interrupt:
726
+ continue
727
+
728
+ latent_model_input = (torch.cat(
729
+ [latents] *
730
+ 2) if self.do_classifier_free_guidance else latents)
731
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
732
+ timestep = t.expand(latent_model_input.shape[0]).to(
733
+ latents.dtype)
734
+
735
+ noise_pred = self.transformer(
736
+ hidden_states=latent_model_input,
737
+ encoder_hidden_states=prompt_embeds,
738
+ timestep=timestep,
739
+ encoder_attention_mask=prompt_attention_mask,
740
+ attention_kwargs=attention_kwargs,
741
+ return_dict=False,
742
+ )[0]
743
+
744
+ # Mochi CFG + Sampling runs in FP32
745
+ noise_pred = noise_pred.to(torch.float32)
746
+ if self.do_classifier_free_guidance:
747
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
748
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
749
+ noise_pred_text - noise_pred_uncond)
750
+
751
+ # compute the previous noisy sample x_t -> x_t-1
752
+ latents_dtype = latents.dtype
753
+ latents = self.scheduler.step(noise_pred,
754
+ t,
755
+ latents.to(torch.float32),
756
+ return_dict=False)[0]
757
+ latents = latents.to(latents_dtype)
758
+
759
+ if latents.dtype != latents_dtype:
760
+ if torch.backends.mps.is_available():
761
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
762
+ latents = latents.to(latents_dtype)
763
+
764
+ if callback_on_step_end is not None:
765
+ callback_kwargs = {}
766
+ for k in callback_on_step_end_tensor_inputs:
767
+ callback_kwargs[k] = locals()[k]
768
+ callback_outputs = callback_on_step_end(
769
+ self, i, t, callback_kwargs)
770
+
771
+ latents = callback_outputs.pop("latents", latents)
772
+ prompt_embeds = callback_outputs.pop(
773
+ "prompt_embeds", prompt_embeds)
774
+
775
+ # call the callback, if provided
776
+ if i == len(timesteps) - 1 or (
777
+ (i + 1) > num_warmup_steps and
778
+ (i + 1) % self.scheduler.order == 0):
779
+ progress_bar.update()
780
+
781
+ if XLA_AVAILABLE:
782
+ xm.mark_step()
783
+
784
+ if get_sequence_parallel_state():
785
+ latents = all_gather(latents, dim=2)
786
+ # latents_shape = list(latents.shape)
787
+ # full_shape = [latents_shape[0] * world_size] + latents_shape[1:]
788
+ # all_latents = torch.zeros(full_shape, dtype=latents.dtype, device=latents.device)
789
+ # torch.distributed.all_gather_into_tensor(all_latents, latents)
790
+ # latents_list = list(all_latents.chunk(world_size, dim=0))
791
+ # latents = torch.cat(latents_list, dim=2)
792
+
793
+ if output_type == "latent":
794
+ video = latents
795
+ else:
796
+ # unscale/denormalize the latents
797
+ # denormalize with the mean and std if available and not None
798
+ has_latents_mean = (hasattr(self.vae.config, "latents_mean")
799
+ and self.vae.config.latents_mean is not None)
800
+ has_latents_std = (hasattr(self.vae.config, "latents_std")
801
+ and self.vae.config.latents_std is not None)
802
+ if has_latents_mean and has_latents_std:
803
+ latents_mean = (torch.tensor(
804
+ self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(
805
+ latents.device, latents.dtype))
806
+ latents_std = (torch.tensor(self.vae.config.latents_std).view(
807
+ 1, 12, 1, 1, 1).to(latents.device, latents.dtype))
808
+ latents = (
809
+ latents * latents_std / self.vae.config.scaling_factor +
810
+ latents_mean)
811
+ else:
812
+ latents = latents / self.vae.config.scaling_factor
813
+
814
+ video = self.vae.decode(latents, return_dict=False)[0]
815
+ video = self.video_processor.postprocess_video(
816
+ video, output_type=output_type)
817
+
818
+ # Offload all models
819
+ self.maybe_free_model_hooks()
820
+ if return_all_states:
821
+ # Pay extra attention here:
822
+ # prompt_embeds with shape torch.Size([2, 256]), where prompt_embeds[1] is the prompt_embeds for the actual prompt
823
+ # prompt_embeds[0] is for negative prompt
824
+ return original_noise, video, latents, prompt_embeds, prompt_attention_mask
825
+
826
+ if not return_dict:
827
+ return (video, )
828
+
829
+ return MochiPipelineOutput(frames=video)
fastvideo/models/qwenimage/__init__.py ADDED
File without changes
fastvideo/models/qwenimage/autoencoder_kl_qwenimage.py ADDED
@@ -0,0 +1,1070 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # We gratefully acknowledge the Wan Team for their outstanding contributions.
16
+ # QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
17
+ # For more information about the Wan VAE, please refer to:
18
+ # - GitHub: https://github.com/Wan-Video/Wan2.1
19
+ # - arXiv: https://arxiv.org/abs/2503.20314
20
+
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+
28
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
29
+ from diffusers.loaders import FromOriginalModelMixin
30
+ from diffusers.utils import logging
31
+ from diffusers.utils.accelerate_utils import apply_forward_hook
32
+ from diffusers.models.activations import get_activation
33
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
34
+ from diffusers.models.modeling_utils import ModelMixin
35
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+ CACHE_T = 2
41
+
42
+
43
class QwenImageCausalConv3d(nn.Conv3d):
    r"""
    A custom 3D causal convolution layer with feature caching support.

    This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
    caching for efficient chunked inference.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        # Set up causal padding. F.pad for 5-D input takes
        # (w_left, w_right, h_top, h_bottom, t_front, t_back): all temporal padding
        # (2 * pad_t) is placed in FRONT of the clip and none behind, so each output
        # frame depends only on current and past frames.
        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
        # Disable the built-in Conv3d padding; padding is applied manually in forward().
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        # x: (b, c, t, h, w). cache_x, when given, holds trailing frames from the
        # previous chunk; prepending them (and shrinking the explicit front padding
        # by the same amount) makes chunked inference match a single full-clip pass.
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)
        return super().forward(x)
86
+
87
+
88
class QwenImageRMS_norm(nn.Module):
    r"""
    A custom RMS normalization layer.

    Normalizes the input along the channel axis (dim 1 when ``channel_first``,
    last dim otherwise), rescales by ``sqrt(dim)``, and applies a learnable
    per-channel gain (and optional bias).

    Args:
        dim (int): The number of dimensions to normalize over.
        channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
            Default is True.
        images (bool, optional): Whether the input represents image data. Default is True.
        bias (bool, optional): Whether to include a learnable bias term. Default is False.
    """

    def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
        super().__init__()
        # Parameter shape broadcasts over (h, w) for images or (t, h, w) for video
        # when the tensor is channel-first; otherwise a flat (dim,) vector suffices.
        if channel_first:
            trailing = (1, 1) if images else (1, 1, 1)
            param_shape = (dim, *trailing)
        else:
            param_shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        if bias:
            self.bias = nn.Parameter(torch.zeros(param_shape))
        else:
            # Plain float 0.0 keeps the forward expression uniform without adding a parameter.
            self.bias = 0.0

    def forward(self, x):
        norm_dim = 1 if self.channel_first else -1
        normalized = F.normalize(x, dim=norm_dim)
        return normalized * self.scale * self.gamma + self.bias
112
+
113
+
114
class QwenImageUpsample(nn.Upsample):
    r"""
    Perform upsampling while ensuring the output tensor has the same data type as the input.

    ``nn.Upsample`` interpolation is not supported for every dtype (e.g. half
    precision on some backends), so the input is promoted to float32 for the
    interpolation and the result is cast back afterwards.

    Args:
        x (torch.Tensor): Input tensor to be upsampled.

    Returns:
        torch.Tensor: Upsampled tensor with the same data type as the input.
    """

    def forward(self, x):
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
127
+
128
+
129
class QwenImageResample(nn.Module):
    r"""
    A custom resampling module for 2D and 3D data.

    Args:
        dim (int): The number of input/output channels.
        mode (str): The resampling mode. Must be one of:
            - 'none': No resampling (identity operation).
            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
            - 'downsample2d': 2D downsampling with zero-padding and convolution.
            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
    """

    def __init__(self, dim: int, mode: str) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample2d":
            # Spatial 2x upsample; halves the channel count.
            self.resample = nn.Sequential(
                QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim // 2, 3, padding=1),
            )
        elif mode == "upsample3d":
            # Same spatial path as 2D, plus a causal temporal conv that DOUBLES
            # channels so forward() can split them into two time steps.
            self.resample = nn.Sequential(
                QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim // 2, 3, padding=1),
            )
            self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == "downsample2d":
            # Asymmetric right/bottom zero-pad keeps the stride-2 conv output at ceil(size/2).
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            # Stride-2 temporal conv halves the frame count.
            self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE(review): `feat_idx=[0]` is a shared mutable default; callers doing
        # stateful chunked inference must pass their own index list (the
        # encoder/decoder do). feat_idx[0] is an in/out cursor into feat_cache.
        b, c, t, h, w = x.size()
        if self.mode == "upsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: store the "Rep" sentinel so the next chunk
                    # knows to zero-pad instead of using real cached frames.
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
                        # Chunk shorter than CACHE_T: borrow the last frame of the previous cache.
                        cache_x = torch.cat(
                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
                        )
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
                        # No real history yet: left-pad the cache snapshot with zeros.
                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    # time_conv doubled the channels; interleave the two halves
                    # along time to double the frame count: (b, c, 2t, h, w).
                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
        # Fold time into the batch dimension for the 2D spatial resample, then restore.
        t = x.shape[2]
        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        x = self.resample(x)
        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)

        if self.mode == "downsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: remember it so the next chunk can prepend its last frame.
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x
214
+
215
+
216
class QwenImageResidualBlock(nn.Module):
    r"""
    A custom residual block module.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
        non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.nonlinearity = get_activation(non_linearity)

        # layers: (norm -> act -> conv) twice; a 1x1x1 conv aligns channels on the
        # skip path only when in/out dims differ.
        self.norm1 = QwenImageRMS_norm(in_dim, images=False)
        self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
        self.norm2 = QwenImageRMS_norm(out_dim, images=False)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
        self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE(review): `feat_idx=[0]` is a shared mutable default; chunked callers
        # must pass their own list. feat_idx[0] advances once per cached conv.
        # Apply shortcut connection
        h = self.conv_shortcut(x)

        # First normalization and activation
        x = self.norm1(x)
        x = self.nonlinearity(x)

        if feat_cache is not None:
            # Snapshot the last CACHE_T frames BEFORE overwriting the slot, and
            # feed the PREVIOUS chunk's cache into the causal conv.
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # Chunk shorter than CACHE_T: borrow the last frame of the previous cache.
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        # Second normalization and activation
        x = self.norm2(x)
        x = self.nonlinearity(x)

        # Dropout
        x = self.dropout(x)

        if feat_cache is not None:
            # Same cache discipline for the second causal conv.
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

            x = self.conv2(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv2(x)

        # Add residual connection
        return x + h
288
+
289
+
290
class QwenImageAttentionBlock(nn.Module):
    r"""
    Self-attention with a single head, applied per-frame over spatial positions.

    Each time step attends only within its own (h*w) tokens — time is folded into
    the batch dimension, so there is no cross-frame attention here.

    Args:
        dim (int): The number of channels in the input tensor.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = QwenImageRMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        identity = x
        batch_size, channels, time, height, width = x.size()

        # Fold time into batch: [b, c, t, h, w] -> [(b*t), c, h, w]
        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
        x = self.norm(x)

        # compute query, key, value; layout becomes [(b*t), 1 head, h*w tokens, 3c]
        qkv = self.to_qkv(x)
        qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
        qkv = qkv.permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(q, k, v)

        # Tokens back to a spatial map: [(b*t), c, h, w]
        x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)

        # output projection
        x = self.proj(x)

        # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
        x = x.view(batch_size, time, channels, height, width)
        x = x.permute(0, 2, 1, 3, 4)

        # Residual connection around the whole attention op.
        return x + identity
333
+
334
+
335
class QwenImageMidBlock(nn.Module):
    """
    Middle block for QwenImageVAE encoder and decoder.

    Structure: resnet -> (attention -> resnet) * num_layers, i.e. always one more
    residual block than attention blocks.

    Args:
        dim (int): Number of input/output channels.
        dropout (float): Dropout rate.
        non_linearity (str): Type of non-linearity to use.
        num_layers (int): Number of attention/resnet pairs after the first resnet.
    """

    def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
        super().__init__()
        self.dim = dim

        # Create the components
        resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
        attentions = []
        for _ in range(num_layers):
            attentions.append(QwenImageAttentionBlock(dim))
            resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Args:
            x (torch.Tensor): Input of shape (b, c, t, h, w).
            feat_cache (list, optional): Per-conv feature cache for chunked inference.
            feat_idx (list, optional): Single-element in/out cursor into feat_cache.

        Returns:
            torch.Tensor: Output tensor of the same shape as the input.
        """
        # Fix: the original `feat_idx=[0]` default is a shared mutable — when a
        # caller omits the argument but supplies feat_cache, increments persist
        # across calls and subsequent calls read the wrong cache slots. Use a
        # None sentinel and allocate a fresh cursor per call instead.
        if feat_idx is None:
            feat_idx = [0]

        # First residual block
        x = self.resnets[0](x, feat_cache, feat_idx)

        # Process through attention and residual blocks
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                x = attn(x)

            x = resnet(x, feat_cache, feat_idx)

        return x
372
+
373
+
374
class QwenImageEncoder3d(nn.Module):
    r"""
    A 3D encoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_downsample (list of bool): Whether to downsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.nonlinearity = get_activation(non_linearity)

        # dimensions: channel widths per stage, e.g. dim_mult [1,2,4,4] -> [dim, dim, 2d, 4d, 4d]
        dims = [dim * u for u in [1] + dim_mult]
        # `scale` tracks the current spatial resolution relative to the input
        # (halved after each downsample) and is matched against attn_scales.
        scale = 1.0

        # init block: RGB in, first feature width out
        self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        self.down_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    self.down_blocks.append(QwenImageAttentionBlock(out_dim))
                # After the first resblock of a stage, channels are already out_dim.
                in_dim = out_dim

            # downsample block — every stage except the last; temporal downsampling
            # only where temperal_downsample[i] is set.
            if i != len(dim_mult) - 1:
                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
                scale /= 2.0

        # middle blocks
        self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)

        # output blocks: project features to the latent dimensionality
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE(review): `feat_idx=[0]` is a shared mutable default; chunked callers
        # must pass a fresh list per clip. feat_idx[0] indexes feat_cache slots in
        # the fixed order the causal convs are visited.
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # Chunk shorter than CACHE_T: borrow the last frame of the previous cache.
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_in(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_in(x)

        ## downsamples
        for layer in self.down_blocks:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # Chunk shorter than CACHE_T: borrow the last frame of the previous cache.
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_out(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_out(x)
        return x
479
+
480
+
481
class QwenImageUpBlock(nn.Module):
    """
    A block that handles upsampling for the QwenImageVAE decoder.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks
        dropout (float): Dropout rate
        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        upsample_mode: Optional[str] = None,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Create layers list
        resnets = []
        # num_res_blocks + 1 residual blocks: the first maps in_dim -> out_dim,
        # the rest keep out_dim.
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
            current_dim = out_dim

        self.resnets = nn.ModuleList(resnets)

        # Add upsampling layer if needed (None for the final decoder stage).
        self.upsamplers = None
        if upsample_mode is not None:
            self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor
            feat_cache (list, optional): Feature cache for causal convolutions
            feat_idx (list, optional): Feature index for cache management.
                NOTE(review): `[0]` is a shared mutable default; chunked callers
                must pass their own list.

        Returns:
            torch.Tensor: Output tensor
        """
        for resnet in self.resnets:
            if feat_cache is not None:
                x = resnet(x, feat_cache, feat_idx)
            else:
                x = resnet(x)

        if self.upsamplers is not None:
            if feat_cache is not None:
                x = self.upsamplers[0](x, feat_cache, feat_idx)
            else:
                x = self.upsamplers[0](x)
        return x
548
+
549
+
550
class QwenImageDecoder3d(nn.Module):
    r"""
    A 3D decoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_upsample (list of bool): Whether to upsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        self.nonlinearity = get_activation(non_linearity)

        # dimensions: mirror of the encoder, widest stage first,
        # e.g. dim_mult [1,2,4,4] -> [4d, 4d, 4d, 2d, d]
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        # Relative spatial resolution at the decoder entry (smallest scale).
        scale = 1.0 / 2 ** (len(dim_mult) - 2)

        # init block: latent channels in, widest feature width out
        self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)

        # upsample blocks
        self.up_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            # Upsample3d/2d halve the channels on output, so stages after the
            # first receive half the nominal width.
            if i > 0:
                in_dim = in_dim // 2

            # Determine if we need upsampling — every stage except the last.
            upsample_mode = None
            if i != len(dim_mult) - 1:
                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"

            # Create and add the upsampling block
            up_block = QwenImageUpBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks,
                dropout=dropout,
                upsample_mode=upsample_mode,
                non_linearity=non_linearity,
            )
            self.up_blocks.append(up_block)

            # Update scale for next iteration
            if upsample_mode is not None:
                scale *= 2.0

        # output blocks: project back to RGB
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE(review): `feat_idx=[0]` is a shared mutable default; chunked callers
        # must pass a fresh list per clip. feat_idx[0] indexes feat_cache slots in
        # the fixed order the causal convs are visited.
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # Chunk shorter than CACHE_T: borrow the last frame of the previous cache.
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_in(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_in(x)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## upsamples
        for up_block in self.up_blocks:
            x = up_block(x, feat_cache, feat_idx)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # Chunk shorter than CACHE_T: borrow the last frame of the previous cache.
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_out(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_out(x)
        return x
665
+
666
+
667
class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
    """

    _supports_gradient_checkpointing = False

    # fmt: off
    @register_to_config
    def __init__(
        self,
        base_dim: int = 96,
        z_dim: int = 16,
        dim_mult: Tuple[int] = [1, 2, 4, 4],
        num_res_blocks: int = 2,
        attn_scales: List[float] = [],
        temperal_downsample: List[bool] = [False, True, True],
        dropout: float = 0.0,
        latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
        latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
    ) -> None:
        # fmt: on
        super().__init__()

        self.z_dim = z_dim
        self.temperal_downsample = temperal_downsample
        # The decoder upsamples in the reverse order of the encoder's downsampling stages.
        self.temperal_upsample = temperal_downsample[::-1]

        # Encoder emits 2 * z_dim channels (mean and log-variance of the Gaussian posterior).
        self.encoder = QwenImageEncoder3d(
            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
        )
        self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)

        self.decoder = QwenImageDecoder3d(
            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
        )

        # NOTE(review): spatial ratio is derived from the *temporal* downsample list length;
        # presumably each downsample stage also halves spatially — confirm against the encoder.
        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)

        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
        # to perform decoding of a single video latent at a time.
        self.use_slicing = False

        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
        # intermediate tiles together, the memory requirement can be lowered.
        self.use_tiling = False

        # The minimal tile height and width for spatial tiling to be used
        self.tile_sample_min_height = 256
        self.tile_sample_min_width = 256

        # The minimal distance between two spatial tiles
        self.tile_sample_stride_height = 192
        self.tile_sample_stride_width = 192

        # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
        # NOTE(review): clear_cache() below recounts the convs on every call and never reads
        # this dict — consider using _cached_conv_counts there.
        self._cached_conv_counts = {
            "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules())
            if self.decoder is not None
            else 0,
            "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules())
            if self.encoder is not None
            else 0,
        }

    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
        tile_sample_min_width: Optional[int] = None,
        tile_sample_stride_height: Optional[float] = None,
        tile_sample_stride_width: Optional[float] = None,
    ) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_stride_height (`int`, *optional*):
                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
                no tiling artifacts produced across the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
                artifacts produced across the width dimension.
        """
        self.use_tiling = True
        # `None` arguments keep the current (default) values.
        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_tiling = False

    def enable_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def clear_cache(self):
        """Reset the per-convolution streaming state used by the causal 3D convs.

        The encoder/decoder receive `feat_cache` / `feat_idx` on every chunked call; this
        reinitializes both caches to `None` slots sized by the number of causal convs.
        """

        def _count_conv3d(model):
            count = 0
            for m in model.modules():
                if isinstance(m, QwenImageCausalConv3d):
                    count += 1
            return count

        # Decoder-side cache.
        self._conv_num = _count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = _count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num

    def _encode(self, x: torch.Tensor):
        """Encode a (B, C, T, H, W) video into posterior parameters (mean/logvar channels)."""
        _, _, num_frame, height, width = x.shape

        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            return self.tiled_encode(x)

        self.clear_cache()
        # Frames are streamed through the causal encoder: the first frame alone, then
        # chunks of 4 frames, with conv state carried in the feat cache between chunks.
        iter_ = 1 + (num_frame - 1) // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                out = torch.cat([out, out_], 2)

        enc = self.quant_conv(out)
        self.clear_cache()
        return enc

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        r"""
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded videos. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_slicing and x.shape[0] > 1:
            # Encode one batch element at a time to bound peak memory.
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)
        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.Tensor, return_dict: bool = True):
        """Decode a (B, C, T, H, W) latent into pixel space, clamped to [-1, 1]."""
        _, _, num_frame, height, width = z.shape
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
            return self.tiled_decode(z, return_dict=return_dict)

        self.clear_cache()
        x = self.post_quant_conv(z)
        # Decode one latent frame at a time, streaming conv state via the feat cache.
        # NOTE(review): both branches issue the identical decoder call; kept for symmetry
        # with `_encode`, only the `torch.cat` accumulation differs.
        for i in range(num_frame):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)

        out = torch.clamp(out, min=-1.0, max=1.0)
        self.clear_cache()
        if not return_dict:
            return (out,)

        return DecoderOutput(sample=out)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        if self.use_slicing and z.shape[0] > 1:
            # Decode one batch element at a time to bound peak memory.
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Linearly cross-fade the bottom edge of tile `a` into the top edge of tile `b` (height axis)."""
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Linearly cross-fade the right edge of tile `a` into the left edge of tile `b` (width axis)."""
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        _, _, num_frames, height, width = x.shape
        latent_height = height // self.spatial_compression_ratio
        latent_width = width // self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        # Overlap (in latent pixels) between neighboring tiles; used for blending below.
        blend_height = tile_latent_min_height - tile_latent_stride_height
        blend_width = tile_latent_min_width - tile_latent_stride_width

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, self.tile_sample_stride_height):
            row = []
            for j in range(0, width, self.tile_sample_stride_width):
                self.clear_cache()
                time = []
                # Same temporal chunking as `_encode`: first frame, then chunks of 4.
                frame_range = 1 + (num_frames - 1) // 4
                for k in range(frame_range):
                    self._enc_conv_idx = [0]
                    if k == 0:
                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                    else:
                        tile = x[
                            :,
                            :,
                            1 + 4 * (k - 1) : 1 + 4 * k,
                            i : i + self.tile_sample_min_height,
                            j : j + self.tile_sample_min_width,
                        ]
                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
                    tile = self.quant_conv(tile)
                    time.append(tile)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
        return enc

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        _, _, num_frames, height, width = z.shape
        sample_height = height * self.spatial_compression_ratio
        sample_width = width * self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        # Overlap (in sample pixels) between neighboring decoded tiles; used for blending below.
        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                self.clear_cache()
                time = []
                for k in range(num_frames):
                    self._conv_idx = [0]
                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                    tile = self.post_quant_conv(tile)
                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
                    time.append(decoded)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """
        Encode `sample` and immediately decode the latent (autoencoder round trip).

        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior (`True`) or use its mode (`False`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                RNG used when sampling from the posterior.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z, return_dict=return_dict)
        return dec
fastvideo/models/qwenimage/pipeline_output.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from diffusers.utils import BaseOutput
8
+
9
+
10
@dataclass
class QwenImagePipelineOutput(BaseOutput):
    """
    Output class for the QwenImage generation pipeline.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    # Decoded output images: a list of PIL images or a single stacked numpy array.
    images: Union[List[PIL.Image.Image], np.ndarray]
fastvideo/models/qwenimage/pipeline_qwenimage.py ADDED
@@ -0,0 +1,727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
21
+
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from diffusers.loaders import QwenImageLoraLoaderMixin
24
+ from fastvideo.models.qwenimage.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
25
+ from fastvideo.models.qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
+ from .pipeline_output import QwenImagePipelineOutput
31
+
32
+
33
# Optional XLA (TPU) support: `xm.mark_step()` is presumably called in the denoising
# loop later in this file — TODO confirm; the flag just records availability here.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Example snippet, presumably injected into `__call__`'s docstring via
# `replace_example_docstring` (imported above) — verify at the call site.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import QwenImagePipeline

        >>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")
        >>> prompt = "A cat holding a sign that says hello world"
        >>> # Depending on the variant being used, the pipeline call will slightly vary.
        >>> # Refer to the pipeline documentation for more details.
        >>> image = pipe(prompt, num_inference_steps=50).images[0]
        >>> image.save("qwenimage.png")
        ```
"""
58
+
59
+
60
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Return the flow-matching timestep shift `mu` for a given image token count.

    The shift interpolates linearly: `base_shift` at `base_seq_len` tokens up to
    `max_shift` at `max_seq_len` tokens (and extrapolates outside that range).
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
71
+
72
+
73
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call `scheduler.set_timesteps` and return the resulting schedule.

    Exactly one of `timesteps` or `sigmas` may be given to override the scheduler's own
    spacing strategy; otherwise the schedule is derived from `num_inference_steps`. Any
    extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default path: let the scheduler derive its own schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule path: the scheduler must accept the corresponding keyword.
    supported = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
131
+
132
+
133
+ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
134
+ r"""
135
+ The QwenImage pipeline for text-to-image generation.
136
+
137
+ Args:
138
+ transformer ([`QwenImageTransformer2DModel`]):
139
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
140
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
141
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
142
+ vae ([`AutoencoderKL`]):
143
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
144
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
145
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
146
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
147
+ tokenizer (`QwenTokenizer`):
148
+ Tokenizer of class
149
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
150
+ """
151
+
152
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
153
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
154
+
155
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKLQwenImage,
        text_encoder: Qwen2_5_VLForConditionalGeneration,
        tokenizer: Qwen2Tokenizer,
        transformer: QwenImageTransformer2DModel,
    ):
        """Register the pipeline components and derive processing constants."""
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        # One spatial 2x compression per temporal-downsample stage; fall back to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
        # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.tokenizer_max_length = 1024
        # Chat-style template wrapped around the user prompt before text encoding.
        self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
        # Number of leading template tokens stripped from the encoder output
        # (presumably the tokenized system/user prefix — confirm against the tokenizer).
        self.prompt_template_encode_start_idx = 34
        self.default_sample_size = 128
180
+
181
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
182
+ bool_mask = mask.bool()
183
+ valid_lengths = bool_mask.sum(dim=1)
184
+ selected = hidden_states[bool_mask]
185
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
186
+
187
+ return split_result
188
+
189
    def _get_qwen_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Encode `prompt` with the Qwen2.5-VL text encoder.

        Wraps each prompt in the chat template, encodes, strips the template prefix, and
        re-pads to a common length. Returns `(prompt_embeds, encoder_attention_mask)` of
        shapes `(batch, seq, dim)` and `(batch, seq)`.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt

        template = self.prompt_template_encode
        drop_idx = self.prompt_template_encode_start_idx
        txt = [template.format(e) for e in prompt]
        # Budget extra tokens for the template prefix so the user text still gets
        # `tokenizer_max_length` tokens after the prefix is dropped below.
        txt_tokens = self.tokenizer(
            txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        encoder_hidden_states = self.text_encoder(
            input_ids=txt_tokens.input_ids,
            attention_mask=txt_tokens.attention_mask,
            output_hidden_states=True,
        )
        hidden_states = encoder_hidden_states.hidden_states[-1]
        # Drop padding per sample, then drop the template-prefix tokens.
        split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
        # Re-pad the variable-length samples to the batch maximum with zeros.
        max_seq_len = max([e.size(0) for e in split_hidden_states])
        prompt_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
        )
        encoder_attention_mask = torch.stack(
            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
        )

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        return prompt_embeds, encoder_attention_mask
226
+
227
+ def encode_prompt(
228
+ self,
229
+ prompt: Union[str, List[str]],
230
+ device: Optional[torch.device] = None,
231
+ num_images_per_prompt: int = 1,
232
+ prompt_embeds: Optional[torch.Tensor] = None,
233
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
234
+ max_sequence_length: int = 1024,
235
+ ):
236
+ r"""
237
+
238
+ Args:
239
+ prompt (`str` or `List[str]`, *optional*):
240
+ prompt to be encoded
241
+ device: (`torch.device`):
242
+ torch device
243
+ num_images_per_prompt (`int`):
244
+ number of images that should be generated per prompt
245
+ prompt_embeds (`torch.Tensor`, *optional*):
246
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
247
+ provided, text embeddings will be generated from `prompt` input argument.
248
+ """
249
+ device = device or self._execution_device
250
+
251
+ prompt = [prompt] if isinstance(prompt, str) else prompt
252
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
253
+
254
+ if prompt_embeds is None:
255
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
256
+
257
+ prompt_embeds = prompt_embeds[:, :max_sequence_length]
258
+ prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
259
+
260
+ _, seq_len, _ = prompt_embeds.shape
261
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
262
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
263
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
264
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
265
+
266
+ return prompt_embeds, prompt_embeds_mask
267
+
268
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_embeds_mask=None,
        negative_prompt_embeds_mask=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        """Validate call arguments, raising `ValueError` on inconsistent combinations.

        Non-divisible `height`/`width` only warn (the image processor resizes later);
        everything else raises.
        """
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of `prompt` / `prompt_embeds` must be supplied.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Precomputed embeddings must come with their matching attention masks.
        if prompt_embeds is not None and prompt_embeds_mask is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        if max_sequence_length is not None and max_sequence_length > 1024:
            raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
322
+
323
+ @staticmethod
324
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
325
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
326
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
327
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
328
+
329
+ return latents
330
+
331
+ @staticmethod
332
+ def _unpack_latents(latents, height, width, vae_scale_factor):
333
+ batch_size, num_patches, channels = latents.shape
334
+
335
+ # VAE applies 8x compression on images but we must also account for packing which requires
336
+ # latent height and width to be divisible by 2.
337
+ height = 2 * (int(height) // (vae_scale_factor * 2))
338
+ width = 2 * (int(width) // (vae_scale_factor * 2))
339
+
340
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
341
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
342
+
343
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
344
+
345
+ return latents
346
+
347
def enable_vae_slicing(self):
    r"""
    Turn on sliced VAE decoding.

    The VAE then processes the input one slice at a time over several steps,
    trading a little speed for lower peak memory and larger feasible batches.
    """
    self.vae.enable_slicing()
353
+
354
def disable_vae_slicing(self):
    r"""
    Turn off sliced VAE decoding.

    Restores single-pass decoding after a prior `enable_vae_slicing` call.
    """
    self.vae.disable_slicing()
360
+
361
def enable_vae_tiling(self):
    r"""
    Turn on tiled VAE decoding.

    The VAE then encodes/decodes the input tile by tile, which saves a large
    amount of memory and makes bigger images feasible.
    """
    self.vae.enable_tiling()
368
+
369
def disable_vae_tiling(self):
    r"""
    Turn off tiled VAE decoding.

    Restores single-pass decoding after a prior `enable_vae_tiling` call.
    """
    self.vae.disable_tiling()
375
+
376
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Create (or pass through) the packed initial latents for sampling.

    When `latents` is supplied it is assumed to already be packed and is only
    moved to the requested device/dtype; otherwise Gaussian noise is sampled
    and packed into the (B, tokens, C*4) layout via `_pack_latents`.
    """
    # The VAE compresses 8x spatially; packing additionally requires the
    # latent height/width to be even, hence the round-down-to-even form.
    latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
    latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))

    if latents is not None:
        return latents.to(device=device, dtype=dtype)

    # One generator per sample is allowed, but only if the counts line up.
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    noise = randn_tensor(
        (batch_size, 1, num_channels_latents, latent_height, latent_width),
        generator=generator,
        device=device,
        dtype=dtype,
    )
    return self._pack_latents(noise, batch_size, num_channels_latents, latent_height, latent_width)
407
+
408
@property
def guidance_scale(self):
    # Guidance scale captured at the start of the current `__call__`.
    return self._guidance_scale

@property
def attention_kwargs(self):
    # Extra kwargs forwarded to the attention processors during denoising.
    return self._attention_kwargs

@property
def num_timesteps(self):
    # Number of scheduler timesteps for the current generation run.
    return self._num_timesteps

@property
def current_timestep(self):
    # Timestep currently being denoised; None outside the denoising loop.
    return self._current_timestep

@property
def interrupt(self):
    # When True, remaining denoising steps are skipped (cooperative abort).
    return self._interrupt
427
+
428
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    true_cfg_scale: float = 4.0,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 1.0,
    num_images_per_prompt: int = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
            not greater than `1`).
        true_cfg_scale (`float`, *optional*, defaults to 4.0):
            When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image. This is set to 1024 by default for the best results.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image. This is set to 1024 by default for the best results.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 1.0):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
            of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
            the text `prompt`, usually at the expense of lower image quality.

            This parameter in the pipeline is there to support future guidance-distilled models when they come up.
            Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
            please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
            enable classifier-free guidance computations.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

    Examples:

    Returns:
        [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
            [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is a list with the generated images.
    """

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_embeds_mask=prompt_embeds_mask,
        negative_prompt_embeds_mask=negative_prompt_embeds_mask,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # True CFG runs two transformer passes per step (cond + uncond); it is
    # only enabled when a negative prompt (or its embeddings+mask) exists.
    has_neg_prompt = negative_prompt is not None or (
        negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
    )
    do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
    prompt_embeds, prompt_embeds_mask = self.encode_prompt(
        prompt=prompt,
        prompt_embeds=prompt_embeds,
        prompt_embeds_mask=prompt_embeds_mask,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
    )
    if do_true_cfg:
        negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
            prompt=negative_prompt,
            prompt_embeds=negative_prompt_embeds,
            prompt_embeds_mask=negative_prompt_embeds_mask,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
        )

    # 4. Prepare latent variables
    # Transformer channels count packed 2x2 patches, hence the // 4.
    num_channels_latents = self.transformer.config.in_channels // 4
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    image_seq_len = latents.shape[1]
    # Resolution-dependent timestep shift for the flow-match scheduler.
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.get("base_image_seq_len", 256),
        self.scheduler.config.get("max_image_seq_len", 4096),
        self.scheduler.config.get("base_shift", 0.5),
        self.scheduler.config.get("max_shift", 1.15),
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # handle guidance (only for guidance-distilled transformer variants)
    if self.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    else:
        guidance = None

    if self.attention_kwargs is None:
        self._attention_kwargs = {}

    # Per-sample text lengths derived from the attention masks.
    txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
    negative_txt_seq_lens = (
        negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
    )

    # 6. Denoising loop
    self.scheduler.set_begin_index(0)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            self._current_timestep = t
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
            with self.transformer.cache_context("cond"):
                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]

            if do_true_cfg:
                with self.transformer.cache_context("uncond"):
                    neg_noise_pred = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=negative_txt_seq_lens,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
                    )[0]
                comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                # Rescale the guided prediction to the conditional norm to
                # avoid the over-saturation typical of high CFG scales.
                cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
                noise_pred = comb_pred * (cond_norm / noise_norm)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                # NOTE: relies on `locals()` to pick up the tensors named in
                # callback_on_step_end_tensor_inputs from this frame.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None
    if output_type == "latent":
        image = latents
    else:
        # Unpack tokens back to a latent grid, undo the VAE's latent
        # normalization, then decode and drop the singleton frame axis.
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = latents.to(self.vae.dtype)
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
        image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return QwenImagePipelineOutput(images=image)
fastvideo/models/qwenimage/transformer_qwenimage.py ADDED
@@ -0,0 +1,645 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ import math
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
26
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
27
+ from diffusers.models.attention import FeedForward
28
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
29
+ from diffusers.models.attention_processor import Attention
30
+ from diffusers.models.cache_utils import CacheMixin
31
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
32
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings, matching the implementation in
    Denoising Diffusion Probabilistic Models.

    Args
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings
    Returns
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # Geometrically spaced frequencies from 1 down to 1/max_period.
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)
    frequencies = torch.exp(exponent).to(timesteps.dtype)

    # Outer product of timesteps and frequencies, optionally rescaled.
    angles = scale * (timesteps[:, None].float() * frequencies[None, :])

    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    # Swap halves so the layout becomes (cos, sin) when requested.
    if flip_sin_to_cos:
        embedding = torch.cat([embedding[:, half_dim:], embedding[:, :half_dim]], dim=-1)

    # Odd target dims get a single zero column appended.
    if embedding_dim % 2 == 1:
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))
    return embedding
92
+
93
+
94
def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to a query or key tensor.

    When `use_real` is True, `freqs_cis` is a `(cos, sin)` pair of real
    tensors of shape [S, D]; otherwise it is a complex tensor and the
    rotation is performed by complex multiplication. The result has the same
    shape and dtype as `x`.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if not use_real:
        # Complex path: pair up the last dim and rotate via complex multiply.
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        rotated = x_complex * freqs_cis.unsqueeze(1)
        return torch.view_as_real(rotated).flatten(3).type_as(x)

    cos, sin = freqs_cis  # [S, D]
    cos = cos[None, None].to(x.device)
    sin = sin[None, None].to(x.device)

    if use_real_unbind_dim == -1:
        # Interleaved (a, b) pairs -> (-b, a); used for flux, cogvideox, hunyuan-dit.
        real_part, imag_part = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-imag_part, real_part], dim=-1).flatten(3)
    elif use_real_unbind_dim == -2:
        # Split-half layout; used for Stable Audio, OmniGen, CogView4 and Cosmos.
        real_part, imag_part = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
        x_rotated = torch.cat([-imag_part, real_part], dim=-1)
    else:
        raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
140
+
141
+
142
class QwenTimestepProjEmbeddings(nn.Module):
    """Project scalar diffusion timesteps into the transformer's conditioning space."""

    def __init__(self, embedding_dim):
        super().__init__()

        # 256-dim sinusoidal features; `scale=1000` rescales the [0, 1] flow
        # timesteps fed by the pipeline (which passes `timestep / 1000`).
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
        # MLP from the 256 sinusoidal features up to the model width.
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep, hidden_states):
        # `hidden_states` is used only to select the output dtype so the
        # embedding matches the rest of the model under mixed precision.
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))  # (N, D)

        conditioning = timesteps_emb

        return conditioning
156
+
157
+
158
class QwenEmbedRope(nn.Module):
    """3D (frame/height/width) rotary position embeddings for QwenImage.

    Precomputes complex RoPE frequency tables for positive positions (image
    tokens) and negative positions (used for the scaled-RoPE centering trick),
    and serves per-resolution frequency tensors for video/image tokens plus a
    slice for text tokens appended after the image sequence.
    """

    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
        super().__init__()
        self.theta = theta
        # Per-axis rotary dims (frame, height, width); each must be even.
        self.axes_dim = axes_dim
        pos_index = torch.arange(4096)
        # Negative positions ..., -2, -1 for the mirrored half of scaled RoPE.
        neg_index = torch.arange(4096).flip(0) * -1 - 1
        self.pos_freqs = torch.cat(
            [
                self.rope_params(pos_index, self.axes_dim[0], self.theta),
                self.rope_params(pos_index, self.axes_dim[1], self.theta),
                self.rope_params(pos_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        self.neg_freqs = torch.cat(
            [
                self.rope_params(neg_index, self.axes_dim[0], self.theta),
                self.rope_params(neg_index, self.axes_dim[1], self.theta),
                self.rope_params(neg_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        # Memoizes per-(idx, height, width) frequency tensors; see forward().
        self.rope_cache = {}

        # NOTE: do not store the frequency tables with register_buffer —
        # buffer dtype casting can drop the imaginary part of complex tensors.
        self.scale_rope = scale_rope

    def rope_params(self, index, dim, theta=10000):
        """
        Args:
            index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
        """
        assert dim % 2 == 0
        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
        # Unit-magnitude complex exponentials e^{i * pos * freq}.
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs

    def forward(self, video_fhw, txt_seq_lens, device):
        """
        Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
            txt_length: [bs] a list of 1 integers representing the length of the text
        """
        # Lazily migrate the precomputed tables to the execution device.
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        # Normalize the input to a list of (frame, height, width) tuples.
        if isinstance(video_fhw, list):
            video_fhw = video_fhw[0]
        if not isinstance(video_fhw, list):
            video_fhw = [video_fhw]

        vid_freqs = []
        max_vid_index = 0
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            rope_key = f"{idx}_{height}_{width}"

            # Under torch.compile, dict mutation would cause recompiles, so
            # the cache is bypassed there and freqs are recomputed.
            # NOTE(review): _compute_video_freqs also carries lru_cache, which
            # keys on `self` and keeps instances alive (double caching with
            # rope_cache) — consider consolidating; confirm before changing.
            if not torch.compiler.is_compiling():
                if rope_key not in self.rope_cache:
                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
                video_freq = self.rope_cache[rope_key]
            else:
                video_freq = self._compute_video_freqs(frame, height, width, idx)
            video_freq = video_freq.to(device)
            vid_freqs.append(video_freq)

            # Track the largest spatial index so text positions start past it.
            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        # Text tokens use positions directly after the image positions.
        max_len = max(txt_seq_lens)
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return vid_freqs, txt_freqs

    @functools.lru_cache(maxsize=None)
    def _compute_video_freqs(self, frame, height, width, idx=0):
        # Build the per-token frequency tensor for one (frame, height, width)
        # grid, concatenating the frame/height/width axis frequencies.
        seq_lens = frame * height * width
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
        if self.scale_rope:
            # Center the spatial positions around zero using negative indices
            # for the first half and positive indices for the second half.
            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
        else:
            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
        return freqs.clone().contiguous()
254
+
255
+
256
class QwenDoubleStreamAttnProcessor2_0:
    """
    Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
    implements joint attention computation where text and image streams are processed together.
    """

    # Optional backend override forwarded to dispatch_attention_fn.
    _attention_backend = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,  # Image stream
        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
        encoder_hidden_states_mask: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Run joint text+image attention and return (img_out, txt_out).

        NOTE(review): `encoder_hidden_states_mask` is accepted but never used
        here — masking must arrive via `attention_mask`; confirm with callers.
        """
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

        # Text length, needed to split the joint sequence back apart below.
        seq_txt = encoder_hidden_states.shape[1]

        # Compute QKV for image stream (sample projections)
        img_query = attn.to_q(hidden_states)
        img_key = attn.to_k(hidden_states)
        img_value = attn.to_v(hidden_states)

        # Compute QKV for text stream (context projections)
        txt_query = attn.add_q_proj(encoder_hidden_states)
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)

        # Reshape for multi-head attention: (..., heads * head_dim) -> (..., heads, head_dim)
        img_query = img_query.unflatten(-1, (attn.heads, -1))
        img_key = img_key.unflatten(-1, (attn.heads, -1))
        img_value = img_value.unflatten(-1, (attn.heads, -1))

        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        # Apply QK normalization
        if attn.norm_q is not None:
            img_query = attn.norm_q(img_query)
        if attn.norm_k is not None:
            img_key = attn.norm_k(img_key)
        if attn.norm_added_q is not None:
            txt_query = attn.norm_added_q(txt_query)
        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        # Apply RoPE (complex form) separately per stream.
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

        # Concatenate for joint attention
        # Order: [text, image]
        joint_query = torch.cat([txt_query, img_query], dim=1)
        joint_key = torch.cat([txt_key, img_key], dim=1)
        joint_value = torch.cat([txt_value, img_value], dim=1)

        # Compute joint attention
        joint_hidden_states = dispatch_attention_fn(
            joint_query,
            joint_key,
            joint_value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )

        # Reshape back: merge (heads, head_dim) into the feature axis.
        joint_hidden_states = joint_hidden_states.flatten(2, 3)
        joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

        # Split attention outputs back
        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part

        # Apply output projections
        img_attn_output = attn.to_out[0](img_attn_output)
        if len(attn.to_out) > 1:
            img_attn_output = attn.to_out[1](img_attn_output)  # dropout

        txt_attn_output = attn.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output
354
+
355
+
356
@maybe_allow_in_graph
class QwenImageTransformerBlock(nn.Module):
    """Dual-stream (image + text) DiT block.

    Each stream has its own AdaLN-style modulation (``img_mod`` / ``txt_mod``)
    and MLP, while attention is computed once, jointly, over the concatenated
    [text, image] sequence by ``self.attn`` (whose processor returns separate
    image and text outputs).

    Args:
        dim: Channel dimension shared by both streams.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Per-head dimension.
        qk_norm: Kind of normalization applied to Q/K projections.
        eps: Epsilon used by all LayerNorm layers and the attention QK norm.
    """

    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        super().__init__()

        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim

        # Image processing modules
        self.img_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),  # For scale, shift, gate for norm1 and norm2
        )
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,  # Enable cross attention for joint computation
            added_kv_proj_dim=dim,  # Enable added KV projections for text stream
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=QwenDoubleStreamAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=eps,
        )
        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        # Text processing modules
        self.txt_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),  # For scale, shift, gate for norm1 and norm2
        )
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        # Text doesn't need separate attention - it's handled by img_attn joint computation
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def _modulate(self, x, mod_params):
        """Apply modulation to input tensor.

        ``mod_params`` is [B, 3*dim]; it is split into (shift, scale, gate).
        Returns the shifted/scaled tensor and the gate, both with a singleton
        sequence axis inserted so they broadcast over [B, seq, dim] inputs.
        """
        shift, scale, gate = mod_params.chunk(3, dim=-1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_mask: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one dual-stream block.

        Returns ``(encoder_hidden_states, hidden_states)`` — note the text
        stream is returned FIRST, mirroring how callers unpack it.
        """
        # Get modulation parameters for both streams
        img_mod_params = self.img_mod(temb)  # [B, 6*dim]
        txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]

        # Split modulation parameters for norm1 and norm2
        img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
        txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]

        # Process image stream - norm1 + modulation
        img_normed = self.img_norm1(hidden_states)
        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)

        # Process text stream - norm1 + modulation
        txt_normed = self.txt_norm1(encoder_hidden_states)
        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)

        # Use QwenAttnProcessor2_0 for joint attention computation
        # This directly implements the DoubleStreamLayerMegatron logic:
        # 1. Computes QKV for both streams
        # 2. Applies QK normalization and RoPE
        # 3. Concatenates and runs joint attention
        # 4. Splits results back to separate streams
        # NOTE(review): encoder_hidden_states_mask is forwarded here but it is
        # unclear from this file whether the processor actually applies it —
        # confirm against QwenDoubleStreamAttnProcessor2_0.
        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=img_modulated,  # Image stream (will be processed as "sample")
            encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
        img_attn_output, txt_attn_output = attn_output

        # Apply attention gates and add residual (like in Megatron)
        hidden_states = hidden_states + img_gate1 * img_attn_output
        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output

        # Process image stream - norm2 + MLP
        img_normed2 = self.img_norm2(hidden_states)
        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
        img_mlp_output = self.img_mlp(img_modulated2)
        hidden_states = hidden_states + img_gate2 * img_mlp_output

        # Process text stream - norm2 + MLP
        txt_normed2 = self.txt_norm2(encoder_hidden_states)
        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
        txt_mlp_output = self.txt_mlp(txt_modulated2)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output

        # Clip to prevent overflow for fp16
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
470
+
471
+
472
class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    """
    The Transformer model introduced in Qwen.

    Args:
        patch_size (`int`, defaults to `2`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `60`):
            The number of layers of dual stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `3584`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        guidance_embeds (`bool`, defaults to `False`):
            Whether to use guidance embeddings for guidance-distilled variant of the model.
        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["QwenImageTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["QwenImageTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 64,
        out_channels: Optional[int] = 16,
        num_layers: int = 60,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 3584,
        guidance_embeds: bool = False,  # TODO: this should probably be removed
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # Rotary position embedding over (frame, height, width) axes.
        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)

        # Timestep conditioning -> temb fed to every transformer block.
        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)

        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

        # Input projections: image patches and text embeddings into inner_dim.
        self.img_in = nn.Linear(in_channels, self.inner_dim)
        self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        # Final AdaLN + linear projection back to patchified output channels.
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
        txt_seq_lens: Optional[List[int]] = None,
        guidance: torch.Tensor = None,  # TODO: this should probably be removed
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`QwenTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
                Mask of the input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # Pull out the PEFT `scale` (if any) before forwarding the remaining
        # kwargs to the attention processors.
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            # NOTE(review): the warning text mentions `joint_attention_kwargs`
            # but this method's parameter is named `attention_kwargs`.
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        hidden_states = self.img_in(hidden_states)

        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        if guidance is not None:
            # Guidance scale is embedded in the thousands range, matching the
            # timestep scaling convention.
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = (
            self.time_text_embed(timestep, hidden_states)
            if guidance is None
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # NOTE(review): the checkpointed call does not forward
                # `attention_kwargs` (joint_attention_kwargs) to the block,
                # unlike the non-checkpointed branch below — confirm intended.
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    encoder_hidden_states_mask,
                    temb,
                    image_rotary_emb,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=attention_kwargs,
                )

        # Use only the image part (hidden_states) from the dual-stream blocks
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
fastvideo/models/stable_diffusion/ddim_with_logprob.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import math
6
+ import torch
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
10
+
11
+
12
+ def _left_broadcast(t, shape):
13
+ assert t.ndim <= len(shape)
14
+ return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
15
+
16
+
17
+ def _get_variance(self, timestep, prev_timestep):
18
+ ## a_t
19
+ alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
20
+
21
+ ## a_t-1
22
+ alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0,self.alphas_cumprod.gather(0, prev_timestep.cpu()),self.final_alpha_cumprod,).to(timestep.device)
23
+
24
+ ## b_t
25
+ beta_prod_t = 1 - alpha_prod_t
26
+
27
+ ## b_t-1
28
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
29
+
30
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1)
31
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
32
+
33
+ return variance
34
+
35
+
36
def ddim_step_with_logprob(
    self: DDIMScheduler,
    model_output: torch.FloatTensor,
    timestep: torch.Tensor,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    prev_sample: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """
    Predict the sample at the previous timestep by reversing the SDE, and also return
    the Gaussian log-probability of that sample under the DDIM transition kernel.
    Core function to propagate the diffusion process from the learned model outputs
    (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`torch.Tensor`): current discrete timestep(s) in the diffusion chain.
            NOTE(review): annotated `int` upstream, but the body calls `.cpu()` and
            `torch.clamp` on it, so a tensor is required.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        eta (`float`): weight of noise for added noise in diffusion step.
            NOTE(review): with eta == 0, std_dev_t is 0 and the log-prob below divides
            by zero — callers are expected to use eta > 0.
        use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
            predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
            `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
            coincide with the one provided as input and `use_clipped_model_output` will have not effect.
        generator: random number generator (only used when `prev_sample` is None).
        prev_sample (`torch.FloatTensor`, *optional*): if given, the log-prob is evaluated at
            this sample instead of drawing a fresh one; mutually exclusive with `generator`.

    Returns:
        `(prev_sample, log_prob)`: the previous sample cast to `sample.dtype`, and the
        per-batch-element log-probability of `prev_sample` under
        N(prev_sample_mean, std_dev_t^2), averaged over all non-batch dimensions.
    """
    assert isinstance(self, DDIMScheduler)
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read DDIM paper in-detail understanding

    # Notation (<variable name> -> <name in paper>
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    ## t-1
    prev_timestep = (
        timestep - self.config.num_train_timesteps // self.num_inference_steps
    )
    # NOTE(review): this clamp forces prev_timestep >= 0, so the
    # `prev_timestep >= 0` branch of torch.where below is always taken and
    # final_alpha_cumprod is never used here.
    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)

    ## a_t
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())

    ## a_t-1
    alpha_prod_t_prev = torch.where(
        prev_timestep.cpu() >= 0,
        self.alphas_cumprod.gather(0, prev_timestep.cpu()),
        self.final_alpha_cumprod,
    )

    ## s0:(2.1924) s5: (2.3384), s15: (2.6422) s24:(2.8335)
    # eta_bound = (((1-alpha_prod_t) * alpha_prod_t_prev) / (alpha_prod_t_prev - alpha_prod_t)) ** (0.5)

    ## a_t, broadcast to the full sample shape (e.g. torch.Size([4, 4, 64, 64]))
    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)

    ## a_t-1
    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
        sample.device
    )

    ## b_t
    beta_prod_t = 1 - alpha_prod_t

    ## pred_x_0: recover the model's estimate of x_0 according to prediction_type
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (
            beta_prod_t**0.5
        ) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. compute variance: "sigma_t(η)"
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)

    ## var = (b_t-1 / b_t) * (1 - a_t/a_t-1)
    variance = _get_variance(self, timestep, prev_timestep)

    ## std = eta * sqrt(var)
    std_dev_t = eta * variance ** (0.5)
    std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

    # 7. deterministic (noise-free) part of x_t-1
    prev_sample_mean = (
        alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
    )

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )

        # alpha = 1
        # scale = 1.0 / (1 + 2*alpha + 2*alpha**2) ** 0.5
        # new_noise_1 = variance_noise[[0]] + alpha * (variance_noise[[0]]-variance_noise[[1]])
        # new_noise_2 = variance_noise[[1]] + alpha * (variance_noise[[1]]-variance_noise[[0]])

        # new_noise_1 = new_noise_1 * scale
        # new_noise_2 = new_noise_2 * scale

        # new_noise = torch.cat((variance_noise[[0]], variance_noise[[1]], new_noise_1, new_noise_2), dim=0)
        # prev_sample = prev_sample_mean + std_dev_t * new_noise

        ## x_t-1 = x_t-1_mean + std * noise
        prev_sample = prev_sample_mean + std_dev_t * variance_noise


    ## one x_t can map to multiple x_t-1 samples

    # log prob of prev_sample given prev_sample_mean and std_dev_t
    # (Gaussian log-density; prev_sample is detached so gradients flow only
    # through the mean/std, not the sampled point itself)
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
        - torch.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )
    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample.type(sample.dtype), log_prob
fastvideo/models/stable_diffusion/ddim_with_logprob_v6.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import math
6
+ import torch
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
10
+
11
+
12
+ def _left_broadcast(t, shape):
13
+ assert t.ndim <= len(shape)
14
+ return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
15
+
16
+
17
+ def _get_variance(self, timestep, prev_timestep):
18
+ ## a_t
19
+ alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
20
+
21
+ ## a_t-1
22
+ alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0,self.alphas_cumprod.gather(0, prev_timestep.cpu()),self.final_alpha_cumprod,).to(timestep.device)
23
+
24
+ ## b_t
25
+ beta_prod_t = 1 - alpha_prod_t
26
+
27
+ ## b_t-1
28
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
29
+
30
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1)
31
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
32
+
33
+ return variance
34
+
35
+
36
def ddim_step_with_logprob(
    self: DDIMScheduler,
    model_output: torch.FloatTensor,
    timestep: torch.Tensor,
    sample: torch.FloatTensor,
    reward_mask: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    prev_sample: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """
    Predict the sample at the previous timestep by reversing the SDE, and also return
    the Gaussian log-probability of that sample under the DDIM transition kernel.
    Core function to propagate the diffusion process from the learned model outputs
    (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`torch.Tensor`): current discrete timestep(s) in the diffusion chain.
            NOTE(review): annotated `int` upstream, but the body calls `.cpu()` and
            `torch.clamp` on it, so a tensor is required.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        reward_mask (`torch.FloatTensor`):
            NOTE(review): accepted but never used anywhere in this function —
            confirm whether it was meant to mask the log-prob or can be removed.
        eta (`float`): weight of noise for added noise in diffusion step.
            NOTE(review): with eta == 0, std_dev_t is 0 and the log-prob below divides
            by zero — callers are expected to use eta > 0.
        use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
            predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
            `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
            coincide with the one provided as input and `use_clipped_model_output` will have not effect.
        generator: random number generator (only used when `prev_sample` is None).
        prev_sample (`torch.FloatTensor`, *optional*): if given, the log-prob is evaluated at
            this sample instead of drawing a fresh one; mutually exclusive with `generator`.

    Returns:
        `(prev_sample, log_prob)`: the previous sample cast to `sample.dtype`, and the
        per-batch-element log-probability of `prev_sample` under
        N(prev_sample_mean, std_dev_t^2), averaged over all non-batch dimensions.
    """
    assert isinstance(self, DDIMScheduler)
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read DDIM paper in-detail understanding

    # Notation (<variable name> -> <name in paper>
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    ## t-1
    prev_timestep = (
        timestep - self.config.num_train_timesteps // self.num_inference_steps
    )
    # NOTE(review): this clamp forces prev_timestep >= 0, so the
    # `prev_timestep >= 0` branch of torch.where below is always taken and
    # final_alpha_cumprod is never used here.
    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)

    ## a_t
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())

    ## a_t-1
    alpha_prod_t_prev = torch.where(
        prev_timestep.cpu() >= 0,
        self.alphas_cumprod.gather(0, prev_timestep.cpu()),
        self.final_alpha_cumprod,
    )

    ## a_t, broadcast to the full sample shape (e.g. torch.Size([4, 4, 64, 64]))
    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)

    ## a_t-1
    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
        sample.device
    )

    ## b_t
    beta_prod_t = 1 - alpha_prod_t

    ## pred_x_0: recover the model's estimate of x_0 according to prediction_type
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (
            beta_prod_t**0.5
        ) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)

    ## (b_t-1 / b_t) * (1 - a_t/a_t-1) : 0.2065
    variance = _get_variance(self, timestep, prev_timestep)

    ## std = eta * sqrt(var)
    std_dev_t = eta * variance ** (0.5)
    std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
        0.5
    ) * pred_epsilon

    # 7. deterministic (noise-free) part of x_t-1
    prev_sample_mean = (
        alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
    )

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )

        ## x_t-1 = x_t-1_mean + std * noise
        prev_sample = prev_sample_mean + std_dev_t * variance_noise
    ## one x_t can map to multiple x_t-1 samples

    # log prob of prev_sample given prev_sample_mean and std_dev_t
    # (Gaussian log-density; prev_sample is detached so gradients flow only
    # through the mean/std, not the sampled point itself)
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))- torch.log(std_dev_t)- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )

    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample.type(sample.dtype), log_prob
fastvideo/models/stable_diffusion/ddim_with_logprob_v6_2.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import math
6
+ import torch
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
10
+
11
+
12
+ def _left_broadcast(t, shape):
13
+ assert t.ndim <= len(shape)
14
+ return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
15
+
16
+
17
+ def _get_variance(self, timestep, prev_timestep):
18
+ ## a_t
19
+ alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
20
+
21
+ ## a_t-1
22
+ alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0,self.alphas_cumprod.gather(0, prev_timestep.cpu()),self.final_alpha_cumprod,).to(timestep.device)
23
+
24
+ ## b_t
25
+ beta_prod_t = 1 - alpha_prod_t
26
+
27
+ ## b_t-1
28
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
29
+
30
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1)
31
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
32
+
33
+ return variance
34
+
35
+
36
+ def ddim_step_with_logprob(
37
+ self: DDIMScheduler,
38
+ model_output: torch.FloatTensor,
39
+ timestep: int,
40
+ sample: torch.FloatTensor,
41
+ eta: float = 0.0,
42
+ use_clipped_model_output: bool = False,
43
+ generator=None,
44
+ prev_sample: Optional[torch.FloatTensor] = None,
45
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
46
+ """
47
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
48
+ process from the learned model outputs (most often the predicted noise).
49
+
50
+ Args:
51
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
52
+ timestep (`int`): current discrete timestep in the diffusion chain.
53
+ sample (`torch.FloatTensor`):
54
+ current instance of sample being created by diffusion process.
55
+ eta (`float`): weight of noise for added noise in diffusion step.
56
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
57
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
58
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
59
+ coincide with the one provided as input and `use_clipped_model_output` will have not effect.
60
+ generator: random number generator.
61
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
62
+ can directly provide the noise for the variance itself. This is useful for methods such as
63
+ CycleDiffusion. (https://arxiv.org/abs/2210.05559)
64
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
65
+
66
+ Returns:
67
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
68
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
69
+ returning a tuple, the first element is the sample tensor.
70
+
71
+ """
72
+ assert isinstance(self, DDIMScheduler)
73
+ if self.num_inference_steps is None:
74
+ raise ValueError(
75
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
76
+ )
77
+
78
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
79
+ # Ideally, read DDIM paper in-detail understanding
80
+
81
+ # Notation (<variable name> -> <name in paper>
82
+ # - pred_noise_t -> e_theta(x_t, t)
83
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
84
+ # - std_dev_t -> sigma_t
85
+ # - eta -> η
86
+ # - pred_sample_direction -> "direction pointing to x_t"
87
+ # - pred_prev_sample -> "x_t-1"
88
+
89
+ ## t-1
90
+ prev_timestep = (
91
+ timestep - self.config.num_train_timesteps // self.num_inference_steps
92
+ )
93
+ prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
94
+
95
+ ## a_t
96
+ alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
97
+
98
+ ## a_t-1
99
+ alpha_prod_t_prev = torch.where(
100
+ prev_timestep.cpu() >= 0,
101
+ self.alphas_cumprod.gather(0, prev_timestep.cpu()),
102
+ self.final_alpha_cumprod,
103
+ )
104
+
105
+ ## a_t # torch.Size([4, 4, 64, 64])
106
+ alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
107
+
108
+ ## a_t-1
109
+ alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
110
+ sample.device
111
+ )
112
+
113
+ ## b_t
114
+ beta_prod_t = 1 - alpha_prod_t
115
+
116
+ ## pred_x_0
117
+ if self.config.prediction_type == "epsilon":
118
+ pred_original_sample = (
119
+ sample - beta_prod_t ** (0.5) * model_output
120
+ ) / alpha_prod_t ** (0.5)
121
+ pred_epsilon = model_output
122
+ elif self.config.prediction_type == "sample":
123
+ pred_original_sample = model_output
124
+ pred_epsilon = (
125
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
126
+ ) / beta_prod_t ** (0.5)
127
+ elif self.config.prediction_type == "v_prediction":
128
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
129
+ beta_prod_t**0.5
130
+ ) * model_output
131
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (
132
+ beta_prod_t**0.5
133
+ ) * sample
134
+ else:
135
+ raise ValueError(
136
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
137
+ " `v_prediction`"
138
+ )
139
+
140
+ # 4. Clip or threshold "predicted x_0"
141
+ if self.config.thresholding:
142
+ pred_original_sample = self._threshold_sample(pred_original_sample)
143
+ elif self.config.clip_sample:
144
+ pred_original_sample = pred_original_sample.clamp(
145
+ -self.config.clip_sample_range, self.config.clip_sample_range
146
+ )
147
+
148
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
149
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
150
+
151
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1) : 0.2065
152
+ variance = _get_variance(self, timestep, prev_timestep)
153
+
154
+ ## std = eta * sqrt(var)
155
+ std_dev_t = eta * variance ** (0.5)
156
+ std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
157
+
158
+ if use_clipped_model_output:
159
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
160
+ pred_epsilon = (
161
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
162
+ ) / beta_prod_t ** (0.5)
163
+
164
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
165
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
166
+ 0.5
167
+ ) * pred_epsilon
168
+
169
+ # 7. x_t-1-less
170
+ prev_sample_mean = (
171
+ alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
172
+ )
173
+
174
+ if prev_sample is not None and generator is not None:
175
+ raise ValueError(
176
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
177
+ " `prev_sample` stays `None`."
178
+ )
179
+
180
+ if prev_sample is None:
181
+ variance_noise = randn_tensor(
182
+ model_output.shape,
183
+ generator=generator,
184
+ device=model_output.device,
185
+ dtype=model_output.dtype,
186
+ )
187
+
188
+ ## x_t-1 = x_t-1_mean + std * noise
189
+ prev_sample = prev_sample_mean + std_dev_t * variance_noise
190
+ ## x_t -> 多个 x_t-1
191
+
192
+ # log prob of prev_sample given prev_sample_mean and std_dev_t
193
+ log_prob = (
194
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))- torch.log(std_dev_t)- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
195
+ )
196
+
197
+ # mean along all but batch dimension
198
+ # log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
199
+
200
+ return prev_sample.type(sample.dtype), log_prob
fastvideo/models/stable_diffusion/ddim_with_logprob_v6_8.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import math
6
+ import torch
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
10
+
11
+
12
+ def _left_broadcast(t, shape):
13
+ assert t.ndim <= len(shape)
14
+ return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
15
+
16
+
17
+ def _get_variance(self, timestep, prev_timestep):
18
+ ## a_t
19
+ alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
20
+
21
+ ## a_t-1
22
+ alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0,self.alphas_cumprod.gather(0, prev_timestep.cpu()),self.final_alpha_cumprod,).to(timestep.device)
23
+
24
+ ## b_t
25
+ beta_prod_t = 1 - alpha_prod_t
26
+
27
+ ## b_t-1
28
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
29
+
30
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1)
31
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
32
+
33
+ return variance
34
+
35
+
36
+ def ddim_step_with_logprob(
37
+ self: DDIMScheduler,
38
+ model_output: torch.FloatTensor,
39
+ timestep: int,
40
+ sample: torch.FloatTensor,
41
+ reward_mask: torch.FloatTensor,
42
+ eta: float = 0.0,
43
+ use_clipped_model_output: bool = False,
44
+ generator=None,
45
+ prev_sample: Optional[torch.FloatTensor] = None,
46
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
47
+ """
48
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
49
+ process from the learned model outputs (most often the predicted noise).
50
+
51
+ Args:
52
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
53
+ timestep (`int`): current discrete timestep in the diffusion chain.
54
+ sample (`torch.FloatTensor`):
55
+ current instance of sample being created by diffusion process.
56
+ eta (`float`): weight of noise for added noise in diffusion step.
57
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
58
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
59
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
60
+ coincide with the one provided as input and `use_clipped_model_output` will have not effect.
61
+ generator: random number generator.
62
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
63
+ can directly provide the noise for the variance itself. This is useful for methods such as
64
+ CycleDiffusion. (https://arxiv.org/abs/2210.05559)
65
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
66
+
67
+ Returns:
68
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
69
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
70
+ returning a tuple, the first element is the sample tensor.
71
+
72
+ """
73
+ assert isinstance(self, DDIMScheduler)
74
+ if self.num_inference_steps is None:
75
+ raise ValueError(
76
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
77
+ )
78
+
79
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
80
+ # Ideally, read DDIM paper in-detail understanding
81
+
82
+ # Notation (<variable name> -> <name in paper>
83
+ # - pred_noise_t -> e_theta(x_t, t)
84
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
85
+ # - std_dev_t -> sigma_t
86
+ # - eta -> η
87
+ # - pred_sample_direction -> "direction pointing to x_t"
88
+ # - pred_prev_sample -> "x_t-1"
89
+
90
+ ## t-1
91
+ prev_timestep = (
92
+ timestep - self.config.num_train_timesteps // self.num_inference_steps
93
+ )
94
+ prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
95
+
96
+ ## a_t
97
+ alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
98
+
99
+ ## a_t-1
100
+ alpha_prod_t_prev = torch.where(
101
+ prev_timestep.cpu() >= 0,
102
+ self.alphas_cumprod.gather(0, prev_timestep.cpu()),
103
+ self.final_alpha_cumprod,
104
+ )
105
+
106
+ ## a_t # torch.Size([4, 4, 64, 64])
107
+ alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
108
+
109
+ ## a_t-1
110
+ alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
111
+ sample.device
112
+ )
113
+
114
+ ## b_t
115
+ beta_prod_t = 1 - alpha_prod_t
116
+
117
+ ## pred_x_0
118
+ if self.config.prediction_type == "epsilon":
119
+ pred_original_sample = (
120
+ sample - beta_prod_t ** (0.5) * model_output
121
+ ) / alpha_prod_t ** (0.5)
122
+ pred_epsilon = model_output
123
+ elif self.config.prediction_type == "sample":
124
+ pred_original_sample = model_output
125
+ pred_epsilon = (
126
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
127
+ ) / beta_prod_t ** (0.5)
128
+ elif self.config.prediction_type == "v_prediction":
129
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
130
+ beta_prod_t**0.5
131
+ ) * model_output
132
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (
133
+ beta_prod_t**0.5
134
+ ) * sample
135
+ else:
136
+ raise ValueError(
137
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
138
+ " `v_prediction`"
139
+ )
140
+
141
+ # 4. Clip or threshold "predicted x_0"
142
+ if self.config.thresholding:
143
+ pred_original_sample = self._threshold_sample(pred_original_sample)
144
+ elif self.config.clip_sample:
145
+ pred_original_sample = pred_original_sample.clamp(
146
+ -self.config.clip_sample_range, self.config.clip_sample_range
147
+ )
148
+
149
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
150
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
151
+
152
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1) : 0.2065
153
+ variance = _get_variance(self, timestep, prev_timestep)
154
+
155
+ ## std = eta * sqrt(var)
156
+ std_dev_t = eta * variance ** (0.5)
157
+ std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
158
+
159
+ if use_clipped_model_output:
160
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
161
+ pred_epsilon = (
162
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
163
+ ) / beta_prod_t ** (0.5)
164
+
165
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
166
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
167
+ 0.5
168
+ ) * pred_epsilon
169
+
170
+ # 7. x_t-1-less
171
+ prev_sample_mean = (
172
+ alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
173
+ )
174
+
175
+ if prev_sample is not None and generator is not None:
176
+ raise ValueError(
177
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
178
+ " `prev_sample` stays `None`."
179
+ )
180
+
181
+ if prev_sample is None:
182
+ variance_noise = randn_tensor(
183
+ model_output.shape,
184
+ generator=generator,
185
+ device=model_output.device,
186
+ dtype=model_output.dtype,
187
+ )
188
+
189
+ ## x_t-1 = x_t-1_mean + std * noise
190
+ prev_sample = prev_sample_mean + std_dev_t * variance_noise
191
+ ## x_t -> 多个 x_t-1
192
+
193
+ # log prob of prev_sample given prev_sample_mean and std_dev_t
194
+ log_prob = (
195
+ - reward_mask * ((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))- torch.log(std_dev_t)- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
196
+ )
197
+
198
+ # mean along all but batch dimension
199
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
200
+
201
+ return prev_sample.type(sample.dtype), log_prob
fastvideo/models/stable_diffusion/ddim_with_logprob_v8.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import math
6
+ import torch
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
10
+
11
+
12
+ def _left_broadcast(t, shape):
13
+ assert t.ndim <= len(shape)
14
+ return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
15
+
16
+
17
+ def _get_variance(self, timestep, prev_timestep):
18
+ ## a_t
19
+ alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
20
+
21
+ ## a_t-1
22
+ alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0,self.alphas_cumprod.gather(0, prev_timestep.cpu()),self.final_alpha_cumprod,).to(timestep.device)
23
+
24
+ ## b_t
25
+ beta_prod_t = 1 - alpha_prod_t
26
+
27
+ ## b_t-1
28
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
29
+
30
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1)
31
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
32
+
33
+ return variance
34
+
35
+
36
+ def ddim_step_with_logprob(
37
+ self: DDIMScheduler,
38
+ model_output: torch.FloatTensor,
39
+ timestep: int,
40
+ sample: torch.FloatTensor,
41
+ eta: float = 0.0,
42
+ use_clipped_model_output: bool = False,
43
+ generator=None,
44
+ prev_sample: Optional[torch.FloatTensor] = None,
45
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
46
+ """
47
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
48
+ process from the learned model outputs (most often the predicted noise).
49
+
50
+ Args:
51
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
52
+ timestep (`int`): current discrete timestep in the diffusion chain.
53
+ sample (`torch.FloatTensor`):
54
+ current instance of sample being created by diffusion process.
55
+ eta (`float`): weight of noise for added noise in diffusion step.
56
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
57
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
58
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
59
+ coincide with the one provided as input and `use_clipped_model_output` will have not effect.
60
+ generator: random number generator.
61
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
62
+ can directly provide the noise for the variance itself. This is useful for methods such as
63
+ CycleDiffusion. (https://arxiv.org/abs/2210.05559)
64
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
65
+
66
+ Returns:
67
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
68
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
69
+ returning a tuple, the first element is the sample tensor.
70
+
71
+ """
72
+ assert isinstance(self, DDIMScheduler)
73
+ if self.num_inference_steps is None:
74
+ raise ValueError(
75
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
76
+ )
77
+
78
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
79
+ # Ideally, read DDIM paper in-detail understanding
80
+
81
+ # Notation (<variable name> -> <name in paper>
82
+ # - pred_noise_t -> e_theta(x_t, t)
83
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
84
+ # - std_dev_t -> sigma_t
85
+ # - eta -> η
86
+ # - pred_sample_direction -> "direction pointing to x_t"
87
+ # - pred_prev_sample -> "x_t-1"
88
+
89
+ ## t-1
90
+ prev_timestep = (
91
+ timestep - self.config.num_train_timesteps // self.num_inference_steps
92
+ )
93
+ prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
94
+
95
+ ## a_t
96
+ alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
97
+
98
+ ## a_t-1
99
+ alpha_prod_t_prev = torch.where(
100
+ prev_timestep.cpu() >= 0,
101
+ self.alphas_cumprod.gather(0, prev_timestep.cpu()),
102
+ self.final_alpha_cumprod,
103
+ )
104
+
105
+ ## a_t # torch.Size([4, 4, 64, 64])
106
+ alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
107
+
108
+ ## a_t-1
109
+ alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
110
+ sample.device
111
+ )
112
+
113
+ ## b_t
114
+ beta_prod_t = 1 - alpha_prod_t
115
+
116
+ ## pred_x_0
117
+ if self.config.prediction_type == "epsilon":
118
+ pred_original_sample = (
119
+ sample - beta_prod_t ** (0.5) * model_output
120
+ ) / alpha_prod_t ** (0.5)
121
+ pred_epsilon = model_output
122
+ elif self.config.prediction_type == "sample":
123
+ pred_original_sample = model_output
124
+ pred_epsilon = (
125
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
126
+ ) / beta_prod_t ** (0.5)
127
+ elif self.config.prediction_type == "v_prediction":
128
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
129
+ beta_prod_t**0.5
130
+ ) * model_output
131
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (
132
+ beta_prod_t**0.5
133
+ ) * sample
134
+ else:
135
+ raise ValueError(
136
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
137
+ " `v_prediction`"
138
+ )
139
+
140
+ # 4. Clip or threshold "predicted x_0"
141
+ if self.config.thresholding:
142
+ pred_original_sample = self._threshold_sample(pred_original_sample)
143
+ elif self.config.clip_sample:
144
+ pred_original_sample = pred_original_sample.clamp(
145
+ -self.config.clip_sample_range, self.config.clip_sample_range
146
+ )
147
+
148
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
149
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
150
+
151
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1) : 0.2065
152
+ variance = _get_variance(self, timestep, prev_timestep)
153
+
154
+ ## std = eta * sqrt(var)
155
+ std_dev_t = eta * variance ** (0.5)
156
+ std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
157
+
158
+ if use_clipped_model_output:
159
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
160
+ pred_epsilon = (
161
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
162
+ ) / beta_prod_t ** (0.5)
163
+
164
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
165
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
166
+ 0.5
167
+ ) * pred_epsilon
168
+
169
+ # 7. x_t-1-less
170
+ prev_sample_mean = (
171
+ alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
172
+ )
173
+
174
+ if prev_sample is not None and generator is not None:
175
+ raise ValueError(
176
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
177
+ " `prev_sample` stays `None`."
178
+ )
179
+
180
+ if prev_sample is None:
181
+ variance_noise = randn_tensor(
182
+ model_output.shape,
183
+ generator=generator,
184
+ device=model_output.device,
185
+ dtype=model_output.dtype,
186
+ )
187
+
188
+ ## x_t-1 = x_t-1_mean + std * noise
189
+ prev_sample = prev_sample_mean + std_dev_t * variance_noise
190
+
191
+
192
+ # log prob of prev_sample given prev_sample_mean and std_dev_t
193
+ log_prob = (
194
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
195
+ - torch.log(std_dev_t)
196
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
197
+ )
198
+ # mean along all but batch dimension
199
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
200
+
201
+ return prev_sample.type(sample.dtype), log_prob, prev_sample_mean, std_dev_t, variance_noise
fastvideo/models/stable_diffusion/ddim_with_logprob_w_x0.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import math
6
+ import torch
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
10
+
11
+
12
+ def _left_broadcast(t, shape):
13
+ assert t.ndim <= len(shape)
14
+ return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
15
+
16
+
17
+ def _get_variance(self, timestep, prev_timestep):
18
+ ## a_t
19
+ alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
20
+
21
+ ## a_t-1
22
+ alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0,self.alphas_cumprod.gather(0, prev_timestep.cpu()),self.final_alpha_cumprod,).to(timestep.device)
23
+
24
+ ## b_t
25
+ beta_prod_t = 1 - alpha_prod_t
26
+
27
+ ## b_t-1
28
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
29
+
30
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1)
31
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
32
+
33
+ return variance
34
+
35
+
36
+ def ddim_step_with_logprob_w_x0(
37
+ self: DDIMScheduler,
38
+ model_output: torch.FloatTensor,
39
+ timestep: int,
40
+ sample: torch.FloatTensor,
41
+ eta: float = 0.0,
42
+ use_clipped_model_output: bool = False,
43
+ generator=None,
44
+ prev_sample: Optional[torch.FloatTensor] = None,
45
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
46
+ """
47
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
48
+ process from the learned model outputs (most often the predicted noise).
49
+
50
+ Args:
51
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
52
+ timestep (`int`): current discrete timestep in the diffusion chain.
53
+ sample (`torch.FloatTensor`):
54
+ current instance of sample being created by diffusion process.
55
+ eta (`float`): weight of noise for added noise in diffusion step.
56
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
57
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
58
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
59
+ coincide with the one provided as input and `use_clipped_model_output` will have not effect.
60
+ generator: random number generator.
61
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
62
+ can directly provide the noise for the variance itself. This is useful for methods such as
63
+ CycleDiffusion. (https://arxiv.org/abs/2210.05559)
64
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
65
+
66
+ Returns:
67
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
68
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
69
+ returning a tuple, the first element is the sample tensor.
70
+
71
+ """
72
+ assert isinstance(self, DDIMScheduler)
73
+ if self.num_inference_steps is None:
74
+ raise ValueError(
75
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
76
+ )
77
+
78
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
79
+ # Ideally, read DDIM paper in-detail understanding
80
+
81
+ # Notation (<variable name> -> <name in paper>
82
+ # - pred_noise_t -> e_theta(x_t, t)
83
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
84
+ # - std_dev_t -> sigma_t
85
+ # - eta -> η
86
+ # - pred_sample_direction -> "direction pointing to x_t"
87
+ # - pred_prev_sample -> "x_t-1"
88
+
89
+ ## t-1
90
+ prev_timestep = (
91
+ timestep - self.config.num_train_timesteps // self.num_inference_steps
92
+ )
93
+ prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
94
+
95
+ ## a_t
96
+ alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
97
+
98
+ ## a_t-1
99
+ alpha_prod_t_prev = torch.where(
100
+ prev_timestep.cpu() >= 0,
101
+ self.alphas_cumprod.gather(0, prev_timestep.cpu()),
102
+ self.final_alpha_cumprod,
103
+ )
104
+
105
+ ## a_t # torch.Size([4, 4, 64, 64])
106
+ alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
107
+
108
+ ## a_t-1
109
+ alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
110
+ sample.device
111
+ )
112
+
113
+ ## b_t
114
+ beta_prod_t = 1 - alpha_prod_t
115
+
116
+ ## pred_x_0
117
+ if self.config.prediction_type == "epsilon":
118
+ pred_original_sample = (
119
+ sample - beta_prod_t ** (0.5) * model_output
120
+ ) / alpha_prod_t ** (0.5)
121
+ pred_epsilon = model_output
122
+ elif self.config.prediction_type == "sample":
123
+ pred_original_sample = model_output
124
+ pred_epsilon = (
125
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
126
+ ) / beta_prod_t ** (0.5)
127
+ elif self.config.prediction_type == "v_prediction":
128
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
129
+ beta_prod_t**0.5
130
+ ) * model_output
131
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (
132
+ beta_prod_t**0.5
133
+ ) * sample
134
+ else:
135
+ raise ValueError(
136
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
137
+ " `v_prediction`"
138
+ )
139
+
140
+ # 4. Clip or threshold "predicted x_0"
141
+ if self.config.thresholding:
142
+ pred_original_sample = self._threshold_sample(pred_original_sample)
143
+ elif self.config.clip_sample:
144
+ pred_original_sample = pred_original_sample.clamp(
145
+ -self.config.clip_sample_range, self.config.clip_sample_range
146
+ )
147
+
148
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
149
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
150
+
151
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1) : 0.2065
152
+ variance = _get_variance(self, timestep, prev_timestep)
153
+
154
+ ## std = eta * sqrt(var)
155
+ std_dev_t = eta * variance ** (0.5)
156
+ std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
157
+
158
+ if use_clipped_model_output:
159
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
160
+ pred_epsilon = (
161
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
162
+ ) / beta_prod_t ** (0.5)
163
+
164
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
165
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
166
+ 0.5
167
+ ) * pred_epsilon
168
+
169
+ # 7. x_t-1-less
170
+ prev_sample_mean = (
171
+ alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
172
+ )
173
+
174
+ if prev_sample is not None and generator is not None:
175
+ raise ValueError(
176
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
177
+ " `prev_sample` stays `None`."
178
+ )
179
+
180
+ if prev_sample is None:
181
+ variance_noise = randn_tensor(
182
+ model_output.shape,
183
+ generator=generator,
184
+ device=model_output.device,
185
+ dtype=model_output.dtype,
186
+ )
187
+
188
+ ## x_t-1 = x_t-1_mean + std * noise
189
+ prev_sample = prev_sample_mean + std_dev_t * variance_noise
190
+ ## x_t -> 多个 x_t-1
191
+
192
+ # log prob of prev_sample given prev_sample_mean and std_dev_t
193
+ log_prob = (
194
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
195
+ - torch.log(std_dev_t)
196
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
197
+ )
198
+ # mean along all but batch dimension
199
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
200
+
201
+ return prev_sample.type(sample.dtype), log_prob, pred_original_sample
fastvideo/models/stable_diffusion/ddim_with_logprob_w_x0_2.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import math
6
+ import torch
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
10
+
11
+
12
+ def _left_broadcast(t, shape):
13
+ assert t.ndim <= len(shape)
14
+ return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
15
+
16
+
17
+ def _get_variance(self, timestep, prev_timestep):
18
+ ## a_t
19
+ alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
20
+
21
+ ## a_t-1
22
+ alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0,self.alphas_cumprod.gather(0, prev_timestep.cpu()),self.final_alpha_cumprod,).to(timestep.device)
23
+
24
+ ## b_t
25
+ beta_prod_t = 1 - alpha_prod_t
26
+
27
+ ## b_t-1
28
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
29
+
30
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1)
31
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
32
+
33
+ return variance
34
+
35
+
36
def ddim_step_with_logprob_w_x0(
    self: DDIMScheduler,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    prev_sample: Optional[torch.FloatTensor] = None,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """
    Perform one reverse DDIM step and additionally return the Gaussian
    log-probability of the resulting sample and the predicted clean sample x_0.

    Implements formulas (12) and (16) of the DDIM paper
    (https://arxiv.org/pdf/2010.02502.pdf).

    Notation (<variable name> -> <name in paper>):
      - pred_epsilon          -> e_theta(x_t, t)
      - pred_original_sample  -> f_theta(x_t, t), i.e. x_0
      - std_dev_t             -> sigma_t
      - eta                   -> eta
      - pred_sample_direction -> "direction pointing to x_t"
      - prev_sample           -> x_{t-1}

    Args:
        model_output (`torch.FloatTensor`): direct output from the learned
            diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
            NOTE(review): the body calls `.cpu()` on it and gathers with it, so
            callers appear to pass a `torch.Tensor` of timestep indices rather
            than a plain int — confirm against callers.
        sample (`torch.FloatTensor`): current noisy sample x_t.
        eta (`float`): weight of the stochastic noise (eta in the DDIM paper).
            With `eta == 0`, `std_dev_t` is zero and the returned `log_prob`
            contains infinities (it divides by `std_dev_t**2` and takes
            `log(std_dev_t)`); use `eta > 0` whenever the log-prob matters.
        use_clipped_model_output (`bool`): if `True`, re-derive epsilon from the
            (possibly clipped) predicted x_0, as done in GLIDE.
        generator: torch random number generator used to sample the noise.
        prev_sample (`torch.FloatTensor`, *optional*): if provided, no new
            sample is drawn and the log-prob is evaluated at this value.
            Mutually exclusive with `generator`.

    Returns:
        Tuple `(prev_sample, log_prob, pred_original_sample)`:
        `prev_sample` is x_{t-1} cast back to `sample.dtype`; `log_prob` is the
        element-wise Gaussian log-density of `prev_sample` under
        N(prev_sample_mean, std_dev_t^2) — NOT reduced over non-batch
        dimensions in this variant; `pred_original_sample` is the predicted x_0.
    """
    assert isinstance(self, DDIMScheduler)
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # t-1: the step size is num_train_timesteps // num_inference_steps.
    prev_timestep = (
        timestep - self.config.num_train_timesteps // self.num_inference_steps
    )
    # Clamp to a valid gather index. Because of this clamp, prev_timestep is
    # never negative below, so the `final_alpha_cumprod` fallback inside
    # torch.where is a dead branch retained from the upstream implementation.
    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)

    # a_t (cumulative alpha product at t)
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())

    # a_{t-1} (fallback never triggers; see clamp above)
    alpha_prod_t_prev = torch.where(
        prev_timestep.cpu() >= 0,
        self.alphas_cumprod.gather(0, prev_timestep.cpu()),
        self.final_alpha_cumprod,
    )

    # Broadcast the per-timestep scalars to the full sample shape,
    # e.g. (B, C, H, W).
    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)

    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
        sample.device
    )

    # b_t = 1 - a_t
    beta_prod_t = 1 - alpha_prod_t

    # 3. Predict x_0 and epsilon from the model output, depending on the
    # parameterization the model was trained with.
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (
            beta_prod_t**0.5
        ) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. Compute variance sigma_t(eta)^2 — formula (16):
    # sigma_t = sqrt((1 - a_{t-1}) / (1 - a_t)) * sqrt(1 - a_t / a_{t-1})
    variance = _get_variance(self, timestep, prev_timestep)

    # std = eta * sqrt(var)
    std_dev_t = eta * variance ** (0.5)
    std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)

    if use_clipped_model_output:
        # pred_epsilon is always re-derived from the clipped x_0 in GLIDE
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)

    # 6. Compute "direction pointing to x_t" — formula (12).
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
        0.5
    ) * pred_epsilon

    # 7. Noiseless mean of x_{t-1} (formula (12) without the sigma_t term).
    prev_sample_mean = (
        alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
    )

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )

        # x_{t-1} = x_{t-1}_mean + std * noise;
        # one x_t can map to many stochastic x_{t-1} samples.
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

    # Log-density of prev_sample under N(prev_sample_mean, std_dev_t^2).
    # NOTE: infinite/NaN when eta == 0 because std_dev_t == 0.
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
        - torch.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )
    # mean along all but batch dimension — intentionally disabled in this
    # variant; log_prob is returned element-wise.
    # log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample.type(sample.dtype), log_prob, pred_original_sample
+ return prev_sample.type(sample.dtype), log_prob, pred_original_sample
fastvideo/models/stable_diffusion/ddim_with_logprob_w_x0_v7.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import math
6
+ import torch
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
10
+
11
+
12
+ def _left_broadcast(t, shape):
13
+ assert t.ndim <= len(shape)
14
+ return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
15
+
16
+
17
+ def _get_variance(self, timestep, prev_timestep):
18
+ ## a_t
19
+ alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
20
+
21
+ ## a_t-1
22
+ alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0,self.alphas_cumprod.gather(0, prev_timestep.cpu()),self.final_alpha_cumprod,).to(timestep.device)
23
+
24
+ ## b_t
25
+ beta_prod_t = 1 - alpha_prod_t
26
+
27
+ ## b_t-1
28
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
29
+
30
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1)
31
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
32
+
33
+ return variance
34
+
35
+
36
def ddim_step_with_logprob_w_x0(
    self: DDIMScheduler,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    prev_sample: Optional[torch.FloatTensor] = None,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """
    Perform one reverse DDIM step (v7 variant) and return the Gaussian
    log-probability of the sampled x_{t-1} plus the predicted clean sample x_0.

    This variant computes BOTH the stochastic (SDE, eta > 0) and the
    deterministic (ODE, eta = 0) updates and, when `eta != 0`, mixes them per
    batch element with a hard-coded layout (see the mixing block below):
    batch elements 0 and 1 are pure SDE; element 2 keeps SDE on the left half
    of the last dimension and ODE on the right half; element 3 is the mirror.
    The indexing assumes batch size >= 4 and a last dimension split at 32
    (e.g. 64x64 latents) — TODO confirm against callers.

    Implements formulas (12) and (16) of the DDIM paper
    (https://arxiv.org/pdf/2010.02502.pdf).

    Args:
        model_output (`torch.FloatTensor`): direct output from the learned
            diffusion model.
        timestep (`int`): current discrete timestep. NOTE(review): the body
            calls `.cpu()` on it, so callers appear to pass a `torch.Tensor`
            of indices — confirm.
        sample (`torch.FloatTensor`): current noisy sample x_t.
        eta (`float`): weight of the stochastic noise. With `eta == 0` the
            returned `log_prob` is infinite (division by std_dev_t == 0) and
            no SDE/ODE mixing is performed.
        use_clipped_model_output (`bool`): if `True`, re-derive epsilon from
            the (possibly clipped) predicted x_0, as done in GLIDE.
        generator: torch random number generator for the sampling noise.
        prev_sample (`torch.FloatTensor`, *optional*): if provided, no sample
            is drawn and the log-prob is evaluated at this value. Mutually
            exclusive with `generator`.

    Returns:
        Tuple `(prev_sample, log_prob, pred_original_sample)`. `log_prob` is
        the element-wise log-density under N(prev_sample_mean, std_dev_t^2)
        (not reduced over non-batch dims); note it is evaluated against the
        SDE mean even for the ODE-mixed regions of elements 2 and 3.
    """
    assert isinstance(self, DDIMScheduler)
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # t-1: step size is num_train_timesteps // num_inference_steps.
    prev_timestep = (
        timestep - self.config.num_train_timesteps // self.num_inference_steps
    )
    # Clamp to a valid gather index; the `final_alpha_cumprod` fallback below
    # therefore never triggers (dead branch kept from upstream).
    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)

    # a_t
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())

    # a_{t-1}
    alpha_prod_t_prev = torch.where(
        prev_timestep.cpu() >= 0,
        self.alphas_cumprod.gather(0, prev_timestep.cpu()),
        self.final_alpha_cumprod,
    )

    # Broadcast per-timestep scalars to the sample shape, e.g. (B, C, H, W).
    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)

    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
        sample.device
    )

    # b_t = 1 - a_t
    beta_prod_t = 1 - alpha_prod_t

    # 3. Predict x_0 and epsilon according to the model's parameterization.
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (
            beta_prod_t**0.5
        ) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. Compute variance sigma_t(eta)^2 — formula (16):
    # sigma_t = sqrt((1 - a_{t-1}) / (1 - a_t)) * sqrt(1 - a_t / a_{t-1})
    variance = _get_variance(self, timestep, prev_timestep)

    # std = eta * sqrt(var)
    std_dev_t = eta * variance ** (0.5)
    std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)

    if use_clipped_model_output:
        # pred_epsilon is always re-derived from the clipped x_0 in GLIDE
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)

    # 6. Compute "direction pointing to x_t" — formula (12).
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
        0.5
    ) * pred_epsilon

    # 7. Noiseless mean of x_{t-1} (formula (12) without the sigma_t term).
    prev_sample_mean = (
        alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
    )

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )

    # Deterministic (ODE, eta = 0) result computed alongside the SDE step:
    # x_{t-1}^ode = sqrt(a_{t-1}) * x_0 + sqrt(1 - a_{t-1}) * eps
    ode_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
    ode_sample_mean = (alpha_prod_t_prev ** (0.5) * pred_original_sample + ode_direction)
    prev_ode_sample = ode_sample_mean

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )

        # SDE result: x_{t-1} = x_{t-1}_mean + std * noise;
        # one x_t can map to many stochastic x_{t-1} samples.
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

        # Mix the SDE and ODE results per batch element. Hard-coded layout:
        # elements 0, 1 pure SDE; element 2 left half SDE / right half ODE
        # along the last dim (split at index 32); element 3 is the mirror.
        # NOTE(review): placement inside the `prev_sample is None` branch is
        # the natural reading of the original source — confirm.
        if eta !=0:
            mixed_tensor = torch.zeros_like(prev_sample).to(prev_sample.device)
            mixed_tensor[0] = prev_sample[0]
            mixed_tensor[1] = prev_sample[1]

            mixed_tensor[2, :, :, :32] = prev_sample[2, :, :, :32]
            mixed_tensor[2, :, :, 32:] = prev_ode_sample[2, :, :, 32:]

            mixed_tensor[3, :, :, :32] = prev_ode_sample[3, :, :, :32]
            mixed_tensor[3, :, :, 32:] = prev_sample[3, :, :, 32:]

            prev_sample = mixed_tensor

    # Log-density of prev_sample under N(prev_sample_mean, std_dev_t^2).
    # NOTE: infinite/NaN when eta == 0 (std_dev_t == 0); also evaluated
    # against the SDE mean even for the ODE-mixed regions above.
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
        - torch.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )
    # mean along all but batch dimension — intentionally disabled in this
    # variant; log_prob is returned element-wise.
    # log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample.type(sample.dtype), log_prob, pred_original_sample
fastvideo/models/stable_diffusion/ddim_with_logprob_wo_eta.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import math
6
+ import torch
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
10
+
11
+
12
+ def _left_broadcast(t, shape):
13
+ assert t.ndim <= len(shape)
14
+ return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
15
+
16
+
17
+ def _get_variance(self, timestep, prev_timestep):
18
+ ## a_t
19
+ alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
20
+
21
+ ## a_t-1
22
+ alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0,self.alphas_cumprod.gather(0, prev_timestep.cpu()),self.final_alpha_cumprod,).to(timestep.device)
23
+
24
+ ## b_t
25
+ beta_prod_t = 1 - alpha_prod_t
26
+
27
+ ## b_t-1
28
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
29
+
30
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1)
31
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
32
+
33
+ return variance
34
+
35
+
36
def ddim_step_with_logprob(
    self: DDIMScheduler,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    prev_sample: Optional[torch.FloatTensor] = None,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """
    Perform one reverse DDIM step and return the sampled x_{t-1} together with
    its Gaussian log-probability under the step's transition distribution.

    Implements formulas (12) and (16) of the DDIM paper
    (https://arxiv.org/pdf/2010.02502.pdf).

    Notation (<variable name> -> <name in paper>):
      - pred_epsilon          -> e_theta(x_t, t)
      - pred_original_sample  -> f_theta(x_t, t), i.e. x_0
      - std_dev_t             -> sigma_t
      - eta                   -> eta
      - pred_sample_direction -> "direction pointing to x_t"
      - prev_sample           -> x_{t-1}

    Args:
        model_output (`torch.FloatTensor`): direct output from the learned
            diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
            NOTE(review): the body calls `.cpu()` on it, so callers appear to
            pass a `torch.Tensor` of timestep indices — confirm.
        sample (`torch.FloatTensor`): current noisy sample x_t.
        eta (`float`): weight of the stochastic noise. With `eta == 0` the
            step is deterministic and the returned `log_prob` is infinite
            (log of a zero standard deviation); use `eta > 0` when the
            log-probability matters.
        use_clipped_model_output (`bool`): if `True`, re-derive epsilon from
            the (possibly clipped) predicted x_0, as done in GLIDE.
        generator: torch random number generator for sampling the noise.
        prev_sample (`torch.FloatTensor`, *optional*): if provided, no new
            sample is drawn and the log-prob is evaluated at this value.
            Mutually exclusive with `generator`.

    Returns:
        Tuple `(prev_sample, log_prob)` where `prev_sample` is x_{t-1} cast
        back to `sample.dtype` and `log_prob` is the Gaussian log-density of
        `prev_sample`, reduced to a per-batch mean over all non-batch
        dimensions.
    """
    assert isinstance(self, DDIMScheduler)
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # t-1: step size is num_train_timesteps // num_inference_steps.
    prev_timestep = (
        timestep - self.config.num_train_timesteps // self.num_inference_steps
    )
    # Clamp to a valid gather index. Consequently prev_timestep is never
    # negative below and the `final_alpha_cumprod` fallback in torch.where is
    # a dead branch kept from the upstream implementation.
    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)

    # a_t
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())

    # a_{t-1} (fallback never triggers; see clamp above)
    alpha_prod_t_prev = torch.where(
        prev_timestep.cpu() >= 0,
        self.alphas_cumprod.gather(0, prev_timestep.cpu()),
        self.final_alpha_cumprod,
    )

    # Broadcast per-timestep scalars to the sample shape, e.g. (B, C, H, W).
    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)

    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
        sample.device
    )

    # b_t = 1 - a_t
    beta_prod_t = 1 - alpha_prod_t

    # 3. Predict x_0 and epsilon according to the model's parameterization.
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (
            beta_prod_t**0.5
        ) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. Compute variance sigma_t(eta)^2 — formula (16):
    # sigma_t = sqrt((1 - a_{t-1}) / (1 - a_t)) * sqrt(1 - a_t / a_{t-1})
    variance = _get_variance(self, timestep, prev_timestep)

    # std = eta * sqrt(var)
    std_dev_t = eta * variance ** (0.5)
    std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)

    if use_clipped_model_output:
        # pred_epsilon is always re-derived from the clipped x_0 in GLIDE
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)

    # 6. Compute "direction pointing to x_t" — formula (12).
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
        0.5
    ) * pred_epsilon

    # 7. Noiseless mean of x_{t-1} (formula (12) without the sigma_t term).
    prev_sample_mean = (
        alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
    )

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )

        # x_{t-1} = x_{t-1}_mean + std * noise
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

    # Log-density of prev_sample under N(prev_sample_mean, std_dev_t^2).
    # NOTE: infinite/NaN when eta == 0 because std_dev_t == 0.
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
        - torch.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )
    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample.type(sample.dtype), log_prob
fastvideo/models/stable_diffusion/pipeline_with_logprob.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
8
+ StableDiffusionPipeline,
9
+ rescale_noise_cfg,
10
+ )
11
+ from .ddim_with_logprob import ddim_step_with_logprob
12
+
13
+
14
+ @torch.no_grad()
15
+ def pipeline_with_logprob(
16
+ self: StableDiffusionPipeline,
17
+ prompt: Union[str, List[str]] = None,
18
+ height: Optional[int] = None,
19
+ width: Optional[int] = None,
20
+ num_inference_steps: int = 50,
21
+ guidance_scale: float = 7.5,
22
+ negative_prompt: Optional[Union[str, List[str]]] = None,
23
+ num_images_per_prompt: Optional[int] = 1,
24
+ eta: float = 0.0,
25
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
26
+ latents: Optional[torch.FloatTensor] = None,
27
+ prompt_embeds: Optional[torch.FloatTensor] = None,
28
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
29
+ output_type: Optional[str] = "pil",
30
+ return_dict: bool = True,
31
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
32
+ callback_steps: int = 1,
33
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
34
+ guidance_rescale: float = 0.0,
35
+ ):
36
+ r"""
37
+ Function invoked when calling the pipeline for generation.
38
+
39
+ Args:
40
+ prompt (`str` or `List[str]`, *optional*):
41
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
42
+ instead.
43
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
44
+ The height in pixels of the generated image.
45
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
46
+ The width in pixels of the generated image.
47
+ num_inference_steps (`int`, *optional*, defaults to 50):
48
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
49
+ expense of slower inference.
50
+ guidance_scale (`float`, *optional*, defaults to 7.5):
51
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
52
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
53
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
54
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
55
+ usually at the expense of lower image quality.
56
+ negative_prompt (`str` or `List[str]`, *optional*):
57
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
58
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
59
+ less than `1`).
60
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
61
+ The number of images to generate per prompt.
62
+ eta (`float`, *optional*, defaults to 0.0):
63
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
64
+ [`schedulers.DDIMScheduler`], will be ignored for others.
65
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
66
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
67
+ to make generation deterministic.
68
+ latents (`torch.FloatTensor`, *optional*):
69
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
70
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
71
+ tensor will ge generated by sampling using the supplied random `generator`.
72
+ prompt_embeds (`torch.FloatTensor`, *optional*):
73
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
74
+ provided, text embeddings will be generated from `prompt` input argument.
75
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
76
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
77
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
78
+ argument.
79
+ output_type (`str`, *optional*, defaults to `"pil"`):
80
+ The output format of the generate image. Choose between
81
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
82
+ return_dict (`bool`, *optional*, defaults to `True`):
83
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
84
+ plain tuple.
85
+ callback (`Callable`, *optional*):
86
+ A function that will be called every `callback_steps` steps during inference. The function will be
87
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
88
+ callback_steps (`int`, *optional*, defaults to 1):
89
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
90
+ called at every step.
91
+ cross_attention_kwargs (`dict`, *optional*):
92
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
93
+ `self.processor` in
94
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
95
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
96
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
97
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
98
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
99
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
100
+
101
+ Examples:
102
+
103
+ Returns:
104
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
105
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
106
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
107
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
108
+ (nsfw) content, according to the `safety_checker`.
109
+ """
110
+ # 0. Default height and width to unet
111
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
112
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
113
+
114
+ # 1. Check inputs. Raise error if not correct
115
+ self.check_inputs(
116
+ prompt,
117
+ height,
118
+ width,
119
+ callback_steps,
120
+ negative_prompt,
121
+ prompt_embeds,
122
+ negative_prompt_embeds,
123
+ )
124
+
125
+ # 2. Define call parameters
126
+ if prompt is not None and isinstance(prompt, str):
127
+ batch_size = 1
128
+ elif prompt is not None and isinstance(prompt, list):
129
+ batch_size = len(prompt)
130
+ else:
131
+ batch_size = prompt_embeds.shape[0]
132
+
133
+ device = self._execution_device
134
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
135
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
136
+ # corresponds to doing no classifier free guidance.
137
+ do_classifier_free_guidance = guidance_scale > 1.0
138
+
139
+ # 3. Encode input prompt
140
+ text_encoder_lora_scale = (
141
+ cross_attention_kwargs.get("scale", None)
142
+ if cross_attention_kwargs is not None
143
+ else None
144
+ )
145
+ prompt_embeds = self._encode_prompt(
146
+ prompt,
147
+ device,
148
+ num_images_per_prompt,
149
+ do_classifier_free_guidance,
150
+ negative_prompt,
151
+ prompt_embeds=prompt_embeds,
152
+ negative_prompt_embeds=negative_prompt_embeds,
153
+ lora_scale=text_encoder_lora_scale,
154
+ )
155
+
156
+ # 4. Prepare timesteps
157
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
158
+ timesteps = self.scheduler.timesteps
159
+
160
+ # 5. Prepare latent variables
161
+ num_channels_latents = self.unet.config.in_channels
162
+ latents = self.prepare_latents(
163
+ batch_size * num_images_per_prompt,
164
+ num_channels_latents,
165
+ height,
166
+ width,
167
+ prompt_embeds.dtype,
168
+ device,
169
+ generator,
170
+ latents,
171
+ )
172
+
173
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
174
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
175
+
176
+ # 7. Denoising loop
177
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
178
+ all_latents = [latents]
179
+ all_log_probs = []
180
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
181
+ for i, t in enumerate(timesteps):
182
+ # expand the latents if we are doing classifier free guidance
183
+ latent_model_input = (
184
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
185
+ )
186
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
187
+
188
+ # predict the noise residual
189
+ noise_pred = self.unet(
190
+ latent_model_input,
191
+ t,
192
+ encoder_hidden_states=prompt_embeds,
193
+ cross_attention_kwargs=cross_attention_kwargs,
194
+ return_dict=False,
195
+ )[0]
196
+
197
+ # perform guidance
198
+ if do_classifier_free_guidance:
199
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
200
+ noise_pred = noise_pred_uncond + guidance_scale * (
201
+ noise_pred_text - noise_pred_uncond
202
+ )
203
+
204
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
205
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
206
+ noise_pred = rescale_noise_cfg(
207
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
208
+ )
209
+
210
+ # compute the previous noisy sample x_t -> x_t-1
211
+ latents, log_prob = ddim_step_with_logprob(
212
+ self.scheduler, noise_pred, t, latents, **extra_step_kwargs
213
+ )
214
+
215
+ all_latents.append(latents)
216
+ all_log_probs.append(log_prob)
217
+
218
+ # call the callback, if provided
219
+ if i == len(timesteps) - 1 or (
220
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
221
+ ):
222
+ progress_bar.update()
223
+ if callback is not None and i % callback_steps == 0:
224
+ callback(i, t, latents)
225
+
226
+ if not output_type == "latent":
227
+ image = self.vae.decode(
228
+ latents / self.vae.config.scaling_factor, return_dict=False
229
+ )[0]
230
+ image, has_nsfw_concept = self.run_safety_checker(
231
+ image, device, prompt_embeds.dtype
232
+ )
233
+ else:
234
+ image = latents
235
+ has_nsfw_concept = None
236
+
237
+ if has_nsfw_concept is None:
238
+ do_denormalize = [True] * image.shape[0]
239
+ else:
240
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
241
+
242
+ image = self.image_processor.postprocess(
243
+ image, output_type=output_type, do_denormalize=do_denormalize
244
+ )
245
+
246
+ # Offload last model to CPU
247
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
248
+ self.final_offload_hook.offload()
249
+
250
+ return image, has_nsfw_concept, all_latents, all_log_probs
fastvideo/models/stable_diffusion/pipeline_with_logprob_p1.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
8
+ StableDiffusionPipeline,
9
+ rescale_noise_cfg,
10
+ )
11
+ from .ddim_with_logprob import ddim_step_with_logprob
12
+
13
+
14
+ @torch.no_grad()
15
+ def pipeline_with_logprob_p1(
16
+ self: StableDiffusionPipeline,
17
+ prompt: Union[str, List[str]] = None,
18
+ height: Optional[int] = None,
19
+ width: Optional[int] = None,
20
+ prefix_step: Optional[int] = None, ## 模型只执行前prefix_step步
21
+ num_inference_steps: int = 50,
22
+ guidance_scale: float = 7.5,
23
+ negative_prompt: Optional[Union[str, List[str]]] = None,
24
+ num_images_per_prompt: Optional[int] = 1,
25
+ eta: float = 0.0,
26
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
27
+ latents: Optional[torch.FloatTensor] = None,
28
+ prompt_embeds: Optional[torch.FloatTensor] = None,
29
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
30
+ output_type: Optional[str] = "pil",
31
+ return_dict: bool = True,
32
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
33
+ callback_steps: int = 1,
34
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
35
+ guidance_rescale: float = 0.0,
36
+ ):
37
+ r"""
38
+ Function invoked when calling the pipeline for generation.
39
+
40
+ Args:
41
+ prompt (`str` or `List[str]`, *optional*):
42
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
43
+ instead.
44
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
45
+ The height in pixels of the generated image.
46
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
47
+ The width in pixels of the generated image.
48
+ num_inference_steps (`int`, *optional*, defaults to 50):
49
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
50
+ expense of slower inference.
51
+ guidance_scale (`float`, *optional*, defaults to 7.5):
52
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
53
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
54
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
55
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
56
+ usually at the expense of lower image quality.
57
+ negative_prompt (`str` or `List[str]`, *optional*):
58
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
59
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
60
+ less than `1`).
61
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
62
+ The number of images to generate per prompt.
63
+ eta (`float`, *optional*, defaults to 0.0):
64
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
65
+ [`schedulers.DDIMScheduler`], will be ignored for others.
66
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
67
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
68
+ to make generation deterministic.
69
+ latents (`torch.FloatTensor`, *optional*):
70
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
71
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
72
+ tensor will ge generated by sampling using the supplied random `generator`.
73
+ prompt_embeds (`torch.FloatTensor`, *optional*):
74
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
75
+ provided, text embeddings will be generated from `prompt` input argument.
76
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
77
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
78
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
79
+ argument.
80
+ output_type (`str`, *optional*, defaults to `"pil"`):
81
+ The output format of the generate image. Choose between
82
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
83
+ return_dict (`bool`, *optional*, defaults to `True`):
84
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
85
+ plain tuple.
86
+ callback (`Callable`, *optional*):
87
+ A function that will be called every `callback_steps` steps during inference. The function will be
88
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
89
+ callback_steps (`int`, *optional*, defaults to 1):
90
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
91
+ called at every step.
92
+ cross_attention_kwargs (`dict`, *optional*):
93
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
94
+ `self.processor` in
95
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
96
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
97
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
98
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
99
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
100
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
101
+
102
+ Examples:
103
+
104
+ Returns:
105
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
106
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
107
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
108
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
109
+ (nsfw) content, according to the `safety_checker`.
110
+ """
111
+ # 0. Default height and width to unet
112
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
113
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
114
+
115
+ # 1. Check inputs. Raise error if not correct
116
+ self.check_inputs(
117
+ prompt,
118
+ height,
119
+ width,
120
+ callback_steps,
121
+ negative_prompt,
122
+ prompt_embeds,
123
+ negative_prompt_embeds,
124
+ )
125
+
126
+ # 2. Define call parameters
127
+ if prompt is not None and isinstance(prompt, str):
128
+ batch_size = 1
129
+ elif prompt is not None and isinstance(prompt, list):
130
+ batch_size = len(prompt)
131
+ else:
132
+ batch_size = prompt_embeds.shape[0]
133
+
134
+ device = self._execution_device
135
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
136
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
137
+ # corresponds to doing no classifier free guidance.
138
+ do_classifier_free_guidance = guidance_scale > 1.0
139
+
140
+ # 3. Encode input prompt
141
+ text_encoder_lora_scale = (
142
+ cross_attention_kwargs.get("scale", None)
143
+ if cross_attention_kwargs is not None
144
+ else None
145
+ )
146
+ prompt_embeds = self._encode_prompt(
147
+ prompt,
148
+ device,
149
+ num_images_per_prompt,
150
+ do_classifier_free_guidance,
151
+ negative_prompt,
152
+ prompt_embeds=prompt_embeds,
153
+ negative_prompt_embeds=negative_prompt_embeds,
154
+ lora_scale=text_encoder_lora_scale,
155
+ )
156
+
157
+ # 4. Prepare timesteps
158
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
159
+ timesteps = self.scheduler.timesteps
160
+
161
+ # 5. Prepare latent variables
162
+ num_channels_latents = self.unet.config.in_channels
163
+ latents = self.prepare_latents(
164
+ batch_size * num_images_per_prompt,
165
+ num_channels_latents,
166
+ height,
167
+ width,
168
+ prompt_embeds.dtype,
169
+ device,
170
+ generator,
171
+ latents,
172
+ )
173
+
174
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
175
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
176
+
177
+ # 7. Denoising loop
178
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
179
+ all_latents = [latents]
180
+ all_log_probs = []
181
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
182
+ for i, t in enumerate(timesteps):
183
+
184
+ ## 第一阶段执行[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14] 共15步采样
185
+ ## 所以第二阶段应该从x_15开始
186
+ if i >= prefix_step:
187
+ break
188
+ # expand the latents if we are doing classifier free guidance
189
+ latent_model_input = (
190
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
191
+ )
192
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
193
+
194
+ # predict the noise residual
195
+ noise_pred = self.unet(
196
+ latent_model_input,
197
+ t,
198
+ encoder_hidden_states=prompt_embeds,
199
+ cross_attention_kwargs=cross_attention_kwargs,
200
+ return_dict=False,
201
+ )[0]
202
+
203
+ # perform guidance
204
+ if do_classifier_free_guidance:
205
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
206
+ noise_pred = noise_pred_uncond + guidance_scale * (
207
+ noise_pred_text - noise_pred_uncond
208
+ )
209
+
210
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
211
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
212
+ noise_pred = rescale_noise_cfg(
213
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
214
+ )
215
+
216
+ # compute the previous noisy sample x_t -> x_t-1
217
+ latents, log_prob = ddim_step_with_logprob(
218
+ self.scheduler, noise_pred, t, latents, **extra_step_kwargs
219
+ )
220
+
221
+ all_latents.append(latents)
222
+ all_log_probs.append(log_prob)
223
+
224
+ # call the callback, if provided
225
+ if i == len(timesteps) - 1 or (
226
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
227
+ ):
228
+ progress_bar.update()
229
+ if callback is not None and i % callback_steps == 0:
230
+ callback(i, t, latents)
231
+
232
+ if not output_type == "latent": ## false
233
+ image = self.vae.decode(
234
+ latents / self.vae.config.scaling_factor, return_dict=False
235
+ )[0]
236
+ image, has_nsfw_concept = self.run_safety_checker(
237
+ image, device, prompt_embeds.dtype
238
+ )
239
+ else:
240
+ image = latents
241
+ has_nsfw_concept = None
242
+
243
+ if has_nsfw_concept is None:
244
+ do_denormalize = [True] * image.shape[0]
245
+ else:
246
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
247
+
248
+ ## 第一阶段不需要x_0
249
+ # image = self.image_processor.postprocess(
250
+ # image, output_type=output_type, do_denormalize=do_denormalize
251
+ # )
252
+ image = None
253
+
254
+ # Offload last model to CPU
255
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
256
+ self.final_offload_hook.offload()
257
+
258
+ return image, has_nsfw_concept, all_latents, all_log_probs
fastvideo/models/stable_diffusion/pipeline_with_logprob_p2.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
8
+ StableDiffusionPipeline,
9
+ rescale_noise_cfg,
10
+ )
11
+ from .ddim_with_logprob import ddim_step_with_logprob
12
+
13
+
14
+ def retrieve_timesteps(
15
+ scheduler,
16
+ num_inference_steps: Optional[int] = None,
17
+ device: Optional[Union[str, torch.device]] = None,
18
+ timesteps: Optional[List[int]] = None,
19
+ sigmas: Optional[List[float]] = None,
20
+ **kwargs,
21
+ ):
22
+ """
23
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
24
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
25
+
26
+ Args:
27
+ scheduler (`SchedulerMixin`):
28
+ The scheduler to get timesteps from.
29
+ num_inference_steps (`int`):
30
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
31
+ must be `None`.
32
+ device (`str` or `torch.device`, *optional*):
33
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
34
+ timesteps (`List[int]`, *optional*):
35
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
36
+ `num_inference_steps` and `sigmas` must be `None`.
37
+ sigmas (`List[float]`, *optional*):
38
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
39
+ `num_inference_steps` and `timesteps` must be `None`.
40
+
41
+ Returns:
42
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
43
+ second element is the number of inference steps.
44
+ """
45
+ if timesteps is not None and sigmas is not None:
46
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
47
+ if timesteps is not None:
48
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
49
+ if not accepts_timesteps:
50
+ raise ValueError(
51
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
52
+ f" timestep schedules. Please check whether you are using the correct scheduler."
53
+ )
54
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
55
+ timesteps = scheduler.timesteps
56
+ num_inference_steps = len(timesteps)
57
+ elif sigmas is not None:
58
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
59
+ if not accept_sigmas:
60
+ raise ValueError(
61
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
62
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
63
+ )
64
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
65
+ timesteps = scheduler.timesteps
66
+ num_inference_steps = len(timesteps)
67
+ else:
68
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
69
+ timesteps = scheduler.timesteps
70
+ return timesteps, num_inference_steps
71
+
72
+ @torch.no_grad()
73
+ def pipeline_with_logprob_p2(
74
+ self: StableDiffusionPipeline,
75
+ prompt: Union[str, List[str]] = None,
76
+ prefix_step: Optional[int] = None,
77
+ height: Optional[int] = None,
78
+ width: Optional[int] = None,
79
+ num_inference_steps: int = 50,
80
+ guidance_scale: float = 7.5,
81
+ negative_prompt: Optional[Union[str, List[str]]] = None,
82
+ num_images_per_prompt: Optional[int] = 1,
83
+ eta: float = 0.0,
84
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
85
+ latents: Optional[torch.FloatTensor] = None,
86
+ prompt_embeds: Optional[torch.FloatTensor] = None,
87
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
88
+ output_type: Optional[str] = "pil",
89
+ return_dict: bool = True,
90
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
91
+ callback_steps: int = 1,
92
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
93
+ guidance_rescale: float = 0.0,
94
+ ):
95
+ r"""
96
+ Function invoked when calling the pipeline for generation.
97
+
98
+ Args:
99
+ prompt (`str` or `List[str]`, *optional*):
100
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
101
+ instead.
102
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
103
+ The height in pixels of the generated image.
104
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
105
+ The width in pixels of the generated image.
106
+ num_inference_steps (`int`, *optional*, defaults to 50):
107
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
108
+ expense of slower inference.
109
+ guidance_scale (`float`, *optional*, defaults to 7.5):
110
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
111
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
112
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
113
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
114
+ usually at the expense of lower image quality.
115
+ negative_prompt (`str` or `List[str]`, *optional*):
116
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
117
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
118
+ less than `1`).
119
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
120
+ The number of images to generate per prompt.
121
+ eta (`float`, *optional*, defaults to 0.0):
122
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
123
+ [`schedulers.DDIMScheduler`], will be ignored for others.
124
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
125
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
126
+ to make generation deterministic.
127
+ latents (`torch.FloatTensor`, *optional*):
128
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
129
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
130
+ tensor will ge generated by sampling using the supplied random `generator`.
131
+ prompt_embeds (`torch.FloatTensor`, *optional*):
132
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
133
+ provided, text embeddings will be generated from `prompt` input argument.
134
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
135
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
136
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
137
+ argument.
138
+ output_type (`str`, *optional*, defaults to `"pil"`):
139
+ The output format of the generate image. Choose between
140
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
141
+ return_dict (`bool`, *optional*, defaults to `True`):
142
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
143
+ plain tuple.
144
+ callback (`Callable`, *optional*):
145
+ A function that will be called every `callback_steps` steps during inference. The function will be
146
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
147
+ callback_steps (`int`, *optional*, defaults to 1):
148
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
149
+ called at every step.
150
+ cross_attention_kwargs (`dict`, *optional*):
151
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
152
+ `self.processor` in
153
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
154
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
155
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
156
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
157
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
158
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
159
+
160
+ Examples:
161
+
162
+ Returns:
163
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
164
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
165
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
166
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
167
+ (nsfw) content, according to the `safety_checker`.
168
+ """
169
+
170
+
171
+ def get_timesteps(self, num_inference_steps, timesteps, strength):
172
+ # get the original timestep using init_timestep
173
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
174
+
175
+ t_start = max(num_inference_steps - init_timestep, 0)
176
+ timesteps = timesteps[t_start * self.scheduler.order :]
177
+
178
+ return timesteps, num_inference_steps - t_start
179
+
180
+ # 0. Default height and width to unet
181
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
182
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
183
+
184
+ # 1. Check inputs. Raise error if not correct
185
+ self.check_inputs(
186
+ prompt,
187
+ height,
188
+ width,
189
+ callback_steps,
190
+ negative_prompt,
191
+ prompt_embeds,
192
+ negative_prompt_embeds,
193
+ )
194
+
195
+ # 2. Define call parameters
196
+ if prompt is not None and isinstance(prompt, str):
197
+ batch_size = 1
198
+ elif prompt is not None and isinstance(prompt, list):
199
+ batch_size = len(prompt)
200
+ else:
201
+ batch_size = prompt_embeds.shape[0]
202
+
203
+ device = self._execution_device
204
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
205
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
206
+ # corresponds to doing no classifier free guidance.
207
+ do_classifier_free_guidance = guidance_scale > 1.0
208
+
209
+ # 3. Encode input prompt
210
+ text_encoder_lora_scale = (
211
+ cross_attention_kwargs.get("scale", None)
212
+ if cross_attention_kwargs is not None
213
+ else None
214
+ )
215
+ prompt_embeds = self._encode_prompt(
216
+ prompt,
217
+ device,
218
+ num_images_per_prompt,
219
+ do_classifier_free_guidance,
220
+ negative_prompt,
221
+ prompt_embeds=prompt_embeds,
222
+ negative_prompt_embeds=negative_prompt_embeds,
223
+ lora_scale=text_encoder_lora_scale,
224
+ )
225
+
226
+ # 4. Prepare timesteps
227
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
228
+ timesteps = self.scheduler.timesteps
229
+
230
+ # 5. Prepare latent variables
231
+ num_channels_latents = self.unet.config.in_channels
232
+ latents = self.prepare_latents(
233
+ batch_size * num_images_per_prompt,
234
+ num_channels_latents,
235
+ height,
236
+ width,
237
+ prompt_embeds.dtype,
238
+ device,
239
+ generator,
240
+ latents,
241
+ )
242
+
243
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
244
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
245
+
246
+ # 7. Denoising loop
247
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
248
+ all_latents = [latents]
249
+ all_log_probs = []
250
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
251
+ for i, t in enumerate(timesteps):
252
+
253
+ if i < prefix_step:
254
+ continue
255
+
256
+ # expand the latents if we are doing classifier free guidance
257
+ latent_model_input = (
258
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
259
+ )
260
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
261
+
262
+ # predict the noise residual
263
+ noise_pred = self.unet(
264
+ latent_model_input,
265
+ t,
266
+ encoder_hidden_states=prompt_embeds,
267
+ cross_attention_kwargs=cross_attention_kwargs,
268
+ return_dict=False,
269
+ )[0]
270
+
271
+ # perform guidance
272
+ if do_classifier_free_guidance:
273
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
274
+ noise_pred = noise_pred_uncond + guidance_scale * (
275
+ noise_pred_text - noise_pred_uncond
276
+ )
277
+
278
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
279
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
280
+ noise_pred = rescale_noise_cfg(
281
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
282
+ )
283
+
284
+ # compute the previous noisy sample x_t -> x_t-1
285
+ latents, log_prob = ddim_step_with_logprob(
286
+ self.scheduler, noise_pred, t, latents, **extra_step_kwargs
287
+ )
288
+
289
+ all_latents.append(latents)
290
+ all_log_probs.append(log_prob)
291
+
292
+ # call the callback, if provided
293
+ if i == len(timesteps) - 1 or (
294
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
295
+ ):
296
+ progress_bar.update()
297
+ if callback is not None and i % callback_steps == 0:
298
+ callback(i, t, latents)
299
+
300
+ if not output_type == "latent":
301
+ image = self.vae.decode(
302
+ latents / self.vae.config.scaling_factor, return_dict=False
303
+ )[0]
304
+ image, has_nsfw_concept = self.run_safety_checker(
305
+ image, device, prompt_embeds.dtype
306
+ )
307
+ else:
308
+ image = latents
309
+ has_nsfw_concept = None
310
+
311
+ if has_nsfw_concept is None:
312
+ do_denormalize = [True] * image.shape[0]
313
+ else:
314
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
315
+
316
+ image = self.image_processor.postprocess(
317
+ image, output_type=output_type, do_denormalize=do_denormalize
318
+ )
319
+
320
+ # Offload last model to CPU
321
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
322
+ self.final_offload_hook.offload()
323
+
324
+ return image, has_nsfw_concept, all_latents, all_log_probs
fastvideo/models/stable_diffusion/pipeline_with_logprob_prefix.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
8
+ StableDiffusionPipeline,
9
+ rescale_noise_cfg,
10
+ )
11
+ from .ddim_with_logprob import ddim_step_with_logprob
12
+
13
+
14
@torch.no_grad()
def pipeline_with_logprob(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""Run Stable Diffusion sampling while recording per-step DDIM log-probabilities.

    "Prefix" variant: the DDIM update is stochastic (uses the caller's `eta`)
    only on the very first denoising step; every subsequent step forces
    `eta = 0`, i.e. is a deterministic DDIM update.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide generation; omit when passing
            `prompt_embeds` instead.
        height (`int`, *optional*):
            Output height in pixels; defaults to
            `unet.config.sample_size * vae_scale_factor`.
        width (`int`, *optional*):
            Output width in pixels; defaults analogously to `height`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            Number of denoising steps.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Classifier-free guidance weight `w`
            (https://arxiv.org/abs/2207.12598); values > 1 enable guidance.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt(s) to steer away from; ignored when guidance is disabled.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            Number of images generated per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            DDIM `eta` (https://arxiv.org/abs/2010.02502); in this variant it
            is applied only on the first denoising step.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Generator(s) for deterministic sampling.
        latents (`torch.FloatTensor`, *optional*):
            Pre-sampled initial noisy latents; sampled from `generator` when
            omitted.
        prompt_embeds / negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Precomputed text embeddings, used instead of encoding `prompt` /
            `negative_prompt`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            One of `"pil"`, `"np"`, or `"latent"`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Unused; kept so the signature mirrors
            `StableDiffusionPipeline.__call__`.
        callback (`Callable`, *optional*):
            Called as `callback(step, timestep, latents)` every
            `callback_steps` steps.
        callback_steps (`int`, *optional*, defaults to 1):
            Frequency of `callback` invocations.
        cross_attention_kwargs (`dict`, *optional*):
            Forwarded to the attention processors.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Rescale factor `phi` from https://arxiv.org/pdf/2305.08891.pdf;
            0 disables rescaling.

    Returns:
        `tuple`: `(image, has_nsfw_concept, all_latents, all_log_probs)` where
        `all_latents` contains the initial latents plus one tensor per step and
        `all_log_probs` contains one log-probability tensor per step.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # `guidance_scale` is the guidance weight `w` of eq. (2) of the Imagen
    # paper (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None)
        if cross_attention_kwargs is not None
        else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    all_log_probs = []
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

            # compute the previous noisy sample x_t -> x_t-1.
            # Only the first step keeps the caller's eta (stochastic); every
            # later step is a deterministic DDIM update with eta = 0.
            if i > 0:
                extra_step_kwargs["eta"] = 0
            latents, log_prob = ddim_step_with_logprob(
                self.scheduler, noise_pred, t, latents, **extra_step_kwargs
            )

            all_latents.append(latents)
            all_log_probs.append(log_prob)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, prompt_embeds.dtype
        )
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_log_probs
fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
8
+ StableDiffusionPipeline,
9
+ rescale_noise_cfg,
10
+ )
11
+ from .ddim_with_logprob import ddim_step_with_logprob
12
+
13
@torch.no_grad()
def pipeline_with_logprob_w_eta(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    eta_step: Optional[int] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""Run Stable Diffusion sampling with stochasticity injected at a single step.

    Steps before `eta_step` are skipped entirely (the caller is expected to
    supply `latents` already advanced to that step); the step at index
    `eta_step` uses the caller's `eta` (stochastic DDIM), and all later steps
    force `eta = 0` (deterministic DDIM). Per-step log-probabilities are
    recorded alongside the latent trajectory.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide generation; omit when passing
            `prompt_embeds` instead.
        eta_step (`int`, *optional*):
            Index of the single denoising step that uses the caller's `eta`.
            Steps with `i < eta_step` are skipped. `None` (the default) skips
            nothing and samples fully deterministically (`eta = 0` throughout).
        height (`int`, *optional*):
            Output height in pixels; defaults to
            `unet.config.sample_size * vae_scale_factor`.
        width (`int`, *optional*):
            Output width in pixels; defaults analogously to `height`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            Number of denoising steps.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Classifier-free guidance weight `w`
            (https://arxiv.org/abs/2207.12598); values > 1 enable guidance.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt(s) to steer away from; ignored when guidance is disabled.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            Number of images generated per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            DDIM `eta` (https://arxiv.org/abs/2010.02502); applied only at
            step `eta_step`.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Generator(s) for deterministic sampling.
        latents (`torch.FloatTensor`, *optional*):
            Pre-sampled noisy latents. When `eta_step` is used these should
            already correspond to timestep `timesteps[eta_step]`.
        prompt_embeds / negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Precomputed text embeddings, used instead of encoding `prompt` /
            `negative_prompt`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            One of `"pil"`, `"np"`, or `"latent"`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Unused; kept so the signature mirrors
            `StableDiffusionPipeline.__call__`.
        callback (`Callable`, *optional*):
            Called as `callback(step, timestep, latents)` every
            `callback_steps` steps.
        callback_steps (`int`, *optional*, defaults to 1):
            Frequency of `callback` invocations.
        cross_attention_kwargs (`dict`, *optional*):
            Forwarded to the attention processors.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Rescale factor `phi` from https://arxiv.org/pdf/2305.08891.pdf;
            0 disables rescaling.

    Returns:
        `tuple`: `(image, has_nsfw_concept, all_latents, all_log_probs)` where
        `all_latents` contains the initial latents plus one tensor per
        executed step and `all_log_probs` contains one log-probability tensor
        per executed step.
    """

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # `guidance_scale` is the guidance weight `w` of eq. (2) of the Imagen
    # paper (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None)
        if cross_attention_kwargs is not None
        else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    all_log_probs = []

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):

            # Skip every step before `eta_step` (the caller supplies latents
            # already advanced to that step). Guarding on None fixes a
            # TypeError (`int < None`) when the default eta_step is used.
            if eta_step is not None and i < eta_step:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual (both CFG branches in one batch)
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance: combine the unconditional and conditional outputs
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1.
            # Randomness is injected only at step `eta_step`; every other step
            # is a deterministic DDIM update with eta = 0.
            if i != eta_step:
                extra_step_kwargs["eta"] = 0
            latents, log_prob = ddim_step_with_logprob(
                self.scheduler, noise_pred, t, latents, **extra_step_kwargs
            )

            all_latents.append(latents)
            all_log_probs.append(log_prob)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, prompt_embeds.dtype
        )
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_log_probs
fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta_bid.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
8
+ StableDiffusionPipeline,
9
+ rescale_noise_cfg,
10
+ )
11
+ from .ddim_with_logprob import ddim_step_with_logprob
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+
14
+
15
+ def _get_variance(scheduler, timestep, prev_timestep):
16
+
17
+ ## a_t
18
+ alpha_prod_t = torch.gather(scheduler.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
19
+
20
+ ## a_t-1
21
+ alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0,scheduler.alphas_cumprod.gather(0, prev_timestep.cpu()),scheduler.final_alpha_cumprod,).to(timestep.device)
22
+
23
+ ## b_t
24
+ beta_prod_t = 1 - alpha_prod_t
25
+
26
+ ## b_t-1
27
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
28
+
29
+ ## (b_t-1 / b_t) * (1 - a_t/a_t-1)
30
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
31
+
32
+ return variance
33
+
34
+ def _left_broadcast(t, shape):
35
+ assert t.ndim <= len(shape)
36
+ return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
37
+
38
+
39
+ @torch.no_grad()
40
+ def pipeline_with_logprob_w_eta_bid(
41
+ self: StableDiffusionPipeline,
42
+ prompt: Union[str, List[str]] = None,
43
+ eta_step: Optional[int] = None,
44
+ height: Optional[int] = None,
45
+ width: Optional[int] = None,
46
+ num_inference_steps: int = 50,
47
+ guidance_scale: float = 7.5,
48
+ negative_prompt: Optional[Union[str, List[str]]] = None,
49
+ num_images_per_prompt: Optional[int] = 1,
50
+ eta: float = 0.0,
51
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
52
+ latents: Optional[torch.FloatTensor] = None,
53
+ prompt_embeds: Optional[torch.FloatTensor] = None,
54
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
55
+ output_type: Optional[str] = "pil",
56
+ return_dict: bool = True,
57
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
58
+ callback_steps: int = 1,
59
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
60
+ guidance_rescale: float = 0.0,
61
+ anchor_aug_latents=None,
62
+ ):
63
+ r"""
64
+ Function invoked when calling the pipeline for generation.
65
+
66
+ Args:
67
+ prompt (`str` or `List[str]`, *optional*):
68
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
69
+ instead.
70
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
71
+ The height in pixels of the generated image.
72
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
73
+ The width in pixels of the generated image.
74
+ num_inference_steps (`int`, *optional*, defaults to 50):
75
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
76
+ expense of slower inference.
77
+ guidance_scale (`float`, *optional*, defaults to 7.5):
78
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
79
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
80
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
81
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
82
+ usually at the expense of lower image quality.
83
+ negative_prompt (`str` or `List[str]`, *optional*):
84
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
85
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
86
+ less than `1`).
87
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
88
+ The number of images to generate per prompt.
89
+ eta (`float`, *optional*, defaults to 0.0):
90
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
91
+ [`schedulers.DDIMScheduler`], will be ignored for others.
92
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
93
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
94
+ to make generation deterministic.
95
+ latents (`torch.FloatTensor`, *optional*):
96
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
97
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
98
+ tensor will ge generated by sampling using the supplied random `generator`.
99
+ prompt_embeds (`torch.FloatTensor`, *optional*):
100
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
101
+ provided, text embeddings will be generated from `prompt` input argument.
102
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
103
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
104
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
105
+ argument.
106
+ output_type (`str`, *optional*, defaults to `"pil"`):
107
+ The output format of the generate image. Choose between
108
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
109
+ return_dict (`bool`, *optional*, defaults to `True`):
110
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
111
+ plain tuple.
112
+ callback (`Callable`, *optional*):
113
+ A function that will be called every `callback_steps` steps during inference. The function will be
114
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
115
+ callback_steps (`int`, *optional*, defaults to 1):
116
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
117
+ called at every step.
118
+ cross_attention_kwargs (`dict`, *optional*):
119
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
120
+ `self.processor` in
121
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
122
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
123
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
124
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
125
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
126
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
127
+
128
+ Examples:
129
+
130
+ Returns:
131
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
132
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
133
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
134
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
135
+ (nsfw) content, according to the `safety_checker`.
136
+ """
137
+
138
+ # 0. Default height and width to unet
139
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
140
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
141
+
142
+ # 1. Check inputs. Raise error if not correct
143
+ self.check_inputs(
144
+ prompt,
145
+ height,
146
+ width,
147
+ callback_steps,
148
+ negative_prompt,
149
+ prompt_embeds,
150
+ negative_prompt_embeds,
151
+ )
152
+
153
+ # 2. Define call parameters
154
+ if prompt is not None and isinstance(prompt, str):
155
+ batch_size = 1
156
+ elif prompt is not None and isinstance(prompt, list):
157
+ batch_size = len(prompt)
158
+ else:
159
+ batch_size = prompt_embeds.shape[0]
160
+
161
+ device = self._execution_device
162
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
163
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
164
+ # corresponds to doing no classifier free guidance.
165
+ do_classifier_free_guidance = guidance_scale > 1.0
166
+
167
+ # 3. Encode input prompt
168
+ text_encoder_lora_scale = (
169
+ cross_attention_kwargs.get("scale", None)
170
+ if cross_attention_kwargs is not None
171
+ else None
172
+ )
173
+ prompt_embeds = self._encode_prompt(
174
+ prompt,
175
+ device,
176
+ num_images_per_prompt,
177
+ do_classifier_free_guidance,
178
+ negative_prompt,
179
+ prompt_embeds=prompt_embeds,
180
+ negative_prompt_embeds=negative_prompt_embeds,
181
+ lora_scale=text_encoder_lora_scale,
182
+ )
183
+
184
+ # 4. Prepare timesteps
185
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
186
+ timesteps = self.scheduler.timesteps
187
+
188
+ # 5. Prepare latent variables
189
+ num_channels_latents = self.unet.config.in_channels
190
+ latents = self.prepare_latents(
191
+ batch_size * num_images_per_prompt,
192
+ num_channels_latents,
193
+ height,
194
+ width,
195
+ prompt_embeds.dtype,
196
+ device,
197
+ generator,
198
+ latents,
199
+ )
200
+
201
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
202
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
203
+
204
+ # 7. Denoising loop
205
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
206
+ all_latents = [latents]
207
+ all_log_probs = []
208
+
209
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
210
+ for i, t in enumerate(timesteps):
211
+
212
+ if i < eta_step:
213
+ continue
214
+
215
+ # expand the latents if we are doing classifier free guidance
216
+ latent_model_input = (
217
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
218
+ )
219
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
220
+
221
+ # predict the noise residual
222
+ noise_pred = self.unet(
223
+ latent_model_input,
224
+ t,
225
+ encoder_hidden_states=prompt_embeds,
226
+ cross_attention_kwargs=cross_attention_kwargs,
227
+ return_dict=False,
228
+ )[0]
229
+
230
+ # perform guidance
231
+ if do_classifier_free_guidance:
232
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
233
+ noise_pred = noise_pred_uncond + guidance_scale * (
234
+ noise_pred_text - noise_pred_uncond
235
+ )
236
+
237
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
238
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
239
+ noise_pred = rescale_noise_cfg(
240
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
241
+ )
242
+
243
+ # compute the previous noisy sample x_t -> x_t-1
244
+ ## 仅第一步引入随机性
245
+ if i == eta_step:
246
+
247
+ ## 对第一项latents引入随机性扰动
248
+ ## 符合SDE的扰动规则: x_t_random = x_t + std_dev_t * variance_noise
249
+ ## std_dev_t = eta * sqrt(var)
250
+ ## var = (b_t-1 / b_t) * (1 - a_t/a_t-1)
251
+
252
+
253
+
254
+ ## x_t_mean -> x_t_aug
255
+ # # get sigma_t+1
256
+ # next_timestep = (t + self.scheduler.config.num_train_timesteps // len(timesteps))
257
+ # variance = _get_variance(self.scheduler, next_timestep, t) ## t t-1 , t+1 t
258
+ # std_dev_t = eta * variance ** (0.5)
259
+ # std_dev_t = _left_broadcast(std_dev_t, latents.shape).to(latents.device)
260
+
261
+ # # get alpha_t
262
+ # alpha_prod_t = torch.gather(self.scheduler.alphas_cumprod, 0, t.cpu()).to(latents.device)
263
+ # alpha_prod_t = _left_broadcast(alpha_prod_t, latents.shape).to(latents.device)
264
+ # x_t_aug = ((1 - alpha_prod_t - std_dev_t**2)**(0.5) - (1-alpha_prod_t)**(0.5)) * anchor_noises + all_latents[0]
265
+
266
+ # ## 前向随机性
267
+ # variance_noise = randn_tensor(
268
+ # latents.shape,
269
+ # generator=generator,
270
+ # device=latents.device,
271
+ # dtype=latents.dtype,
272
+ # )
273
+ # aug_latents = x_t_aug + std_dev_t * variance_noise
274
+
275
+ new_latents = torch.cat((latents[:2], anchor_aug_latents[2:]), dim=0)
276
+ all_latents[0] = new_latents
277
+
278
+
279
+ ## 输入换成增强后的new_latents
280
+ latents, log_prob = ddim_step_with_logprob(
281
+ self.scheduler, noise_pred, t, new_latents, **extra_step_kwargs
282
+ )
283
+
284
+ else:
285
+ ## 其他步按DDIM确定性采样得到结果
286
+ extra_step_kwargs["eta"] = 0
287
+ latents, log_prob = ddim_step_with_logprob(
288
+ self.scheduler, noise_pred, t, latents, **extra_step_kwargs
289
+ )
290
+
291
+ all_latents.append(latents)
292
+ all_log_probs.append(log_prob)
293
+
294
+ # call the callback, if provided
295
+ if i == len(timesteps) - 1 or (
296
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
297
+ ):
298
+ progress_bar.update()
299
+ if callback is not None and i % callback_steps == 0:
300
+ callback(i, t, latents)
301
+
302
+ if not output_type == "latent":
303
+ image = self.vae.decode(
304
+ latents / self.vae.config.scaling_factor, return_dict=False
305
+ )[0]
306
+ image, has_nsfw_concept = self.run_safety_checker(
307
+ image, device, prompt_embeds.dtype
308
+ )
309
+ else:
310
+ image = latents
311
+ has_nsfw_concept = None
312
+
313
+ if has_nsfw_concept is None:
314
+ do_denormalize = [True] * image.shape[0]
315
+ else:
316
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
317
+
318
+ image = self.image_processor.postprocess(
319
+ image, output_type=output_type, do_denormalize=do_denormalize
320
+ )
321
+
322
+ # Offload last model to CPU
323
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
324
+ self.final_offload_hook.offload()
325
+
326
+ return image, has_nsfw_concept, all_latents, all_log_probs
fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta_mask.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.

from typing import Any, Callable, Dict, List, Optional, Union

import torch

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    rescale_noise_cfg,
)

from .ddim_with_logprob_w_x0 import ddim_step_with_logprob_w_x0


@torch.no_grad()
def pipeline_with_logprob_w_eta_mask(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    eta_step: Optional[int] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Sample from a Stable Diffusion pipeline while recording, for every executed
    DDIM step, the latents, the transition log-probability, and the predicted x0.

    Stochasticity (the DDIM `eta` term) is applied at exactly one step of the
    trajectory, ``eta_step``; every other step uses deterministic DDIM
    (eta forced to 0). Steps with index ``i < eta_step`` are skipped entirely,
    so a trajectory can be resumed from an intermediate step.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide generation. If not defined,
            `prompt_embeds` must be passed instead.
        eta_step (`int`, *optional*):
            Index of the single denoising step at which the stochastic (eta)
            DDIM update is used. Defaults to 0 (the first step) when `None`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Classifier-free guidance weight `w` (https://arxiv.org/abs/2207.12598);
            CFG is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide generation. Ignored when CFG is
            disabled (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            DDIM eta (https://arxiv.org/abs/2010.02502), applied only at
            `eta_step`.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Torch generator(s) to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents; sampled with `generator` if not given.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings; computed from `prompt` if not given.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            `"pil"`, `"np"`, or `"latent"` (skips VAE decode and safety check).
        return_dict (`bool`, *optional*, defaults to `True`):
            Unused here; kept for signature compatibility with the stock
            pipeline `__call__`.
        callback (`Callable`, *optional*):
            Called as `callback(step, timestep, latents)` every
            `callback_steps` steps.
        callback_steps (`int`, *optional*, defaults to 1):
            Frequency at which `callback` is invoked.
        cross_attention_kwargs (`dict`, *optional*):
            Forwarded to the pipeline's `AttentionProcessor`s.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            CFG rescale factor `φ` from https://arxiv.org/pdf/2305.08891.pdf;
            helps with overexposure under zero-terminal-SNR schedules.

    Returns:
        `tuple`: `(image, has_nsfw_concept, all_latents, all_log_probs,
        all_pred_z0)`. `all_latents` contains one more entry (the initial
        latents) than `all_log_probs` / `all_pred_z0`.
    """
    # Interpret a missing eta_step as "the first denoising step". Without this,
    # the comparisons `i < eta_step` / `i == eta_step` below would raise a
    # TypeError when the caller leaves eta_step at its None default.
    if eta_step is None:
        eta_step = 0

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # `guidance_scale` is the guidance weight `w` of eq. (2) in the Imagen
    # paper (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None)
        if cross_attention_kwargs is not None
        else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    all_log_probs = []
    all_pred_z0 = []

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):

            # Skip every step before the one where stochasticity is injected.
            if i < eta_step:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

            # compute the previous noisy sample x_t -> x_t-1.
            # Stochasticity (eta) is used only at `eta_step`; every other step
            # forces deterministic DDIM sampling. The in-place mutation of
            # extra_step_kwargs is persistent, which is safe because no step
            # after `eta_step` should ever use a non-zero eta again.
            if i != eta_step:
                extra_step_kwargs["eta"] = 0
            latents, log_prob, pred_z0 = ddim_step_with_logprob_w_x0(
                self.scheduler, noise_pred, t, latents, **extra_step_kwargs
            )

            all_latents.append(latents)
            all_log_probs.append(log_prob)
            all_pred_z0.append(pred_z0)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, prompt_embeds.dtype
        )
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_log_probs, all_pred_z0
fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta_mask2.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.

from typing import Any, Callable, Dict, List, Optional, Union

import torch

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    rescale_noise_cfg,
)
# NOTE(review): ddim_step_with_logprob is imported but unused in this module;
# kept to avoid breaking any external `from ... import` chains.
from .ddim_with_logprob import ddim_step_with_logprob
from .ddim_with_logprob_w_x0_2 import ddim_step_with_logprob_w_x0


@torch.no_grad()
def pipeline_with_logprob_w_eta_mask(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    eta_step: Optional[int] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Sample from a Stable Diffusion pipeline while recording, for every executed
    DDIM step, the latents, the transition log-probability, and the predicted x0
    (this variant uses the step function from ``ddim_with_logprob_w_x0_2``).

    Stochasticity (the DDIM `eta` term) is applied at exactly one step of the
    trajectory, ``eta_step``; every other step uses deterministic DDIM
    (eta forced to 0). Steps with index ``i < eta_step`` are skipped entirely,
    so a trajectory can be resumed from an intermediate step.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide generation. If not defined,
            `prompt_embeds` must be passed instead.
        eta_step (`int`, *optional*):
            Index of the single denoising step at which the stochastic (eta)
            DDIM update is used. Defaults to 0 (the first step) when `None`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Classifier-free guidance weight `w` (https://arxiv.org/abs/2207.12598);
            CFG is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide generation. Ignored when CFG is
            disabled (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            DDIM eta (https://arxiv.org/abs/2010.02502), applied only at
            `eta_step`.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Torch generator(s) to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents; sampled with `generator` if not given.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings; computed from `prompt` if not given.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            `"pil"`, `"np"`, or `"latent"` (skips VAE decode and safety check).
        return_dict (`bool`, *optional*, defaults to `True`):
            Unused here; kept for signature compatibility with the stock
            pipeline `__call__`.
        callback (`Callable`, *optional*):
            Called as `callback(step, timestep, latents)` every
            `callback_steps` steps.
        callback_steps (`int`, *optional*, defaults to 1):
            Frequency at which `callback` is invoked.
        cross_attention_kwargs (`dict`, *optional*):
            Forwarded to the pipeline's `AttentionProcessor`s.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            CFG rescale factor `φ` from https://arxiv.org/pdf/2305.08891.pdf;
            helps with overexposure under zero-terminal-SNR schedules.

    Returns:
        `tuple`: `(image, has_nsfw_concept, all_latents, all_log_probs,
        all_pred_z0)`. `all_latents` contains one more entry (the initial
        latents) than `all_log_probs` / `all_pred_z0`.
    """
    # Interpret a missing eta_step as "the first denoising step". Without this,
    # the comparisons `i < eta_step` / `i == eta_step` below would raise a
    # TypeError when the caller leaves eta_step at its None default.
    if eta_step is None:
        eta_step = 0

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # `guidance_scale` is the guidance weight `w` of eq. (2) in the Imagen
    # paper (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None)
        if cross_attention_kwargs is not None
        else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    all_log_probs = []
    all_pred_z0 = []

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):

            # Skip every step before the one where stochasticity is injected.
            if i < eta_step:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

            # compute the previous noisy sample x_t -> x_t-1.
            # Stochasticity (eta) is used only at `eta_step`; every other step
            # forces deterministic DDIM sampling. The in-place mutation of
            # extra_step_kwargs is persistent, which is safe because no step
            # after `eta_step` should ever use a non-zero eta again.
            if i != eta_step:
                extra_step_kwargs["eta"] = 0
            latents, log_prob, pred_z0 = ddim_step_with_logprob_w_x0(
                self.scheduler, noise_pred, t, latents, **extra_step_kwargs
            )

            all_latents.append(latents)
            all_log_probs.append(log_prob)
            all_pred_z0.append(pred_z0)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, prompt_embeds.dtype
        )
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_log_probs, all_pred_z0
fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta_v7.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
8
+ StableDiffusionPipeline,
9
+ rescale_noise_cfg,
10
+ )
11
+ from .ddim_with_logprob import ddim_step_with_logprob
12
+ from .ddim_with_logprob_w_x0_v7 import ddim_step_with_logprob_w_x0
13
+
14
@torch.no_grad()
def pipeline_with_logprob_w_eta_mask(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    eta_step: Optional[int] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Run the Stable Diffusion denoising loop while recording per-step log
    probabilities and predicted clean latents (x0), injecting stochasticity at
    exactly one step.

    Denoising begins at step index ``eta_step`` (all earlier timesteps are
    skipped). The step at ``eta_step`` uses the caller-supplied ``eta``
    (stochastic DDIM); every later step forces ``eta = 0`` (deterministic
    DDIM).

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide generation. One of `prompt` or
            `prompt_embeds` must be supplied.
        eta_step (`int`, *optional*):
            Index of the single denoising step at which the caller-supplied
            ``eta`` is applied. ``None`` is treated as ``0`` (stochasticity at
            the first step).
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            Height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            Width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            Number of denoising steps.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Classifier-free guidance scale; guidance is enabled when > 1.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt(s) not to guide generation. Ignored when guidance is off.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            Number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            DDIM eta (https://arxiv.org/abs/2010.02502); applied only at
            ``eta_step``.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Generator(s) for deterministic sampling.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents; sampled from the generator if omitted.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-computed text embeddings.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-computed negative text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            `"pil"`, `"np"`, or `"latent"`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Kept for signature compatibility; this function always returns a
            plain tuple.
        callback (`Callable`, *optional*):
            Called as `callback(step, timestep, latents)` every
            `callback_steps` steps.
        callback_steps (`int`, *optional*, defaults to 1):
            Frequency of `callback` invocation.
        cross_attention_kwargs (`dict`, *optional*):
            Passed through to the `AttentionProcessor`.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Rescale factor from https://arxiv.org/pdf/2305.08891.pdf; fixes
            overexposure with zero-terminal-SNR schedules.

    Returns:
        `tuple`: ``(image, has_nsfw_concept, all_latents, all_log_probs,
        all_pred_z0)`` where `all_latents` contains the initial latents plus
        one entry per executed step, and `all_log_probs` / `all_pred_z0`
        contain one entry per executed step.
    """
    # Normalize eta_step: the loop below compares `i < eta_step` and
    # `i == eta_step`, which would raise TypeError against None.
    if eta_step is None:
        eta_step = 0

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # `guidance_scale` follows the guidance weight `w` of eq. (2) of the
    # Imagen paper (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1`
    # disables classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None)
        if cross_attention_kwargs is not None
        else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    all_log_probs = []
    all_pred_z0 = []

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Steps before eta_step are skipped entirely (the caller is
            # expected to supply latents already denoised to that step).
            if i < eta_step:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

            # compute the previous noisy sample x_t -> x_t-1.
            # Only the step at eta_step keeps the caller-supplied eta
            # (stochastic); all later steps use deterministic DDIM (eta=0).
            if i != eta_step:
                extra_step_kwargs["eta"] = 0
            latents, log_prob, pred_z0 = ddim_step_with_logprob_w_x0(
                self.scheduler, noise_pred, t, latents, **extra_step_kwargs
            )

            all_latents.append(latents)
            all_log_probs.append(log_prob)
            all_pred_z0.append(pred_z0)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, prompt_embeds.dtype
        )
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_log_probs, all_pred_z0
fastvideo/models/stable_diffusion/pipeline_with_logprob_w_eta_v8.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
8
+ StableDiffusionPipeline,
9
+ rescale_noise_cfg,
10
+ )
11
+ from .ddim_with_logprob_v8 import ddim_step_with_logprob
12
+
13
@torch.no_grad()
def pipeline_with_logprob_w_eta(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    eta_step: Optional[int] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Run the Stable Diffusion denoising loop while recording, for every
    executed step, the log probability, the predicted previous-sample mean,
    the step's standard deviation, and the variance noise drawn by the
    scheduler step, injecting stochasticity at exactly one step.

    Denoising begins at step index ``eta_step`` (all earlier timesteps are
    skipped). The step at ``eta_step`` uses the caller-supplied ``eta``
    (stochastic DDIM); every later step forces ``eta = 0`` (deterministic
    DDIM).

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide generation. One of `prompt` or
            `prompt_embeds` must be supplied.
        eta_step (`int`, *optional*):
            Index of the single denoising step at which the caller-supplied
            ``eta`` is applied. ``None`` is treated as ``0`` (stochasticity at
            the first step).
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            Height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            Width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            Number of denoising steps.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Classifier-free guidance scale; guidance is enabled when > 1.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt(s) not to guide generation. Ignored when guidance is off.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            Number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            DDIM eta (https://arxiv.org/abs/2010.02502); applied only at
            ``eta_step``.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Generator(s) for deterministic sampling.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents; sampled from the generator if omitted.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-computed text embeddings.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-computed negative text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            `"pil"`, `"np"`, or `"latent"`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Kept for signature compatibility; this function always returns a
            plain tuple.
        callback (`Callable`, *optional*):
            Called as `callback(step, timestep, latents)` every
            `callback_steps` steps.
        callback_steps (`int`, *optional*, defaults to 1):
            Frequency of `callback` invocation.
        cross_attention_kwargs (`dict`, *optional*):
            Passed through to the `AttentionProcessor`.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Rescale factor from https://arxiv.org/pdf/2305.08891.pdf; fixes
            overexposure with zero-terminal-SNR schedules.

    Returns:
        `tuple`: ``(image, has_nsfw_concept, all_latents, all_log_probs,
        all_prev_sample_mean, all_std_dev_t, all_variance_noise)`` where
        `all_latents` contains the initial latents plus one entry per executed
        step and the remaining lists hold one entry per executed step.
    """
    # Normalize eta_step: the loop below compares `i < eta_step` and
    # `i == eta_step`, which would raise TypeError against None.
    if eta_step is None:
        eta_step = 0

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # `guidance_scale` follows the guidance weight `w` of eq. (2) of the
    # Imagen paper (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1`
    # disables classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None)
        if cross_attention_kwargs is not None
        else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    all_log_probs = []
    all_prev_sample_mean = []
    all_std_dev_t = []
    all_variance_noise = []

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Steps before eta_step are skipped entirely (the caller is
            # expected to supply latents already denoised to that step).
            if i < eta_step:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

            # compute the previous noisy sample x_t -> x_t-1.
            # Only the step at eta_step keeps the caller-supplied eta
            # (stochastic); all later steps use deterministic DDIM (eta=0).
            if i != eta_step:
                extra_step_kwargs["eta"] = 0
            latents, log_prob, prev_sample_mean, std_dev_t, variance_noise = (
                ddim_step_with_logprob(
                    self.scheduler, noise_pred, t, latents, **extra_step_kwargs
                )
            )

            all_latents.append(latents)
            all_log_probs.append(log_prob)
            all_prev_sample_mean.append(prev_sample_mean)
            all_std_dev_t.append(std_dev_t)
            all_variance_noise.append(variance_noise)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, prompt_embeds.dtype
        )
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_log_probs, all_prev_sample_mean, all_std_dev_t, all_variance_noise
fastvideo/models/stable_diffusion/pipeline_with_logprob_wo_eta.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
8
+ StableDiffusionPipeline,
9
+ rescale_noise_cfg,
10
+ )
11
+ from .ddim_with_logprob_wo_eta import ddim_step_with_logprob
12
+
13
+
14
+ @torch.no_grad()
15
+ def pipeline_with_logprob_wo_eta(
16
+ self: StableDiffusionPipeline,
17
+ prompt: Union[str, List[str]] = None,
18
+ height: Optional[int] = None,
19
+ width: Optional[int] = None,
20
+ num_inference_steps: int = 50,
21
+ guidance_scale: float = 7.5,
22
+ negative_prompt: Optional[Union[str, List[str]]] = None,
23
+ num_images_per_prompt: Optional[int] = 1,
24
+ eta: float = 0.0,
25
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
26
+ latents: Optional[torch.FloatTensor] = None,
27
+ prompt_embeds: Optional[torch.FloatTensor] = None,
28
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
29
+ output_type: Optional[str] = "pil",
30
+ return_dict: bool = True,
31
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
32
+ callback_steps: int = 1,
33
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
34
+ guidance_rescale: float = 0.0,
35
+ ):
36
+ r"""
37
+ Function invoked when calling the pipeline for generation.
38
+
39
+ Args:
40
+ prompt (`str` or `List[str]`, *optional*):
41
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
42
+ instead.
43
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
44
+ The height in pixels of the generated image.
45
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
46
+ The width in pixels of the generated image.
47
+ num_inference_steps (`int`, *optional*, defaults to 50):
48
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
49
+ expense of slower inference.
50
+ guidance_scale (`float`, *optional*, defaults to 7.5):
51
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
52
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
53
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
54
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
55
+ usually at the expense of lower image quality.
56
+ negative_prompt (`str` or `List[str]`, *optional*):
57
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
58
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
59
+ less than `1`).
60
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
61
+ The number of images to generate per prompt.
62
+ eta (`float`, *optional*, defaults to 0.0):
63
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
64
+ [`schedulers.DDIMScheduler`], will be ignored for others.
65
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
66
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
67
+ to make generation deterministic.
68
+ latents (`torch.FloatTensor`, *optional*):
69
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
70
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
71
+ tensor will ge generated by sampling using the supplied random `generator`.
72
+ prompt_embeds (`torch.FloatTensor`, *optional*):
73
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
74
+ provided, text embeddings will be generated from `prompt` input argument.
75
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
76
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
77
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
78
+ argument.
79
+ output_type (`str`, *optional*, defaults to `"pil"`):
80
+ The output format of the generate image. Choose between
81
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
82
+ return_dict (`bool`, *optional*, defaults to `True`):
83
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
84
+ plain tuple.
85
+ callback (`Callable`, *optional*):
86
+ A function that will be called every `callback_steps` steps during inference. The function will be
87
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
88
+ callback_steps (`int`, *optional*, defaults to 1):
89
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
90
+ called at every step.
91
+ cross_attention_kwargs (`dict`, *optional*):
92
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
93
+ `self.processor` in
94
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
95
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
96
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
97
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
98
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
99
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
100
+
101
+ Examples:
102
+
103
+ Returns:
104
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
105
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
106
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
107
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
108
+ (nsfw) content, according to the `safety_checker`.
109
+ """
110
+ # 0. Default height and width to unet
111
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
112
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
113
+
114
+ # 1. Check inputs. Raise error if not correct
115
+ self.check_inputs(
116
+ prompt,
117
+ height,
118
+ width,
119
+ callback_steps,
120
+ negative_prompt,
121
+ prompt_embeds,
122
+ negative_prompt_embeds,
123
+ )
124
+
125
+ # 2. Define call parameters
126
+ if prompt is not None and isinstance(prompt, str):
127
+ batch_size = 1
128
+ elif prompt is not None and isinstance(prompt, list):
129
+ batch_size = len(prompt)
130
+ else:
131
+ batch_size = prompt_embeds.shape[0]
132
+
133
+ device = self._execution_device
134
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
135
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
136
+ # corresponds to doing no classifier free guidance.
137
+ do_classifier_free_guidance = guidance_scale > 1.0
138
+
139
+ # 3. Encode input prompt
140
+ text_encoder_lora_scale = (
141
+ cross_attention_kwargs.get("scale", None)
142
+ if cross_attention_kwargs is not None
143
+ else None
144
+ )
145
+ prompt_embeds = self._encode_prompt(
146
+ prompt,
147
+ device,
148
+ num_images_per_prompt,
149
+ do_classifier_free_guidance,
150
+ negative_prompt,
151
+ prompt_embeds=prompt_embeds,
152
+ negative_prompt_embeds=negative_prompt_embeds,
153
+ lora_scale=text_encoder_lora_scale,
154
+ )
155
+
156
+ # 4. Prepare timesteps
157
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
158
+ timesteps = self.scheduler.timesteps
159
+
160
+ # 5. Prepare latent variables
161
+ num_channels_latents = self.unet.config.in_channels
162
+ latents = self.prepare_latents(
163
+ batch_size * num_images_per_prompt,
164
+ num_channels_latents,
165
+ height,
166
+ width,
167
+ prompt_embeds.dtype,
168
+ device,
169
+ generator,
170
+ latents,
171
+ )
172
+
173
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
174
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
175
+
176
+ # 7. Denoising loop
177
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
178
+ all_latents = [latents]
179
+ all_log_probs = []
180
+
181
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
182
+ for i, t in enumerate(timesteps):
183
+
184
+ # expand the latents if we are doing classifier free guidance
185
+ latent_model_input = (
186
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
187
+ )
188
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
189
+
190
+ # predict the noise residual
191
+ noise_pred = self.unet(
192
+ latent_model_input,
193
+ t,
194
+ encoder_hidden_states=prompt_embeds,
195
+ cross_attention_kwargs=cross_attention_kwargs,
196
+ return_dict=False,
197
+ )[0]
198
+
199
+ # perform guidance
200
+ if do_classifier_free_guidance:
201
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
202
+ noise_pred = noise_pred_uncond + guidance_scale * (
203
+ noise_pred_text - noise_pred_uncond
204
+ )
205
+
206
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
207
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
208
+ noise_pred = rescale_noise_cfg(
209
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
210
+ )
211
+
212
+ # compute the previous noisy sample x_t -> x_t-1
213
+ latents, log_prob = ddim_step_with_logprob(
214
+ self.scheduler, noise_pred, t, latents, **extra_step_kwargs
215
+ )
216
+
217
+ all_latents.append(latents)
218
+ all_log_probs.append(log_prob)
219
+
220
+ # call the callback, if provided
221
+ if i == len(timesteps) - 1 or (
222
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
223
+ ):
224
+ progress_bar.update()
225
+ if callback is not None and i % callback_steps == 0:
226
+ callback(i, t, latents)
227
+
228
+ if not output_type == "latent": ## false
229
+ image = self.vae.decode(
230
+ latents / self.vae.config.scaling_factor, return_dict=False
231
+ )[0]
232
+ image, has_nsfw_concept = self.run_safety_checker(
233
+ image, device, prompt_embeds.dtype
234
+ )
235
+ else:
236
+ image = latents
237
+ has_nsfw_concept = None
238
+
239
+ if has_nsfw_concept is None:
240
+ do_denormalize = [True] * image.shape[0]
241
+ else:
242
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
243
+
244
+ image = self.image_processor.postprocess(
245
+ image, output_type=output_type, do_denormalize=do_denormalize
246
+ )
247
+
248
+ # Offload last model to CPU
249
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
250
+ self.final_offload_hook.offload()
251
+
252
+ return image, has_nsfw_concept, all_latents, all_log_probs
fastvideo/models/stable_diffusion/pipeline_with_logprob_wo_eta_2.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py, which is licensed under MIT license.
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
8
+ StableDiffusionPipeline,
9
+ rescale_noise_cfg,
10
+ )
11
+ from .ddim_with_logprob import ddim_step_with_logprob
12
+
13
+
14
@torch.no_grad()
def pipeline_with_logprob_wo_eta(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Run the Stable Diffusion sampling loop while recording the latent at every
    denoising step, using :func:`ddim_step_with_logprob` in place of the
    scheduler's ``step`` so that per-step log-probabilities could be tracked.

    This is a patched version of ``StableDiffusionPipeline.__call__`` (see
    https://github.com/kvablack/ddpo-pytorch); it is meant to be bound to a
    pipeline instance, hence the explicit ``self`` parameter.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined,
            one has to pass `prompt_embeds` instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a
            higher quality image at the expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is
            defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is
            enabled by setting `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not
            defined, one has to pass `negative_prompt_embeds` instead. Ignored
            when not using guidance (i.e., if `guidance_scale` is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper:
            https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch
            generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution,
            to be used as inputs for image generation. If not provided, a
            latents tensor will be generated by sampling using the supplied
            random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings
            will be generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided,
            negative_prompt_embeds will be generated from the `negative_prompt`
            input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image`
            or `np.array`. Pass `"latent"` to skip VAE decoding and the safety
            checker and return the raw latents.
        return_dict (`bool`, *optional*, defaults to `True`):
            Unused. Kept only for signature compatibility with
            `StableDiffusionPipeline.__call__`; this function always returns a
            plain tuple (see Returns).
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during
            inference with arguments `callback(step: int, timestep: int,
            latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If
            not specified, the callback will be called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the
            `AttentionProcessor` as defined under `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise
            Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf); `guidance_rescale`
            is defined as `φ` in equation 16 of that paper. Rescaling should
            fix overexposure when using zero terminal SNR.

    Returns:
        `tuple`: ``(image, has_nsfw_concept, all_latents, all_aug_latents)``
        where ``image`` is the post-processed output (or raw latents when
        `output_type == "latent"`), ``has_nsfw_concept`` is a list of bools
        from the safety checker (or `None` when decoding was skipped),
        ``all_latents`` is a list of ``num_inference_steps + 1`` latent
        tensors (initial noise plus one per step), and ``all_aug_latents``
        currently contains only the initial latents because the augmented
        (eta=1) re-sampling branch is disabled in this variant.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None)
        if cross_attention_kwargs is not None
        else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    # NOTE: the augmented-latent (eta=1) re-sampling step is disabled in this
    # "_wo_eta" variant, so this list is never appended to after initialization.
    all_aug_latents = [latents]

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents, log_prob = ddim_step_with_logprob(
                self.scheduler, noise_pred, t, latents, **extra_step_kwargs
            )

            all_latents.append(latents)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, prompt_embeds.dtype
        )
    else:
        image = latents
        has_nsfw_concept = None

    # Denormalize only images the safety checker did not flag.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_aug_latents