# Source: VideoVAEPlus-tactile / src/models/autoencoder_temporal.py (16 kB)
# Commit e7c18b3 (verified): "Add source, configs, inference scripts", uploaded by WitneyWW.
import math
import torch
import torch.nn as nn
from src.modules.attention_temporal_videoae import *
from einops import rearrange, reduce, repeat
# xformers is an optional dependency: when it cannot be imported, the attention
# code in this file falls back to a plain einsum/softmax implementation.
# ``except Exception`` (rather than a bare ``except``) keeps that best-effort
# behaviour for any import-time failure while letting KeyboardInterrupt and
# SystemExit propagate.
try:
    import xformers
    import xformers.ops as xops

    XFORMERS_IS_AVAILBLE = True
except Exception:
    XFORMERS_IS_AVAILBLE = False
def silu(x):
    """Return the SiLU (swish) activation of ``x``: ``x * sigmoid(x)``, element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
class SiLU(nn.Module):
    """Parameter-free module that applies the ``silu`` function to its input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``silu(x)``."""
        activated = silu(x)
        return activated
def Normalize(in_channels, norm_type="group"):
    """Build the normalization layer used throughout this file.

    Args:
        in_channels: Number of channels of the tensor to be normalized.
        norm_type: ``"group"`` for a 32-group affine ``GroupNorm`` (eps 1e-6),
            or ``"batch"`` for ``SyncBatchNorm``.

    Returns:
        The freshly constructed normalization module.
    """
    assert norm_type in ["group", "batch"]
    if norm_type == "batch":
        return torch.nn.SyncBatchNorm(in_channels)
    if norm_type == "group":
        group_norm = torch.nn.GroupNorm(
            num_groups=32,
            num_channels=in_channels,
            eps=1e-6,
            affine=True,
        )
        return group_norm
# Does not support dilation
class SamePadConv3d(nn.Module):
    """3D convolution that pads its input explicitly, then runs an unpadded ``nn.Conv3d``.

    For every axis the total padding is ``kernel - stride``; when that amount
    is odd the extra element goes on the leading side.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the convolution.
        kernel_size: Int (applied to all three axes) or a 3-sequence.
        stride: Int (applied to all three axes) or a 3-sequence.
        bias: Whether the convolution has a bias term.
        padding_type: ``mode`` argument forwarded to ``F.pad``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        bias=True,
        padding_type="replicate",
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)

        # assumes that the input shape is divisible by stride
        per_axis_pad = [k - s for k, s in zip(kernel_size, stride)]

        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # so walk the axes in reverse order.
        flat_pad = []
        for amount in reversed(per_axis_pad):
            after = amount // 2
            before = after + amount % 2
            flat_pad.extend((before, after))

        self.pad_input = tuple(flat_pad)
        self.padding_type = padding_type
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            bias=bias,
        )

    def forward(self, x):
        """Pad ``x`` with the precomputed amounts and convolve it."""
        padded = F.pad(x, self.pad_input, mode=self.padding_type)
        return self.conv(padded)
