# semo/losses/modules.py
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from typing import Optional, Tuple, Union
from einops import rearrange
from diffusers.models.attention import Attention, FeedForward
def video_to_image(func):
    # Apply an image (4D) operator to video (5D) input by folding the time
    # axis into the batch axis and restoring it afterwards.
    def wrapper(self, x, *args, **kwargs):
        if x.dim() == 5:
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = func(self, x, *args, **kwargs)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        else:
            # Plain 4D image batches go through the wrapped op directly.
            x = func(self, x, *args, **kwargs)
        return x
    return wrapper
def nonlinearity(x):
    # Swish / SiLU activation.
    return x * torch.sigmoid(x)

def cast_tuple(t, length=1):
    # Broadcast a scalar to a tuple of the given length; pass tuples/lists through.
    return t if isinstance(t, (tuple, list)) else ((t,) * length)
class Block(nn.Module):
    # Thin placeholder base class for the blocks below.
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
class Conv2d(nn.Conv2d):
    # Standard 2D convolution that also accepts 5D video input (handled
    # frame-by-frame via @video_to_image).
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]] = 3,
stride: Union[int, Tuple[int]] = 1,
padding: Union[str, int, Tuple[int]] = 0,
dilation: Union[int, Tuple[int]] = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
@video_to_image
def forward(self, x):
return super().forward(x)
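
# Usage sketch (illustrative, not called at import; shapes are assumptions):
# the @video_to_image decorator lets this Conv2d accept a 5D video tensor by
# folding time into the batch dimension and unfolding it afterwards.
def _example_conv2d_on_video():
    conv = Conv2d(3, 8, kernel_size=3, padding=1)
    x = torch.randn(2, 3, 4, 16, 16)  # (b, c, t, h, w)
    y = conv(x)  # processed as (b * t, c, h, w), then reshaped back
    assert y.shape == (2, 8, 4, 16, 16)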
class CausalConv3d(nn.Module):
def __init__(
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
):
super().__init__()
self.kernel_size = cast_tuple(kernel_size, 3)
self.time_kernel_size = self.kernel_size[0]
self.chan_in = chan_in
self.chan_out = chan_out
stride = kwargs.pop("stride", 1)
padding = kwargs.pop("padding", 0)
padding = list(cast_tuple(padding, 3))
padding[0] = 0
stride = cast_tuple(stride, 3)
self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
        # Note: unused; causal temporal padding is applied in forward() instead.
        self.pad = nn.ReplicationPad2d((0, 0, self.time_kernel_size - 1, 0))
self._init_weights(init_method)
    def _init_weights(self, init_method):
        if init_method == "avg":
            assert (
                self.kernel_size[1] == 1 and self.kernel_size[2] == 1
            ), "only support temporal up/down sample"
            assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
            # Initialize as a temporal average: each output channel averages its
            # own input channel over the temporal window (the 1/3 weighting
            # assumes time_kernel_size == 3).
            weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))
            eyes = torch.cat(
                [
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                ],
                dim=-1,
            )
            weight[:, :, :, 0, 0] = eyes
self.conv.weight = nn.Parameter(
weight,
requires_grad=True,
)
elif init_method == "zero":
self.conv.weight = nn.Parameter(
torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
requires_grad=True,
)
if self.conv.bias is not None:
nn.init.constant_(self.conv.bias, 0)
    def forward(self, x):
        # Causal temporal padding: repeat the first frame (time_kernel_size - 1)
        # times on the left, so no output frame sees future frames and the
        # temporal length is preserved (e.g. 1 image frame + 16 video frames in,
        # 17 frames out).
        first_frame_pad = x[:, :, :1, :, :].repeat(
            (1, 1, self.time_kernel_size - 1, 1, 1)
        )  # (b, c, t, h, w)
        x = torch.cat((first_frame_pad, x), dim=2)
        return self.conv(x)
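
# Usage sketch (illustrative, not called at import; shapes are assumptions):
# the causal pad repeats the first frame, so the temporal length is preserved
# and no output frame depends on future frames.
def _example_causal_conv3d():
    conv = CausalConv3d(4, 4, kernel_size=3, padding=1)  # spatial padding only
    x = torch.randn(1, 4, 9, 16, 16)  # (b, c, t, h, w)
    y = conv(x)
    assert y.shape == (1, 4, 9, 16, 16)  # t preserved by the causal pad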
class CausalConv3d_GC(CausalConv3d):
    # Same as CausalConv3d, but runs the convolution under gradient
    # checkpointing to trade recompute time for activation memory.
    def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs):
        super().__init__(chan_in, chan_out, kernel_size, init_method, **kwargs)

    def forward(self, x):
        # Same causal temporal padding as the parent class.
        first_frame_pad = x[:, :, :1, :, :].repeat(
            (1, 1, self.time_kernel_size - 1, 1, 1)
        )  # (b, c, t, h, w)
        x = torch.cat((first_frame_pad, x), dim=2)
        return checkpoint(self.conv, x)
class ActNorm(nn.Module):
    # Activation normalization (as in Glow): a per-channel affine transform
    # whose parameters are data-initialized on the first training batch.
    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
assert affine
super().__init__()
self.logdet = logdet
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
self.allow_reverse_init = allow_reverse_init
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
    def initialize(self, input):
        # Data-dependent init: set loc/scale so the initializing batch comes out
        # with zero mean and unit variance per channel.
        with torch.no_grad():
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
mean = (
flatten.mean(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
std = (
flatten.std(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
self.loc.data.copy_(-mean)
self.scale.data.copy_(1 / (std + 1e-6))
def forward(self, input, reverse=False):
if reverse:
return self.reverse(input)
if len(input.shape) == 2:
input = input[:,:,None,None]
squeeze = True
else:
squeeze = False
_, _, height, width = input.shape
if self.training and self.initialized.item() == 0:
self.initialize(input)
self.initialized.fill_(1)
h = self.scale * (input + self.loc)
if squeeze:
h = h.squeeze(-1).squeeze(-1)
if self.logdet:
log_abs = torch.log(torch.abs(self.scale))
logdet = height*width*torch.sum(log_abs)
logdet = logdet * torch.ones(input.shape[0]).to(input)
return h, logdet
return h
def reverse(self, output):
if self.training and self.initialized.item() == 0:
if not self.allow_reverse_init:
raise RuntimeError(
"Initializing ActNorm in reverse direction is "
"disabled by default. Use allow_reverse_init=True to enable."
)
else:
self.initialize(output)
self.initialized.fill_(1)
if len(output.shape) == 2:
output = output[:,:,None,None]
squeeze = True
else:
squeeze = False
h = output / self.scale - self.loc
if squeeze:
h = h.squeeze(-1).squeeze(-1)
return h
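
# Usage sketch (illustrative, not called at import; shapes are assumptions):
# the first forward pass in training mode triggers the data-dependent init,
# after which the initializing batch comes out with zero mean per channel.
def _example_actnorm():
    norm = ActNorm(8)
    norm.train()
    x = torch.randn(4, 8, 16, 16) * 3.0 + 1.0
    y = norm(x)
    assert y.shape == x.shape
    assert torch.allclose(y.mean(dim=(0, 2, 3)), torch.zeros(8), atol=1e-4)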
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
bias: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
    def forward(self, image_embeds: torch.Tensor):
        r"""
        Args:
            image_embeds (`torch.Tensor`):
                Input image embeddings of shape `(batch_size, num_frames, channels, height, width)`
                or `(batch_size, channels, height, width)`.
        Returns:
            `torch.Tensor`: Patch tokens of shape
            `(batch_size, num_frames * (height // patch_size) * (width // patch_size), embed_dim)`,
            with `num_frames = 1` for 4D input.
        """
if image_embeds.dim() == 5:
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
else:
batch, channels, height, width = image_embeds.shape
num_frames = 1
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
return image_embeds # [batch, num_frames x height x width, channels]
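
# Usage sketch (illustrative, not called at import; shapes are assumptions):
# a 5-frame latent video is flattened into a single patch-token sequence.
def _example_patch_embed():
    embed = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1920)
    x = torch.randn(2, 5, 16, 32, 32)  # (b, t, c, h, w)
    tokens = embed(x)
    # 5 frames * (32 // 2) * (32 // 2) = 1280 tokens per sample
    assert tokens.shape == (2, 1280, 1920)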
class DiscTransformer(nn.Module):
    # Transformer block conditioned on a time embedding via LayerNormZero
    # (adaLN-Zero): self-attention and feed-forward sub-blocks, each gated by
    # an embedding-dependent scale.
def __init__(self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
# 1. Self Attention
self.norm1 = LayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
)
# 2. Feed Forward
self.norm2 = LayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
) -> torch.Tensor:
norm_hidden_states, gate_msa = self.norm1(
hidden_states, temb
)
# attention
attn_output = self.attn1(
hidden_states=norm_hidden_states,
)
hidden_states = hidden_states + gate_msa * attn_output
# norm & modulate
norm_hidden_states, gate_ff = self.norm2(
hidden_states, temb
)
# feed-forward
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output
return hidden_states
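
# Usage sketch (illustrative, not called at import; dims are assumptions):
# note that dim must equal num_attention_heads * attention_head_dim.
def _example_disc_transformer():
    block = DiscTransformer(
        dim=128,
        num_attention_heads=4,
        attention_head_dim=32,
        time_embed_dim=64,
    )
    hidden = torch.randn(2, 10, 128)  # (batch, tokens, dim)
    temb = torch.randn(2, 64)         # conditioning embedding
    out = block(hidden, temb)
    assert out.shape == (2, 10, 128)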
class LayerNormZero(nn.Module):
    # adaLN-Zero conditioning: projects the conditioning embedding to a shift,
    # scale, and gate; the gate is returned for scaling the residual branch.
def __init__(
self,
conditioning_dim: int,
embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
) -> None:
super().__init__()
self.embed_dim = embedding_dim
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_dim, 3 * embedding_dim, bias=bias)
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(
self, hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = self.linear(self.silu(temb)).chunk(3, dim=1)
if len(hidden_states.shape) == 3:
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
gate = gate[:, None, :]
else:
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
return hidden_states, gate
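
# Usage sketch (illustrative, not called at import; dims are assumptions):
# the conditioning embedding yields a shift, scale, and gate; the gate is
# returned so the caller can scale the residual branch.
def _example_layer_norm_zero():
    norm = LayerNormZero(conditioning_dim=64, embedding_dim=128)
    hidden = torch.randn(2, 10, 128)
    temb = torch.randn(2, 64)
    normed, gate = norm(hidden, temb)
    assert normed.shape == (2, 10, 128) and gate.shape == (2, 1, 128)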
class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.
    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
        output_dim (`int`, *optional*): Size of the projected conditioning; defaults to `2 * embedding_dim`.
        norm_elementwise_affine (`bool`, defaults to `False`): Whether the LayerNorm learns affine parameters.
        norm_eps (`float`, defaults to `1e-5`): Epsilon used by the LayerNorm.
        chunk_dim (`int`, defaults to `0`): Dimension along which the projected embedding is split into scale and shift.
    """
def __init__(
self,
embedding_dim: int,
num_embeddings: Optional[int] = None,
output_dim: Optional[int] = None,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.chunk_dim = chunk_dim
output_dim = output_dim or embedding_dim * 2
if num_embeddings is not None:
self.emb = nn.Embedding(num_embeddings, embedding_dim)
else:
self.emb = None
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.emb is not None:
temb = self.emb(timestep)
temb = self.linear(self.silu(temb))
if self.chunk_dim == 1:
            # Note: the order is "shift, scale" here but "scale, shift" in the
            # other branch. This branch is specific to CogVideoX for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
else:
scale, shift = temb.chunk(2, dim=0)
x = self.norm(x) * (1 + scale) + shift
return x
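
# Usage sketch (illustrative, not called at import; dims are assumptions):
# with chunk_dim=1 (the CogVideoX-style branch), the projected embedding is
# split into a per-sample shift and scale that modulate the normalized input.
def _example_ada_layer_norm():
    norm = AdaLayerNorm(embedding_dim=64, output_dim=256, chunk_dim=1)
    x = torch.randn(2, 10, 128)  # LayerNorm is applied over output_dim // 2
    temb = torch.randn(2, 64)
    out = norm(x, temb=temb)
    assert out.shape == (2, 10, 128)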