| import torch | |
| from torch import nn | |
| from einops import repeat | |
| from .helper_funcs import default | |
class PixelShuffleUpsample(nn.Module):
    """Upsample a feature map with a 1x1 conv, SiLU and pixel shuffle.

    The 1x1 convolution expands the channels to ``dim_out * scale_factor**2``.
    SiLU is applied, and ``nn.PixelShuffle`` then folds each group of
    ``scale_factor**2`` channels into a ``scale_factor x scale_factor``
    spatial block. The output is therefore ``scale_factor`` times larger in
    both spatial dimensions.

    The conv is initialised ICNR-style (see ``init_conv_``). At
    initialisation, all pixels inside each ``scale_factor x scale_factor``
    output block receive the same value, which avoids checkerboard patterns
    at the start of training.

    Args:
        dim: number of input channels.
        dim_out: number of output channels. When not given, it is resolved
            to ``dim`` through the project's ``default`` helper.
        scale_factor: spatial upsampling factor. Defaults to 2, which matches
            the previously hard-coded behaviour.

    Raises:
        ValueError: if ``scale_factor`` is smaller than 1.
    """

    def __init__(self, dim, dim_out=None, scale_factor=2):
        super().__init__()
        if scale_factor < 1:
            raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")
        dim_out = default(dim_out, dim)
        # Stored so init_conv_ replicates filters in groups that match the
        # PixelShuffle factor used in self.net.
        self.scale_factor = scale_factor
        conv = nn.Conv2d(dim, dim_out * scale_factor**2, 1)
        self.net = nn.Sequential(conv, nn.SiLU(), nn.PixelShuffle(scale_factor))
        self.init_conv_(conv)

    def init_conv_(self, conv):
        """ICNR-initialise ``conv`` in place.

        One Kaiming-uniform kernel is drawn for every
        ``out_channels // scale_factor**2`` filter. Each kernel is then
        replicated ``scale_factor**2`` times in consecutive output channels.
        PixelShuffle folds exactly such a consecutive channel group into one
        output channel, so the replicated filters fill each spatial block
        with identical values at initialisation. The bias, if present, is
        zeroed so the replicated channels stay identical.

        Args:
            conv: the ``nn.Conv2d`` whose weight (and bias) are overwritten.
        """
        copies = self.scale_factor**2
        o, i, h, w = conv.weight.shape
        # Allocate with the weight's own dtype/device so the copy below needs
        # no implicit conversion.
        base = conv.weight.new_empty(o // copies, i, h, w)
        nn.init.kaiming_uniform_(base)
        # "(o r)" places the r copies of filter o next to each other
        # (channel index o * r + k), matching PixelShuffle's channel grouping.
        replicated = repeat(base, "o ... -> (o r) ...", r=copies)
        with torch.no_grad():
            conv.weight.copy_(replicated)
        if conv.bias is not None:
            nn.init.zeros_(conv.bias)

    def forward(self, x):
        """Apply conv -> SiLU -> PixelShuffle.

        ``x`` must be a tensor that ``nn.Conv2d`` accepts, with ``dim``
        channels. The result has ``dim_out`` channels, and its height and
        width are both multiplied by ``scale_factor``.
        """
        return self.net(x)