# Path: World_Model/URSA/diffnext/models/autoencoders/autoencoder_vq_cosmos3d.py
# Uploaded by BryanW: "Add files using upload-large-folder tool" (commit d403233, verified).
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Simple implementation of AutoEncoderVQ for Cosmos3D."""
import math
import torch
from einops import rearrange
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffnext.models.autoencoders.modeling_utils import IdentityDistribution
from diffnext.models.autoencoders.modeling_utils import DecoderOutput, TilingMixin
from diffnext.models.autoencoders.quantizers import FSQuantizer
from diffnext.models.autoencoders.wavelets_utils import Patcher3D
class GroupNorm2D(nn.GroupNorm):
    """Group normalization applied to each frame of a 5D tensor separately.

    The input is treated as ``(b, c, t, h, w)``. The time axis is folded into
    the batch axis before normalizing, so statistics are computed per frame,
    and the original layout is restored on the way out.
    """

    def forward(self, x) -> torch.Tensor:
        """Normalize ``x`` frame by frame and return it in ``(b, c, t, h, w)`` layout."""
        batch = x.size(0)
        # (b, c, t, h, w) -> (b * t, c, h, w): every frame becomes its own sample.
        frames = x.transpose(1, 2).flatten(0, 1)
        normed = super().forward(frames)
        # (b * t, c, h, w) -> (b, t, c, h, w) -> (b, c, t, h, w).
        return normed.unflatten(0, (batch, -1)).transpose(1, 2)
class Conv3d(nn.Conv3d):
    """3D convolution that pads the temporal axis on the leading side only.

    Accepts every ``nn.Conv3d`` argument plus an optional ``time_stride``
    keyword that overrides the stride along the temporal (first kernel) axis.
    The built-in temporal padding is disabled; instead the input is padded in
    front by replicating the first frame, so no trailing frames are padded.
    """

    def __init__(self, *args, **kwargs):
        time_stride = kwargs.pop("time_stride", None)
        super().__init__(*args, **kwargs)
        kernel_t = self.kernel_size[0]
        stride_t = time_stride or self.stride[0]
        # Keep the spatial stride/padding, but take over the temporal axis.
        self.stride = (stride_t, *self.stride[1:])
        self.padding = (0, *self.padding[1:])
        if kernel_t == 1:
            # A kernel of length 1 in time needs no temporal padding.
            self.pad = nn.Identity()
        else:
            # ReplicationPad3d order is (left, right, top, bottom, front, back).
            front = (kernel_t - 1) + (1 - stride_t)
            self.pad = nn.ReplicationPad3d((0, 0, 0, 0, front, 0))

    @classmethod
    def new_factorized(cls, dim, out_dim):
        """Return a spatial (1x3x3) conv followed by a temporal (3x1x1) conv."""
        spatial_conv = cls(dim, out_dim, (1, 3, 3), 1, 1)
        temporal_conv = cls(out_dim, out_dim, (3, 1, 1))
        return nn.Sequential(spatial_conv, temporal_conv)

    def forward(self, x) -> torch.Tensor:
        """Pad the front of the temporal axis, then convolve."""
        padded = self.pad(x)
        return super().forward(padded)
class Attention(nn.Module):
    """Single-head self-attention over one axis group of a ``(b, c, t, h, w)`` tensor.

    ``perm`` is the einops layout the normalized input is rearranged into
    before attention; it decides which axes form the sequence and which are
    folded into the batch.
    """

    def __init__(self, dim, perm="(b t) 1 (h w) c"):
        super().__init__()
        self.group_norm = GroupNorm2D(1, dim, eps=1e-6)
        self.perm = perm
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        # Kept as a one-element ModuleList so the parameter is named ``to_out.0``.
        self.to_out = nn.ModuleList([nn.Linear(dim, dim)])

    @classmethod
    def new_factorized(cls, dim) -> nn.Sequential:
        """Return spatial attention (over ``h w``) followed by temporal attention (over ``t``)."""
        spatial_attn = cls(dim, "(b t) 1 (h w) c")
        temporal_attn = cls(dim, "(b h w) 1 t c")
        return nn.Sequential(spatial_attn, temporal_attn)

    def forward(self, x) -> torch.Tensor:
        """Apply pre-normalized attention and add the input back as a residual."""
        residual = x
        batch, _, _, height, width = x.shape
        tokens = rearrange(self.group_norm(x), f"b c t h w -> {self.perm}")
        query = self.to_q(tokens)
        key = self.to_k(tokens)
        value = self.to_v(tokens)
        attended = nn.functional.scaled_dot_product_attention(query, key, value)
        out = self.to_out[0](attended)
        # Undo the layout change; the remaining axis size is inferred by einops.
        out = rearrange(out, f"{self.perm} -> b c t h w", b=batch, h=height, w=width)
        return out.add_(residual)
class Resize(nn.Module):
    """Resampling layer for ``(b, c, t, h, w)`` tensors.

    ``spatial`` / ``temporal`` select the mode per axis group:
    ``1`` halves the resolution, ``2`` doubles it, and any other value
    (e.g. ``0``) leaves that axis group untouched. Each resampling step is a
    convolution plus a parameter-free shortcut (average pooling when
    downsampling, nearest-neighbour repetition when upsampling).
    """

    def __init__(self, dim, spatial=1, temporal=1):
        super().__init__()
        self.spatial = spatial
        self.temporal = temporal
        conv1, conv2 = nn.Identity(), nn.Identity()
        if spatial == 1 or temporal == 1:
            # Downsampling: strided spatial conv, then (optionally) strided temporal conv.
            conv1 = Conv3d(dim, dim, (1, 3, 3), 2, time_stride=1)
            if temporal:
                conv2 = Conv3d(dim, dim, (3, 1, 1), 1, time_stride=2)
        elif spatial == 2 or temporal == 2:
            # Upsampling: (optional) temporal conv, then spatial conv.
            if temporal:
                conv1 = Conv3d(dim, dim, (3, 1, 1), 1, 0)
            conv2 = Conv3d(dim, dim, (1, 3, 3), 1, 1)
        self.conv1 = conv1
        self.conv2 = conv2
        # Pointwise mixing after resampling; skipped when nothing is resampled.
        self.conv3 = Conv3d(dim, dim, 1) if spatial or temporal else nn.Identity()

    def forward(self, x) -> torch.Tensor:
        """Resample ``x`` according to the configured spatial/temporal modes."""
        functional = nn.functional
        if self.spatial == 1:
            pooled = functional.avg_pool3d(x, (1, 2, 2), (1, 2, 2))
            # Zero-pad right/bottom so the stride-2 3x3 conv matches the pooled size.
            padded = functional.pad(x, (0, 1, 0, 1, 0, 0))
            x = self.conv1(padded).add_(pooled)
        if self.temporal == 1:
            # Repeat the first frame once so conv and pooling see the same length.
            x = functional.pad(x, (0, 0, 0, 0, 1, 0), "replicate")
            strided = self.conv2(x)
            x = strided.add_(functional.avg_pool3d(x, (2, 1, 1), (2, 1, 1)))
        if self.temporal == 2:
            # Duplicate every frame, then drop the first copy.
            x = x.repeat_interleave(2, dim=2)[:, :, 1:]
            x = self.conv1(x).add_(x)
        if self.spatial == 2:
            x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
            x = self.conv2(x).add_(x)
        return self.conv3(x)
class ResBlock(nn.Module):
    """Pre-activation residual block built from factorized 3D convolutions.

    Two ``norm -> SiLU -> conv`` stages are followed by a residual add. When
    the channel count changes, the shortcut goes through a 1x1x1 convolution.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        self.norm1 = GroupNorm2D(1, dim, eps=1e-6)
        self.conv1 = Conv3d.new_factorized(dim, out_dim)
        self.norm2 = GroupNorm2D(1, out_dim, eps=1e-6)
        self.conv2 = Conv3d.new_factorized(out_dim, out_dim)
        self.conv_shortcut = None if out_dim == dim else Conv3d(dim, out_dim, 1)
        self.nonlinearity = nn.SiLU()
        # Registered with p=0 and not applied in forward().
        self.dropout = nn.Dropout(0)

    def forward(self, x) -> torch.Tensor:
        """Return ``conv2(act(norm2(conv1(act(norm1(x)))))) + shortcut(x)``."""
        residual = x if self.conv_shortcut is None else self.conv_shortcut(x)
        hidden = self.conv1(self.nonlinearity(self.norm1(x)))
        hidden = self.conv2(self.nonlinearity(self.norm2(hidden)))
        return hidden.add_(residual)
