Delete ref_embedder/conditional_embedder.py
Browse files
ref_embedder/conditional_embedder.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
from diffusers import ModelMixin, ConfigMixin
|
| 4 |
-
from diffusers.configuration_utils import register_to_config
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
class ConditionalEmbedder(ModelMixin, ConfigMixin):
|
| 8 |
-
"""
|
| 9 |
-
Patchifies VAE-encoded conditions (source video or reference image)
|
| 10 |
-
into the DiT hidden dimension space via a Conv3d layer.
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
@register_to_config
|
| 14 |
-
def __init__(
|
| 15 |
-
self,
|
| 16 |
-
in_dim: int = 48,
|
| 17 |
-
dim: int = 3072,
|
| 18 |
-
patch_size: list = [1, 2, 2],
|
| 19 |
-
zero_init: bool = True,
|
| 20 |
-
ref_pad_first: bool = False,
|
| 21 |
-
):
|
| 22 |
-
super().__init__()
|
| 23 |
-
kernel_size = tuple(patch_size)
|
| 24 |
-
self.patch_embedding = nn.Conv3d(
|
| 25 |
-
in_dim, dim, kernel_size=kernel_size, stride=kernel_size
|
| 26 |
-
)
|
| 27 |
-
self.ref_pad_first = ref_pad_first
|
| 28 |
-
if zero_init:
|
| 29 |
-
nn.init.zeros_(self.patch_embedding.weight)
|
| 30 |
-
nn.init.zeros_(self.patch_embedding.bias)
|
| 31 |
-
|
| 32 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 33 |
-
return self.patch_embedding(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|