class SamePadConvTranspose3d(nn.Module):
    """Transposed 3D convolution preceded by explicit input padding.

    The input is padded by ``kernel - stride`` per axis (extra element on the
    leading side when odd) and the inner ``nn.ConvTranspose3d`` is built with
    ``padding = kernel - 1`` on every axis.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the transposed convolution.
        kernel_size: Int (applied to all three axes) or a 3-sequence.
        stride: Int (applied to all three axes) or a 3-sequence.
        bias: Whether the transposed convolution has a bias term.
        padding_type: ``mode`` argument forwarded to ``F.pad``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        bias=True,
        padding_type="replicate",
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)

        per_axis_pad = [k - s for k, s in zip(kernel_size, stride)]

        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # so walk the axes in reverse order.
        flat_pad = []
        for amount in reversed(per_axis_pad):
            after = amount // 2
            before = after + amount % 2
            flat_pad.extend((before, after))

        self.pad_input = tuple(flat_pad)
        self.padding_type = padding_type

        inner_padding = tuple(k - 1 for k in kernel_size)
        self.convt = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            bias=bias,
            padding=inner_padding,
        )

    def forward(self, x):
        """Pad ``x`` with the precomputed amounts and apply the transposed convolution."""
        padded = F.pad(x, self.pad_input, mode=self.padding_type)
        return self.convt(padded)
class ResBlock(nn.Module):
    """Residual block: (norm -> SiLU -> 3x3x3 conv) twice, plus a skip connection.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output; defaults to ``in_channels``.
        conv_shortcut: Stored as ``use_conv_shortcut`` but not consulted; the
            skip path uses a 3x3x3 ``SamePadConv3d`` whenever the channel
            counts differ and is the identity otherwise.
        dropout: Probability for ``self.dropout``. NOTE(review): the dropout
            module is constructed but ``forward`` never applies it; left
            unchanged so training behaviour stays identical — confirm intent.
        norm_type: Forwarded to ``Normalize`` (``"group"`` or ``"batch"``).
        padding_type: Padding mode forwarded to every ``SamePadConv3d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        norm_type="group",
        padding_type="replicate",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type)
        self.conv1 = SamePadConv3d(
            in_channels, out_channels, kernel_size=3, padding_type=padding_type
        )
        self.dropout = torch.nn.Dropout(dropout)
        # norm2 runs on the output of conv1, which has `out_channels` channels.
        # Sizing it with `in_channels` (as before) fails at run time whenever
        # the two differ; for in_channels == out_channels nothing changes.
        self.norm2 = Normalize(out_channels, norm_type)
        self.conv2 = SamePadConv3d(
            out_channels, out_channels, kernel_size=3, padding_type=padding_type
        )
        if self.in_channels != self.out_channels:
            # Project the skip path so it can be added to the residual branch.
            self.conv_shortcut = SamePadConv3d(
                in_channels, out_channels, kernel_size=3, padding_type=padding_type
            )

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``."""
        h = self.norm1(x)
        h = silu(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = silu(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x)
        return x + h
class SpatialCrossAttention(nn.Module):
    """Multi-head attention from spatial patches of each video frame to a context.

    The input ``x`` of shape ``(b, c, t, height, width)`` is cut into
    ``patch_size x patch_size`` patches; every frame becomes a sequence of
    patch tokens that attends to ``context``. The result has the same shape
    as ``x``. The residual connection is NOT added here.

    Args:
        query_dim: Channel count ``c`` of the input.
        patch_size: Side length of the square spatial patches.
        context_dim: Feature size of the context tokens; defaults to
            ``query_dim``.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(
        self,
        query_dim,
        patch_size=1,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head

        self.patch_size = patch_size
        # Each token is a flattened patch: patch_size * patch_size * channels.
        patch_dim = query_dim * patch_size * patch_size
        self.norm = nn.LayerNorm(patch_dim)

        self.to_q = nn.Linear(patch_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, patch_dim), nn.Dropout(dropout)
        )
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        """Attend from the patches of ``x`` to ``context``.

        Args:
            x: Tensor of shape ``(b, c, t, height, width)``; ``height`` and
                ``width`` are expected to be divisible by ``patch_size``.
            context: Optional ``(b, n, d)`` tensor shared by all ``t`` frames
                of a sample. When omitted, the normalized patch tokens are
                used (self-attention); that only matches ``to_k``/``to_v``
                when ``context_dim`` equals the patch feature size.
            mask: Optional tensor whose first axis is the batch; remaining
                axes are flattened to one key axis. Entries ``> 0.5`` mark
                keys that may be attended to.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, t, height, width = x.shape

        # patch: [patch_size, patch_size]
        divide_factor_height = height // self.patch_size
        divide_factor_width = width // self.patch_size
        x = rearrange(
            x,
            "b c t (df1 ph) (df2 pw) -> (b t) (df1 df2) (ph pw c)",
            df1=divide_factor_height,
            df2=divide_factor_width,
            ph=self.patch_size,
            pw=self.patch_size,
        )
        x = self.norm(x)

        if exists(context):
            # One context per sample: replicate it for each of the t frames so
            # its batch axis lines up with the "(b t)" layout of x.
            context = repeat(context, "b n d -> (b t) n d", b=b, t=t)
        else:
            # Self-attention fallback: x is already laid out as "(b t) n d".
            # Repeating it again (as was done before) fails for t > 1.
            context = x

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(
            lambda tensor: rearrange(tensor, "b n (h d) -> (b h) n d", h=self.heads),
            (q, k, v),
        )

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            mask = repeat(mask, "b j -> (b t h) () j", t=t, h=self.heads)

        if XFORMERS_IS_AVAILBLE:
            if exists(mask):
                mask = mask.to(q.dtype)
                max_neg_value = -torch.finfo(q.dtype).max

                # Additive bias: 0 where attention is allowed, a very large
                # negative number where it is masked out.
                attn_bias = torch.zeros_like(mask)
                attn_bias.masked_fill_(mask <= 0.5, max_neg_value)
                attn_bias = attn_bias.expand(-1, q.shape[1], -1)
                num_queries, num_keys = attn_bias.shape[1], attn_bias.shape[2]

                # Allocate the bias with both trailing sizes rounded up to a
                # multiple of 8 and pass a slice of it (presumably to satisfy
                # the xformers alignment requirement for attn_bias).
                # The former `.detach().cpu()` copies of mask and attn_bias
                # were never read afterwards and only forced device-to-host
                # transfers, so they are gone.
                padded_bias = torch.zeros(
                    (
                        attn_bias.shape[0],
                        (num_queries + 7) // 8 * 8,
                        (num_keys + 7) // 8 * 8,
                    ),
                    dtype=attn_bias.dtype,
                    device=attn_bias.device,
                )
                padded_bias[:, :num_queries, :num_keys] = attn_bias

                out = xops.memory_efficient_attention(
                    q,
                    k,
                    v,
                    attn_bias=padded_bias[:, :num_queries, :num_keys],
                    scale=self.scale,
                )
            else:
                out = xops.memory_efficient_attention(q, k, v, scale=self.scale)
        else:
            sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
            if exists(mask):
                max_neg_value = -torch.finfo(sim.dtype).max
                sim.masked_fill_(~(mask > 0.5), max_neg_value)
            attn = sim.softmax(dim=-1)
            out = einsum("b i j, b j d -> b i d", attn, v)

        out = rearrange(out, "(b h) n d -> b n (h d)", h=self.heads)
        ret = self.to_out(out)
        # Undo the patch flattening to get back to (b, c, t, height, width).
        ret = rearrange(
            ret,
            "(b t) (df1 df2) (ph pw c) -> b c t (df1 ph) (df2 pw)",
            b=b,
            t=t,
            df1=divide_factor_height,
            df2=divide_factor_width,
            ph=self.patch_size,
            pw=self.patch_size,
        )
        return ret
# ---------------------------------------------------------------------------------------------------=
class EncoderTemporal1DCNN(nn.Module):
    """Convolutional encoder that downsamples a video tensor along time only.

    Layout: ``conv_in`` -> ``int(log2(temporal_scale_factor))`` stages of
    (stride-(2, 1, 1) conv that doubles the channels, ``ResBlock``, optional
    ``SpatialCrossAttention``) -> ``final_block`` (norm, SiLU, conv) that
    emits ``out_ch * 2`` channels.

    Args:
        ch: Channels of the input tensor.
        out_ch: The final convolution outputs ``out_ch * 2`` channels
            (presumably mean and log-variance of a latent — confirm with the
            caller).
        attn_temporal_factor: Accumulated temporal downsampling factors
            (2, 4, ...) after which a cross-attention layer is inserted.
        temporal_scale_factor: Total temporal downsampling. Non powers of two
            are floored to the previous power of two by ``int(log2(...))``.
        hidden_channel: Channels after ``conv_in``; doubled at every stage.
        **ignore_kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        attn_temporal_factor=(),
        temporal_scale_factor=4,
        hidden_channel=128,
        **ignore_kwargs
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.temporal_scale_factor = temporal_scale_factor

        # conv_in + resblock + down_block + resblock + down_block + final_block
        self.conv_in = SamePadConv3d(
            ch, hidden_channel, kernel_size=3, padding_type="replicate"
        )

        self.mid_blocks = nn.ModuleList()
        num_ds = int(math.log2(temporal_scale_factor))
        norm_type = "group"
        curr_temporal_factor = 1

        # Channel count fed to final_block when there are no downsampling
        # stages (temporal_scale_factor == 1); previously a NameError.
        out_channels = hidden_channel

        for i in range(num_ds):
            block = nn.Module()
            # compute in_ch, out_ch, stride
            in_channels = hidden_channel * 2**i
            out_channels = hidden_channel * 2 ** (i + 1)
            temporal_stride = 2
            curr_temporal_factor = curr_temporal_factor * 2

            block.down = SamePadConv3d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=(temporal_stride, 1, 1),
                padding_type="replicate",
            )
            block.res = ResBlock(out_channels, out_channels, norm_type=norm_type)

            block.attn = nn.ModuleList()
            if curr_temporal_factor in attn_temporal_factor:
                # The context feature size is fixed to 1024 here.
                block.attn.append(
                    SpatialCrossAttention(query_dim=out_channels, context_dim=1024)
                )

            self.mid_blocks.append(block)

        self.final_block = nn.Sequential(
            Normalize(out_channels, norm_type),
            SiLU(),
            SamePadConv3d(
                out_channels, out_ch * 2, kernel_size=3, padding_type="replicate"
            ),
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-initialize Linear and Conv3d weights and zero their biases."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                # `requires_grad` is the flag; `requires_grad_` is a bound
                # method and therefore always truthy.
                if module.weight.requires_grad:
                    torch.nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
            if isinstance(module, nn.Conv3d):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def forward(self, x, text_embeddings=None, text_attn_mask=None):
        """Encode ``x``.

        Args:
            x: Tensor laid out as ``[b, c, t, h, w]``.
            text_embeddings: Context handed to every attention layer.
            text_attn_mask: Mask handed to every attention layer.

        Returns:
            Tensor with ``out_ch * 2`` channels; the temporal axis has been
            strided by 2 once per stage.
        """
        h = self.conv_in(x)
        for block in self.mid_blocks:
            h = block.down(h)
            h = block.res(h)
            for attn in block.attn:
                # The attention layer has no skip connection of its own.
                h = attn(h, context=text_embeddings, mask=text_attn_mask) + h

        h = self.final_block(h)
        return h
class TemporalUpsample(nn.Module):
    """Module wrapper around ``F.interpolate`` with fixed interpolation settings.

    Args:
        size: Target output size forwarded to ``F.interpolate``.
        scale_factor: Scale factor forwarded to ``F.interpolate``.
        mode: Interpolation mode.
        align_corners: ``align_corners`` flag forwarded to ``F.interpolate``.
    """

    def __init__(
        self, size=None, scale_factor=None, mode="nearest", align_corners=None
    ):
        super().__init__()
        self.size, self.scale_factor = size, scale_factor
        self.mode, self.align_corners = mode, align_corners

    def forward(self, x):
        """Resample ``x`` with the settings captured at construction time."""
        interpolate_kwargs = {
            "size": self.size,
            "scale_factor": self.scale_factor,
            "mode": self.mode,
            "align_corners": self.align_corners,
        }
        return F.interpolate(x, **interpolate_kwargs)
class DecoderTemporal1DCNN(nn.Module):
    """Convolutional decoder that upsamples a tensor along time only.

    Layout: ``conv_in`` -> ``int(log2(temporal_scale_factor))`` stages of
    (``ConvTranspose3d`` with stride (2, 1, 1), ``ResBlock``, optional
    attention, ``ResBlock``, optional attention) -> ``conv_last``.

    Args:
        ch: Channels of the input tensor.
        out_ch: Channels produced by ``conv_last``.
        attn_temporal_factor: Temporal factors at which cross-attention layers
            are inserted. The first stage is tested against
            ``temporal_scale_factor`` and the value is halved after each stage.
        temporal_scale_factor: Total temporal upsampling. Non powers of two
            are floored to the previous power of two by ``int(log2(...))``.
        hidden_channel: Base channel count; the first stage works at
            ``hidden_channel * 2 ** num_stages`` channels and each later
            stage halves it.
        **ignore_kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        attn_temporal_factor=(),
        temporal_scale_factor=4,
        hidden_channel=128,
        **ignore_kwargs
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.temporal_scale_factor = temporal_scale_factor

        num_us = int(math.log2(temporal_scale_factor))
        norm_type = "group"

        # conv_in, mid_blocks, final_block
        # out channel of encoder, before the last conv layer
        enc_out_channels = hidden_channel * 2**num_us
        self.conv_in = SamePadConv3d(
            ch, enc_out_channels, kernel_size=3, padding_type="replicate"
        )

        self.mid_blocks = nn.ModuleList()
        curr_temporal_factor = self.temporal_scale_factor

        # Channel count fed to conv_last when there are no upsampling stages
        # (temporal_scale_factor == 1); previously a NameError.
        out_channels = enc_out_channels

        for i in range(num_us):
            block = nn.Module()
            # The first stage keeps the channel count; every later stage
            # halves the output of the previous one.
            in_channels = (
                enc_out_channels if i == 0 else hidden_channel * 2 ** (num_us - i + 1)
            )
            out_channels = hidden_channel * 2 ** (num_us - i)

            # kernel 3, stride 2, padding 1, output_padding 1 along time gives
            # (T - 1) * 2 - 2 + 3 + 1 = 2 * T frames; h and w are unchanged.
            block.up = torch.nn.ConvTranspose3d(
                in_channels,
                out_channels,
                kernel_size=(3, 3, 3),
                stride=(2, 1, 1),
                padding=(1, 1, 1),
                output_padding=(1, 0, 0),
            )

            block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            block.attn1 = nn.ModuleList()
            if curr_temporal_factor in attn_temporal_factor:
                # The context feature size is fixed to 1024 here.
                block.attn1.append(
                    SpatialCrossAttention(query_dim=out_channels, context_dim=1024)
                )

            block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            block.attn2 = nn.ModuleList()
            if curr_temporal_factor in attn_temporal_factor:
                block.attn2.append(
                    SpatialCrossAttention(query_dim=out_channels, context_dim=1024)
                )

            curr_temporal_factor = curr_temporal_factor / 2
            self.mid_blocks.append(block)

        self.conv_last = SamePadConv3d(out_channels, out_ch, kernel_size=3)

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-initialize Linear, Conv3d and ConvTranspose3d weights; zero biases."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                # `requires_grad` is the flag; `requires_grad_` is a bound
                # method and therefore always truthy.
                if module.weight.requires_grad:
                    torch.nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
            if isinstance(module, nn.Conv3d):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            if isinstance(module, nn.ConvTranspose3d):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def forward(self, x, text_embeddings=None, text_attn_mask=None):
        """Decode ``x``.

        Args:
            x: Tensor laid out as ``[b, c, t, h, w]``.
            text_embeddings: Context handed to every attention layer.
            text_attn_mask: Mask handed to every attention layer.

        Returns:
            Tensor with ``out_ch`` channels; the temporal axis has been
            doubled once per stage.
        """
        h = self.conv_in(x)
        for block in self.mid_blocks:
            h = block.up(h)
            h = block.res1(h)
            for attn in block.attn1:
                # The attention layer has no skip connection of its own.
                h = attn(h, context=text_embeddings, mask=text_attn_mask) + h

            h = block.res2(h)
            for attn in block.attn2:
                h = attn(h, context=text_embeddings, mask=text_attn_mask) + h

        h = self.conv_last(h)
        return h