class UNetResBlock(nn.Module):
    """Stack of ``depth`` residual blocks with an optional resampling layer.

    ``downsample`` / ``upsample`` are ``(spatial, temporal)`` mode tuples
    forwarded to ``Resize``; a falsy value means no such layer is created.
    """

    def __init__(self, dim, out_dim, depth=2, downsample=None, upsample=None):
        super().__init__()
        # Only the first block changes the channel count.
        in_dims = [dim if i == 0 else out_dim for i in range(depth)]
        self.resnets = nn.ModuleList(ResBlock(in_dim, out_dim) for in_dim in in_dims)
        if downsample:
            self.downsamplers = nn.ModuleList([Resize(out_dim, *downsample)])
        else:
            self.downsamplers = []
        if upsample:
            self.upsamplers = nn.ModuleList([Resize(out_dim, *upsample)])
        else:
            self.upsamplers = []

    def forward(self, x) -> torch.Tensor:
        """Run the residual blocks, then the downsampler and/or upsampler if present."""
        for resnet in self.resnets:
            x = resnet(x)
        if self.downsamplers:
            x = self.downsamplers[0](x)
        if self.upsamplers:
            x = self.upsamplers[0](x)
        return x
class UNetMidBlock(nn.Module):
    """Bottleneck block: one residual block, then ``depth`` (attention, residual) pairs."""

    def __init__(self, dim, depth=1):
        super().__init__()
        self.resnets = nn.ModuleList(ResBlock(dim, dim) for _ in range(depth + 1))
        self.attentions = nn.ModuleList(Attention.new_factorized(dim) for _ in range(depth))

    def forward(self, x) -> torch.Tensor:
        """Apply the leading residual block, then alternate attention and residual blocks."""
        first_resnet, *remaining_resnets = self.resnets
        x = first_resnet(x)
        for attention, resnet in zip(self.attentions, remaining_resnets):
            x = attention(x)
            x = resnet(x)
        return x
