hkomp commited on
Commit
a13d12f
·
1 Parent(s): a00b35e

Add model and code

Browse files
Files changed (26) hide show
  1. diffusers_sv3d/__init__.py +2 -0
  2. diffusers_sv3d/__pycache__/__init__.cpython-311.pyc +0 -0
  3. diffusers_sv3d/models/__init__.py +1 -0
  4. diffusers_sv3d/models/__pycache__/__init__.cpython-311.pyc +0 -0
  5. diffusers_sv3d/models/unets/__init__.py +1 -0
  6. diffusers_sv3d/models/unets/__pycache__/__init__.cpython-311.pyc +0 -0
  7. diffusers_sv3d/models/unets/__pycache__/unet_spatio_temporal_condition.cpython-311.pyc +0 -0
  8. diffusers_sv3d/models/unets/unet_spatio_temporal_condition.py +483 -0
  9. diffusers_sv3d/pipelines/__init__.py +1 -0
  10. diffusers_sv3d/pipelines/__pycache__/__init__.cpython-311.pyc +0 -0
  11. diffusers_sv3d/pipelines/stable_video_diffusion/__init__.py +2 -0
  12. diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/__init__.cpython-311.pyc +0 -0
  13. diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/pipeline_stable_video_3d_diffusion.cpython-311.pyc +0 -0
  14. diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/pipeline_stable_video_3d_diffusion_rotate.cpython-311.pyc +0 -0
  15. diffusers_sv3d/pipelines/stable_video_diffusion/pipeline_stable_video_3d_diffusion.py +469 -0
  16. diffusers_sv3d/pipelines/stable_video_diffusion/pipeline_stable_video_3d_diffusion_rotate.py +371 -0
  17. pretrained_sv3d/feature_extractor/preprocessor_config.json +27 -0
  18. pretrained_sv3d/image_encoder/config.json +23 -0
  19. pretrained_sv3d/image_encoder/model.safetensors +3 -0
  20. pretrained_sv3d/model_index.json +3 -0
  21. pretrained_sv3d/scheduler/scheduler_config.json +22 -0
  22. pretrained_sv3d/unet/config.json +37 -0
  23. pretrained_sv3d/unet/diffusion_pytorch_model.safetensors +3 -0
  24. pretrained_sv3d/vae/config.json +38 -0
  25. pretrained_sv3d/vae/diffusion_pytorch_model.safetensors +3 -0
  26. train.py +79 -0
diffusers_sv3d/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .models import SV3DUNetSpatioTemporalConditionModel
2
+ from .pipelines import StableVideo3DDiffusionPipeline, StableVideo3DDiffusionPipelineRotate
diffusers_sv3d/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (378 Bytes). View file
 
diffusers_sv3d/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .unets import SV3DUNetSpatioTemporalConditionModel
diffusers_sv3d/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (251 Bytes). View file
 
diffusers_sv3d/models/unets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .unet_spatio_temporal_condition import SV3DUNetSpatioTemporalConditionModel
diffusers_sv3d/models/unets/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (282 Bytes). View file
 
diffusers_sv3d/models/unets/__pycache__/unet_spatio_temporal_condition.cpython-311.pyc ADDED
Binary file (24.2 kB). View file
 
