Spaces:
Paused
Paused
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .common import CausalConv3d
class Upsampler(nn.Module):
    """Base module that records by how much a layer upsamples its input.

    The class holds no parameters and defines no ``forward``; it only stores
    the two factors as plain attributes.

    Args:
        spatial_upsample_factor: Stored as ``self.spatial_upsample_factor``.
        temporal_upsample_factor: Stored as ``self.temporal_upsample_factor``.
    """

    def __init__(
        self,
        spatial_upsample_factor: int = 1,
        temporal_upsample_factor: int = 1,
    ):
        super(Upsampler, self).__init__()
        # Plain Python attributes (not buffers / parameters).
        self.temporal_upsample_factor = temporal_upsample_factor
        self.spatial_upsample_factor = spatial_upsample_factor
class SpatialUpsampler3D(Upsampler):
    """Nearest-neighbour 2x upsampling of the last two dims, then a 3D conv.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the convolution. ``None``
            (the default) means "same as ``in_channels``".
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        super().__init__(spatial_upsample_factor=2)
        if out_channels is None:
            out_channels = in_channels
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` by ``(1, 2, 2)`` and apply ``self.conv``.

        Args:
            x: 5-D tensor (a 3-element ``scale_factor`` requires 5-D input);
                presumably laid out as (batch, channels, time, height, width)
                -- TODO confirm against callers.

        Returns:
            The convolved tensor. Its exact shape depends on the padding
            behaviour of ``CausalConv3d``, which is defined in ``.common``.
        """
        # Third dim is left untouched; only the last two dims are doubled.
        x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
        return self.conv(x)
class SpatialUpsamplerD2S3D(Upsampler):
    """2x spatial upsampling via convolution followed by depth-to-space.

    The convolution produces ``out_channels * 4`` channels; ``forward`` then
    folds each group of 4 channels into a 2x2 block of the last two dims.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count after depth-to-space. ``None`` (the
            default) means "same as ``in_channels``".
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        super().__init__(spatial_upsample_factor=2)
        if out_channels is None:
            out_channels = in_channels
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels * 4,
            kernel_size=3,
        )

        # Draw one Kaiming-normal filter per *final* output channel and repeat
        # it 4 times along the output axis. The copies are adjacent, matching
        # the "(c p1 p2)" grouping used in forward(), so the 4 sub-pixel
        # filters of each channel start out identical.
        o, i, t, h, w = self.conv.weight.shape
        conv_weight = torch.empty(o // 4, i, t, h, w)
        nn.init.kaiming_normal_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
        with torch.no_grad():
            self.conv.weight.copy_(conv_weight)
        # Tolerate a bias-free convolution instead of failing on None.
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``self.conv`` and rearrange channels into space.

        Args:
            x: Input tensor; after the convolution it must be 5-D with a
                channel dim divisible by 4 (required by the rearrange pattern).

        Returns:
            Tensor with 1/4 of the convolution's channels and the last two
            dims doubled.
        """
        x = self.conv(x)
        return rearrange(x, "b (c p1 p2) t h w -> b c t (h p1) (w p2)", p1=2, p2=2)