class Encoder(nn.Module):
    """AE encoder: patchify, convolve, downsample, bottleneck, project to latents.

    Args:
        dim: Number of input channels (before patching).
        out_dim: Number of output (latent) channels.
        block_dims: Channel widths; consecutive pairs define the down blocks.
        block_depth: Residual blocks per down block.
        patch_size: Patch size handed to ``Patcher3D``; the remaining stride
            (``stride / patch_size``) is realised by ``Resize`` layers.
        temporal_stride: Total temporal compression factor.
        spatial_stride: Total spatial compression factor.
    """

    def __init__(
        self,
        dim,
        out_dim,
        block_dims,
        block_depth,
        patch_size=4,
        temporal_stride=8,
        spatial_stride=8,
    ):
        super().__init__()
        patch_log2 = int(math.log2(patch_size))
        # Number of stride-2 stages still needed after patching.
        spatial_downs = int(math.log2(spatial_stride)) - patch_log2
        temporal_downs = int(math.log2(temporal_stride)) - patch_log2
        self.patcher = Patcher3D(patch_size)
        self.conv_in = Conv3d.new_factorized(dim * patch_size**3, block_dims[0])
        self.down_blocks = nn.ModuleList()
        num_resizing = len(block_dims) - 2
        for i, (in_dim, block_dim) in enumerate(zip(block_dims[:-1], block_dims[1:])):
            downsample = None
            if i < num_resizing:
                # 1 = downsample this axis group in this stage, 0 = leave as is.
                downsample = (int(i < spatial_downs), int(i < temporal_downs))
            block = UNetResBlock(in_dim, block_dim, block_depth, downsample=downsample)
            self.down_blocks.append(block)
        # ``block_dim`` now holds the width of the last down block.
        self.mid_block = UNetMidBlock(block_dim)
        self.conv_norm_out = GroupNorm2D(1, block_dim, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = Conv3d.new_factorized(block_dim, out_dim)

    def forward(self, x) -> torch.Tensor:
        """Encode a ``(b, c, t, h, w)`` tensor into latent features."""
        # Repeat the first frame ``patch_size`` times before patching.
        first_frame, later_frames = x[:, :, :1], x[:, :, 1:]
        repeated = first_frame.repeat_interleave(self.patcher.patch_size, 2)
        x = torch.cat([repeated, later_frames], 2)
        # Patcher3D is defined elsewhere; dwt is applied num_strides times.
        for _ in range(self.patcher.num_strides):
            x = self.patcher.dwt(x)
        x = self.conv_in(x)
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        x = self.conv_act(self.conv_norm_out(x))
        return self.conv_out(x)
class Decoder(nn.Module):
    """AE decoder: project latents, bottleneck, upsample, convolve, unpatchify.

    Args:
        dim: Number of input (latent) channels.
        out_dim: Number of output channels (after unpatching).
        block_dims: Channel widths in encoder order; they are reversed here.
        block_depth: The up blocks use ``block_depth + 1`` residual blocks each.
        patch_size: Patch size handed to ``Patcher3D``; the remaining stride
            (``stride / patch_size``) is realised by ``Resize`` layers.
        temporal_stride: Total temporal expansion factor.
        spatial_stride: Total spatial expansion factor.
    """

    def __init__(
        self,
        dim,
        out_dim,
        block_dims,
        block_depth,
        patch_size=4,
        temporal_stride=8,
        spatial_stride=8,
    ):
        super().__init__()
        block_dims = list(reversed(block_dims))
        patch_log2 = int(math.log2(patch_size))
        # Number of x2 stages still needed before unpatching.
        spatial_ups = int(math.log2(spatial_stride)) - patch_log2
        temporal_ups = int(math.log2(temporal_stride)) - patch_log2
        self.patcher = Patcher3D(patch_size)
        self.conv_in = Conv3d.new_factorized(dim, block_dims[0])
        self.mid_block = UNetMidBlock(block_dims[0])
        self.up_blocks = nn.ModuleList()
        num_resizing = len(block_dims) - 2
        for i, block_dim in enumerate(block_dims[:-1]):
            # Input width is the previous stage's width (the first stage reuses its own).
            in_dim = block_dims[i - 1] if i > 0 else block_dims[0]
            upsample = None
            if i < num_resizing:
                temporal = 0 < i < temporal_ups + 1
                spatial = temporal or (i < spatial_ups and spatial_ups > temporal_ups)
                # 2 = upsample this axis group in this stage, 0 = leave as is.
                upsample = (2 if spatial else 0, 2 if temporal else 0)
            block = UNetResBlock(in_dim, block_dim, block_depth + 1, upsample=upsample)
            self.up_blocks.append(block)
        # ``block_dim`` now holds the width of the last up block.
        self.conv_norm_out = GroupNorm2D(1, block_dim, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = Conv3d.new_factorized(block_dim, out_dim * patch_size**3)

    def forward(self, x) -> torch.Tensor:
        """Decode latent features back to a ``(b, c, t, h, w)`` tensor."""
        x = self.conv_in(x)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        x = self.conv_act(self.conv_norm_out(x))
        x = self.conv_out(x)
        # Patcher3D is defined elsewhere; idwt is applied num_strides times.
        for _ in range(self.patcher.num_strides):
            x = self.patcher.idwt(x)
        # Drop the leading ``patch_size - 1`` frames.
        first_kept = self.patcher.patch_size - 1
        return x[:, :, first_kept:]
class AutoencoderVQCosmos3D(ModelMixin, ConfigMixin, TilingMixin):
    """Vector-quantized 3D autoencoder (Cosmos3D layout) with an FSQ quantizer.

    Several constructor arguments (``down_block_types``, ``up_block_types``,
    ``act_fn``, ``norm_num_groups``, ``sample_size``, ``num_vq_embeddings``,
    ``force_upcast``, ``_quantizer_name``) are not read by this constructor;
    they are still recorded in the config via ``register_to_config``.
    """

    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock3D",) * 3,
        up_block_types=("UpDecoderBlock3D",) * 3,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        act_fn="silu",
        latent_channels=16,
        norm_num_groups=1,
        sample_size=1024,
        sample_frames=17,
        num_vq_embeddings=64000,
        vq_embed_dim=6,
        force_upcast=False,
        patch_size=4,
        temporal_stride=4,
        spatial_stride=8,
        _quantizer_name="FSQuantizer",
    ):
        super().__init__()
        # Latent frame count for one tile of ``sample_frames`` input frames.
        latent_min_t = (sample_frames - 1) // temporal_stride + 1
        TilingMixin.__init__(self, sample_frames, latent_min_t=latent_min_t, sample_ovr_t=1)
        stride_args = {
            "patch_size": patch_size,
            "temporal_stride": temporal_stride,
            "spatial_stride": spatial_stride,
        }
        self.encoder = Encoder(
            in_channels, latent_channels, block_out_channels, layers_per_block, **stride_args
        )
        self.decoder = Decoder(
            latent_channels, out_channels, block_out_channels, layers_per_block, **stride_args
        )
        # 1x1x1 projections between latent channels and the quantizer's embedding dim.
        self.quant_conv = Conv3d(latent_channels, vq_embed_dim, 1)
        self.post_quant_conv = Conv3d(vq_embed_dim, latent_channels, 1)
        self.quantizer = FSQuantizer()
        self.latent_dist = IdentityDistribution

    def scale_(self, x) -> torch.Tensor:
        """Scale the input latents (identity for this model)."""
        return x

    def unscale_(self, x) -> torch.Tensor:
        """Unscale the input latents (identity for this model)."""
        return x

    def encode(self, x) -> AutoencoderKLOutput:
        """Encode the input samples.

        Returns an ``AutoencoderKLOutput`` whose ``latent_dist`` wraps the
        quantizer output in ``IdentityDistribution``.
        """
        # tiled_encoder is not defined in this class; presumably from TilingMixin.
        features = self.tiled_encoder(self.forward(x))
        features = self.quant_conv(features)
        quantized = self.quantizer.quantize(features)
        return AutoencoderKLOutput(latent_dist=self.latent_dist(quantized))

    def decode(self, ids) -> DecoderOutput:
        """Decode the input indices.

        A 4D dequantized tensor gets a singleton axis inserted at dim 2 for
        decoding; that axis is squeezed out of the result again.
        """
        z = self.quantizer.dequantize(ids)
        needs_time_axis = z.dim() == 4
        if needs_time_axis:
            z = z.unsqueeze_(2)
        z = self.post_quant_conv(self.forward(z))
        # tiled_decoder is not defined in this class; presumably from TilingMixin.
        x = self.tiled_decoder(z)
        if needs_time_axis:
            x = x.squeeze_(2)
        return DecoderOutput(sample=x)

    def forward(self, x):  # NOOP.
        """Return ``x`` unchanged."""
        return x