"""Simple implementation of AutoEncoderVQ for Cosmos3D."""

import math

import torch
from einops import rearrange
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin

from diffnext.models.autoencoders.modeling_utils import IdentityDistribution
from diffnext.models.autoencoders.modeling_utils import DecoderOutput, TilingMixin
from diffnext.models.autoencoders.quantizers import FSQuantizer
from diffnext.models.autoencoders.wavelets_utils import Patcher3D


class GroupNorm2D(nn.GroupNorm):
    """2D group normalization applied per video frame."""

    def forward(self, x) -> torch.Tensor:
        # Fold time into the batch, normalize each frame, then restore the (B, C, T, H, W) layout.
        x, bsz = super().forward(x.transpose(1, 2).flatten(0, 1)), x.size(0)
        return rearrange(x, "(b t) c h w -> b c t h w", b=bsz)


class Conv3d(nn.Conv3d):
    """3D convolution with causal temporal padding."""

    def __init__(self, *args, **kwargs):
        stride_t = kwargs.pop("time_stride", None)
        super(Conv3d, self).__init__(*args, **kwargs)
        pad_t = (self.kernel_size[0] - 1) + (1 - (stride_t or self.stride[0]))
        self.stride = (stride_t or self.stride[0],) + self.stride[1:]
        self.padding = (0,) + self.padding[1:]
        self.pad = nn.ReplicationPad3d((0,) * 4 + (pad_t, 0))  # pad past frames only (causal)
        self.pad = nn.Identity() if self.kernel_size[0] == 1 else self.pad

    @classmethod
    def new_factorized(cls, dim, out_dim):
        """Return a spatial (1x3x3) convolution followed by a causal temporal (3x1x1) convolution."""
        return nn.Sequential(cls(dim, out_dim, (1, 3, 3), 1, 1), cls(out_dim, out_dim, (3, 1, 1)))

    def forward(self, x) -> torch.Tensor:
        return super(Conv3d, self).forward(self.pad(x))


class Attention(nn.Module):
    """Single-headed self-attention over spatial or temporal positions."""

    def __init__(self, dim, perm="(b t) 1 (h w) c"):
        super(Attention, self).__init__()
        self.group_norm, self.perm = GroupNorm2D(1, dim, eps=1e-6), perm
        self.to_q, self.to_k, self.to_v = [nn.Linear(dim, dim) for _ in range(3)]
        self.to_out = nn.ModuleList([nn.Linear(dim, dim)])

    @classmethod
    def new_factorized(cls, dim) -> nn.Sequential:
        """Return a spatial attention followed by a temporal attention."""
        return nn.Sequential(cls(dim, "(b t) 1 (h w) c"), cls(dim, "(b h w) 1 t c"))

    def forward(self, x) -> torch.Tensor:
        shortcut, x, (bsz, _, _, h, w) = x, self.group_norm(x), x.size()
        x = rearrange(x, "b c t h w -> %s" % self.perm)
        q, k, v = [f(x) for f in (self.to_q, self.to_k, self.to_v)]
        o = self.to_out[0](nn.functional.scaled_dot_product_attention(q, k, v))
        return rearrange(o, "%s -> b c t h w" % self.perm, b=bsz, h=h, w=w).add_(shortcut)


class Resize(nn.Module):
    """Down/upsample layer along the spatial and/or temporal axes."""

    def __init__(self, dim, spatial=1, temporal=1):
        super(Resize, self).__init__()
        # Modes: 1 = downsample by 2, 2 = upsample, 0 = keep the axis unchanged.
        self.spatial, self.temporal = spatial, temporal
        self.conv1, self.conv2 = nn.Identity(), nn.Identity()
        if spatial == 1 or temporal == 1:
            self.conv1 = Conv3d(dim, dim, (1, 3, 3), 2, time_stride=1)
            self.conv2 = Conv3d(dim, dim, (3, 1, 1), 1, time_stride=2) if temporal else self.conv2
        elif spatial == 2 or temporal == 2:
            self.conv1 = Conv3d(dim, dim, (3, 1, 1), 1, 0) if temporal else self.conv1
            self.conv2 = Conv3d(dim, dim, (1, 3, 3), 1, 1)
        self.conv3 = Conv3d(dim, dim, 1) if spatial or temporal else nn.Identity()

    def forward(self, x) -> torch.Tensor:
        if self.spatial == 1:  # spatial downsample: strided conv + average pooling shortcut
            _ = nn.functional.avg_pool3d(x, (1, 2, 2), (1, 2, 2))
            x = self.conv1(nn.functional.pad(x, (0, 1, 0, 1, 0, 0))).add_(_)
        if self.temporal == 1:  # temporal downsample: strided conv + average pooling shortcut
            x = nn.functional.pad(x, (0, 0, 0, 0, 1, 0), "replicate")
            x = self.conv2(x).add_(nn.functional.avg_pool3d(x, (2, 1, 1), (2, 1, 1)))
        if self.temporal == 2:  # temporal upsample: repeat frames + residual conv
            x = x.repeat_interleave(2, dim=2)[:, :, 1:]
            x = self.conv1(x).add_(x)
        if self.spatial == 2:  # spatial upsample: repeat pixels + residual conv
            x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
            x = self.conv2(x).add_(x)
        return self.conv3(x)


class ResBlock(nn.Module):
    """Resnet block."""

    def __init__(self, dim, out_dim):
        super(ResBlock, self).__init__()
        self.norm1 = GroupNorm2D(1, dim, eps=1e-6)
        self.conv1 = Conv3d.new_factorized(dim, out_dim)
        self.norm2 = GroupNorm2D(1, out_dim, eps=1e-6)
        self.conv2 = Conv3d.new_factorized(out_dim, out_dim)
        self.conv_shortcut = Conv3d(dim, out_dim, 1) if out_dim != dim else None
        self.nonlinearity, self.dropout = nn.SiLU(), nn.Dropout(0)

    def forward(self, x) -> torch.Tensor:
        shortcut = self.conv_shortcut(x) if self.conv_shortcut else x
        x = self.conv1(self.nonlinearity(self.norm1(x)))
        return self.conv2(self.nonlinearity(self.norm2(x))).add_(shortcut)


class UNetResBlock(nn.Module):
    """UNet resnet block."""

    def __init__(self, dim, out_dim, depth=2, downsample=None, upsample=None):
        super(UNetResBlock, self).__init__()
        block_dims = [(out_dim, out_dim) if i > 0 else (dim, out_dim) for i in range(depth)]
        self.resnets = nn.ModuleList(ResBlock(*dims) for dims in block_dims)
        self.downsamplers = nn.ModuleList([Resize(out_dim, *downsample)]) if downsample else []
        self.upsamplers = nn.ModuleList([Resize(out_dim, *upsample)]) if upsample else []

    def forward(self, x) -> torch.Tensor:
        for resnet in self.resnets:
            x = resnet(x)
        x = self.downsamplers[0](x) if self.downsamplers else x
        return self.upsamplers[0](x) if self.upsamplers else x


class UNetMidBlock(nn.Module):
    """UNet mid block."""

    def __init__(self, dim, depth=1):
        super(UNetMidBlock, self).__init__()
        self.resnets = nn.ModuleList(ResBlock(dim, dim) for _ in range(depth + 1))
        self.attentions = nn.ModuleList(Attention.new_factorized(dim) for _ in range(depth))

    def forward(self, x) -> torch.Tensor:
        x = self.resnets[0](x)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            x = resnet(attn(x))
        return x


class Encoder(nn.Module):
    """AE encoder."""

    def __init__(
        self,
        dim,
        out_dim,
        block_dims,
        block_depth,
        patch_size=4,
        temporal_stride=8,
        spatial_stride=8,
    ):
        super(Encoder, self).__init__()
        # The wavelet patcher already reduces each axis by patch_size; the down blocks add the rest.
        spatial_downs = int(math.log2(spatial_stride)) - int(math.log2(patch_size))
        temporal_downs = int(math.log2(temporal_stride)) - int(math.log2(patch_size))
        self.patcher = Patcher3D(patch_size)
        self.conv_in = Conv3d.new_factorized(dim * patch_size**3, block_dims[0])
        self.down_blocks = nn.ModuleList()
        for i, dim in enumerate(block_dims[:-1]):
            downsample, block_dim = None, block_dims[i + 1]
            if i < len(block_dims) - 2:
                downsample = int(i < spatial_downs), int(i < temporal_downs)
            args = (dim, block_dim, block_depth)
            self.down_blocks += [UNetResBlock(*args, downsample=downsample)]
        self.mid_block = UNetMidBlock(block_dim)
        self.conv_norm_out, self.conv_act = GroupNorm2D(1, block_dim, eps=1e-6), nn.SiLU()
        self.conv_out = Conv3d.new_factorized(block_dim, out_dim)

    def forward(self, x) -> torch.Tensor:
        # Repeat the first frame so the (1 + T)-frame clip divides by the temporal patch size.
        x = torch.cat([x[:, :, :1].repeat_interleave(self.patcher.patch_size, 2), x[:, :, 1:]], 2)
        for _ in range(self.patcher.num_strides):
            x = self.patcher.dwt(x)
        x = self.conv_in(x)
        for blk in self.down_blocks:
            x = blk(x)
        x = self.mid_block(x)
        return self.conv_out(self.conv_act(self.conv_norm_out(x)))
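
# Note (illustrative): with the default AutoencoderVQCosmos3D config below (patch_size=4,
# spatial_stride=8, temporal_stride=4), the encoder maps a (B, 3, 17, H, W) clip to a
# (B, latent_channels, 5, H / 8, W / 8) latent: the wavelet patcher reduces every axis by 4x
# and a single strided Resize block adds the remaining 2x in space. The decoder inverts this
# path and drops the (patch_size - 1) leading frames added by the first-frame repetition.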


class Decoder(nn.Module):
    """AE decoder."""

    def __init__(
        self,
        dim,
        out_dim,
        block_dims,
        block_depth,
        patch_size=4,
        temporal_stride=8,
        spatial_stride=8,
    ):
        super(Decoder, self).__init__()
        block_dims = list(reversed(block_dims))
        spatial_ups = int(math.log2(spatial_stride)) - int(math.log2(patch_size))
        temporal_ups = int(math.log2(temporal_stride)) - int(math.log2(patch_size))
        self.patcher = Patcher3D(patch_size)
        self.conv_in = Conv3d.new_factorized(dim, block_dims[0])
        self.mid_block = UNetMidBlock(block_dims[0])
        self.up_blocks = nn.ModuleList()
        for i, block_dim in enumerate(block_dims[:-1]):
            upsample, dim = None, block_dims[max(i - 1, 0)]
            if i < len(block_dims) - 2:
                temporal = 0 < i < temporal_ups + 1
                spatial = temporal or (i < spatial_ups and spatial_ups > temporal_ups)
                upsample = (2 if spatial else 0, 2 if temporal else 0)
            args = (dim, block_dim, block_depth + 1)
            self.up_blocks += [UNetResBlock(*args, upsample=upsample)]
        self.conv_norm_out, self.conv_act = GroupNorm2D(1, block_dim, eps=1e-6), nn.SiLU()
        self.conv_out = Conv3d.new_factorized(block_dim, out_dim * patch_size**3)

    def forward(self, x) -> torch.Tensor:
        x = self.conv_in(x)
        x = self.mid_block(x)
        for blk in self.up_blocks:
            x = blk(x)
        x = self.conv_out(self.conv_act(self.conv_norm_out(x)))
        for _ in range(self.patcher.num_strides):
            x = self.patcher.idwt(x)
        # Drop the frames introduced by repeating the first frame in the encoder.
        return x[:, :, self.patcher.patch_size - 1 :]


class AutoencoderVQCosmos3D(ModelMixin, ConfigMixin, TilingMixin):
    """AutoEncoder VQ."""

    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock3D",) * 3,
        up_block_types=("UpDecoderBlock3D",) * 3,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        act_fn="silu",
        latent_channels=16,
        norm_num_groups=1,
        sample_size=1024,
        sample_frames=17,
        num_vq_embeddings=64000,
        vq_embed_dim=6,
        force_upcast=False,
        patch_size=4,
        temporal_stride=4,
        spatial_stride=8,
        _quantizer_name="FSQuantizer",
    ):
        super(AutoencoderVQCosmos3D, self).__init__()
        latent_min_t = (sample_frames - 1) // temporal_stride + 1
        TilingMixin.__init__(self, sample_frames, latent_min_t=latent_min_t, sample_ovr_t=1)
        extra_args = {"patch_size": patch_size}
        extra_args.update({"temporal_stride": temporal_stride, "spatial_stride": spatial_stride})
        channels, layers = block_out_channels, layers_per_block
        self.encoder = Encoder(in_channels, latent_channels, channels, layers, **extra_args)
        self.decoder = Decoder(latent_channels, out_channels, channels, layers, **extra_args)
        self.quant_conv = Conv3d(latent_channels, vq_embed_dim, 1)
        self.post_quant_conv = Conv3d(vq_embed_dim, latent_channels, 1)
        self.quantizer, self.latent_dist = FSQuantizer(), IdentityDistribution

    def scale_(self, x) -> torch.Tensor:
        """Scale the input latents."""
        return x  # no-op: latents are used without scaling

    def unscale_(self, x) -> torch.Tensor:
        """Unscale the input latents."""
        return x

    def encode(self, x) -> AutoencoderKLOutput:
        """Encode the input samples."""
        z = self.tiled_encoder(self.forward(x))
        z = self.quant_conv(z)
        posterior = self.latent_dist(self.quantizer.quantize(z))
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, ids) -> DecoderOutput:
        """Decode the input indices."""
        z = self.quantizer.dequantize(ids)
        # Image (4D) inputs get a singleton temporal axis, removed again after decoding.
        extra_dim = 2 if z.dim() == 4 else None
        z = z.unsqueeze_(extra_dim) if extra_dim is not None else z
        z = self.post_quant_conv(self.forward(z))
        x = self.tiled_decoder(z)
        x = x.squeeze_(extra_dim) if extra_dim is not None else x
        return DecoderOutput(sample=x)

    def forward(self, x):
        return x  # identity passthrough used by encode() and decode()
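

# ------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library). It runs the raw encoder and
# decoder with randomly initialized weights just to show the tensor shapes; real use would
# load pretrained weights (e.g. AutoencoderVQCosmos3D.from_pretrained with a real checkpoint
# path) and call encode() / decode(), which add FSQ quantization and tiled processing.
# ------------------------------------------------------------------------------------------
if __name__ == "__main__":
    model = AutoencoderVQCosmos3D().eval()
    video = torch.randn(1, 3, 17, 256, 256)  # (batch, channels, frames, height, width)
    with torch.no_grad():
        z = model.quant_conv(model.encoder(video))  # (1, 6, 5, 32, 32) pre-quantization latents
        recon = model.decoder(model.post_quant_conv(z))  # (1, 3, 17, 256, 256) reconstruction
    print(z.shape, recon.shape)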