diffusers_sv3d/models/unets/unet_spatio_temporal_condition.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ from diffusers.models.unets.unet_spatio_temporal_condition import *
4
+
5
+
6
+ # Copied from diffusers.models.unets.unet_spatio_temporal_condition UNetSpatioTemporalConditionModel
7
+ class SV3DUNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
8
+ r"""
9
+ A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and
10
+ returns a sample shaped output.
11
+
12
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
13
+ for all models (such as downloading or saving).
14
+
15
+ Parameters:
16
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
17
+ Height and width of input/output sample.
18
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
19
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
20
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
21
+ The tuple of downsample blocks to use.
22
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
23
+ The tuple of upsample blocks to use.
24
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
25
+ The tuple of output channels for each block.
26
+ addition_time_embed_dim: (`int`, defaults to 256):
27
+ Dimension to to encode the additional time ids.
28
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
29
+ The dimension of the projection of encoded `added_time_ids`.
30
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
31
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
32
+ The dimension of the cross attention features.
33
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
34
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
35
+ [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
36
+ [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
37
+ [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
38
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
39
+ The number of attention heads.
40
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
41
+ """
42
+
43
+ _supports_gradient_checkpointing = True
44
+
45
+ @register_to_config
46
+ def __init__(
47
+ self,
48
+ sample_size: Optional[int] = None,
49
+ in_channels: int = 8,
50
+ out_channels: int = 4,
51
+ down_block_types: Tuple[str] = (
52
+ "CrossAttnDownBlockSpatioTemporal",
53
+ "CrossAttnDownBlockSpatioTemporal",
54
+ "CrossAttnDownBlockSpatioTemporal",
55
+ "DownBlockSpatioTemporal",
56
+ ),
57
+ up_block_types: Tuple[str] = (
58
+ "UpBlockSpatioTemporal",
59
+ "CrossAttnUpBlockSpatioTemporal",
60
+ "CrossAttnUpBlockSpatioTemporal",
61
+ "CrossAttnUpBlockSpatioTemporal",
62
+ ),
63
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
64
+ addition_time_embed_dim: int = 256,
65
+ projection_class_embeddings_input_dim: int = 768,
66
+ layers_per_block: Union[int, Tuple[int]] = 2,
67
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
68
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
69
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
70
+ num_frames: int = 25,
71
+ ):
72
+ super().__init__()
73
+
74
+ self.sample_size = sample_size
75
+
76
+ # Check inputs
77
+ if len(down_block_types) != len(up_block_types):
78
+ raise ValueError(
79
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
80
+ )
81
+
82
+ if len(block_out_channels) != len(down_block_types):
83
+ raise ValueError(
84
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
85
+ )
86
+
87
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
88
+ raise ValueError(
89
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
90
+ )
91
+
92
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
93
+ raise ValueError(
94
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
95
+ )
96
+
97
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
98
+ raise ValueError(
99
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
100
+ )
101
+
102
+ # input
103
+ self.conv_in = nn.Conv2d(
104
+ in_channels,
105
+ block_out_channels[0],
106
+ kernel_size=3,
107
+ padding=1,
108
+ )
109
+
110
+ # time
111
+ time_embed_dim = block_out_channels[0] * 4
112
+
113
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
114
+ timestep_input_dim = block_out_channels[0]
115
+
116
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
117
+
118
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
119
+ self.add_angle_proj = Timesteps(2*addition_time_embed_dim, True, downscale_freq_shift=0) # encode camera angles
120
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
121
+
122
+ self.down_blocks = nn.ModuleList([])
123
+ self.up_blocks = nn.ModuleList([])
124
+
125
+ if isinstance(num_attention_heads, int):
126
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
127
+
128
+ if isinstance(cross_attention_dim, int):
129
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
130
+
131
+ if isinstance(layers_per_block, int):
132
+ layers_per_block = [layers_per_block] * len(down_block_types)
133
+
134
+ if isinstance(transformer_layers_per_block, int):
135
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
136
+
137
+ blocks_time_embed_dim = time_embed_dim
138
+
139
+ # down
140
+ output_channel = block_out_channels[0]
141
+ for i, down_block_type in enumerate(down_block_types):
142
+ input_channel = output_channel
143
+ output_channel = block_out_channels[i]
144
+ is_final_block = i == len(block_out_channels) - 1
145
+
146
+ down_block = get_down_block(
147
+ down_block_type,
148
+ num_layers=layers_per_block[i],
149
+ transformer_layers_per_block=transformer_layers_per_block[i],
150
+ in_channels=input_channel,
151
+ out_channels=output_channel,
152
+ temb_channels=blocks_time_embed_dim,
153
+ add_downsample=not is_final_block,
154
+ resnet_eps=1e-5,
155
+ cross_attention_dim=cross_attention_dim[i],
156
+ num_attention_heads=num_attention_heads[i],
157
+ resnet_act_fn="silu",
158
+ )
159
+ self.down_blocks.append(down_block)
160
+
161
+ # mid
162
+ self.mid_block = UNetMidBlockSpatioTemporal(
163
+ block_out_channels[-1],
164
+ temb_channels=blocks_time_embed_dim,
165
+ transformer_layers_per_block=transformer_layers_per_block[-1],
166
+ cross_attention_dim=cross_attention_dim[-1],
167
+ num_attention_heads=num_attention_heads[-1],
168
+ )
169
+
170
+ # count how many layers upsample the images
171
+ self.num_upsamplers = 0
172
+
173
+ # up
174
+ reversed_block_out_channels = list(reversed(block_out_channels))
175
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
176
+ reversed_layers_per_block = list(reversed(layers_per_block))
177
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
178
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
179
+
180
+ output_channel = reversed_block_out_channels[0]
181
+ for i, up_block_type in enumerate(up_block_types):
182
+ is_final_block = i == len(block_out_channels) - 1
183
+
184
+ prev_output_channel = output_channel
185
+ output_channel = reversed_block_out_channels[i]
186
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
187
+
188
+ # add upsample block for all BUT final layer
189
+ if not is_final_block:
190
+ add_upsample = True
191
+ self.num_upsamplers += 1
192
+ else:
193
+ add_upsample = False
194
+
195
+ up_block = get_up_block(
196
+ up_block_type,
197
+ num_layers=reversed_layers_per_block[i] + 1,
198
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
199
+ in_channels=input_channel,
200
+ out_channels=output_channel,
201
+ prev_output_channel=prev_output_channel,
202
+ temb_channels=blocks_time_embed_dim,
203
+ add_upsample=add_upsample,
204
+ resnet_eps=1e-5,
205
+ resolution_idx=i,
206
+ cross_attention_dim=reversed_cross_attention_dim[i],
207
+ num_attention_heads=reversed_num_attention_heads[i],
208
+ resnet_act_fn="silu",
209
+ )
210
+ self.up_blocks.append(up_block)
211
+ prev_output_channel = output_channel
212
+
213
+ # out
214
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
215
+ self.conv_act = nn.SiLU()
216
+
217
+ self.conv_out = nn.Conv2d(
218
+ block_out_channels[0],
219
+ out_channels,
220
+ kernel_size=3,
221
+ padding=1,
222
+ )
223
+
224
+ @property
225
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
226
+ r"""
227
+ Returns:
228
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
229
+ indexed by its weight name.
230
+ """
231
+ # set recursively
232
+ processors = {}
233
+
234
+ def fn_recursive_add_processors(
235
+ name: str,
236
+ module: torch.nn.Module,
237
+ processors: Dict[str, AttentionProcessor],
238
+ ):
239
+ if hasattr(module, "get_processor"):
240
+ processors[f"{name}.processor"] = module.get_processor()
241
+
242
+ for sub_name, child in module.named_children():
243
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
244
+
245
+ return processors
246
+
247
+ for name, module in self.named_children():
248
+ fn_recursive_add_processors(name, module, processors)
249
+
250
+ return processors
251
+
252
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
253
+ r"""
254
+ Sets the attention processor to use to compute attention.
255
+
256
+ Parameters:
257
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
258
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
259
+ for **all** `Attention` layers.
260
+
261
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
262
+ processor. This is strongly recommended when setting trainable attention processors.
263
+
264
+ """
265
+ count = len(self.attn_processors.keys())
266
+
267
+ if isinstance(processor, dict) and len(processor) != count:
268
+ raise ValueError(
269
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
270
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
271
+ )
272
+
273
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
274
+ if hasattr(module, "set_processor"):
275
+ if not isinstance(processor, dict):
276
+ module.set_processor(processor)
277
+ else:
278
+ module.set_processor(processor.pop(f"{name}.processor"))
279
+
280
+ for sub_name, child in module.named_children():
281
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
282
+
283
+ for name, module in self.named_children():
284
+ fn_recursive_attn_processor(name, module, processor)
285
+
286
+ def set_default_attn_processor(self):
287
+ """
288
+ Disables custom attention processors and sets the default attention implementation.
289
+ """
290
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
291
+ processor = AttnProcessor()
292
+ else:
293
+ raise ValueError(
294
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
295
+ )
296
+
297
+ self.set_attn_processor(processor)
298
+
299
+ def _set_gradient_checkpointing(self, module, value=False):
300
+ if hasattr(module, "gradient_checkpointing"):
301
+ module.gradient_checkpointing = value
302
+
303
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
304
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
305
+ """
306
+ Sets the attention processor to use [feed forward
307
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
308
+
309
+ Parameters:
310
+ chunk_size (`int`, *optional*):
311
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
312
+ over each tensor of dim=`dim`.
313
+ dim (`int`, *optional*, defaults to `0`):
314
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
315
+ or dim=1 (sequence length).
316
+ """
317
+ if dim not in [0, 1]:
318
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
319
+
320
+ # By default chunk size is 1
321
+ chunk_size = chunk_size or 1
322
+
323
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
324
+ if hasattr(module, "set_chunk_feed_forward"):
325
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
326
+
327
+ for child in module.children():
328
+ fn_recursive_feed_forward(child, chunk_size, dim)
329
+
330
+ for module in self.children():
331
+ fn_recursive_feed_forward(module, chunk_size, dim)
332
+
333
+ def forward(
334
+ self,
335
+ sample: torch.Tensor,
336
+ timestep: Union[torch.Tensor, float, int],
337
+ encoder_hidden_states: torch.Tensor,
338
+ added_time_ids: Union[torch.Tensor, List[torch.Tensor]],
339
+ return_dict: bool = True,
340
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
341
+ r"""
342
+ The [`UNetSpatioTemporalConditionModel`] forward method.
343
+
344
+ Args:
345
+ sample (`torch.Tensor`):
346
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
347
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
348
+ encoder_hidden_states (`torch.Tensor`):
349
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
350
+ added_time_ids: (`torch.Tensor`):
351
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
352
+ embeddings and added to the time embeddings.
353
+ return_dict (`bool`, *optional*, defaults to `True`):
354
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead
355
+ of a plain tuple.
356
+ Returns:
357
+ [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
358
+ If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
359
+ returned, otherwise a `tuple` is returned where the first element is the sample tensor.
360
+ """
361
+ # 1. time
362
+ timesteps = timestep
363
+ if not torch.is_tensor(timesteps):
364
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
365
+ # This would be a good case for the `match` statement (Python 3.10+)
366
+ is_mps = sample.device.type == "mps"
367
+ if isinstance(timestep, float):
368
+ dtype = torch.float32 if is_mps else torch.float64
369
+ else:
370
+ dtype = torch.int32 if is_mps else torch.int64
371
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
372
+ elif len(timesteps.shape) == 0:
373
+ timesteps = timesteps[None].to(sample.device)
374
+
375
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
376
+ batch_size, num_frames = sample.shape[:2]
377
+ timesteps = timesteps.expand(batch_size)
378
+
379
+ t_emb = self.time_proj(timesteps)
380
+
381
+ # `Timesteps` does not contain any weights and will always return f32 tensors
382
+ # but time_embedding might actually be running in fp16. so we need to cast here.
383
+ # there might be better ways to encapsulate this.
384
+ t_emb = t_emb.to(dtype=sample.dtype)
385
+
386
+ emb = self.time_embedding(t_emb)
387
+
388
+ if isinstance(added_time_ids, torch.Tensor):
389
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
390
+ time_embeds = time_embeds.reshape((batch_size, -1))
391
+ time_embeds = time_embeds.to(emb.dtype)
392
+ aug_emb = self.add_embedding(time_embeds)
393
+ emb = emb + aug_emb
394
+
395
+ # Repeat the embeddings num_video_frames times
396
+ # emb: [batch, channels] -> [batch * frames, channels]
397
+ emb = emb.repeat_interleave(num_frames, dim=0)
398
+ elif isinstance(added_time_ids, list):
399
+ # Repeat the embeddings num_video_frames times
400
+ # emb: [batch, channels] -> [batch * frames, channels]
401
+ emb = emb.repeat_interleave(num_frames, dim=0)
402
+
403
+ cond_aug = added_time_ids[0]
404
+ cond_aug_emb = self.add_time_proj(cond_aug.flatten())
405
+ time_embeds = cond_aug_emb
406
+ time_embeds = time_embeds.to(emb.dtype)
407
+ aug_emb = self.add_embedding(time_embeds)
408
+ emb = emb + aug_emb
409
+ else:
410
+ raise ValueError
411
+
412
+ # Flatten the batch and frames dimensions
413
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
414
+ sample = sample.flatten(0, 1)
415
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
416
+
417
+ # Taken care of in the pipeline (to allow reference manipulations)
418
+ # encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
419
+
420
+ # 2. pre-process
421
+ sample = self.conv_in(sample)
422
+
423
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
424
+
425
+ down_block_res_samples = (sample,)
426
+ for downsample_block in self.down_blocks:
427
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
428
+ sample, res_samples = downsample_block(
429
+ hidden_states=sample,
430
+ temb=emb,
431
+ encoder_hidden_states=encoder_hidden_states,
432
+ image_only_indicator=image_only_indicator,
433
+ )
434
+ else:
435
+ sample, res_samples = downsample_block(
436
+ hidden_states=sample,
437
+ temb=emb,
438
+ image_only_indicator=image_only_indicator,
439
+ )
440
+
441
+ down_block_res_samples += res_samples
442
+
443
+ # 4. mid
444
+ sample = self.mid_block(
445
+ hidden_states=sample,
446
+ temb=emb,
447
+ encoder_hidden_states=encoder_hidden_states,
448
+ image_only_indicator=image_only_indicator,
449
+ )
450
+
451
+ # 5. up
452
+ for i, upsample_block in enumerate(self.up_blocks):
453
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
454
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
455
+
456
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
457
+ sample = upsample_block(
458
+ hidden_states=sample,
459
+ temb=emb,
460
+ res_hidden_states_tuple=res_samples,
461
+ encoder_hidden_states=encoder_hidden_states,
462
+ image_only_indicator=image_only_indicator,
463
+ )
464
+ else:
465
+ sample = upsample_block(
466
+ hidden_states=sample,
467
+ temb=emb,
468
+ res_hidden_states_tuple=res_samples,
469
+ image_only_indicator=image_only_indicator,
470
+ )
471
+
472
+ # 6. post-process
473
+ sample = self.conv_norm_out(sample)
474
+ sample = self.conv_act(sample)
475
+ sample = self.conv_out(sample)
476
+
477
+ # 7. Reshape back to original shape
478
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
479
+
480
+ if not return_dict:
481
+ return (sample,)
482
+
483
+ return UNetSpatioTemporalConditionOutput(sample=sample)
diffusers_sv3d/pipelines/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .stable_video_diffusion import StableVideo3DDiffusionPipeline, StableVideo3DDiffusionPipelineRotate
diffusers_sv3d/pipelines/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (318 Bytes). View file
 
diffusers_sv3d/pipelines/stable_video_diffusion/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .pipeline_stable_video_3d_diffusion import StableVideo3DDiffusionPipeline
2
+ from .pipeline_stable_video_3d_diffusion_rotate import StableVideo3DDiffusionPipelineRotate
diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (418 Bytes). View file
 
diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/pipeline_stable_video_3d_diffusion.cpython-311.pyc ADDED
Binary file (24.2 kB). View file
 
diffusers_sv3d/pipelines/stable_video_diffusion/__pycache__/pipeline_stable_video_3d_diffusion_rotate.cpython-311.pyc ADDED
Binary file (21.1 kB). View file
 
diffusers_sv3d/pipelines/stable_video_diffusion/pipeline_stable_video_3d_diffusion.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Union
2
+
3
+ import PIL.Image
4
+ import torch
5
+ from diffusers.models.attention_processor import AttnProcessor2_0
6
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
7
+ StableVideoDiffusionPipeline,
8
+ _append_dims,
9
+ randn_tensor,
10
+ retrieve_timesteps,
11
+ )
12
+
13
+ from self_attn_swap import ACTIVATE_LAYER_CANDIDATE_SV3D, SharedAttentionProcessorThree
14
+
15
+ # Constants
16
+ HEIGHT = 576
17
+ WIDTH = 576
18
+ NUM_FRAMES = 21
19
+ NOISE_AUG_STRENGTH = 1e-5
20
+ DECODE_CHUNK_SIZE = 2
21
+ NUM_VID = 1
22
+ BATCH_SIZE = 1
23
+ MIN_CFG = 1.0
24
+ MAX_CFG = 2.5
25
+
26
+
27
+ class StableVideo3DDiffusionPipeline(StableVideoDiffusionPipeline):
28
+ def __init__(self, vae, image_encoder, unet, scheduler, feature_extractor):
29
+ super().__init__(vae, image_encoder, unet, scheduler, feature_extractor)
30
+
31
+ def _get_add_time_ids(
32
+ self, dtype: torch.dtype, num_processes, do_classifier_free_guidance: bool
33
+ ) -> List[torch.Tensor]:
34
+ cond_aug = torch.tensor([NOISE_AUG_STRENGTH] * 21, dtype=dtype).repeat(BATCH_SIZE * num_processes, 1)
35
+
36
+ if do_classifier_free_guidance:
37
+ cond_aug = torch.cat([cond_aug, cond_aug])
38
+
39
+ add_time_ids = [cond_aug]
40
+
41
+ self.unet.to(dtype=torch.float16)
42
+ self.vae.to(dtype=torch.float16)
43
+
44
+ return add_time_ids
45
+
46
+ def prepare_video_latents(
47
+ self,
48
+ images: List[torch.Tensor],
49
+ timestep: torch.Tensor,
50
+ add_noise: bool = True,
51
+ refine_frames: Optional[int] = None,
52
+ original_latents: Optional[torch.Tensor] = None,
53
+ ) -> torch.Tensor:
54
+ """Prepare video latents by encoding frames and optionally adding noise."""
55
+ encoded_frames = [self._encode_vae_image(image, self.device, NUM_VID, False) for image in images]
56
+ encoded_frames = [frame.to(images[0].dtype) for frame in encoded_frames]
57
+
58
+ # TODO: check scaling factor?
59
+ encoded_frames = [self.vae.config.scaling_factor * frame for frame in encoded_frames]
60
+
61
+ if add_noise:
62
+ video_latents = [
63
+ self.scheduler.add_noise(
64
+ frame,
65
+ randn_tensor(encoded_frames[0].shape, self.generator, self.device, images[0].dtype),
66
+ timestep,
67
+ )
68
+ for frame in encoded_frames
69
+ ]
70
+ else:
71
+ video_latents = encoded_frames
72
+
73
+ if refine_frames is not None and original_latents is not None:
74
+ video_latents = encoded_frames
75
+
76
+ for i in range(len(video_latents)):
77
+ if i in refine_frames:
78
+ video_latents[i] = original_latents[i].unsqueeze(0)
79
+
80
+ return torch.stack(video_latents, dim=1)
81
+
82
+ def activate_layers(self, config: Dict[str, List[Union[float, int]]], swapping_type="linear") -> Dict[str, AttnProcessor2_0]:
83
+ """Activate swapping attention mechanism in specific UNet layers."""
84
+
85
+ # Setup default values first
86
+ default_attn_procs = {}
87
+
88
+ for layer in self.unet.attn_processors.keys():
89
+ default_attn_procs[layer] = AttnProcessor2_0()
90
+
91
+ self.unet.set_attn_processor(default_attn_procs)
92
+
93
+ spatial_attn = [layer for layer in ACTIVATE_LAYER_CANDIDATE_SV3D if ".transformer_blocks.0.attn1" in layer]
94
+ temporal_attn = [
95
+ layer for layer in ACTIVATE_LAYER_CANDIDATE_SV3D if ".temporal_transformer_blocks.0.attn1" in layer
96
+ ]
97
+
98
+ assert len(spatial_attn) == len(config["spatial_ratio"]) == len(config["spatial_strength"])
99
+ assert len(temporal_attn) == len(config["temporal_ratio"]) == len(config["temporal_strength"])
100
+
101
+ ratios = {}
102
+ for layer, ratio, strength in zip(spatial_attn, config["spatial_ratio"], config["spatial_strength"]):
103
+ ratios[layer] = {"ratio": ratio, "strength": strength}
104
+
105
+ for layer, ratio, strength in zip(temporal_attn, config["temporal_ratio"], config["temporal_strength"]):
106
+ ratios[layer] = {"ratio": ratio, "strength": strength}
107
+
108
+ attn_procs = {}
109
+
110
+ for layer in self.unet.attn_processors.keys():
111
+ if layer in ratios:
112
+ attn_procs[layer] = SharedAttentionProcessorThree(
113
+ unet_chunk_size=2, activate_step_indices=config["activate_steps"], ratio=ratios[layer], swapping_type=swapping_type
114
+ )
115
+ else:
116
+ attn_procs[layer] = AttnProcessor2_0()
117
+
118
+ self.unet.set_attn_processor(attn_procs)
119
+
120
+ return attn_procs
121
+
122
+ def _decode_vae_frames(self, image_latents: torch.Tensor) -> torch.Tensor:
123
+ frames = []
124
+ for i in range(21):
125
+ frame = self.vae.decode(image_latents[:, i], self.device).sample
126
+ frames.append(frame)
127
+ return torch.stack(frames, dim=2)
128
+
129
+ def _preprocess_reference_images(self, reference_images: List[PIL.Image.Image]) -> List[torch.Tensor]:
130
+ """Helper method to preprocess reference images consistently"""
131
+ processed_images = []
132
+ for image in reference_images:
133
+ ref_image = self.video_processor.preprocess(image, HEIGHT, WIDTH).to(self.device)
134
+ ref_noise = randn_tensor(ref_image.shape, self.generator, self.device, ref_image.dtype)
135
+ ref_image = ref_image + NOISE_AUG_STRENGTH * ref_noise
136
+ processed_images.append(ref_image)
137
+ return processed_images
138
+
139
+ def _preprocess_image(self, image: Union[PIL.Image.Image, torch.Tensor]) -> torch.Tensor:
140
+ """Preprocess a single image with noise augmentation"""
141
+ processed = self.video_processor.preprocess(image, HEIGHT, WIDTH).to(self.device)
142
+ noise = randn_tensor(processed.shape, self.generator, self.device, processed.dtype)
143
+ return processed + NOISE_AUG_STRENGTH * noise
144
+
145
+ def _denoise_loop(
146
+ self,
147
+ latents: torch.Tensor,
148
+ image_latents: torch.Tensor,
149
+ image_embeddings: torch.Tensor,
150
+ added_time_ids: List[torch.Tensor],
151
+ timesteps: torch.Tensor,
152
+ z0_reference_images: Optional[List[torch.Tensor]] = None,
153
+ z0_shape_images: Optional[List[torch.Tensor]] = None,
154
+ refinement: bool = False,
155
+ refine_frames: Optional[list] = None,
156
+ z0_mid_images: Optional[List[torch.Tensor]] = None,
157
+ output_type: str = "pil",
158
+ add_noise: bool = True,
159
+ ):
160
+ num_warmup_steps = len(timesteps) - self.num_inference_steps * self.scheduler.order
161
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
162
+
163
+ intermediate_steps = []
164
+
165
+ normal_latents = None
166
+
167
+ with torch.autocast(device_type=self.device.type, dtype=torch.float16):
168
+ with self.progress_bar(total=self.num_inference_steps) as progress_bar:
169
+ for i, t in enumerate(timesteps):
170
+ if i in self.replace_reference_steps:
171
+ latents[0] = self.prepare_video_latents(
172
+ z0_reference_images,
173
+ timestep=t.repeat(1),
174
+ add_noise=add_noise,
175
+ )
176
+
177
+ if refinement and z0_mid_images is not None:
178
+ latents[1] = self.prepare_video_latents(
179
+ z0_mid_images,
180
+ timestep=t.repeat(1),
181
+ add_noise=add_noise,
182
+ refine_frames=refine_frames,
183
+ original_latents=latents[1],
184
+ )
185
+
186
+ if refinement and z0_shape_images is not None:
187
+ latents[2] = self.prepare_video_latents(
188
+ z0_shape_images,
189
+ timestep=t.repeat(1),
190
+ add_noise=add_noise,
191
+ )
192
+
193
+ # expand the latents if we are doing cfg
194
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
195
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
196
+
197
+ # Concatenate image_latents over channels dimension
198
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
199
+
200
+ torch.cuda.empty_cache()
201
+
202
+
203
+ print(latent_model_input.shape, t, image_embeddings.shape, added_time_ids[0].shape)
204
+
205
+ # predict the noise residual
206
+ noise_pred = self.unet(
207
+ latent_model_input, # 2/4/6,21,8,72,72
208
+ t, # float
209
+ encoder_hidden_states=image_embeddings, # 42/84/126,1,1024
210
+ added_time_ids=added_time_ids, # 2/4/6,21
211
+ return_dict=False,
212
+ )[0] # 1/2/3,21,4,72,72
213
+
214
+ # perform guidance
215
+ if self.do_classifier_free_guidance:
216
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
217
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
218
+
219
+ # compute the previous noisy sample x_t -> x_t-1
220
+ step_output = self.scheduler.step(noise_pred, t, latents) # EulerDiscreteScheduler
221
+ latents = step_output.prev_sample
222
+ normal_latents = step_output.pred_original_sample
223
+
224
+ if self.return_intermediate_steps:
225
+ if needs_upcasting:
226
+ self.vae.to(dtype=torch.float16)
227
+ frames = self.decode_latents(normal_latents, NUM_FRAMES, DECODE_CHUNK_SIZE)
228
+ frames = self.video_processor.postprocess_video(frames, "pil")
229
+ intermediate_steps.append(frames)
230
+
231
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
232
+ progress_bar.update()
233
+
234
+ if not output_type == "latent":
235
+ # cast back to fp16 if needed
236
+ if needs_upcasting:
237
+ self.vae.to(dtype=torch.float16)
238
+ frames = self.decode_latents(latents, NUM_FRAMES, DECODE_CHUNK_SIZE)
239
+ frames = self.video_processor.postprocess_video(frames, output_type)
240
+ else:
241
+ frames = latents
242
+
243
+ self.maybe_free_model_hooks()
244
+
245
+ return frames, intermediate_steps
246
+
247
    @torch.no_grad()
    def __call__(
        self,
        input_image: PIL.Image.Image,
        reference_images: List[PIL.Image.Image],
        num_inference_steps: int = 25,
        replace_reference_steps: List[int] = list(),  # NOTE(review): mutable default argument — shared across calls; prefer None sentinel
        return_intermediate_steps: bool = False,
        seed: int = 42,
        same_starting_latents: bool = True,
        refinement: bool = False,
        refine_frames: Optional[list] = None,
        add_noise: bool = True,
    ):
        """Generate three synchronized orbit videos (reference / mid / shape).

        The three videos are denoised jointly in one batch; the reference
        video is periodically re-injected from ``reference_images`` at the
        steps in ``replace_reference_steps``. With ``refinement`` a second
        denoising pass is run after rolling the frame lists so that the frame
        at ``refine_frames[-1]`` becomes the new front view.

        Returns:
            (frames, new_front_image, intermediate_steps-or-None), where
            ``frames`` is a list of three per-video frame lists.
        """
        # 0. Set seed (the generator is stored on self and reused by the
        # preprocessing helpers).
        self.generator = torch.manual_seed(seed)

        # 1. Check inputs. Raise error if not correct.
        self.check_inputs(input_image, HEIGHT, WIDTH)

        # 2. Define call parameters.
        self.num_inference_steps = num_inference_steps
        self.return_intermediate_steps = return_intermediate_steps
        self.replace_reference_steps = replace_reference_steps
        self._guidance_scale = MAX_CFG

        # 3. Encode input image (CLIP). Each _encode_image returns the
        # (unconditional, conditional) pair stacked along dim 0.
        image_embeddings_combined = [
            self._encode_image(reference_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
            self._encode_image(input_image, self.device, NUM_VID, self.do_classifier_free_guidance),
            self._encode_image(input_image, self.device, NUM_VID, self.do_classifier_free_guidance),
        ]

        all_embeddings = torch.cat(image_embeddings_combined, dim=0)  # uc, c, uc, c, (uc, c)
        # Regroup so all unconditional embeddings precede all conditional ones.
        embeddings_order = torch.tensor([0, 2, 4, 1, 3, 5])
        reordered_embeddings = all_embeddings[embeddings_order]  # uc, uc, (uc), c, c, (c)
        image_embeddings = reordered_embeddings.repeat_interleave(NUM_FRAMES, dim=0)

        # 4. Encode using VAE.
        image = self._preprocess_image(input_image)

        ref_image = self._preprocess_image(reference_images[-1])
        z0_reference_images = self._preprocess_reference_images(reference_images)

        # The SD VAE is upcast to fp32 for numerically stable encode/decode.
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        image_latents = self._encode_vae_image(image, self.device, NUM_VID, self.do_classifier_free_guidance)
        image_latents = image_latents.to(image_embeddings.dtype)

        ref_image_latents = self._encode_vae_image(ref_image, self.device, NUM_VID, self.do_classifier_free_guidance)
        ref_image_latents = ref_image_latents.to(image_embeddings.dtype)

        # Broadcast each conditioning latent across all NUM_FRAMES frames.
        image_latents_full = [
            ref_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
            image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
            image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
        ]

        image_latents = torch.cat(image_latents_full, dim=0)
        # Same uncond-first regrouping as for the CLIP embeddings above.
        image_latents_order = torch.tensor([0, 2, 4, 1, 3, 5])
        image_latents = image_latents[image_latents_order]

        if needs_upcasting:  # cast back to fp16 if needed
            self.vae.to(dtype=torch.float16)

        # Three videos are denoised jointly: reference, mid, shape.
        num_processes = 3

        # 5. Get Added Time IDs.
        added_time_ids = self._get_add_time_ids(
            image_embeddings.dtype,
            num_processes,
            self.do_classifier_free_guidance,
        )  # list of tensor [2, 21] or [4, 21] or [6, 21] -> just 4x the same

        added_time_ids = [a.to(self.device) for a in added_time_ids]

        timesteps, self.num_inference_steps = retrieve_timesteps(self.scheduler, self.num_inference_steps, self.device)

        # 7. Prepare latent variables.
        num_channels_latents = self.unet.config.in_channels  # 8
        latents = self.prepare_latents(
            BATCH_SIZE * num_processes,
            NUM_FRAMES,
            num_channels_latents,
            HEIGHT,
            WIDTH,
            image_embeddings.dtype,
            self.device,
            self.generator,
        )  # 2/3,21,4,72,72

        if same_starting_latents:
            # All three videos start from identical noise (copies of latents[2]).
            latents[0] = latents[1] = latents[2]

        # 8. Prepare guidance scale: triangular schedule rising from MIN_CFG at
        # the front frame to MAX_CFG at the middle frame and back.
        guidance_scale = torch.cat(
            [
                torch.linspace(MIN_CFG, MAX_CFG, NUM_FRAMES // 2 + 1)[1:].unsqueeze(0),
                torch.linspace(MAX_CFG, MIN_CFG, NUM_FRAMES - NUM_FRAMES // 2 + 1)[1:].unsqueeze(0),
            ],
            dim=-1,
        )

        guidance_scale = guidance_scale.to(self.device, latents.dtype)
        guidance_scale = guidance_scale.repeat(BATCH_SIZE, 1)
        guidance_scale = _append_dims(guidance_scale, latents.ndim)  # [1,21,1,1,1]

        self._guidance_scale = guidance_scale

        # 9. Denoising loop (first pass).
        frames, intemediate_steps = self._denoise_loop(
            latents=latents,
            image_latents=image_latents,
            image_embeddings=image_embeddings,
            added_time_ids=added_time_ids,
            timesteps=timesteps,
            z0_reference_images=z0_reference_images,
            output_type="pil",
            add_noise=add_noise,
        )
        new_front_image = None
        if refinement:
            assert refine_frames is not None
            # Roll all frame lists so refine_frames[-1] becomes the new front view.
            current_front_frame_idx = refine_frames[-1]
            shift = NUM_FRAMES - current_front_frame_idx

            mid_images = frames[1]
            shape_images = frames[2]

            new_front_image = mid_images[current_front_frame_idx]

            # roll the lists
            reference_images = reference_images[shift:] + reference_images[:shift]
            shape_images = shape_images[shift:] + shape_images[:shift]
            mid_images = mid_images[shift:] + mid_images[:shift]

            latents = self.prepare_latents(
                BATCH_SIZE * num_processes,
                NUM_FRAMES,
                num_channels_latents,
                HEIGHT,
                WIDTH,
                image_embeddings.dtype,
                self.device,
                self.generator,
            )
            if same_starting_latents:
                latents[0] = latents[1] = latents[2]

            timesteps, self.num_inference_steps = retrieve_timesteps(
                self.scheduler, self.num_inference_steps, self.device
            )

            # Re-encode the (rolled) front frame of each video for conditioning.
            ref_image = self._preprocess_image(z0_reference_images[-1])
            ref_image_latents = self._encode_vae_image(
                ref_image, self.device, NUM_VID, self.do_classifier_free_guidance
            )
            ref_image_latents = ref_image_latents.to(image_embeddings.dtype)

            mid_image = self._preprocess_image(mid_images[-1])
            mid_image_latents = self._encode_vae_image(
                mid_image, self.device, NUM_VID, self.do_classifier_free_guidance
            )
            mid_image_latents = mid_image_latents.to(image_embeddings.dtype)

            shape_image = self._preprocess_image(shape_images[-1])
            shape_image_latents = self._encode_vae_image(
                shape_image, self.device, NUM_VID, self.do_classifier_free_guidance
            )
            shape_image_latents = shape_image_latents.to(image_embeddings.dtype)

            image_latents_full = [
                ref_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
                mid_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
                shape_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
            ]

            image_latents = torch.cat(image_latents_full, dim=0)
            image_latents = image_latents[image_latents_order]

            # CLIP embeddings on the new front frame.
            image_embeddings_combined = [
                self._encode_image(reference_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
                self._encode_image(mid_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
                self._encode_image(shape_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
            ]
            all_embeddings = torch.cat(image_embeddings_combined, dim=0)  # uc, c, uc, c, (uc, c)
            embeddings_order = torch.tensor([0, 2, 4, 1, 3, 5])
            reordered_embeddings = all_embeddings[embeddings_order]  # uc, uc, (uc), c, c, (c)
            image_embeddings = reordered_embeddings.repeat_interleave(NUM_FRAMES, dim=0)

            z0_mid_images = self._preprocess_reference_images(mid_images)
            z0_shape_images = self._preprocess_reference_images(shape_images)

            # Second denoising pass with mid/shape re-injection enabled.
            frames, intemediate_steps = self._denoise_loop(
                latents=latents,
                image_latents=image_latents,
                image_embeddings=image_embeddings,
                added_time_ids=added_time_ids,
                timesteps=timesteps,
                z0_reference_images=z0_reference_images,
                z0_shape_images=z0_shape_images,
                refinement=refinement,
                refine_frames=refine_frames,
                z0_mid_images=z0_mid_images,
                add_noise=add_noise,
            )

            # Roll back frames to original order.
            frames = [
                frames[0][(-shift):] + frames[0][:-shift],
                frames[1][(-shift):] + frames[1][:-shift],
                frames[2][(-shift):] + frames[2][:-shift],
            ]

        if return_intermediate_steps:
            return frames, new_front_image, intemediate_steps

        return frames, new_front_image, None
diffusers_sv3d/pipelines/stable_video_diffusion/pipeline_stable_video_3d_diffusion_rotate.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Union
2
+
3
+ import PIL.Image
4
+ import torch
5
+ from diffusers.models.attention_processor import AttnProcessor2_0
6
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
7
+ StableVideoDiffusionPipeline,
8
+ _append_dims,
9
+ randn_tensor,
10
+ retrieve_timesteps,
11
+ )
12
+
13
+ from self_attn_swap import ACTIVATE_LAYER_CANDIDATE_SV3D, SharedAttentionProcessorThree
14
+
15
# Constants
HEIGHT = 576  # output frame height in pixels
WIDTH = 576  # output frame width in pixels
NUM_FRAMES = 21  # frames per generated orbit video
NOISE_AUG_STRENGTH = 1e-5  # scale of the noise added to conditioning images
DECODE_CHUNK_SIZE = 2  # frames decoded per VAE chunk (memory/speed trade-off)
NUM_VID = 1  # videos per prompt
# NOTE(review): module-level RNG is shared mutable state — every call (and
# every import) advances the same generator, so runs are order-dependent.
GENERATOR = torch.manual_seed(42)
OUTPUT_TYPE = "pil"
BATCH_SIZE = 1
MIN_CFG = 1.0  # classifier-free guidance scale at the front/back frame
MAX_CFG = 2.5  # classifier-free guidance scale at the middle frame
27
+
28
+
29
class StableVideo3DDiffusionPipelineRotate(StableVideoDiffusionPipeline):
    """SV3D pipeline variant that jointly denoises a reference, a "mid" and an
    optional "shape" orbit video, periodically re-injecting clean frames so the
    mid video rotates while staying consistent with the reference."""

    def __init__(self, vae, image_encoder, unet, scheduler, feature_extractor):
        super().__init__(vae, image_encoder, unet, scheduler, feature_extractor)

    def _get_add_time_ids(
        self, dtype: torch.dtype, num_processes, do_classifier_free_guidance: bool
    ) -> List[torch.Tensor]:
        """Build the per-frame conditioning-augmentation tensor.

        Returns a single-element list holding the NOISE_AUG_STRENGTH values,
        doubled along dim 0 for classifier-free guidance.
        """
        cond_aug = torch.tensor([NOISE_AUG_STRENGTH] * 21, dtype=dtype).repeat(BATCH_SIZE * num_processes, 1)

        if do_classifier_free_guidance:
            cond_aug = torch.cat([cond_aug, cond_aug])

        add_time_ids = [cond_aug]

        # NOTE(review): casting the UNet/VAE to fp16 is a surprising side
        # effect for a "_get_*" helper — consider moving it to setup.
        self.unet.to(dtype=torch.float16)
        self.vae.to(dtype=torch.float16)

        return add_time_ids

    def prepare_video_latents(
        self,
        images: List[torch.Tensor],
        timestep: torch.Tensor,
        add_noise: bool = True,
        active_size: Optional[int] = None,
        original_latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Prepare video latents by encoding frames and optionally adding noise.

        When ``active_size`` and ``original_latents`` are given, the trailing
        window of frames (indices NUM_FRAMES-active_size-1 .. NUM_FRAMES-2) is
        kept from ``original_latents`` instead of being re-encoded.
        """
        encoded_frames = [self._encode_vae_image(image, self.device, NUM_VID, False) for image in images]
        encoded_frames = [frame.to(images[0].dtype) for frame in encoded_frames]

        # TODO: check scaling factor?
        encoded_frames = [self.vae.config.scaling_factor * frame for frame in encoded_frames]

        # Noise each frame to the requested timestep so it matches the
        # in-flight denoising state.
        if add_noise:
            video_latents = [
                self.scheduler.add_noise(
                    frame,
                    randn_tensor(encoded_frames[0].shape, GENERATOR, self.device, images[0].dtype),
                    timestep,
                )
                for frame in encoded_frames
            ]
        else:
            video_latents = encoded_frames

        if active_size is not None and original_latents is not None:
            for i in range(len(video_latents)):
                if NUM_FRAMES - active_size - 1 <= i < NUM_FRAMES - 1:
                    video_latents[i] = original_latents[i].unsqueeze(0)

        return torch.stack(video_latents, dim=1)

    def activate_layers(self, config: Dict[str, List[Union[float, int]]]) -> Dict[str, AttnProcessor2_0]:
        """Activate swapping attention mechanism in specific UNet layers.

        ``config`` must provide per-layer ratio/strength lists matching the
        number of spatial and temporal attn1 candidates, plus the step indices
        at which swapping is active.
        """
        spatial_attn = [layer for layer in ACTIVATE_LAYER_CANDIDATE_SV3D if ".transformer_blocks.0.attn1" in layer]
        temporal_attn = [
            layer for layer in ACTIVATE_LAYER_CANDIDATE_SV3D if ".temporal_transformer_blocks.0.attn1" in layer
        ]

        assert len(spatial_attn) == len(config["spatial_ratio"]) == len(config["spatial_strength"])
        assert len(temporal_attn) == len(config["temporal_ratio"]) == len(config["temporal_strength"])

        ratios = {}
        for layer, ratio, strength in zip(spatial_attn, config["spatial_ratio"], config["spatial_strength"]):
            ratios[layer] = {"ratio": ratio, "strength": strength}

        for layer, ratio, strength in zip(temporal_attn, config["temporal_ratio"], config["temporal_strength"]):
            ratios[layer] = {"ratio": ratio, "strength": strength}

        attn_procs = {}

        # Selected layers get the shared (swapping) processor; everything else
        # keeps the stock SDPA processor.
        for layer in self.unet.attn_processors.keys():
            if layer in ratios:
                attn_procs[layer] = SharedAttentionProcessorThree(
                    unet_chunk_size=2, activate_step_indices=config["activate_steps"], ratio=ratios[layer]
                )
            else:
                attn_procs[layer] = AttnProcessor2_0()

        self.unet.set_attn_processor(attn_procs)

        return attn_procs

    def _decode_vae_frames(self, image_latents: torch.Tensor) -> torch.Tensor:
        """Decode per-frame latents; returns frames stacked along dim 2."""
        frames = []
        for i in range(21):
            # NOTE(review): `self.device` lands in the `return_dict` slot of
            # AutoencoderKL.decode (truthy -> default True) — works by
            # accident; should be `self.vae.decode(image_latents[:, i])`.
            frame = self.vae.decode(image_latents[:, i], self.device).sample
            frames.append(frame)
        return torch.stack(frames, dim=2)

    def _preprocess_reference_images(self, reference_images: List[PIL.Image.Image]) -> List[torch.Tensor]:
        """Helper method to preprocess reference images consistently"""
        processed_images = []
        for image in reference_images:
            ref_image = self.video_processor.preprocess(image, HEIGHT, WIDTH).to(self.device)
            ref_noise = randn_tensor(ref_image.shape, GENERATOR, self.device, ref_image.dtype)
            ref_image = ref_image + NOISE_AUG_STRENGTH * ref_noise
            processed_images.append(ref_image)
        return processed_images

    def _preprocess_image(self, image: Union[PIL.Image.Image, torch.Tensor]) -> torch.Tensor:
        """Preprocess a single image with noise augmentation"""
        processed = self.video_processor.preprocess(image, HEIGHT, WIDTH).to(self.device)
        noise = randn_tensor(processed.shape, GENERATOR, self.device, processed.dtype)
        return processed + NOISE_AUG_STRENGTH * noise

    def _denoise_loop(
        self,
        latents: torch.Tensor,
        image_latents: torch.Tensor,
        image_embeddings: torch.Tensor,
        added_time_ids: List[torch.Tensor],
        timesteps: torch.Tensor,
        mids_active_size: int,
        z0_mid_images: List[torch.Tensor],
        z0_reference_images: Optional[List[torch.Tensor]] = None,
        z0_shape_images: Optional[List[torch.Tensor]] = None,
    ):
        """Denoising loop with clean-frame re-injection at the steps in
        ``self.replace_reference_steps``.

        The mid video keeps a trailing "active" window of ``mids_active_size``
        frames from its in-flight latents after step 5, so those frames keep
        denoising freely while the rest are pinned to the clean encodings.
        """
        num_warmup_steps = len(timesteps) - self.num_inference_steps * self.scheduler.order
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        intermediate_steps = []

        normal_latents = None  # scheduler's predicted x0 at the current step

        with torch.autocast(device_type=self.device.type, dtype=torch.float16):
            with self.progress_bar(total=self.num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    # NOTE(review): nesting reconstructed from a
                    # whitespace-mangled source; the mid/shape re-injections
                    # are assumed to share the replace_reference_steps gate —
                    # confirm against the original file.
                    if i in self.replace_reference_steps:
                        latents[0] = self.prepare_video_latents(
                            z0_reference_images,
                            timestep=t.repeat(1),
                            add_noise=True,
                        )

                        latents[1] = self.prepare_video_latents(
                            z0_mid_images,
                            timestep=t.repeat(1),
                            add_noise=True,
                            active_size=mids_active_size if i > 5 else None,
                            original_latents=latents[1],
                        )

                        if z0_shape_images is not None:
                            latents[2] = self.prepare_video_latents(
                                z0_shape_images,
                                timestep=t.repeat(1),
                                add_noise=True,
                            )

                    # expand the latents if we are doing cfg
                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # Concatenate image_latents over channels dimension
                    latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

                    torch.cuda.empty_cache()

                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,  # 2/4/6,21,8,72,72
                        t,  # float
                        encoder_hidden_states=image_embeddings,  # 42/84/126,1,1024
                        added_time_ids=added_time_ids,  # 2/4/6,21
                        return_dict=False,
                    )[0]  # 1/2/3,21,4,72,72

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    step_output = self.scheduler.step(noise_pred, t, latents)  # EulerDiscreteScheduler
                    latents = step_output.prev_sample
                    normal_latents = step_output.pred_original_sample

                    if self.return_intermediate_steps:
                        if needs_upcasting:
                            self.vae.to(dtype=torch.float16)
                        frames = self.decode_latents(normal_latents, NUM_FRAMES, DECODE_CHUNK_SIZE)
                        frames = self.video_processor.postprocess_video(frames, OUTPUT_TYPE)
                        intermediate_steps.append(frames)

                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()

        if not OUTPUT_TYPE == "latent":
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents, NUM_FRAMES, DECODE_CHUNK_SIZE)
            frames = self.video_processor.postprocess_video(frames, OUTPUT_TYPE)
        else:
            frames = latents

        self.maybe_free_model_hooks()

        return frames, intermediate_steps

    @torch.no_grad()
    def __call__(
        self,
        mid_images: List[PIL.Image.Image],
        reference_images: List[PIL.Image.Image],
        shape_images: Optional[List[PIL.Image.Image]] = None,
        num_inference_steps: int = 25,
        replace_reference_steps: List[int] = list(),  # NOTE(review): mutable default argument — shared across calls
        return_intermediate_steps: bool = False,
        mids_active_size: int = 5,
    ):
        """Rotate-refine pass: jointly denoise reference + mid (+ optional
        shape) videos, returning the decoded frame lists (and per-step
        intermediates when requested)."""
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(mid_images[-1], HEIGHT, WIDTH)

        # 2. Define call parameters
        self.num_inference_steps = num_inference_steps
        self.return_intermediate_steps = return_intermediate_steps
        self.replace_reference_steps = replace_reference_steps
        self._guidance_scale = MAX_CFG

        # 3. Encode input image (CLIP)
        image_embeddings_combined = [
            self._encode_image(reference_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
            self._encode_image(mid_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance),
        ]
        if shape_images is not None:
            image_embeddings_combined.append(
                self._encode_image(shape_images[-1], self.device, NUM_VID, self.do_classifier_free_guidance)
            )
        all_embeddings = torch.cat(image_embeddings_combined, dim=0)  # uc, c, uc, c, (uc, c)
        # NOTE(review): truthiness check — an empty shape_images list takes
        # the 2-video branch, unlike the `is not None` checks elsewhere.
        embeddings_order = torch.tensor([0, 2, 4, 1, 3, 5]) if shape_images else torch.tensor([0, 2, 1, 3])
        reordered_embeddings = all_embeddings[embeddings_order]  # uc, uc, (uc), c, c, (c)
        image_embeddings = reordered_embeddings.repeat_interleave(NUM_FRAMES, dim=0)

        # 4. Encode using VAE
        image = self._preprocess_image(mid_images[-1])
        ref_image = self._preprocess_image(reference_images[-1])

        z0_reference_images = self._preprocess_reference_images(reference_images)
        z0_mid_images = self._preprocess_reference_images(mid_images)

        if shape_images is not None:
            shape_image = self._preprocess_image(shape_images[-1])
            z0_shape_images = self._preprocess_reference_images(
                shape_images,
            )
        else:
            shape_image = None
            z0_shape_images = None

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        image_latents = self._encode_vae_image(image, self.device, NUM_VID, self.do_classifier_free_guidance)
        image_latents = image_latents.to(image_embeddings.dtype)

        ref_image_latents = self._encode_vae_image(ref_image, self.device, NUM_VID, self.do_classifier_free_guidance)
        ref_image_latents = ref_image_latents.to(image_embeddings.dtype)

        if shape_images is not None:
            shape_image_latents = self._encode_vae_image(
                shape_image, self.device, NUM_VID, self.do_classifier_free_guidance
            )
            shape_image_latents = shape_image_latents.to(image_embeddings.dtype)

        # Broadcast each conditioning latent across all NUM_FRAMES frames.
        image_latents_full = [
            ref_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
            image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1),
        ]

        if shape_images is not None:
            shape_image_latents = shape_image_latents.unsqueeze(1).repeat(1, NUM_FRAMES, 1, 1, 1)
            image_latents_full.append(shape_image_latents)

        image_latents = torch.cat(image_latents_full, dim=0)
        image_latents_order = torch.tensor([0, 2, 4, 1, 3, 5]) if shape_images else torch.tensor([0, 2, 1, 3])
        image_latents = image_latents[image_latents_order]

        if needs_upcasting:  # cast back to fp16 if needed
            self.vae.to(dtype=torch.float16)

        num_processes = 2 if shape_images is None else 3

        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            image_embeddings.dtype,
            num_processes,
            self.do_classifier_free_guidance,
        )  # list of tensor [2, 21] or [4, 21] or [6, 21] -> just 4x the same

        added_time_ids = [a.to(self.device) for a in added_time_ids]

        timesteps, self.num_inference_steps = retrieve_timesteps(self.scheduler, self.num_inference_steps, self.device)

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels  # 8
        latents = self.prepare_latents(
            BATCH_SIZE * num_processes,
            NUM_FRAMES,
            num_channels_latents,
            HEIGHT,
            WIDTH,
            image_embeddings.dtype,
            self.device,
            GENERATOR,
        )  # 2/3,21,4,72,72

        # 8. Prepare guidance scale: triangular schedule MIN->MAX->MIN over
        # the frame axis.
        guidance_scale = torch.cat(
            [
                torch.linspace(MIN_CFG, MAX_CFG, NUM_FRAMES // 2 + 1)[1:].unsqueeze(0),
                torch.linspace(MAX_CFG, MIN_CFG, NUM_FRAMES - NUM_FRAMES // 2 + 1)[1:].unsqueeze(0),
            ],
            dim=-1,
        )

        guidance_scale = guidance_scale.to(self.device, latents.dtype)
        guidance_scale = guidance_scale.repeat(BATCH_SIZE, 1)
        guidance_scale = _append_dims(guidance_scale, latents.ndim)  # [1,21,1,1,1]

        self._guidance_scale = guidance_scale

        # 9. Denoising loop
        frames, intemediate_steps = self._denoise_loop(
            latents=latents,
            image_latents=image_latents,
            image_embeddings=image_embeddings,
            added_time_ids=added_time_ids,
            timesteps=timesteps,
            mids_active_size=mids_active_size,
            z0_mid_images=z0_mid_images,
            z0_reference_images=z0_reference_images,
            z0_shape_images=z0_shape_images,
        )

        if return_intermediate_steps:
            return frames, intemediate_steps

        return frames
pretrained_sv3d/feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "image_mean": [
12
+ 0.48145466,
13
+ 0.4578275,
14
+ 0.40821073
15
+ ],
16
+ "image_processor_type": "CLIPImageProcessor",
17
+ "image_std": [
18
+ 0.26862954,
19
+ 0.26130258,
20
+ 0.27577711
21
+ ],
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "shortest_edge": 224
26
+ }
27
+ }
pretrained_sv3d/image_encoder/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "stabilityai/stable-video-diffusion-img2vid-xt",
3
+ "architectures": [
4
+ "CLIPVisionModelWithProjection"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "dropout": 0.0,
8
+ "hidden_act": "gelu",
9
+ "hidden_size": 1280,
10
+ "image_size": 224,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "model_type": "clip_vision_model",
16
+ "num_attention_heads": 16,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 32,
19
+ "patch_size": 14,
20
+ "projection_dim": 1024,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.45.2"
23
+ }
pretrained_sv3d/image_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed1e5af7b4042ca30ec29999a4a5cfcac90b7fb610fd05ace834f2dcbb763eab
3
+ size 2528371296
pretrained_sv3d/model_index.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37fe3c7758e588c386817b6e681f2aaa7bc8c212d628b7c36f758e0a6d972e29
3
+ size 492
pretrained_sv3d/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EulerDiscreteScheduler",
3
+ "_diffusers_version": "0.30.3",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "final_sigmas_type": "zero",
9
+ "interpolation_type": "linear",
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "v_prediction",
12
+ "rescale_betas_zero_snr": false,
13
+ "set_alpha_to_one": false,
14
+ "sigma_max": 700.0,
15
+ "sigma_min": 0.002,
16
+ "skip_prk_steps": true,
17
+ "steps_offset": 1,
18
+ "timestep_spacing": "leading",
19
+ "timestep_type": "continuous",
20
+ "trained_betas": null,
21
+ "use_karras_sigmas": true
22
+ }
pretrained_sv3d/unet/config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "SV3DUNetSpatioTemporalConditionModel",
3
+ "_diffusers_version": "0.30.3",
4
+ "addition_time_embed_dim": 256,
5
+ "block_out_channels": [
6
+ 320,
7
+ 640,
8
+ 1280,
9
+ 1280
10
+ ],
11
+ "cross_attention_dim": 1024,
12
+ "down_block_types": [
13
+ "CrossAttnDownBlockSpatioTemporal",
14
+ "CrossAttnDownBlockSpatioTemporal",
15
+ "CrossAttnDownBlockSpatioTemporal",
16
+ "DownBlockSpatioTemporal"
17
+ ],
18
+ "in_channels": 8,
19
+ "layers_per_block": 2,
20
+ "num_attention_heads": [
21
+ 5,
22
+ 10,
23
+ 20,
24
+ 20
25
+ ],
26
+ "num_frames": 25,
27
+ "out_channels": 4,
28
+ "projection_class_embeddings_input_dim": 256,
29
+ "sample_size": 72,
30
+ "transformer_layers_per_block": 1,
31
+ "up_block_types": [
32
+ "UpBlockSpatioTemporal",
33
+ "CrossAttnUpBlockSpatioTemporal",
34
+ "CrossAttnUpBlockSpatioTemporal",
35
+ "CrossAttnUpBlockSpatioTemporal"
36
+ ]
37
+ }
pretrained_sv3d/unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00d35a0c7e024ebc55feeecf55baa039700f3d2b2d396e58d7cd0e6bbb18eedd
3
+ size 6096060984
pretrained_sv3d/vae/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.30.3",
4
+ "_name_or_path": "chenguolin/stable-diffusion-v1-5",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 4,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 512,
28
+ "scaling_factor": 0.18215,
29
+ "shift_factor": null,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": true,
37
+ "use_quant_conv": true
38
+ }
pretrained_sv3d/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4d2b5932bb4151e54e694fd31ccf51fca908223c9485bd56cd0e1d83ad94c49
3
+ size 334643268
train.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch.optim import AdamW
5
+ import torch.nn.functional as F
6
+
7
+ from diffusers_sv3d.pipelines.stable_video_diffusion.pipeline_stable_video_3d_diffusion import (
8
+ StableVideo3DDiffusionPipeline,
9
+ )
10
+
11
# Configuration
BATCH_SIZE = 1
LR = 1e-5
NUM_EPOCHS = 10
SAVE_DIR = "checkpoints"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Allow overriding the pretrained-weights location via the SV3D_PATH env var;
# falls back to the original hard-coded development path.
SV3D_PATH = os.path.abspath(
    os.environ.get(
        "SV3D_PATH", "/home/hubert/projects/sv3d-pbr/sv3d_diffusers/pretrained_sv3d"
    )
)
18
+
19
+
20
def train():
    """Smoke-test fine-tuning loop for the SV3D UNet.

    Loads the pretrained pipeline in fp16, freezes the whole UNet except one
    conv layer, and runs one forward pass per epoch on random stand-in
    tensors. The loss/backward/step calls are left disabled (TODO) until real
    data is wired in and the UNet's return type is confirmed.
    """
    # Ensure the checkpoint directory exists before any saving is attempted.
    os.makedirs(SAVE_DIR, exist_ok=True)

    # Load the pretrained pipeline in half precision and move it to the
    # target device.
    pipeline = StableVideo3DDiffusionPipeline.from_pretrained(
        SV3D_PATH,
        revision="fp16",
        torch_dtype=torch.float16,
    )
    pipeline.to(DEVICE)

    # Freeze every UNet parameter first ...
    for param in pipeline.unet.parameters():
        param.requires_grad = False

    # ... then unfreeze exactly one conv layer for lightweight fine-tuning.
    for name, param in pipeline.unet.named_parameters():
        if "down_blocks.2.resnets.0.spatial_res_block.conv1" in name:
            param.requires_grad = True
            print(f"Unfreezing: {name}")

    # Report how much of the network is actually trainable.
    trainable_params = sum(p.numel() for p in pipeline.unet.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in pipeline.unet.parameters())
    print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_params/total_params:.2%})")

    # Optimize only the unfrozen parameters.
    optimizer = AdamW([p for p in pipeline.unet.parameters() if p.requires_grad], lr=LR)

    # Training loop (forward-only smoke test for now).
    for epoch in range(NUM_EPOCHS):
        pipeline.unet.train()
        optimizer.zero_grad()

        # Random stand-in tensors. Shapes assume 6 views x 21 frames,
        # 72x72 latents with 8 input channels, and 1024-dim conditioning
        # -- TODO confirm against the UNet config before real training.
        latents = torch.randn((6, 21, 8, 72, 72), dtype=torch.float16, device=DEVICE)
        t = 0.123
        encoder_hidden_states = torch.randn((126, 1, 1024), dtype=torch.float16, device=DEVICE)
        added_time_ids = torch.randn((6, 21), dtype=torch.float16, device=DEVICE)
        target_noise = torch.randn((6, 21, 8, 72, 72), dtype=torch.float16, device=DEVICE)

        noise_pred = pipeline.unet(
            latents,
            t,
            encoder_hidden_states=encoder_hidden_states,
            added_time_ids=[added_time_ids],
        )

        # NOTE(review): diffusers UNets usually return an output object, so
        # `.shape` here presumably relies on the custom SV3D UNet returning a
        # plain tensor -- verify before enabling the loss below.
        print(noise_pred.shape)
        # TODO: enable once real data and the UNet return type are confirmed:
        # loss = F.mse_loss(noise_pred, target_noise)
        # loss.backward()
        # optimizer.step()


if __name__ == "__main__":
    train()