import torch as t
from torch import nn


class Patch(nn.Module):
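    """Embed a batch of video clips into per-frame patch tokens.

    Each frame passes through a small convolutional stem, is split into
    non-overlapping patches, and each patch is linearly projected to
    `out_channels` dimensions.

    Shape: (batch, dur, in_channels, H, W) -> (batch, dur, num_patches, out_channels).
    """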
|
|
    def __init__(self, in_channels=3, out_channels=64, patch_size=2):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        dim = out_channels

        # nn.GroupNorm(32, ch) requires ch to be divisible by 32; the conv stem
        # uses dim // 2 channels, so gate on dim // 2 rather than on dim.
        if (dim // 2) % 32 == 0:
            self.init_conv_seq = nn.Sequential(
                nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
                nn.GroupNorm(32, dim // 2),
                nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
                nn.GroupNorm(32, dim // 2),
            )
        else:
            self.init_conv_seq = nn.Sequential(
                nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
                nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
            )

        # A patch token flattens dim // 2 channels over a patch_size x patch_size
        # window, hence the parenthesized input width.
        self.x_embedder = nn.Linear(patch_size * patch_size * (dim // 2), dim, bias=True)
        nn.init.constant_(self.x_embedder.bias, 0)

    def forward(self, x):
        batch, dur, c, h, w = x.shape
        x = x.reshape(-1, c, h, w)  # fold time into the batch dimension
        x = self.init_conv_seq(x)  # (batch * dur, dim // 2, h, w)
        x = self.patchify(x)  # (batch * dur, num_patches, (dim // 2) * patch_size**2)
        x = self.x_embedder(x)  # (batch * dur, num_patches, dim)
        x = x.reshape(batch, dur, -1, self.out_channels)
        return x

    def patchify(self, x):
        # (B, C, H, W) -> (B, (H // p) * (W // p), C * p * p): carve each feature
        # map into non-overlapping p x p patches and flatten each patch.
        B, C, H, W = x.size()
        x = x.view(
            B,
            C,
            H // self.patch_size,
            self.patch_size,
            W // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x
|
|
class UnPatch(nn.Module): |
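    """Inverse of Patch: project tokens back to pixel patches and reassemble frames.

    Shape: (batch, dur, num_patches, in_channels) -> (batch, dur, out_channels, height, width).
    """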
|
|
    def __init__(self, height, width, in_channels=64, out_channels=3, patch_size=2):
        super().__init__()
        self.width = width
        self.height = height
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Project each token to a flattened patch of out_channels * p * p pixels.
        self.unpatch = nn.Linear(in_channels, out_channels * patch_size**2)

    def forward(self, x):
        x = self.unpatch(x)  # (batch, dur, seq, out_channels * patch_size**2)
        batch, dur, seq, d = x.shape
        x = x.reshape(-1, seq, d)  # fold time into the batch dimension
        x = self.unpatchify(x)
        x = x.reshape(batch, dur, self.out_channels, self.height, self.width)
        return x

    def unpatchify(self, x):
        # (N, h * w, p * p * c) -> (N, c, h * p, w * p): reassemble per-patch
        # pixel vectors into full-resolution images.
        c = self.out_channels
        p = self.patch_size
        h = self.height // p
        w = self.width // p
        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = t.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(x.shape[0], c, h * p, w * p)
        return imgs
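

# Minimal round-trip sketch (illustrative; the batch size, frame count, and
# 32 x 32 resolution below are assumed, not taken from the modules above).
if __name__ == "__main__":
    frames = t.randn(2, 4, 3, 32, 32)  # (batch=2, dur=4, C=3, H=32, W=32)
    patch = Patch(in_channels=3, out_channels=64, patch_size=2)
    tokens = patch(frames)
    assert tokens.shape == (2, 4, 256, 64)  # (32 // 2) ** 2 = 256 patches per frame
    unpatch = UnPatch(height=32, width=32, in_channels=64, out_channels=3, patch_size=2)
    recon = unpatch(tokens)
    assert recon.shape == frames.shape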