import torch
import torch.nn as nn
import torch.nn.functional as F


class FrameEncoder(nn.Module):
    """Encode a batch of frame sequences into per-frame latent maps.

    Input:  (b, t, image_channels, H, W)
    Output: (b, t, latent_dim, H / 2**n_down, W / 2**n_down), where n_down is
    the number of True entries in config["down"].

    Expected config keys: "image_channels", "num_channels", "latent_dim",
    "mult" (per-stage channel multipliers) and "down" (per-stage downsample
    flags). "mult" and "down" must have equal length.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        # Explicit ValueError instead of assert: asserts are stripped under -O.
        if len(config["mult"]) != len(config["down"]):
            raise ValueError("config['mult'] and config['down'] must have the same length")

        layers = [
            nn.Conv2d(config["image_channels"], config["num_channels"],
                      kernel_size=3, stride=1, padding=1)
        ]
        in_channels = config["num_channels"]
        for mult, down in zip(config["mult"], config["down"]):
            out_channels = mult * config["num_channels"]
            layers.append(ResidualBlock(in_channels, out_channels))
            in_channels = out_channels
            if down:
                layers.append(Downsample(out_channels))
        # Final norm + projection into the latent space.
        # NOTE(review): num_groups=32 requires the widest stage to be a
        # multiple of 32 — this mirrors the hard-coded value in the original.
        layers.extend([
            nn.GroupNorm(num_groups=32, num_channels=in_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(in_channels, config["latent_dim"],
                      kernel_size=3, stride=1, padding=1),
        ])
        self.encoder = nn.Sequential(*layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Fold time into the batch axis, encode each frame, unfold back."""
        b, t = x.shape[:2]
        x = x.flatten(0, 1)            # b t c h w -> (b t) c h w
        x = self.encoder(x)
        return x.unflatten(0, (b, t))  # (b t) c h w -> b t c h w


class FrameDecoder(nn.Module):
    """Decode per-frame latent maps back to frames (mirror of FrameEncoder).

    Input:  (b, t, latent_dim, h, w)
    Output: (b, t, image_channels, h * 2**n_up, w * 2**n_up), where n_up is
    the number of True entries in config["down"].
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        if len(config["mult"]) != len(config["down"]):
            raise ValueError("config['mult'] and config['down'] must have the same length")
        # Explicit check: an empty "mult" would otherwise surface as a
        # confusing NameError on the latent-projection conv below.
        if not config["mult"]:
            raise ValueError("config['mult'] must be non-empty")

        # Build stages in encoder order, then reverse so the decoder mirrors
        # the encoder exactly (widest stage next to the latent projection).
        layers = []
        out_channels = config["num_channels"]
        for mult, down in zip(config["mult"], config["down"]):
            in_channels = mult * config["num_channels"]
            layers.append(ResidualBlock(in_channels, out_channels))
            out_channels = in_channels
            if down:
                layers.append(Upsample(in_channels))
        layers.reverse()
        # After the loop, in_channels holds the deepest (widest) stage width.
        layers.insert(0, nn.Conv2d(config["latent_dim"], in_channels,
                                   kernel_size=3, stride=1, padding=1))
        layers.extend([
            nn.GroupNorm(num_groups=32, num_channels=config["num_channels"]),
            nn.SiLU(inplace=True),
            nn.Conv2d(config["num_channels"], config["image_channels"],
                      kernel_size=3, stride=1, padding=1),
        ])
        self.decoder = nn.Sequential(*layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Fold time into the batch axis, decode each frame, unfold back."""
        b, t = x.shape[:2]
        x = x.flatten(0, 1)            # b t c h w -> (b t) c h w
        x = self.decoder(x)
        return x.unflatten(0, (b, t))  # (b t) c h w -> b t c h w


class ResidualBlock(nn.Module):
    """Pre-activation residual block: (GN -> SiLU -> Conv3x3) twice, plus skip.

    A 1x1 projection aligns the skip path when in/out channel counts differ;
    otherwise the skip is the identity.
    """

    def __init__(self, in_channels: int, out_channels: int, num_groups_norm: int = 32) -> None:
        super().__init__()
        self.f = nn.Sequential(
            nn.GroupNorm(num_groups_norm, in_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups_norm, out_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )
        self.skip_projection = (
            nn.Identity() if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.skip_projection(x) + self.f(x)


class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2, kernel-2 convolution."""

    def __init__(self, num_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=2, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class Upsample(nn.Module):
    """Double spatial resolution (nearest-neighbor), then smooth with a 3x3 conv."""

    def __init__(self, num_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x)