class TemporalUpsampler3D(Upsampler):
    """2x temporal upsampling that keeps the first frame, then a 3D conv.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the convolution. ``None``
            (the default) means "same as ``in_channels``".
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        super().__init__(
            spatial_upsample_factor=1,
            temporal_upsample_factor=2,
        )
        if out_channels is None:
            out_channels = in_channels
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Interpolate along dim 2 (all but the first slice) and convolve.

        With ``T = x.shape[2]``: when ``T > 1`` the first slice is passed
        through unchanged and the remaining ``T - 1`` slices are trilinearly
        interpolated to ``2 * (T - 1)``, giving ``2 * T - 1`` slices before the
        convolution. When ``T <= 1`` no interpolation happens.

        Args:
            x: 5-D tensor; presumably (batch, channels, time, height, width)
                -- TODO confirm against callers.

        Returns:
            Output of ``self.conv`` on the (possibly) interpolated tensor.
        """
        if x.shape[2] > 1:
            first_frame, rest = x[:, :, :1], x[:, :, 1:]
            rest = F.interpolate(rest, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.cat([first_frame, rest], dim=2)
        return self.conv(x)
class TemporalUpsamplerD2S3D(Upsampler):
    """2x temporal upsampling via convolution followed by depth-to-time.

    The convolution produces ``out_channels * 2`` channels; ``forward`` folds
    each pair of channels into two consecutive slices of dim 2 and then drops
    the very first slice.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count after the rearrangement. ``None`` (the
            default) means "same as ``in_channels``".
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        super().__init__(
            spatial_upsample_factor=1,
            temporal_upsample_factor=2,
        )
        if out_channels is None:
            out_channels = in_channels
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels * 2,
            kernel_size=3,
        )

        # One Kaiming-normal filter per final channel, duplicated into adjacent
        # output slots so both "(c p1)" sub-filters start out identical.
        o, i, t, h, w = self.conv.weight.shape
        conv_weight = torch.empty(o // 2, i, t, h, w)
        nn.init.kaiming_normal_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")
        with torch.no_grad():
            self.conv.weight.copy_(conv_weight)
        # Tolerate a bias-free convolution instead of failing on None.
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``self.conv``, move channel pairs into dim 2, drop slice 0.

        Args:
            x: Input tensor; after the convolution it must be 5-D with an even
                channel dim (required by the rearrange pattern).

        Returns:
            Tensor with half of the convolution's channels and ``2 * T - 1``
            slices along dim 2, where ``T`` is the size of dim 2 after the
            convolution.
        """
        x = self.conv(x)
        x = rearrange(x, "b (c p1) t h w -> b c (t p1) h w", p1=2)
        # Discard the first of the 2*T slices.
        return x[:, :, 1:]
class SpatialTemporalUpsampler3D(Upsampler):
    """2x spatial upsampling + conv, followed by optional 2x temporal upsampling.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the convolution. ``None``
            (the default) means "same as ``in_channels``".

    Attributes set in ``__init__`` and read in ``forward``:
        padding_flag: Integer selecting the temporal behaviour (default 0).
        set_3dgroupnorm: When truthy, temporal interpolation uses "nearest"
            instead of "trilinear" (default False).
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        super().__init__(
            spatial_upsample_factor=2,
            temporal_upsample_factor=2,
        )
        if out_channels is None:
            out_channels = in_channels
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
        )
        # NOTE(review): nothing in this class changes these two after
        # construction; presumably external code overwrites them -- confirm.
        self.padding_flag = 0
        self.set_3dgroupnorm = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample space, convolve, then upsample dim 2 per ``padding_flag``.

        Temporal step (``T`` is the size of dim 2 after the convolution):

        * ``padding_flag == 0``: if ``T > 1`` the first slice is kept as is and
          the remaining slices are interpolated 2x (``2 * T - 1`` slices);
          otherwise nothing is done.
        * ``padding_flag in (2, 4, 5, 6)``: the whole tensor is interpolated
          2x along dim 2.
        * any other value: no temporal interpolation.

        Args:
            x: 5-D tensor; presumably (batch, channels, time, height, width)
                -- TODO confirm against callers.

        Returns:
            The upsampled tensor.
        """
        x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
        x = self.conv(x)

        temporal_mode = "nearest" if self.set_3dgroupnorm else "trilinear"
        if self.padding_flag == 0:
            if x.shape[2] > 1:
                first_frame, rest = x[:, :, :1], x[:, :, 1:]
                rest = F.interpolate(rest, scale_factor=(2, 1, 1), mode=temporal_mode)
                x = torch.cat([first_frame, rest], dim=2)
        elif self.padding_flag in (2, 4, 5, 6):
            x = F.interpolate(x, scale_factor=(2, 1, 1), mode=temporal_mode)
        return x
class SpatialTemporalUpsamplerD2S3D(Upsampler):
    """2x spatial and temporal upsampling via convolution + depth-to-space-time.

    The convolution produces ``out_channels * 8`` channels; ``forward`` folds
    each group of 8 channels into a 2x2x2 block over the last three dims and
    then drops the very first slice of dim 2.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count after the rearrangement. ``None`` (the
            default) means "same as ``in_channels``".
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        super().__init__(
            spatial_upsample_factor=2,
            temporal_upsample_factor=2,
        )
        if out_channels is None:
            out_channels = in_channels
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels * 8,
            kernel_size=3,
        )

        # One Kaiming-normal filter per final channel, repeated 8 times in
        # adjacent output slots so all "(c p1 p2 p3)" sub-filters of a channel
        # start out identical.
        o, i, t, h, w = self.conv.weight.shape
        conv_weight = torch.empty(o // 8, i, t, h, w)
        nn.init.kaiming_normal_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 8) ...")
        with torch.no_grad():
            self.conv.weight.copy_(conv_weight)
        # Tolerate a bias-free convolution instead of failing on None.
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``self.conv``, unfold channels into (t, h, w), drop slice 0.

        Args:
            x: Input tensor; after the convolution it must be 5-D with a
                channel dim divisible by 8 (required by the rearrange pattern).

        Returns:
            Tensor with 1/8 of the convolution's channels, the last two dims
            doubled, and ``2 * T - 1`` slices along dim 2, where ``T`` is the
            size of dim 2 after the convolution.
        """
        x = self.conv(x)
        x = rearrange(
            x,
            "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
            p1=2,
            p2=2,
            p3=2,
        )
        # Discard the first of the 2*T slices.
        return x[:, :, 1:]