pong / src /nn /patch.py
chrisxx's picture
Add Neural Pong application files
8746765
from torch import nn
from einops import rearrange
import torch as t
class Patch(nn.Module): # adapted from https://github.com/cloneofsimo/minRF
def __init__(self, in_channels=3, out_channels=64, patch_size=2):
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
dim = out_channels
if dim % 32 == 0 and dim > 32:
self.init_conv_seq = nn.Sequential(
nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
nn.SiLU(),
nn.GroupNorm(32, dim // 2),
nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
nn.SiLU(),
nn.GroupNorm(32, dim // 2),
)
else:
self.init_conv_seq = nn.Sequential(
nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
nn.SiLU(),
nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
nn.SiLU(),
)
self.x_embedder = nn.Linear(patch_size * patch_size * dim // 2, dim, bias=True)
nn.init.constant_(self.x_embedder.bias, 0)
def forward(self, x):
batch, dur, c, h, w = x.shape
x = x.reshape(-1, c, h, w)
x = self.init_conv_seq(x)
x = self.patchify(x)
x = self.x_embedder(x)
x = x.reshape(batch, dur, -1, self.out_channels)
return x
def patchify(self, x):
B, C, H, W = x.size()
x = x.view(
B,
C,
H // self.patch_size,
self.patch_size,
W // self.patch_size,
self.patch_size,
)
x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
return x
class UnPatch(nn.Module):
def __init__(self, height, width, in_channels=64, out_channels=3, patch_size=2):
super().__init__()
self.width = width
self.height = height
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.unpatch = nn.Linear(in_channels, out_channels*patch_size**2)
def forward(self, x):
x = self.unpatch(x)
batch, dur, seq, d = x.shape
x = x.reshape(-1, seq, d)
x = self.unpatchify(x)
x = x.reshape(batch, dur, self.out_channels, self.height, self.width)
return x
def unpatchify(self, x):
c = self.out_channels
p = self.patch_size
h = self.height // p
w = self.width // p
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = t.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs