linyq commited on
Commit
7b07a17
·
verified ·
1 Parent(s): 5c9c465

Delete ref_embedder/conditional_embedder.py

Browse files
ref_embedder/conditional_embedder.py DELETED
@@ -1,33 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from diffusers import ModelMixin, ConfigMixin
4
- from diffusers.configuration_utils import register_to_config
5
-
6
-
7
- class ConditionalEmbedder(ModelMixin, ConfigMixin):
8
- """
9
- Patchifies VAE-encoded conditions (source video or reference image)
10
- into the DiT hidden dimension space via a Conv3d layer.
11
- """
12
-
13
- @register_to_config
14
- def __init__(
15
- self,
16
- in_dim: int = 48,
17
- dim: int = 3072,
18
- patch_size: list = [1, 2, 2],
19
- zero_init: bool = True,
20
- ref_pad_first: bool = False,
21
- ):
22
- super().__init__()
23
- kernel_size = tuple(patch_size)
24
- self.patch_embedding = nn.Conv3d(
25
- in_dim, dim, kernel_size=kernel_size, stride=kernel_size
26
- )
27
- self.ref_pad_first = ref_pad_first
28
- if zero_init:
29
- nn.init.zeros_(self.patch_embedding.weight)
30
- nn.init.zeros_(self.patch_embedding.bias)
31
-
32
- def forward(self, x: torch.Tensor) -> torch.Tensor:
33
- return self.patch_embedding(x)