Add model and code
Browse files- diffusers_sv3d/__init__.py +2 -0
- diffusers_sv3d/__pycache__/__init__.cpython-311.pyc +0 -0
- diffusers_sv3d/models/__init__.py +1 -0
- diffusers_sv3d/models/__pycache__/__init__.cpython-311.pyc +0 -0
- diffusers_sv3d/models/unets/__init__.py +1 -0
- diffusers_sv3d/models/unets/__pycache__/__init__.cpython-311.pyc +0 -0
- diffusers_sv3d/models/unets/__pycache__/unet_spatio_temporal_condition.cpython-311.pyc +0 -0
- diffusers_sv3d/models/unets/unet_spatio_temporal_condition.py +483 -0
- diffusers_sv3d/pipelines/__init__.py +1 -0
- diffusers_sv3d/pipelines/__pycache__/__init__.cpython-311.pyc +0 -0
- diffusers_sv3d/pipelines/stable_video_diffusion/__init__.py +2 -0
- diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/__init__.cpython-311.pyc +0 -0
- diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/pipeline_stable_video_3d_diffusion.cpython-311.pyc +0 -0
- diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/pipeline_stable_video_3d_diffusion_rotate.cpython-311.pyc +0 -0
- diffusers_sv3d/pipelines/stable_video_diffusion/pipeline_stable_video_3d_diffusion.py +469 -0
- diffusers_sv3d/pipelines/stable_video_diffusion/pipeline_stable_video_3d_diffusion_rotate.py +371 -0
- pretrained_sv3d/feature_extractor/preprocessor_config.json +27 -0
- pretrained_sv3d/image_encoder/config.json +23 -0
- pretrained_sv3d/image_encoder/model.safetensors +3 -0
- pretrained_sv3d/model_index.json +3 -0
- pretrained_sv3d/scheduler/scheduler_config.json +22 -0
- pretrained_sv3d/unet/config.json +37 -0
- pretrained_sv3d/unet/diffusion_pytorch_model.safetensors +3 -0
- pretrained_sv3d/vae/config.json +38 -0
- pretrained_sv3d/vae/diffusion_pytorch_model.safetensors +3 -0
- train.py +79 -0
diffusers_sv3d/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .models import SV3DUNetSpatioTemporalConditionModel
|
| 2 |
+
from .pipelines import StableVideo3DDiffusionPipeline, StableVideo3DDiffusionPipelineRotate
|
diffusers_sv3d/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (378 Bytes). View file
|
|
|
diffusers_sv3d/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .unets import SV3DUNetSpatioTemporalConditionModel
|
diffusers_sv3d/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (251 Bytes). View file
|
|
|
diffusers_sv3d/models/unets/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .unet_spatio_temporal_condition import SV3DUNetSpatioTemporalConditionModel
|
diffusers_sv3d/models/unets/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (282 Bytes). View file
|
|
|
diffusers_sv3d/models/unets/__pycache__/unet_spatio_temporal_condition.cpython-311.pyc
ADDED
|
Binary file (24.2 kB). View file
|
|
|
diffusers_sv3d/models/unets/unet_spatio_temporal_condition.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
|
| 3 |
+
from diffusers.models.unets.unet_spatio_temporal_condition import *
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Copied from diffusers.models.unets.unet_spatio_temporal_condition UNetSpatioTemporalConditionModel
class SV3DUNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditional state, and a timestep and
    returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        addition_time_embed_dim: (`int`, defaults to 256):
            Dimension to encode the additional time ids.
        projection_class_embeddings_input_dim (`int`, defaults to 768):
            The dimension of the projection of encoded `added_time_ids`.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
            [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
            [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
        num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 20, 20)`):
            The number of attention heads.
        num_frames (`int`, *optional*, defaults to 25):
            The number of video frames the model is configured for.
    """

    # Lets diffusers' checkpointing machinery toggle checkpointing on sub-blocks
    # via `_set_gradient_checkpointing` below.
    _supports_gradient_checkpointing = True
|
| 44 |
+
|
| 45 |
+
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 8,
    out_channels: int = 4,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlockSpatioTemporal",
        "CrossAttnDownBlockSpatioTemporal",
        "CrossAttnDownBlockSpatioTemporal",
        "DownBlockSpatioTemporal",
    ),
    up_block_types: Tuple[str] = (
        "UpBlockSpatioTemporal",
        "CrossAttnUpBlockSpatioTemporal",
        "CrossAttnUpBlockSpatioTemporal",
        "CrossAttnUpBlockSpatioTemporal",
    ),
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    addition_time_embed_dim: int = 256,
    projection_class_embeddings_input_dim: int = 768,
    layers_per_block: Union[int, Tuple[int]] = 2,
    cross_attention_dim: Union[int, Tuple[int]] = 1024,
    transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
    num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
    num_frames: int = 25,
):
    super().__init__()

    self.sample_size = sample_size

    # Check inputs: every per-block setting must have one entry per block.
    if len(down_block_types) != len(up_block_types):
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )

    if len(block_out_channels) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
        )

    # input: 2D conv applied per frame (frames are flattened into the batch in forward()).
    self.conv_in = nn.Conv2d(
        in_channels,
        block_out_channels[0],
        kernel_size=3,
        padding=1,
    )

    # time
    time_embed_dim = block_out_channels[0] * 4

    self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
    timestep_input_dim = block_out_channels[0]

    self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

    self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
    # encode camera angles
    # NOTE(review): `add_angle_proj` is constructed here but is not referenced in the
    # forward() shown in this file — confirm whether it is used elsewhere or is dead state.
    self.add_angle_proj = Timesteps(2 * addition_time_embed_dim, True, downscale_freq_shift=0)
    self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

    self.down_blocks = nn.ModuleList([])
    self.up_blocks = nn.ModuleList([])

    # Normalize scalar per-block settings into one-entry-per-block sequences.
    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)

    if isinstance(cross_attention_dim, int):
        cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

    if isinstance(layers_per_block, int):
        layers_per_block = [layers_per_block] * len(down_block_types)

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

    blocks_time_embed_dim = time_embed_dim

    # down: channel count grows along block_out_channels; the last block has no downsampler.
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block[i],
            transformer_layers_per_block=transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=blocks_time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=1e-5,
            cross_attention_dim=cross_attention_dim[i],
            num_attention_heads=num_attention_heads[i],
            resnet_act_fn="silu",
        )
        self.down_blocks.append(down_block)

    # mid
    self.mid_block = UNetMidBlockSpatioTemporal(
        block_out_channels[-1],
        temb_channels=blocks_time_embed_dim,
        transformer_layers_per_block=transformer_layers_per_block[-1],
        cross_attention_dim=cross_attention_dim[-1],
        num_attention_heads=num_attention_heads[-1],
    )

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up: mirror of the down path, so every per-block setting is reversed.
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))
    reversed_layers_per_block = list(reversed(layers_per_block))
    reversed_cross_attention_dim = list(reversed(cross_attention_dim))
    reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

    output_channel = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        is_final_block = i == len(block_out_channels) - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        # skip connections come from the next-lower resolution; clamp at the last index
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

        # add upsample block for all BUT final layer
        if not is_final_block:
            add_upsample = True
            self.num_upsamplers += 1
        else:
            add_upsample = False

        up_block = get_up_block(
            up_block_type,
            num_layers=reversed_layers_per_block[i] + 1,
            transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            temb_channels=blocks_time_embed_dim,
            add_upsample=add_upsample,
            resnet_eps=1e-5,
            resolution_idx=i,
            cross_attention_dim=reversed_cross_attention_dim[i],
            num_attention_heads=reversed_num_attention_heads[i],
            resnet_act_fn="silu",
        )
        self.up_blocks.append(up_block)
        prev_output_channel = output_channel

    # out
    self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
    self.conv_act = nn.SiLU()

    self.conv_out = nn.Conv2d(
        block_out_channels[0],
        out_channels,
        kernel_size=3,
        padding=1,
    )
|
| 223 |
+
|
| 224 |
+
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model with
        indexed by its weight name.
    """
    # Walk the module tree depth-first, collecting every module that exposes
    # `get_processor` under its fully-qualified weight name.
    collected: Dict[str, AttentionProcessor] = {}

    def _walk(prefix: str, mod: torch.nn.Module) -> None:
        if hasattr(mod, "get_processor"):
            collected[f"{prefix}.processor"] = mod.get_processor()
        for child_name, child in mod.named_children():
            _walk(f"{prefix}.{child_name}", child)

    for top_name, top_mod in self.named_children():
        _walk(top_name, top_mod)

    return collected
|
| 251 |
+
|
| 252 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    """
    count = len(self.attn_processors.keys())

    # A dict must cover every attention layer exactly once.
    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def _assign(prefix: str, mod: torch.nn.Module, proc) -> None:
        # Leaf attention modules expose `set_processor`; hand them either the shared
        # processor instance or the dict entry keyed by this module's path.
        if hasattr(mod, "set_processor"):
            if isinstance(proc, dict):
                mod.set_processor(proc.pop(f"{prefix}.processor"))
            else:
                mod.set_processor(proc)
        for child_name, child in mod.named_children():
            _assign(f"{prefix}.{child_name}", child, proc)

    for top_name, top_mod in self.named_children():
        _assign(top_name, top_mod, processor)
|
| 285 |
+
|
| 286 |
+
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.
    """
    # Only safe when every current processor is a standard cross-attention processor;
    # otherwise refuse rather than silently replacing a custom processor.
    if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(processor)
|
| 298 |
+
|
| 299 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 300 |
+
if hasattr(module, "gradient_checkpointing"):
|
| 301 |
+
module.gradient_checkpointing = value
|
| 302 |
+
|
| 303 |
+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
    """
    Sets the attention processor to use [feed forward
    chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

    Parameters:
        chunk_size (`int`, *optional*):
            The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
            over each tensor of dim=`dim`.
        dim (`int`, *optional*, defaults to `0`):
            The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
            or dim=1 (sequence length).
    """
    if dim not in [0, 1]:
        raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

    # By default chunk size is 1
    chunk_size = chunk_size or 1

    # Recursively enable chunking on every sub-module that supports it.
    def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
        if hasattr(module, "set_chunk_feed_forward"):
            module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

        for child in module.children():
            fn_recursive_feed_forward(child, chunk_size, dim)

    for module in self.children():
        fn_recursive_feed_forward(module, chunk_size, dim)
|
| 332 |
+
|
| 333 |
+
def forward(
    self,
    sample: torch.Tensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    added_time_ids: Union[torch.Tensor, List[torch.Tensor]],
    return_dict: bool = True,
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
    r"""
    The [`SV3DUNetSpatioTemporalConditionModel`] forward method.

    Args:
        sample (`torch.Tensor`):
            The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
        timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
            NOTE: unlike the upstream model, per-frame repetition of these states is expected to be
            done by the caller (see the comment below).
        added_time_ids: (`torch.Tensor` or `list` of `torch.Tensor`):
            The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
            embeddings and added to the time embeddings. When a list is passed, only the first element
            (the conditioning-augmentation ids) is used and the resulting embedding is added per frame.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead
            of a plain tuple.
    Returns:
        [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
            If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is
            returned, otherwise a `tuple` is returned where the first element is the sample tensor.
    """
    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            # mps has no float64 support, so fall back to float32 there
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    batch_size, num_frames = sample.shape[:2]
    timesteps = timesteps.expand(batch_size)

    t_emb = self.time_proj(timesteps)

    # `Timesteps` does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=sample.dtype)

    emb = self.time_embedding(t_emb)

    if isinstance(added_time_ids, torch.Tensor):
        # Tensor path: conditioning ids are flattened per batch item and added
        # BEFORE the per-frame repetition.
        time_embeds = self.add_time_proj(added_time_ids.flatten())
        time_embeds = time_embeds.reshape((batch_size, -1))
        time_embeds = time_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(time_embeds)
        emb = emb + aug_emb

        # Repeat the embeddings num_video_frames times
        # emb: [batch, channels] -> [batch * frames, channels]
        emb = emb.repeat_interleave(num_frames, dim=0)
    elif isinstance(added_time_ids, list):
        # List path: the time embedding is repeated per frame FIRST, then the
        # conditioning-augmentation embedding (one row per batch*frame entry) is added.
        # Repeat the embeddings num_video_frames times
        # emb: [batch, channels] -> [batch * frames, channels]
        emb = emb.repeat_interleave(num_frames, dim=0)

        cond_aug = added_time_ids[0]
        cond_aug_emb = self.add_time_proj(cond_aug.flatten())
        time_embeds = cond_aug_emb
        time_embeds = time_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(time_embeds)
        emb = emb + aug_emb
    else:
        # Fix: the original raised a bare ValueError with no message, which made
        # failures here hard to diagnose. The exception type is unchanged.
        raise ValueError(
            f"`added_time_ids` must be a `torch.Tensor` or a `list` of tensors, got {type(added_time_ids)}."
        )

    # Flatten the batch and frames dimensions
    # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
    sample = sample.flatten(0, 1)
    # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]

    # Taken care of in the pipeline (to allow reference manipulations)
    # encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)

    # 2. pre-process
    sample = self.conv_in(sample)

    image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)

    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                image_only_indicator=image_only_indicator,
            )
        else:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                image_only_indicator=image_only_indicator,
            )

        down_block_res_samples += res_samples

    # 4. mid
    sample = self.mid_block(
        hidden_states=sample,
        temb=emb,
        encoder_hidden_states=encoder_hidden_states,
        image_only_indicator=image_only_indicator,
    )

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        # Pop the matching skip connections collected on the way down.
        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                image_only_indicator=image_only_indicator,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                image_only_indicator=image_only_indicator,
            )

    # 6. post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    # 7. Reshape back to original shape
    sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])

    if not return_dict:
        return (sample,)

    return UNetSpatioTemporalConditionOutput(sample=sample)
|
diffusers_sv3d/pipelines/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .stable_video_diffusion import StableVideo3DDiffusionPipeline, StableVideo3DDiffusionPipelineRotate
|
diffusers_sv3d/pipelines/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (318 Bytes). View file
|
|
|
diffusers_sv3d/pipelines/stable_video_diffusion/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .pipeline_stable_video_3d_diffusion import StableVideo3DDiffusionPipeline
|
| 2 |
+
from .pipeline_stable_video_3d_diffusion_rotate import StableVideo3DDiffusionPipelineRotate
|
diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (418 Bytes). View file
|
|
|
diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/pipeline_stable_video_3d_diffusion.cpython-311.pyc
ADDED
|
Binary file (24.2 kB). View file
|
|
|
diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/pipeline_stable_video_3d_diffusion_rotate.cpython-311.pyc
ADDED
|
Binary file (21.1 kB). View file
|
|
|
diffusers_sv3d/pipelines/stable_video_diffusion/pipeline_stable_video_3d_diffusion.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import PIL.Image
|
| 4 |
+
import torch
|
| 5 |
+
from diffusers.models.attention_processor import AttnProcessor2_0
|
| 6 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
|
| 7 |
+
StableVideoDiffusionPipeline,
|
| 8 |
+
_append_dims,
|
| 9 |
+
randn_tensor,
|
| 10 |
+
retrieve_timesteps,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from self_attn_swap import ACTIVATE_LAYER_CANDIDATE_SV3D, SharedAttentionProcessorThree
|
| 14 |
+
|
| 15 |
+
# Constants
HEIGHT = 576  # target frame height passed to video_processor.preprocess
WIDTH = 576  # target frame width passed to video_processor.preprocess
NUM_FRAMES = 21  # frames per generated video (used by prepare_latents / repeat_interleave)
NOISE_AUG_STRENGTH = 1e-5  # scale of the Gaussian noise added to conditioning images
DECODE_CHUNK_SIZE = 2  # frames decoded per VAE chunk (bounds peak memory in decode_latents)
NUM_VID = 1  # videos per conditioning image, forwarded to _encode_image/_encode_vae_image
BATCH_SIZE = 1  # batch size per stream
MIN_CFG = 1.0  # classifier-free guidance scale at the schedule's endpoints
MAX_CFG = 2.5  # classifier-free guidance scale at the schedule's peak
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class StableVideo3DDiffusionPipeline(StableVideoDiffusionPipeline):
|
| 28 |
+
def __init__(self, vae, image_encoder, unet, scheduler, feature_extractor):
|
| 29 |
+
super().__init__(vae, image_encoder, unet, scheduler, feature_extractor)
|
| 30 |
+
|
| 31 |
+
def _get_add_time_ids(
|
| 32 |
+
self, dtype: torch.dtype, num_processes, do_classifier_free_guidance: bool
|
| 33 |
+
) -> List[torch.Tensor]:
|
| 34 |
+
cond_aug = torch.tensor([NOISE_AUG_STRENGTH] * 21, dtype=dtype).repeat(BATCH_SIZE * num_processes, 1)
|
| 35 |
+
|
| 36 |
+
if do_classifier_free_guidance:
|
| 37 |
+
cond_aug = torch.cat([cond_aug, cond_aug])
|
| 38 |
+
|
| 39 |
+
add_time_ids = [cond_aug]
|
| 40 |
+
|
| 41 |
+
self.unet.to(dtype=torch.float16)
|
| 42 |
+
self.vae.to(dtype=torch.float16)
|
| 43 |
+
|
| 44 |
+
return add_time_ids
|
| 45 |
+
|
| 46 |
+
    def prepare_video_latents(
        self,
        images: List[torch.Tensor],
        timestep: torch.Tensor,
        add_noise: bool = True,
        refine_frames: Optional[List[int]] = None,
        original_latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Prepare video latents by encoding frames and optionally adding noise.

        Each image is VAE-encoded (without classifier-free duplication), scaled
        by the VAE scaling factor, optionally noised to `timestep`, and the
        per-frame latents are stacked along dim 1.
        """
        encoded_frames = [self._encode_vae_image(image, self.device, NUM_VID, False) for image in images]
        encoded_frames = [frame.to(images[0].dtype) for frame in encoded_frames]

        # TODO: check scaling factor?
        encoded_frames = [self.vae.config.scaling_factor * frame for frame in encoded_frames]

        if add_noise:
            # Noise each clean frame latent up to the current diffusion timestep.
            video_latents = [
                self.scheduler.add_noise(
                    frame,
                    randn_tensor(encoded_frames[0].shape, self.generator, self.device, images[0].dtype),
                    timestep,
                )
                for frame in encoded_frames
            ]
        else:
            video_latents = encoded_frames

        if refine_frames is not None and original_latents is not None:
            # NOTE(review): this reassignment discards the noised latents built
            # above and restarts from the clean encodings before splicing in
            # `original_latents` at the refine indices — confirm intentional.
            video_latents = encoded_frames

            for i in range(len(video_latents)):
                if i in refine_frames:
                    video_latents[i] = original_latents[i].unsqueeze(0)

        # Stack per-frame latents along the frame dimension.
        return torch.stack(video_latents, dim=1)
|
| 81 |
+
|
| 82 |
+
def activate_layers(self, config: Dict[str, List[Union[float, int]]], swapping_type="linear") -> Dict[str, AttnProcessor2_0]:
|
| 83 |
+
"""Activate swapping attention mechanism in specific UNet layers."""
|
| 84 |
+
|
| 85 |
+
# Setup default values first
|
| 86 |
+
default_attn_procs = {}
|
| 87 |
+
|
| 88 |
+
for layer in self.unet.attn_processors.keys():
|
| 89 |
+
default_attn_procs[layer] = AttnProcessor2_0()
|
| 90 |
+
|
| 91 |
+
self.unet.set_attn_processor(default_attn_procs)
|
| 92 |
+
|
| 93 |
+
spatial_attn = [layer for layer in ACTIVATE_LAYER_CANDIDATE_SV3D if ".transformer_blocks.0.attn1" in layer]
|
| 94 |
+
temporal_attn = [
|
| 95 |
+
layer for layer in ACTIVATE_LAYER_CANDIDATE_SV3D if ".temporal_transformer_blocks.0.attn1" in layer
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
assert len(spatial_attn) == len(config["spatial_ratio"]) == len(config["spatial_strength"])
|
| 99 |
+
assert len(temporal_attn) == len(config["temporal_ratio"]) == len(config["temporal_strength"])
|
| 100 |
+
|
| 101 |
+
ratios = {}
|
| 102 |
+
for layer, ratio, strength in zip(spatial_attn, config["spatial_ratio"], config["spatial_strength"]):
|
| 103 |
+
ratios[layer] = {"ratio": ratio, "strength": strength}
|
| 104 |
+
|
| 105 |
+
for layer, ratio, strength in zip(temporal_attn, config["temporal_ratio"], config["temporal_strength"]):
|
| 106 |
+
ratios[layer] = {"ratio": ratio, "strength": strength}
|
| 107 |
+
|
| 108 |
+
attn_procs = {}
|
| 109 |
+
|
| 110 |
+
for layer in self.unet.attn_processors.keys():
|
| 111 |
+
if layer in ratios:
|
| 112 |
+
attn_procs[layer] = SharedAttentionProcessorThree(
|
| 113 |
+
unet_chunk_size=2, activate_step_indices=config["activate_steps"], ratio=ratios[layer], swapping_type=swapping_type
|
| 114 |
+
)
|
| 115 |
+
else:
|
| 116 |
+
attn_procs[layer] = AttnProcessor2_0()
|
| 117 |
+
|
| 118 |
+
self.unet.set_attn_processor(attn_procs)
|
| 119 |
+
|
| 120 |
+
return attn_procs
|
| 121 |
+
|
| 122 |
+
def _decode_vae_frames(self, image_latents: torch.Tensor) -> torch.Tensor:
|
| 123 |
+
frames = []
|
| 124 |
+
for i in range(21):
|
| 125 |
+
frame = self.vae.decode(image_latents[:, i], self.device).sample
|
| 126 |
+
frames.append(frame)
|
| 127 |
+
return torch.stack(frames, dim=2)
|
| 128 |
+
|
| 129 |
+
def _preprocess_reference_images(self, reference_images: List[PIL.Image.Image]) -> List[torch.Tensor]:
|
| 130 |
+
"""Helper method to preprocess reference images consistently"""
|
| 131 |
+
processed_images = []
|
| 132 |
+
for image in reference_images:
|
| 133 |
+
ref_image = self.video_processor.preprocess(image, HEIGHT, WIDTH).to(self.device)
|
| 134 |
+
ref_noise = randn_tensor(ref_image.shape, self.generator, self.device, ref_image.dtype)
|
| 135 |
+
ref_image = ref_image + NOISE_AUG_STRENGTH * ref_noise
|
| 136 |
+
processed_images.append(ref_image)
|
| 137 |
+
return processed_images
|
| 138 |
+
|
| 139 |
+
def _preprocess_image(self, image: Union[PIL.Image.Image, torch.Tensor]) -> torch.Tensor:
|
| 140 |
+
"""Preprocess a single image with noise augmentation"""
|
| 141 |
+
processed = self.video_processor.preprocess(image, HEIGHT, WIDTH).to(self.device)
|
| 142 |
+
noise = randn_tensor(processed.shape, self.generator, self.device, processed.dtype)
|
| 143 |
+
return processed + NOISE_AUG_STRENGTH * noise
|
| 144 |
+
|
| 145 |
+
def _denoise_loop(
|
| 146 |
+
self,
|
| 147 |
+
latents: torch.Tensor,
|
| 148 |
+
image_latents: torch.Tensor,
|
| 149 |
+
image_embeddings: torch.Tensor,
|
| 150 |
+
added_time_ids: List[torch.Tensor],
|
| 151 |
+
timesteps: torch.Tensor,
|
| 152 |
+
z0_reference_images: Optional[List[torch.Tensor]] = None,
|
| 153 |
+
z0_shape_images: Optional[List[torch.Tensor]] = None,
|
| 154 |
+
refinement: bool = False,
|
| 155 |
+
refine_frames: Optional[list] = None,
|
| 156 |
+
z0_mid_images: Optional[List[torch.Tensor]] = None,
|
| 157 |
+
output_type: str = "pil",
|
| 158 |
+
add_noise: bool = True,
|
| 159 |
+
):
|
| 160 |
+
num_warmup_steps = len(timesteps) - self.num_inference_steps * self.scheduler.order
|
| 161 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 162 |
+
|
| 163 |
+
intermediate_steps = []
|
| 164 |
+
|
| 165 |
+
normal_latents = None
|
| 166 |
+
|
| 167 |
+
with torch.autocast(device_type=self.device.type, dtype=torch.float16):
|
| 168 |
+
with self.progress_bar(total=self.num_inference_steps) as progress_bar:
|
| 169 |
+
for i, t in enumerate(timesteps):
|
| 170 |
+
if i in self.replace_reference_steps:
|
| 171 |
+
latents[0] = self.prepare_video_latents(
|
| 172 |
+
z0_reference_images,
|
| 173 |
+
timestep=t.repeat(1),
|
| 174 |
+
add_noise=add_noise,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
if refinement and z0_mid_images is not None:
|
| 178 |
+
latents[1] = self.prepare_video_latents(
|
| 179 |
+
z0_mid_images,
|
| 180 |
+
timestep=t.repeat(1),
|
| 181 |
+
add_noise=add_noise,
|
| 182 |
+
refine_frames=refine_frames,
|
| 183 |
+
original_latents=latents[1],
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
if refinement and z0_shape_images is not None:
|
| 187 |
+
latents[2] = self.prepare_video_latents(
|
| 188 |
+
z0_shape_images,
|
| 189 |
+
timestep=t.repeat(1),
|
| 190 |
+
add_noise=add_noise,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# expand the latents if we are doing cfg
|
| 194 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 195 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 196 |
+
|
| 197 |
+
# Concatenate image_latents over channels dimension
|
| 198 |
+
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
| 199 |
+
|
| 200 |
+
torch.cuda.empty_cache()
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
print(latent_model_input.shape, t, image_embeddings.shape, added_time_ids[0].shape)
|
| 204 |
+
|
| 205 |
+
# predict the noise residual
|
| 206 |
+
noise_pred = self.unet(
|
| 207 |
+
latent_model_input, # 2/4/6,21,8,72,72
|
| 208 |
+
t, # float
|
| 209 |
+
encoder_hidden_states=image_embeddings, # 42/84/126,1,1024
|
| 210 |
+
added_time_ids=added_time_ids, # 2/4/6,21
|
| 211 |
+
return_dict=False,
|
| 212 |
+
)[0] # 1/2/3,21,4,72,72
|
| 213 |
+
|
| 214 |
+
# perform guidance
|
| 215 |
+
if self.do_classifier_free_guidance:
|
| 216 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
| 217 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 218 |
+
|
| 219 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 220 |
+
step_output = self.scheduler.step(noise_pred, t, latents) # EulerDiscreteScheduler
|
| 221 |
+
latents = step_output.prev_sample
|
| 222 |
+
normal_latents = step_output.pred_original_sample
|
| 223 |
+
|
| 224 |
+
if self.return_intermediate_steps:
|
| 225 |
+
if needs_upcasting:
|
| 226 |
+
self.vae.to(dtype=torch.float16)
|
| 227 |
+
frames = self.decode_latents(normal_latents, NUM_FRAMES, DECODE_CHUNK_SIZE)
|
| 228 |
+
frames = self.video_processor.postprocess_video(frames, "pil")
|
| 229 |
+
intermediate_steps.append(frames)
|
| 230 |
+
|
| 231 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 232 |
+
progress_bar.update()
|
| 233 |
+
|
| 234 |
+
if not output_type == "latent":
|
| 235 |
+
# cast back to fp16 if needed
|
| 236 |
+
if needs_upcasting:
|
| 237 |
+
self.vae.to(dtype=torch.float16)
|
| 238 |
+
frames = self.decode_latents(latents, NUM_FRAMES, DECODE_CHUNK_SIZE)
|
| 239 |
+
frames = self.video_processor.postprocess_video(frames, output_type)
|
| 240 |
+
else:
|
| 241 |
+
frames = latents
|
| 242 |
+
|
| 243 |
+
self.maybe_free_model_hooks()
|
| 244 |
+
|
| 245 |
+
return frames, intermediate_steps
|
| 246 |
+
|
| 247 |
+
@torch.no_grad()
|
| 248 |
+
def __call__(
|
| 249 |
+
self,
|
| 250 |
+
input_image: PIL.Image.Image,
|
| 251 |
+
reference_images: List[PIL.Image.Image],
|
| 252 |
+
num_inference_steps: int = 25,
|
| 253 |
+
replace_reference_steps: List[int] = list(),
|
| 254 |
+
return_intermediate_steps: bool = False,
|
| 255 |
+
seed: int = 42,
|
| 256 |
+
same_starting_latents: bool = True,
|
| 257 |
+
refinement: bool = False,
|
| 258 |
+
refine_frames: Optional[list] = None,
|
| 259 |
+
add_noise: bool = True,
|
| 260 |
+
):
|
| 261 |
+
# 0. Set seed
|
| 262 |
+
self.generator = torch.manual_seed(seed)
|
| 263 |
+
|
| 264 |
+
# 1. Check inputs. Raise error if not correct
|
| 265 |
+
self.check_inputs(input_image, HEIGHT, WIDTH)
|
| 266 |
+
|
| 267 |
+
# 2. Define call parameters
|
| 268 |
+
self.num_inference_steps = num_inference_steps
|
| 269 |
+
self.return_intermediate_steps = return_intermediate_steps
|
| 270 |
+
self.replace_reference_steps = replace_reference_steps
|
| 271 |
+
self._guidance_scale = MAX_CFG
|
| 272 |
+
|
| 273 |
+
# z0_mid_images = None
|
| 274 |
+
|
| 275 |
+
# 3. Encode input image (CLIP)
|
| 276 |
+
image_embeddings_combined = [
|
| 277 |
+
self._encode_image(reference_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
|
| 278 |
+
self._encode_image(input_image, self.device, NUM_VID, self.do_classifier_free_guidance),
|
| 279 |
+
self._encode_image(input_image, self.device, NUM_VID, self.do_classifier_free_guidance),
|
| 280 |
+
]
|
| 281 |
+
|
| 282 |
+
all_embeddings = torch.cat(image_embeddings_combined, dim=0) # uc, c, uc, c, (uc, c)
|
| 283 |
+
embeddings_order = torch.tensor([0, 2, 4, 1, 3, 5])
|
| 284 |
+
reordered_embeddings = all_embeddings[embeddings_order] # uc, uc, (uc), c, c, (c)
|
| 285 |
+
image_embeddings = reordered_embeddings.repeat_interleave(NUM_FRAMES, dim=0)
|
| 286 |
+
|
| 287 |
+
# 4. Encode using VAE
|
| 288 |
+
image = self._preprocess_image(input_image)
|
| 289 |
+
|
| 290 |
+
ref_image = self._preprocess_image(reference_images[-1])
|
| 291 |
+
z0_reference_images = self._preprocess_reference_images(reference_images)
|
| 292 |
+
|
| 293 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 294 |
+
if needs_upcasting:
|
| 295 |
+
self.vae.to(dtype=torch.float32)
|
| 296 |
+
|
| 297 |
+
image_latents = self._encode_vae_image(image, self.device, NUM_VID, self.do_classifier_free_guidance)
|
| 298 |
+
image_latents = image_latents.to(image_embeddings.dtype)
|
| 299 |
+
|
| 300 |
+
ref_image_latents = self._encode_vae_image(ref_image, self.device, NUM_VID, self.do_classifier_free_guidance)
|
| 301 |
+
ref_image_latents = ref_image_latents.to(image_embeddings.dtype)
|
| 302 |
+
|
| 303 |
+
image_latents_full = [
|
| 304 |
+
ref_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
|
| 305 |
+
image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
|
| 306 |
+
image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
|
| 307 |
+
]
|
| 308 |
+
|
| 309 |
+
image_latents = torch.cat(image_latents_full, dim=0)
|
| 310 |
+
image_latents_order = torch.tensor([0, 2, 4, 1, 3, 5])
|
| 311 |
+
image_latents = image_latents[image_latents_order]
|
| 312 |
+
|
| 313 |
+
if needs_upcasting: # cast back to fp16 if needed
|
| 314 |
+
self.vae.to(dtype=torch.float16)
|
| 315 |
+
|
| 316 |
+
num_processes = 3
|
| 317 |
+
|
| 318 |
+
# 5. Get Added Time IDs
|
| 319 |
+
added_time_ids = self._get_add_time_ids(
|
| 320 |
+
image_embeddings.dtype,
|
| 321 |
+
num_processes,
|
| 322 |
+
self.do_classifier_free_guidance,
|
| 323 |
+
) # list of tensor [2, 21] or [4, 21] or [6, 21] -> just 4x the same
|
| 324 |
+
|
| 325 |
+
added_time_ids = [a.to(self.device) for a in added_time_ids]
|
| 326 |
+
|
| 327 |
+
timesteps, self.num_inference_steps = retrieve_timesteps(self.scheduler, self.num_inference_steps, self.device)
|
| 328 |
+
|
| 329 |
+
# 7. Prepare latent variables
|
| 330 |
+
num_channels_latents = self.unet.config.in_channels # 8
|
| 331 |
+
latents = self.prepare_latents(
|
| 332 |
+
BATCH_SIZE * num_processes,
|
| 333 |
+
NUM_FRAMES,
|
| 334 |
+
num_channels_latents,
|
| 335 |
+
HEIGHT,
|
| 336 |
+
WIDTH,
|
| 337 |
+
image_embeddings.dtype,
|
| 338 |
+
self.device,
|
| 339 |
+
self.generator,
|
| 340 |
+
) # 2/3,21,4,72,72
|
| 341 |
+
|
| 342 |
+
if same_starting_latents:
|
| 343 |
+
latents[0] = latents[1] = latents[2]
|
| 344 |
+
|
| 345 |
+
# 8. Prepare guidance scale
|
| 346 |
+
guidance_scale = torch.cat(
|
| 347 |
+
[
|
| 348 |
+
torch.linspace(MIN_CFG, MAX_CFG, NUM_FRAMES // 2 + 1)[1:].unsqueeze(0),
|
| 349 |
+
torch.linspace(MAX_CFG, MIN_CFG, NUM_FRAMES - NUM_FRAMES // 2 + 1)[1:].unsqueeze(0),
|
| 350 |
+
],
|
| 351 |
+
dim=-1,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
guidance_scale = guidance_scale.to(self.device, latents.dtype)
|
| 355 |
+
guidance_scale = guidance_scale.repeat(BATCH_SIZE, 1)
|
| 356 |
+
guidance_scale = _append_dims(guidance_scale, latents.ndim) # [1,21,1,1,1]
|
| 357 |
+
|
| 358 |
+
self._guidance_scale = guidance_scale
|
| 359 |
+
|
| 360 |
+
# 9. Denoising loop
|
| 361 |
+
frames, intemediate_steps = self._denoise_loop(
|
| 362 |
+
latents=latents,
|
| 363 |
+
image_latents=image_latents,
|
| 364 |
+
image_embeddings=image_embeddings,
|
| 365 |
+
added_time_ids=added_time_ids,
|
| 366 |
+
timesteps=timesteps,
|
| 367 |
+
z0_reference_images=z0_reference_images,
|
| 368 |
+
output_type="pil",
|
| 369 |
+
add_noise=add_noise,
|
| 370 |
+
)
|
| 371 |
+
new_front_image = None
|
| 372 |
+
if refinement:
|
| 373 |
+
assert refine_frames is not None
|
| 374 |
+
current_front_frame_idx = refine_frames[-1]
|
| 375 |
+
shift = NUM_FRAMES - current_front_frame_idx
|
| 376 |
+
|
| 377 |
+
mid_images = frames[1]
|
| 378 |
+
shape_images = frames[2]
|
| 379 |
+
|
| 380 |
+
new_front_image = mid_images[current_front_frame_idx]
|
| 381 |
+
|
| 382 |
+
# roll the lists
|
| 383 |
+
reference_images = reference_images[shift:] + reference_images[:shift]
|
| 384 |
+
shape_images = shape_images[shift:] + shape_images[:shift]
|
| 385 |
+
mid_images = mid_images[shift:] + mid_images[:shift]
|
| 386 |
+
|
| 387 |
+
latents = self.prepare_latents(
|
| 388 |
+
BATCH_SIZE * num_processes,
|
| 389 |
+
NUM_FRAMES,
|
| 390 |
+
num_channels_latents,
|
| 391 |
+
HEIGHT,
|
| 392 |
+
WIDTH,
|
| 393 |
+
image_embeddings.dtype,
|
| 394 |
+
self.device,
|
| 395 |
+
self.generator,
|
| 396 |
+
)
|
| 397 |
+
if same_starting_latents:
|
| 398 |
+
latents[0] = latents[1] = latents[2]
|
| 399 |
+
|
| 400 |
+
timesteps, self.num_inference_steps = retrieve_timesteps(
|
| 401 |
+
self.scheduler, self.num_inference_steps, self.device
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
ref_image = self._preprocess_image(z0_reference_images[-1])
|
| 405 |
+
ref_image_latents = self._encode_vae_image(
|
| 406 |
+
ref_image, self.device, NUM_VID, self.do_classifier_free_guidance
|
| 407 |
+
)
|
| 408 |
+
ref_image_latents = ref_image_latents.to(image_embeddings.dtype)
|
| 409 |
+
|
| 410 |
+
mid_image = self._preprocess_image(mid_images[-1])
|
| 411 |
+
mid_image_latents = self._encode_vae_image(
|
| 412 |
+
mid_image, self.device, NUM_VID, self.do_classifier_free_guidance
|
| 413 |
+
)
|
| 414 |
+
mid_image_latents = mid_image_latents.to(image_embeddings.dtype)
|
| 415 |
+
|
| 416 |
+
shape_image = self._preprocess_image(shape_images[-1])
|
| 417 |
+
shape_image_latents = self._encode_vae_image(
|
| 418 |
+
shape_image, self.device, NUM_VID, self.do_classifier_free_guidance
|
| 419 |
+
)
|
| 420 |
+
shape_image_latents = shape_image_latents.to(image_embeddings.dtype)
|
| 421 |
+
|
| 422 |
+
image_latents_full = [
|
| 423 |
+
ref_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
|
| 424 |
+
mid_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
|
| 425 |
+
shape_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
|
| 426 |
+
]
|
| 427 |
+
|
| 428 |
+
image_latents = torch.cat(image_latents_full, dim=0)
|
| 429 |
+
image_latents = image_latents[image_latents_order]
|
| 430 |
+
|
| 431 |
+
# CLIP embeddings on the new front frame
|
| 432 |
+
image_embeddings_combined = [
|
| 433 |
+
self._encode_image(reference_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
|
| 434 |
+
self._encode_image(mid_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
|
| 435 |
+
self._encode_image(shape_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
|
| 436 |
+
]
|
| 437 |
+
all_embeddings = torch.cat(image_embeddings_combined, dim=0) # uc, c, uc, c, (uc, c)
|
| 438 |
+
embeddings_order = torch.tensor([0, 2, 4, 1, 3, 5])
|
| 439 |
+
reordered_embeddings = all_embeddings[embeddings_order] # uc, uc, (uc), c, c, (c)
|
| 440 |
+
image_embeddings = reordered_embeddings.repeat_interleave(NUM_FRAMES, dim=0)
|
| 441 |
+
|
| 442 |
+
z0_mid_images = self._preprocess_reference_images(mid_images)
|
| 443 |
+
z0_shape_images = self._preprocess_reference_images(shape_images)
|
| 444 |
+
|
| 445 |
+
frames, intemediate_steps = self._denoise_loop(
|
| 446 |
+
latents=latents,
|
| 447 |
+
image_latents=image_latents,
|
| 448 |
+
image_embeddings=image_embeddings,
|
| 449 |
+
added_time_ids=added_time_ids,
|
| 450 |
+
timesteps=timesteps,
|
| 451 |
+
z0_reference_images=z0_reference_images,
|
| 452 |
+
z0_shape_images=z0_shape_images,
|
| 453 |
+
refinement=refinement,
|
| 454 |
+
refine_frames=refine_frames,
|
| 455 |
+
z0_mid_images=z0_mid_images,
|
| 456 |
+
add_noise=add_noise,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
# Roll back frames to original order
|
| 460 |
+
frames = [
|
| 461 |
+
frames[0][(-shift):] + frames[0][:-shift],
|
| 462 |
+
frames[1][(-shift):] + frames[1][:-shift],
|
| 463 |
+
frames[2][(-shift):] + frames[2][:-shift],
|
| 464 |
+
]
|
| 465 |
+
|
| 466 |
+
if return_intermediate_steps:
|
| 467 |
+
return frames, new_front_image, intemediate_steps
|
| 468 |
+
|
| 469 |
+
return frames, new_front_image, None
|
diffusers_sv3d/pipelines/stable_video_diffusion/pipeline_stable_video_3d_diffusion_rotate.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import PIL.Image
|
| 4 |
+
import torch
|
| 5 |
+
from diffusers.models.attention_processor import AttnProcessor2_0
|
| 6 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
|
| 7 |
+
StableVideoDiffusionPipeline,
|
| 8 |
+
_append_dims,
|
| 9 |
+
randn_tensor,
|
| 10 |
+
retrieve_timesteps,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from self_attn_swap import ACTIVATE_LAYER_CANDIDATE_SV3D, SharedAttentionProcessorThree
|
| 14 |
+
|
| 15 |
+
# Constants
HEIGHT = 576  # target frame height passed to video_processor.preprocess
WIDTH = 576  # target frame width passed to video_processor.preprocess
NUM_FRAMES = 21  # frames per generated video
NOISE_AUG_STRENGTH = 1e-5  # scale of the Gaussian noise added to conditioning images
DECODE_CHUNK_SIZE = 2  # frames decoded per VAE chunk (bounds peak memory)
NUM_VID = 1  # videos per conditioning image
# NOTE(review): module-level shared RNG — its state persists across calls and
# across pipeline instances, so results depend on call order.
GENERATOR = torch.manual_seed(42)
OUTPUT_TYPE = "pil"  # output format used by the denoising loop's decode path
BATCH_SIZE = 1  # batch size per stream
MIN_CFG = 1.0  # classifier-free guidance scale at the schedule's endpoints
MAX_CFG = 2.5  # classifier-free guidance scale at the schedule's peak
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class StableVideo3DDiffusionPipelineRotate(StableVideoDiffusionPipeline):
|
| 30 |
+
def __init__(self, vae, image_encoder, unet, scheduler, feature_extractor):
|
| 31 |
+
super().__init__(vae, image_encoder, unet, scheduler, feature_extractor)
|
| 32 |
+
|
| 33 |
+
def _get_add_time_ids(
|
| 34 |
+
self, dtype: torch.dtype, num_processes, do_classifier_free_guidance: bool
|
| 35 |
+
) -> List[torch.Tensor]:
|
| 36 |
+
cond_aug = torch.tensor([NOISE_AUG_STRENGTH] * 21, dtype=dtype).repeat(BATCH_SIZE * num_processes, 1)
|
| 37 |
+
|
| 38 |
+
if do_classifier_free_guidance:
|
| 39 |
+
cond_aug = torch.cat([cond_aug, cond_aug])
|
| 40 |
+
|
| 41 |
+
add_time_ids = [cond_aug]
|
| 42 |
+
|
| 43 |
+
self.unet.to(dtype=torch.float16)
|
| 44 |
+
self.vae.to(dtype=torch.float16)
|
| 45 |
+
|
| 46 |
+
return add_time_ids
|
| 47 |
+
|
| 48 |
+
    def prepare_video_latents(
        self,
        images: List[torch.Tensor],
        timestep: torch.Tensor,
        add_noise: bool = True,
        active_size: Optional[int] = None,
        original_latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Prepare video latents by encoding frames and optionally adding noise.

        Each image is VAE-encoded (without classifier-free duplication), scaled
        by the VAE scaling factor, optionally noised to `timestep`, and the
        per-frame latents are stacked along dim 1. When `active_size` and
        `original_latents` are given, the trailing window of `active_size`
        frames (excluding the final frame) is taken from `original_latents`
        instead of the fresh encodings.
        """
        encoded_frames = [self._encode_vae_image(image, self.device, NUM_VID, False) for image in images]
        encoded_frames = [frame.to(images[0].dtype) for frame in encoded_frames]

        # TODO: check scaling factor?
        encoded_frames = [self.vae.config.scaling_factor * frame for frame in encoded_frames]

        # add noise
        if add_noise:
            # Noise each clean frame latent up to the current diffusion timestep
            # using the module-level shared GENERATOR.
            video_latents = [
                self.scheduler.add_noise(
                    frame,
                    randn_tensor(encoded_frames[0].shape, GENERATOR, self.device, images[0].dtype),
                    timestep,
                )
                for frame in encoded_frames
            ]
        else:
            video_latents = encoded_frames

        if active_size is not None and original_latents is not None:
            # Preserve the active window [NUM_FRAMES - active_size - 1,
            # NUM_FRAMES - 1) from the ongoing denoising trajectory.
            for i in range(len(video_latents)):
                if NUM_FRAMES - active_size - 1 <= i < NUM_FRAMES - 1:
                    video_latents[i] = original_latents[i].unsqueeze(0)

        # Stack per-frame latents along the frame dimension.
        return torch.stack(video_latents, dim=1)
|
| 82 |
+
|
| 83 |
+
def activate_layers(self, config: Dict[str, List[Union[float, int]]]) -> Dict[str, AttnProcessor2_0]:
|
| 84 |
+
"""Activate swapping attention mechanism in specific UNet layers."""
|
| 85 |
+
spatial_attn = [layer for layer in ACTIVATE_LAYER_CANDIDATE_SV3D if ".transformer_blocks.0.attn1" in layer]
|
| 86 |
+
temporal_attn = [
|
| 87 |
+
layer for layer in ACTIVATE_LAYER_CANDIDATE_SV3D if ".temporal_transformer_blocks.0.attn1" in layer
|
| 88 |
+
]
|
| 89 |
+
|
| 90 |
+
assert len(spatial_attn) == len(config["spatial_ratio"]) == len(config["spatial_strength"])
|
| 91 |
+
assert len(temporal_attn) == len(config["temporal_ratio"]) == len(config["temporal_strength"])
|
| 92 |
+
|
| 93 |
+
ratios = {}
|
| 94 |
+
for layer, ratio, strength in zip(spatial_attn, config["spatial_ratio"], config["spatial_strength"]):
|
| 95 |
+
ratios[layer] = {"ratio": ratio, "strength": strength}
|
| 96 |
+
|
| 97 |
+
for layer, ratio, strength in zip(temporal_attn, config["temporal_ratio"], config["temporal_strength"]):
|
| 98 |
+
ratios[layer] = {"ratio": ratio, "strength": strength}
|
| 99 |
+
|
| 100 |
+
attn_procs = {}
|
| 101 |
+
|
| 102 |
+
for layer in self.unet.attn_processors.keys():
|
| 103 |
+
if layer in ratios:
|
| 104 |
+
attn_procs[layer] = SharedAttentionProcessorThree(
|
| 105 |
+
unet_chunk_size=2, activate_step_indices=config["activate_steps"], ratio=ratios[layer]
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
attn_procs[layer] = AttnProcessor2_0()
|
| 109 |
+
|
| 110 |
+
self.unet.set_attn_processor(attn_procs)
|
| 111 |
+
|
| 112 |
+
return attn_procs
|
| 113 |
+
|
| 114 |
+
def _decode_vae_frames(self, image_latents: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
frames = []
|
| 116 |
+
for i in range(21):
|
| 117 |
+
frame = self.vae.decode(image_latents[:, i], self.device).sample
|
| 118 |
+
frames.append(frame)
|
| 119 |
+
return torch.stack(frames, dim=2)
|
| 120 |
+
|
| 121 |
+
def _preprocess_reference_images(self, reference_images: List[PIL.Image.Image]) -> List[torch.Tensor]:
|
| 122 |
+
"""Helper method to preprocess reference images consistently"""
|
| 123 |
+
processed_images = []
|
| 124 |
+
for image in reference_images:
|
| 125 |
+
ref_image = self.video_processor.preprocess(image, HEIGHT, WIDTH).to(self.device)
|
| 126 |
+
ref_noise = randn_tensor(ref_image.shape, GENERATOR, self.device, ref_image.dtype)
|
| 127 |
+
ref_image = ref_image + NOISE_AUG_STRENGTH * ref_noise
|
| 128 |
+
processed_images.append(ref_image)
|
| 129 |
+
return processed_images
|
| 130 |
+
|
| 131 |
+
def _preprocess_image(self, image: Union[PIL.Image.Image, torch.Tensor]) -> torch.Tensor:
|
| 132 |
+
"""Preprocess a single image with noise augmentation"""
|
| 133 |
+
processed = self.video_processor.preprocess(image, HEIGHT, WIDTH).to(self.device)
|
| 134 |
+
noise = randn_tensor(processed.shape, GENERATOR, self.device, processed.dtype)
|
| 135 |
+
return processed + NOISE_AUG_STRENGTH * noise
|
| 136 |
+
|
| 137 |
+
    def _denoise_loop(
        self,
        latents: torch.Tensor,
        image_latents: torch.Tensor,
        image_embeddings: torch.Tensor,
        added_time_ids: List[torch.Tensor],
        timesteps: torch.Tensor,
        mids_active_size: int,
        z0_mid_images: List[torch.Tensor],
        z0_reference_images: Optional[List[torch.Tensor]] = None,
        z0_shape_images: Optional[List[torch.Tensor]] = None,
    ):
        """Run the diffusion denoising loop and decode the result.

        Stream 1 (the "mid" video) is re-anchored from `z0_mid_images` at every
        step; after step 5 a trailing window of `mids_active_size` frames is
        preserved from its current latents. Returns
        ``(frames, intermediate_steps)``.
        """
        num_warmup_steps = len(timesteps) - self.num_inference_steps * self.scheduler.order
        # NOTE(review): named "upcasting" but this path casts the VAE to fp16
        # before decoding — confirm against the base pipeline's convention.
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        intermediate_steps = []

        normal_latents = None

        with torch.autocast(device_type=self.device.type, dtype=torch.float16):
            with self.progress_bar(total=self.num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    # Re-anchor stream 0 from the reference images at selected steps.
                    if i in self.replace_reference_steps:
                        latents[0] = self.prepare_video_latents(
                            z0_reference_images,
                            timestep=t.repeat(1),
                            add_noise=True,
                        )

                    # Stream 1 is rebuilt every step; after step 5 the active
                    # window of its current latents is preserved.
                    latents[1] = self.prepare_video_latents(
                        z0_mid_images,
                        timestep=t.repeat(1),
                        add_noise=True,
                        active_size=mids_active_size if i > 5 else None,
                        original_latents=latents[1],
                    )

                    if z0_shape_images is not None:
                        latents[2] = self.prepare_video_latents(
                            z0_shape_images,
                            timestep=t.repeat(1),
                            add_noise=True,
                        )

                    # expand the latents if we are doing cfg
                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # Concatenate image_latents over channels dimension
                    latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

                    torch.cuda.empty_cache()

                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,  # 2/4/6,21,8,72,72
                        t,  # float
                        encoder_hidden_states=image_embeddings,  # 42/84/126,1,1024
                        added_time_ids=added_time_ids,  # 2/4/6,21
                        return_dict=False,
                    )[0]  # 1/2/3,21,4,72,72

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    step_output = self.scheduler.step(noise_pred, t, latents)  # EulerDiscreteScheduler
                    latents = step_output.prev_sample
                    normal_latents = step_output.pred_original_sample

                    if self.return_intermediate_steps:
                        if needs_upcasting:
                            self.vae.to(dtype=torch.float16)
                        # Decode the predicted x0 for a per-step preview.
                        frames = self.decode_latents(normal_latents, NUM_FRAMES, DECODE_CHUNK_SIZE)
                        frames = self.video_processor.postprocess_video(frames, OUTPUT_TYPE)
                        intermediate_steps.append(frames)

                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()

        if not OUTPUT_TYPE == "latent":
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents, NUM_FRAMES, DECODE_CHUNK_SIZE)
            frames = self.video_processor.postprocess_video(frames, OUTPUT_TYPE)
        else:
            frames = latents

        self.maybe_free_model_hooks()

        return frames, intermediate_steps
|
| 231 |
+
|
| 232 |
+
@torch.no_grad()
def __call__(
    self,
    mid_images: List[PIL.Image.Image],
    reference_images: List[PIL.Image.Image],
    shape_images: Optional[List[PIL.Image.Image]] = None,
    num_inference_steps: int = 25,
    replace_reference_steps: Optional[List[int]] = None,
    return_intermediate_steps: bool = False,
    mids_active_size: int = 5,
):
    """Run the SV3D denoising pipeline conditioned on mid, reference and
    (optionally) shape images.

    Args:
        mid_images: Conditioning images; only the last entry is CLIP/VAE-encoded,
            the full list is preprocessed for the denoise loop.
        reference_images: Reference-view images, handled like ``mid_images``.
        shape_images: Optional third conditioning stream; when given, a third
            CLIP embedding / VAE latent branch is added.
        num_inference_steps: Number of scheduler steps.
        replace_reference_steps: Step indices forwarded to the denoise loop
            (stored on ``self``). Defaults to an empty list.
        return_intermediate_steps: When True, also return decoded frames of
            each step's predicted x0.
        mids_active_size: Forwarded to ``self._denoise_loop``.

    Returns:
        Decoded frames, or ``(frames, intermediate_steps)`` when
        ``return_intermediate_steps`` is True.
    """
    # BUGFIX: avoid the mutable-default-argument pitfall — the old
    # ``List[int] = list()`` default was a single shared list that could be
    # mutated across calls via ``self.replace_reference_steps``.
    if replace_reference_steps is None:
        replace_reference_steps = []
    # BUGFIX: decide the shape branch once with an identity check. The
    # original mixed ``is not None`` (encoding branches) with plain
    # truthiness (reorder tensors), so ``shape_images=[]`` would encode three
    # streams but build a 4-entry reorder index for 6 embeddings.
    use_shape = shape_images is not None

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(mid_images[-1], HEIGHT, WIDTH)

    # 2. Define call parameters
    self.num_inference_steps = num_inference_steps
    self.return_intermediate_steps = return_intermediate_steps
    self.replace_reference_steps = replace_reference_steps
    self._guidance_scale = MAX_CFG

    # 3. Encode input image (CLIP)
    image_embeddings_combined = [
        self._encode_image(reference_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
        self._encode_image(mid_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
    ]
    if use_shape:
        image_embeddings_combined.append(
            self._encode_image(shape_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance)
        )
    all_embeddings = torch.cat(image_embeddings_combined, dim=0)  # uc, c, uc, c, (uc, c)
    # Regroup so all unconditional embeddings precede all conditional ones.
    embeddings_order = torch.tensor([0, 2, 4, 1, 3, 5]) if use_shape else torch.tensor([0, 2, 1, 3])
    reordered_embeddings = all_embeddings[embeddings_order]  # uc, uc, (uc), c, c, (c)
    image_embeddings = reordered_embeddings.repeat_interleave(NUM_FRAMES, dim=0)

    # 4. Encode using VAE
    image = self._preprocess_image(mid_images[-1])
    ref_image = self._preprocess_image(reference_images[-1])

    z0_reference_images = self._preprocess_reference_images(reference_images)
    z0_mid_images = self._preprocess_reference_images(mid_images)

    if use_shape:
        shape_image = self._preprocess_image(shape_images[-1])
        z0_shape_images = self._preprocess_reference_images(
            shape_images,
        )
    else:
        shape_image = None
        z0_shape_images = None

    # The fp16 VAE may require a temporary fp32 upcast for stable encoding.
    needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
    if needs_upcasting:
        self.vae.to(dtype=torch.float32)

    image_latents = self._encode_vae_image(image, self.device, NUM_VID, self.do_classifier_free_guidance)
    image_latents = image_latents.to(image_embeddings.dtype)

    ref_image_latents = self._encode_vae_image(ref_image, self.device, NUM_VID, self.do_classifier_free_guidance)
    ref_image_latents = ref_image_latents.to(image_embeddings.dtype)

    if use_shape:
        shape_image_latents = self._encode_vae_image(
            shape_image, self.device, NUM_VID, self.do_classifier_free_guidance
        )
        shape_image_latents = shape_image_latents.to(image_embeddings.dtype)

    image_latents_full = [
        ref_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
        image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
    ]

    if use_shape:
        shape_image_latents = shape_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1)
        image_latents_full.append(shape_image_latents)

    image_latents = torch.cat(image_latents_full, dim=0)
    # Same reorder as the CLIP embeddings so latents and embeddings stay aligned.
    image_latents_order = torch.tensor([0, 2, 4, 1, 3, 5]) if use_shape else torch.tensor([0, 2, 1, 3])
    image_latents = image_latents[image_latents_order]

    if needs_upcasting:  # cast back to fp16 if needed
        self.vae.to(dtype=torch.float16)

    # One "process" per conditioning stream (reference, mid, optional shape).
    num_processes = 2 if shape_images is None else 3

    # 5. Get Added Time IDs
    added_time_ids = self._get_add_time_ids(
        image_embeddings.dtype,
        num_processes,
        self.do_classifier_free_guidance,
    )  # list of tensor [2, 21] or [4, 21] or [6, 21] -> just 4x the same

    added_time_ids = [a.to(self.device) for a in added_time_ids]

    timesteps, self.num_inference_steps = retrieve_timesteps(self.scheduler, self.num_inference_steps, self.device)

    # 7. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels  # 8
    latents = self.prepare_latents(
        BATCH_SIZE * num_processes,
        NUM_FRAMES,
        num_channels_latents,
        HEIGHT,
        WIDTH,
        image_embeddings.dtype,
        self.device,
        GENERATOR,
    )  # 2/3,21,4,72,72

    # 8. Prepare guidance scale — per-frame triangular schedule rising
    # MIN_CFG -> MAX_CFG over the first half and falling back over the second.
    guidance_scale = torch.cat(
        [
            torch.linspace(MIN_CFG, MAX_CFG, NUM_FRAMES // 2 + 1)[1:].unsqueeze(0),
            torch.linspace(MAX_CFG, MIN_CFG, NUM_FRAMES - NUM_FRAMES // 2 + 1)[1:].unsqueeze(0),
        ],
        dim=-1,
    )

    guidance_scale = guidance_scale.to(self.device, latents.dtype)
    guidance_scale = guidance_scale.repeat(BATCH_SIZE, 1)
    guidance_scale = _append_dims(guidance_scale, latents.ndim)  # [1,21,1,1,1]

    self._guidance_scale = guidance_scale

    # 9. Denoising loop
    frames, intermediate_steps = self._denoise_loop(
        latents=latents,
        image_latents=image_latents,
        image_embeddings=image_embeddings,
        added_time_ids=added_time_ids,
        timesteps=timesteps,
        mids_active_size=mids_active_size,
        z0_mid_images=z0_mid_images,
        z0_reference_images=z0_reference_images,
        z0_shape_images=z0_shape_images,
    )

    if return_intermediate_steps:
        return frames, intermediate_steps

    return frames
|
pretrained_sv3d/feature_extractor/preprocessor_config.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"crop_size": {
|
| 3 |
+
"height": 224,
|
| 4 |
+
"width": 224
|
| 5 |
+
},
|
| 6 |
+
"do_center_crop": true,
|
| 7 |
+
"do_convert_rgb": true,
|
| 8 |
+
"do_normalize": true,
|
| 9 |
+
"do_rescale": true,
|
| 10 |
+
"do_resize": true,
|
| 11 |
+
"image_mean": [
|
| 12 |
+
0.48145466,
|
| 13 |
+
0.4578275,
|
| 14 |
+
0.40821073
|
| 15 |
+
],
|
| 16 |
+
"image_processor_type": "CLIPImageProcessor",
|
| 17 |
+
"image_std": [
|
| 18 |
+
0.26862954,
|
| 19 |
+
0.26130258,
|
| 20 |
+
0.27577711
|
| 21 |
+
],
|
| 22 |
+
"resample": 3,
|
| 23 |
+
"rescale_factor": 0.00392156862745098,
|
| 24 |
+
"size": {
|
| 25 |
+
"shortest_edge": 224
|
| 26 |
+
}
|
| 27 |
+
}
|
pretrained_sv3d/image_encoder/config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "stabilityai/stable-video-diffusion-img2vid-xt",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"CLIPVisionModelWithProjection"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"dropout": 0.0,
|
| 8 |
+
"hidden_act": "gelu",
|
| 9 |
+
"hidden_size": 1280,
|
| 10 |
+
"image_size": 224,
|
| 11 |
+
"initializer_factor": 1.0,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 5120,
|
| 14 |
+
"layer_norm_eps": 1e-05,
|
| 15 |
+
"model_type": "clip_vision_model",
|
| 16 |
+
"num_attention_heads": 16,
|
| 17 |
+
"num_channels": 3,
|
| 18 |
+
"num_hidden_layers": 32,
|
| 19 |
+
"patch_size": 14,
|
| 20 |
+
"projection_dim": 1024,
|
| 21 |
+
"torch_dtype": "float32",
|
| 22 |
+
"transformers_version": "4.45.2"
|
| 23 |
+
}
|
pretrained_sv3d/image_encoder/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ed1e5af7b4042ca30ec29999a4a5cfcac90b7fb610fd05ace834f2dcbb763eab
|
| 3 |
+
size 2528371296
|
pretrained_sv3d/model_index.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37fe3c7758e588c386817b6e681f2aaa7bc8c212d628b7c36f758e0a6d972e29
|
| 3 |
+
size 492
|
pretrained_sv3d/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EulerDiscreteScheduler",
|
| 3 |
+
"_diffusers_version": "0.30.3",
|
| 4 |
+
"beta_end": 0.012,
|
| 5 |
+
"beta_schedule": "scaled_linear",
|
| 6 |
+
"beta_start": 0.00085,
|
| 7 |
+
"clip_sample": false,
|
| 8 |
+
"final_sigmas_type": "zero",
|
| 9 |
+
"interpolation_type": "linear",
|
| 10 |
+
"num_train_timesteps": 1000,
|
| 11 |
+
"prediction_type": "v_prediction",
|
| 12 |
+
"rescale_betas_zero_snr": false,
|
| 13 |
+
"set_alpha_to_one": false,
|
| 14 |
+
"sigma_max": 700.0,
|
| 15 |
+
"sigma_min": 0.002,
|
| 16 |
+
"skip_prk_steps": true,
|
| 17 |
+
"steps_offset": 1,
|
| 18 |
+
"timestep_spacing": "leading",
|
| 19 |
+
"timestep_type": "continuous",
|
| 20 |
+
"trained_betas": null,
|
| 21 |
+
"use_karras_sigmas": true
|
| 22 |
+
}
|
pretrained_sv3d/unet/config.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SV3DUNetSpatioTemporalConditionModel",
|
| 3 |
+
"_diffusers_version": "0.30.3",
|
| 4 |
+
"addition_time_embed_dim": 256,
|
| 5 |
+
"block_out_channels": [
|
| 6 |
+
320,
|
| 7 |
+
640,
|
| 8 |
+
1280,
|
| 9 |
+
1280
|
| 10 |
+
],
|
| 11 |
+
"cross_attention_dim": 1024,
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"CrossAttnDownBlockSpatioTemporal",
|
| 14 |
+
"CrossAttnDownBlockSpatioTemporal",
|
| 15 |
+
"CrossAttnDownBlockSpatioTemporal",
|
| 16 |
+
"DownBlockSpatioTemporal"
|
| 17 |
+
],
|
| 18 |
+
"in_channels": 8,
|
| 19 |
+
"layers_per_block": 2,
|
| 20 |
+
"num_attention_heads": [
|
| 21 |
+
5,
|
| 22 |
+
10,
|
| 23 |
+
20,
|
| 24 |
+
20
|
| 25 |
+
],
|
| 26 |
+
"num_frames": 25,
|
| 27 |
+
"out_channels": 4,
|
| 28 |
+
"projection_class_embeddings_input_dim": 256,
|
| 29 |
+
"sample_size": 72,
|
| 30 |
+
"transformer_layers_per_block": 1,
|
| 31 |
+
"up_block_types": [
|
| 32 |
+
"UpBlockSpatioTemporal",
|
| 33 |
+
"CrossAttnUpBlockSpatioTemporal",
|
| 34 |
+
"CrossAttnUpBlockSpatioTemporal",
|
| 35 |
+
"CrossAttnUpBlockSpatioTemporal"
|
| 36 |
+
]
|
| 37 |
+
}
|
pretrained_sv3d/unet/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:00d35a0c7e024ebc55feeecf55baa039700f3d2b2d396e58d7cd0e6bbb18eedd
|
| 3 |
+
size 6096060984
|
pretrained_sv3d/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.30.3",
|
| 4 |
+
"_name_or_path": "chenguolin/stable-diffusion-v1-5",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 4,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 512,
|
| 28 |
+
"scaling_factor": 0.18215,
|
| 29 |
+
"shift_factor": null,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": true,
|
| 37 |
+
"use_quant_conv": true
|
| 38 |
+
}
|
pretrained_sv3d/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b4d2b5932bb4151e54e694fd31ccf51fca908223c9485bd56cd0e1d83ad94c49
|
| 3 |
+
size 334643268
|
train.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

import torch
from torch.optim import AdamW
import torch.nn.functional as F

from diffusers_sv3d.pipelines.stable_video_diffusion.pipeline_stable_video_3d_diffusion import (
    StableVideo3DDiffusionPipeline,
)

# Configuration
BATCH_SIZE = 1  # videos per training step
LR = 1e-5  # AdamW learning rate
NUM_EPOCHS = 10  # number of training iterations (one synthetic batch each)
SAVE_DIR = "checkpoints"  # output directory created by train()
# Prefer GPU when available; the fp16 pipeline is impractical on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Machine-specific absolute path to the local SV3D weights — TODO make configurable.
SV3D_PATH = os.path.abspath("/home/hubert/projects/sv3d-pbr/sv3d_diffusers/pretrained_sv3d")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def train():
    """Fine-tune a single UNet conv layer with a synthetic forward pass.

    Loads the SV3D pipeline, freezes the whole UNet except one down-block
    convolution, and runs the UNet on random tensors shaped like the real
    pipeline activations. The loss/backward steps are left commented out
    until real data and matching targets are wired in.
    """
    # Create directories
    os.makedirs(SAVE_DIR, exist_ok=True)

    # Create pipeline
    pipeline = StableVideo3DDiffusionPipeline.from_pretrained(
        SV3D_PATH,
        revision="fp16",
        torch_dtype=torch.float16,
    )
    pipeline.to(DEVICE)

    # Freeze everything first ...
    for param in pipeline.unet.parameters():
        param.requires_grad = False

    # ... then unfreeze only one specific layer (an inner down-block conv).
    for name, param in pipeline.unet.named_parameters():
        if "down_blocks.2.resnets.0.spatial_res_block.conv1" in name:
            param.requires_grad = True
            print(f"Unfreezing: {name}")

    # Count trainable parameters
    trainable_params = sum(p.numel() for p in pipeline.unet.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in pipeline.unet.parameters())
    print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_params/total_params:.2%})")

    # Optimizer over only the unfrozen parameters.
    # NOTE(review): plain AdamW directly on fp16 weights is numerically fragile;
    # consider fp32 master weights or a GradScaler for real training.
    optimizer = AdamW([p for p in pipeline.unet.parameters() if p.requires_grad], lr=LR)

    # Training loop
    for epoch in range(NUM_EPOCHS):
        pipeline.unet.train()

        # Prepare for backward pass
        optimizer.zero_grad()

        # Random stand-ins matching the pipeline's activation shapes
        # (6 = CFG pairs for 3 conditioning streams, 21 frames, 8 latent channels, 72x72).
        latents = torch.randn((6, 21, 8, 72, 72), dtype=torch.float16, device=DEVICE)
        t = 0.123  # continuous-valued timestep, as in the inference pipeline
        encoder_hidden_states = torch.randn((126, 1, 1024), dtype=torch.float16, device=DEVICE)
        added_time_ids = torch.randn((6, 21), dtype=torch.float16, device=DEVICE)
        target_noise = torch.randn((6, 21, 8, 72, 72), dtype=torch.float16, device=DEVICE)

        # BUGFIX: without return_dict=False the UNet returns an output object,
        # so the `.shape` access below would raise AttributeError. Unpack the
        # sample tensor the same way the inference pipeline does.
        noise_pred = pipeline.unet(
            latents,
            t,
            encoder_hidden_states=encoder_hidden_states,
            added_time_ids=[added_time_ids],
            return_dict=False,
        )[0]

        print(noise_pred.shape)
        # TODO(review): target_noise has 8 channels but the UNet predicts
        # out_channels=4 — align shapes before enabling the loss below.
        # loss = F.mse_loss(noise_pred, target_noise)
        # Backward pass and optimizer step
        # loss.backward()
        # optimizer.step()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
    # Script entry point: run the training loop only when executed directly.
    train()
|