from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention, FeedForward
from einops import rearrange
from torch import nn
from torch.utils.checkpoint import checkpoint
|
|
def video_to_image(func):
    """Apply an image-space op frame-wise to video: 5D input (b, c, t, h, w)
    is folded into the batch dimension, processed, and unfolded back."""

    def wrapper(self, x, *args, **kwargs):
        if x.dim() == 5:
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = func(self, x, *args, **kwargs)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        else:
            # 4D (image) inputs are processed as-is.
            x = func(self, x, *args, **kwargs)
        return x

    return wrapper
|
|
def nonlinearity(x):
    # Swish / SiLU activation.
    return x * torch.sigmoid(x)


def cast_tuple(t, length=1):
    # e.g. cast_tuple(3, length=3) -> (3, 3, 3); tuples and lists pass through.
    return t if isinstance(t, (tuple, list)) else ((t,) * length)


class Block(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
|
|
class Conv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        padding: Union[str, int, Tuple[int]] = 0,
        dilation: Union[int, Tuple[int]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )

    @video_to_image
    def forward(self, x):
        return super().forward(x)
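
# Example (sketch; illustrative shapes): thanks to @video_to_image, this
# Conv2d accepts both image and video tensors.
#
#   conv = Conv2d(3, 8, kernel_size=3, padding=1)
#   images = torch.randn(2, 3, 32, 32)      # (b, c, h, w)    -> (2, 8, 32, 32)
#   video = torch.randn(2, 3, 5, 32, 32)    # (b, c, t, h, w) -> (2, 8, 5, 32, 32)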
|
|
class CausalConv3d(nn.Module):
    """3D convolution with causal temporal padding: the first frame is
    replicated so the output at time t never sees frames after t."""

    def __init__(
        self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        self.time_kernel_size = self.kernel_size[0]
        self.chan_in = chan_in
        self.chan_out = chan_out
        stride = kwargs.pop("stride", 1)
        padding = kwargs.pop("padding", 0)
        padding = list(cast_tuple(padding, 3))
        padding[0] = 0  # no symmetric temporal padding; forward() pads causally instead
        stride = cast_tuple(stride, 3)
        self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
        self.pad = nn.ReplicationPad2d((0, 0, self.time_kernel_size - 1, 0))  # unused here; forward() replicates frame 0
        self._init_weights(init_method)

    def _init_weights(self, init_method):
        if init_method == "avg":
            # Initialize as a temporal average (identity across channels);
            # assumes a temporal kernel of size 3 and a 1x1 spatial kernel.
            assert (
                self.kernel_size[1] == 1 and self.kernel_size[2] == 1
            ), "only support temporal up/down sample"
            assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
            weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))

            eyes = torch.concat(
                [
                    torch.eye(self.chan_in).unsqueeze(-1) / 3,
                    torch.eye(self.chan_in).unsqueeze(-1) / 3,
                    torch.eye(self.chan_in).unsqueeze(-1) / 3,
                ],
                dim=-1,
            )
            weight[:, :, :, 0, 0] = eyes

            self.conv.weight = nn.Parameter(
                weight,
                requires_grad=True,
            )
        elif init_method == "zero":
            self.conv.weight = nn.Parameter(
                torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
                requires_grad=True,
            )
        if self.conv.bias is not None:
            nn.init.constant_(self.conv.bias, 0)

    def forward(self, x):
        # Causal temporal padding: replicate the first frame
        # (time_kernel_size - 1) times in front of the clip.
        first_frame_pad = x[:, :, :1, :, :].repeat(
            (1, 1, self.time_kernel_size - 1, 1, 1)
        )
        x = torch.concatenate((first_frame_pad, x), dim=2)
        return self.conv(x)
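
# Example (sketch): temporal length is preserved because the missing
# (time_kernel_size - 1) frames are filled by replicating frame 0.
#
#   conv3d = CausalConv3d(8, 8, kernel_size=3, padding=1)
#   video = torch.randn(1, 8, 9, 16, 16)   # (b, c, t, h, w)
#   out = conv3d(video)                     # -> (1, 8, 9, 16, 16)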
|
|
class CausalConv3d_GC(CausalConv3d):
    """CausalConv3d variant that runs the convolution under gradient
    checkpointing, trading recompute time for activation memory."""

    def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int]], init_method="random", **kwargs):
        super().__init__(chan_in, chan_out, kernel_size, init_method, **kwargs)

    def forward(self, x):
        first_frame_pad = x[:, :, :1, :, :].repeat(
            (1, 1, self.time_kernel_size - 1, 1, 1)
        )
        x = torch.concatenate((first_frame_pad, x), dim=2)
        return checkpoint(self.conv, x)
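
# Note: newer PyTorch versions warn when checkpoint() is called without an
# explicit `use_reentrant` argument; passing use_reentrant=False is the
# commonly recommended non-reentrant form if that warning matters here.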
|
|
class ActNorm(nn.Module):
    """Activation normalization (Glow-style): a per-channel affine transform
    whose parameters are data-dependently initialized from the first batch
    seen during training."""

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        # Set loc/scale so this batch comes out zero-mean, unit-variance
        # per channel.
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
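
# Example (sketch): the first training-mode forward pass triggers the
# data-dependent initialization, after which that batch is roughly
# zero-mean and unit-variance per channel.
#
#   actnorm = ActNorm(16)
#   actnorm.train()
#   y = actnorm(torch.randn(4, 16, 8, 8))   # initializes loc/scale, then normalizes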
|
|
class PatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        embed_dim: int = 1920,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            image_embeds (`torch.Tensor`):
                Input image embeddings of shape `(batch_size, num_frames, channels, height, width)` or
                `(batch_size, channels, height, width)`.

        Returns:
            `torch.Tensor`:
                Patch tokens of shape `(batch_size, num_frames * num_patches, embed_dim)`, where
                `num_patches = (height // patch_size) * (width // patch_size)` and `num_frames` is 1
                for 4D input.
        """
        if image_embeds.dim() == 5:
            batch, num_frames, channels, height, width = image_embeds.shape
            image_embeds = image_embeds.reshape(-1, channels, height, width)
        else:
            batch, channels, height, width = image_embeds.shape
            num_frames = 1

        image_embeds = self.proj(image_embeds)  # (b * t, embed_dim, h / p, w / p)
        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # (b, t, num_patches, embed_dim)
        image_embeds = image_embeds.flatten(1, 2)  # (b, t * num_patches, embed_dim)

        return image_embeds
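
# Example (sketch): with patch_size=2, a 16x16 spatial grid yields 8*8 = 64
# tokens per frame.
#
#   embed = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1920)
#   latents = torch.randn(2, 4, 16, 16, 16)  # (b, t, c, h, w)
#   tokens = embed(latents)                   # -> (2, 4 * 64, 1920) = (2, 256, 1920)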
|
|
class DiscTransformer(nn.Module):
    """Transformer block with time-conditioned, gated LayerNorm: each
    sub-layer's output is scaled by a gate regressed from the conditioning
    embedding before being added to the residual stream."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        self.norm1 = LayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        self.norm2 = LayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
    ) -> torch.Tensor:
        # Attention with gated residual.
        norm_hidden_states, gate_msa = self.norm1(hidden_states, temb)
        attn_output = self.attn1(hidden_states=norm_hidden_states)
        hidden_states = hidden_states + gate_msa * attn_output

        # Feed-forward with gated residual.
        norm_hidden_states, gate_ff = self.norm2(hidden_states, temb)
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + gate_ff * ff_output

        return hidden_states
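
# Example (sketch; hypothetical sizes): `temb` is the conditioning embedding
# consumed by LayerNormZero (defined below).
#
#   block = DiscTransformer(dim=512, num_attention_heads=8,
#                           attention_head_dim=64, time_embed_dim=512)
#   tokens = torch.randn(2, 256, 512)    # (batch, seq_len, dim)
#   temb = torch.randn(2, 512)           # (batch, time_embed_dim)
#   out = block(tokens, temb)            # -> (2, 256, 512)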
|
|
class LayerNormZero(nn.Module):
    """LayerNorm whose shift, scale, and gate are regressed from a
    conditioning embedding."""

    def __init__(
        self,
        conditioning_dim: int,
        embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
    ) -> None:
        super().__init__()

        self.embed_dim = embedding_dim
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_dim, 3 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        shift, scale, gate = self.linear(self.silu(temb)).chunk(3, dim=1)
        if len(hidden_states.shape) == 3:
            # (batch, seq, dim): broadcast the per-batch modulation over tokens.
            hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
            gate = gate[:, None, :]
        else:
            hidden_states = self.norm(hidden_states) * (1 + scale) + shift
        return hidden_states, gate
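
# Note: the returned `gate` is not applied here; the caller multiplies it
# into the residual branch (see DiscTransformer.forward above).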
|
|
class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*): The size of the embeddings dictionary. If set, timesteps
            are looked up in an internal `nn.Embedding`; otherwise a precomputed `temb` must be passed.
        output_dim (`int`, *optional*): Output dimension of the modulation projection; defaults to
            `2 * embedding_dim`.
        norm_elementwise_affine (`bool`, defaults to `False`): Whether the LayerNorm has learnable
            per-element affine parameters.
        norm_eps (`float`, defaults to `1e-5`): Epsilon used by the LayerNorm.
        chunk_dim (`int`, defaults to `0`): Dimension along which the projected embedding is split
            into shift and scale (see `forward`).
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: Optional[int] = None,
        output_dim: Optional[int] = None,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()

        self.chunk_dim = chunk_dim
        output_dim = output_dim or embedding_dim * 2

        if num_embeddings is not None:
            self.emb = nn.Embedding(num_embeddings, embedding_dim)
        else:
            self.emb = None

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if self.emb is not None:
            temb = self.emb(timestep)

        temb = self.linear(self.silu(temb))

        if self.chunk_dim == 1:
            # Split along the feature dimension and broadcast over the sequence.
            shift, scale = temb.chunk(2, dim=1)
            shift = shift[:, None, :]
            scale = scale[:, None, :]
        else:
            # Note the (scale, shift) order here differs from the chunk_dim == 1 branch.
            scale, shift = temb.chunk(2, dim=0)

        x = self.norm(x) * (1 + scale) + shift
        return x
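
# Example (sketch): discrete-timestep conditioning through the internal
# embedding table; chunk_dim=1 broadcasts the modulation over the sequence.
#
#   ada_norm = AdaLayerNorm(embedding_dim=128, num_embeddings=1000, chunk_dim=1)
#   x = torch.randn(2, 16, 128)          # (batch, seq_len, dim)
#   t = torch.randint(0, 1000, (2,))     # integer timesteps
#   y = ada_norm(x, timestep=t)          # -> (2, 16, 128)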