import math from typing import Literal import numpy as np import torch import torch.nn.functional as F from einops import rearrange from torch import Tensor, nn # trainner-redux https://github.com/the-database/traiNNer-redux # from traiNNer.utils.registry import ARCH_REGISTRY # neosr https://github.com/neosr-project/neosr/tree/master # from neosr.archs.arch_util import net_opt # from neosr.utils.registry import ARCH_REGISTRY # # upscale, __ = net_opt() # basic sr https://github.com/XPixelGroup/BasicSR/tree/master # from basicsr.utils.registry import ARCH_REGISTRY SampleMods = Literal[ "conv", "pixelshuffledirect", "pixelshuffle", "nearest+conv", "dysample", "transpose+conv", "lda", "pa_up", ] def ICNR(tensor, initializer, upscale_factor=2, *args, **kwargs): upscale_factor_squared = upscale_factor * upscale_factor assert tensor.shape[0] % upscale_factor_squared == 0, ( "The size of the first dimension: " f"tensor.shape[0] = {tensor.shape[0]}" " is not divisible by square of upscale_factor: " f"upscale_factor = {upscale_factor}" ) sub_kernel = torch.empty( tensor.shape[0] // upscale_factor_squared, *tensor.shape[1:] ) sub_kernel = initializer(sub_kernel, *args, **kwargs) return sub_kernel.repeat_interleave(upscale_factor_squared, dim=0) class DySample(nn.Module): """Adapted from 'Learning to Upsample by Learning to Sample': https://arxiv.org/abs/2308.15085 https://github.com/tiny-smart/dysample """ def __init__( self, in_channels: int = 64, out_ch: int = 3, scale: int = 2, groups: int = 4, end_convolution: bool = True, end_kernel=1, ) -> None: super().__init__() if in_channels <= groups or in_channels % groups != 0: msg = "Incorrect in_channels and groups values." raise ValueError(msg) out_channels = 2 * groups * scale**2 self.scale = scale self.groups = groups self.end_convolution = end_convolution if end_convolution: self.end_conv = nn.Conv2d( in_channels, out_ch, end_kernel, 1, end_kernel // 2 ) self.offset = nn.Conv2d(in_channels, out_channels, 1) self.scope = nn.Conv2d(in_channels, out_channels, 1, bias=False) if self.training: nn.init.trunc_normal_(self.offset.weight, std=0.02) nn.init.constant_(self.scope.weight, val=0) self.register_buffer("init_pos", self._init_pos()) def _init_pos(self) -> Tensor: h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale return ( torch.stack(torch.meshgrid([h, h], indexing="ij")) .transpose(1, 2) .repeat(1, self.groups, 1) .reshape(1, -1, 1, 1) ) def forward(self, x: Tensor) -> Tensor: offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos B, _, H, W = offset.shape offset = offset.view(B, 2, -1, H, W) coords_h = torch.arange(H) + 0.5 coords_w = torch.arange(W) + 0.5 coords = ( torch.stack(torch.meshgrid([coords_w, coords_h], indexing="ij")) .transpose(1, 2) .unsqueeze(1) .unsqueeze(0) .type(x.dtype) .to(x.device, non_blocking=True) ) normalizer = torch.tensor( [W, H], dtype=x.dtype, device=x.device, pin_memory=True ).view(1, 2, 1, 1, 1) coords = 2 * (coords + offset) / normalizer - 1 coords = ( F.pixel_shuffle(coords.reshape(B, -1, H, W), self.scale) .view(B, 2, -1, self.scale * H, self.scale * W) .permute(0, 2, 3, 4, 1) .contiguous() .flatten(0, 1) ) output = F.grid_sample( x.reshape(B * self.groups, -1, H, W), coords, mode="bilinear", align_corners=False, padding_mode="border", ).view(B, -1, self.scale * H, self.scale * W) if self.end_convolution: output = self.end_conv(output) return output class LayerNorm(nn.Module): def __init__(self, dim: int = 64, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.bias = nn.Parameter(torch.zeros(dim)) self.eps = eps self.dim = (dim,) def forward(self, x): if x.is_contiguous(memory_format=torch.channels_last): return F.layer_norm( x.permute(0, 2, 3, 1), self.dim, self.weight, self.bias, self.eps ).permute(0, 3, 1, 2) u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) return self.weight[:, None, None] * x + self.bias[:, None, None] class LDA_AQU(nn.Module): def __init__( self, in_channels=48, reduction_factor=4, nh=1, scale_factor=2.0, k_e=3, k_u=3, n_groups=2, range_factor=11, rpb=True, ) -> None: super().__init__() self.k_u = k_u self.num_head = nh self.scale_factor = scale_factor self.n_groups = n_groups self.offset_range_factor = range_factor self.attn_dim = in_channels // (reduction_factor * self.num_head) self.scale = self.attn_dim**-0.5 self.rpb = rpb self.hidden_dim = in_channels // reduction_factor self.proj_q = nn.Conv2d( in_channels, self.hidden_dim, kernel_size=1, stride=1, padding=0, bias=False ) self.proj_k = nn.Conv2d( in_channels, self.hidden_dim, kernel_size=1, stride=1, padding=0, bias=False ) self.group_channel = in_channels // (reduction_factor * self.n_groups) # print(self.group_channel) self.conv_offset = nn.Sequential( nn.Conv2d( self.group_channel, self.group_channel, 3, 1, 1, groups=self.group_channel, bias=False, ), LayerNorm(self.group_channel), nn.SiLU(), nn.Conv2d(self.group_channel, 2 * k_u**2, k_e, 1, k_e // 2), ) print(2 * k_u**2) self.layer_norm = LayerNorm(in_channels) self.pad = int((self.k_u - 1) / 2) base = np.arange(-self.pad, self.pad + 1).astype(np.float32) base_y = np.repeat(base, self.k_u) base_x = np.tile(base, self.k_u) base_offset = np.stack([base_y, base_x], axis=1).flatten() base_offset = torch.tensor(base_offset).view(1, -1, 1, 1) self.register_buffer("base_offset", base_offset, persistent=False) if self.rpb: self.relative_position_bias_table = nn.Parameter( torch.zeros( 1, self.num_head, 1, self.k_u**2, self.hidden_dim // self.num_head ) ) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def init_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.xavier_uniform(m) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) nn.init.constant_(self.conv_offset[-1].weight, 0) nn.init.constant_(self.conv_offset[-1].bias, 0) def get_offset(self, offset, Hout, Wout): B, _, _, _ = offset.shape device = offset.device row_indices = torch.arange(Hout, device=device) col_indices = torch.arange(Wout, device=device) row_indices, col_indices = torch.meshgrid(row_indices, col_indices) index_tensor = torch.stack((row_indices, col_indices), dim=-1).view( 1, Hout, Wout, 2 ) offset = rearrange( offset, "b (kh kw d) h w -> b kh h kw w d", kh=self.k_u, kw=self.k_u ) offset = offset + index_tensor.view(1, 1, Hout, 1, Wout, 2) offset = offset.contiguous().view(B, self.k_u * Hout, self.k_u * Wout, 2) offset[..., 0] = 2 * offset[..., 0] / (Hout - 1) - 1 offset[..., 1] = 2 * offset[..., 1] / (Wout - 1) - 1 offset = offset.flip(-1) return offset def extract_feats(self, x, offset, ks=3): out = nn.functional.grid_sample( x, offset, mode="bilinear", padding_mode="zeros", align_corners=True ) out = rearrange(out, "b c (ksh h) (ksw w) -> b (ksh ksw) c h w", ksh=ks, ksw=ks) return out def forward(self, x): B, C, H, W = x.shape out_H, out_W = int(H * self.scale_factor), int(W * self.scale_factor) v = x x = self.layer_norm(x) q = self.proj_q(x) k = self.proj_k(x) q = torch.nn.functional.interpolate( q, (out_H, out_W), mode="bilinear", align_corners=True ) q_off = q.view(B * self.n_groups, -1, out_H, out_W) pred_offset = self.conv_offset(q_off) offset = pred_offset.tanh().mul(self.offset_range_factor) + self.base_offset.to( x.dtype ) k = k.view(B * self.n_groups, self.hidden_dim // self.n_groups, H, W) v = v.view(B * self.n_groups, C // self.n_groups, H, W) offset = self.get_offset(offset, out_H, out_W) k = self.extract_feats(k, offset=offset) v = self.extract_feats(v, offset=offset) q = rearrange(q, "b (nh c) h w -> b nh (h w) () c", nh=self.num_head) k = rearrange(k, "(b g) n c h w -> b (h w) n (g c)", g=self.n_groups) v = rearrange(v, "(b g) n c h w -> b (h w) n (g c)", g=self.n_groups) k = rearrange(k, "b n1 n (nh c) -> b nh n1 n c", nh=self.num_head) v = rearrange(v, "b n1 n (nh c) -> b nh n1 n c", nh=self.num_head) if self.rpb: k = k + self.relative_position_bias_table q = q * self.scale attn = q @ k.transpose(-1, -2) attn = attn.softmax(dim=-1) out = attn @ v out = rearrange(out, "b nh (h w) t c -> b (nh c) (t h) w", h=out_H) return out class PA(nn.Module): def __init__(self, dim) -> None: super().__init__() self.conv = nn.Sequential(nn.Conv2d(dim, dim, 1), nn.Sigmoid()) def forward(self, x): return x.mul(self.conv(x)) class UniUpsampleV3(nn.Sequential): def __init__( self, upsample: SampleMods = "pa_up", scale: int = 2, in_dim: int = 48, out_dim: int = 3, mid_dim: int = 48, group: int = 4, # Only DySample dysample_end_kernel=1, # needed only for compatibility with version 2 ) -> None: m = [] if scale == 1 or upsample == "conv": m.append(nn.Conv2d(in_dim, out_dim, 3, 1, 1)) elif upsample == "pixelshuffledirect": m.extend( [nn.Conv2d(in_dim, out_dim * scale**2, 3, 1, 1), nn.PixelShuffle(scale)] ) elif upsample == "pixelshuffle": m.extend([nn.Conv2d(in_dim, mid_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)]) if (scale & (scale - 1)) == 0: # scale = 2^n for _ in range(int(math.log2(scale))): m.extend( [nn.Conv2d(mid_dim, 4 * mid_dim, 3, 1, 1), nn.PixelShuffle(2)] ) elif scale == 3: m.extend([nn.Conv2d(mid_dim, 9 * mid_dim, 3, 1, 1), nn.PixelShuffle(3)]) else: raise ValueError( f"scale {scale} is not supported. Supported scales: 2^n and 3." ) m.append(nn.Conv2d(mid_dim, out_dim, 3, 1, 1)) elif upsample == "nearest+conv": if (scale & (scale - 1)) == 0: for _ in range(int(math.log2(scale))): m.extend( ( nn.Conv2d(in_dim, in_dim, 3, 1, 1), nn.Upsample(scale_factor=2), nn.LeakyReLU(negative_slope=0.2, inplace=True), ) ) m.extend( ( nn.Conv2d(in_dim, in_dim, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), ) ) elif scale == 3: m.extend( ( nn.Conv2d(in_dim, in_dim, 3, 1, 1), nn.Upsample(scale_factor=scale), nn.LeakyReLU(negative_slope=0.2, inplace=True), nn.Conv2d(in_dim, in_dim, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), ) ) else: raise ValueError( f"scale {scale} is not supported. Supported scales: 2^n and 3." ) m.append(nn.Conv2d(in_dim, out_dim, 3, 1, 1)) elif upsample == "dysample": if mid_dim != in_dim: m.extend( [nn.Conv2d(in_dim, mid_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)] ) m.append( DySample(mid_dim, out_dim, scale, group, end_kernel=dysample_end_kernel) ) # m.append(nn.Conv2d(mid_dim, out_dim, dysample_end_kernel, 1, dysample_end_kernel//2)) # kernel 1 causes chromatic artifacts elif upsample == "transpose+conv": if scale == 2: m.append(nn.ConvTranspose2d(in_dim, out_dim, 4, 2, 1)) elif scale == 3: m.append(nn.ConvTranspose2d(in_dim, out_dim, 3, 3, 0)) elif scale == 4: m.extend( [ nn.ConvTranspose2d(in_dim, in_dim, 4, 2, 1), nn.GELU(), nn.ConvTranspose2d(in_dim, out_dim, 4, 2, 1), ] ) else: raise ValueError( f"scale {scale} is not supported. Supported scales: 2, 3, 4" ) m.append(nn.Conv2d(out_dim, out_dim, 3, 1, 1)) elif upsample == "lda": if mid_dim != in_dim: m.extend( [nn.Conv2d(in_dim, mid_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)] ) m.append(LDA_AQU(mid_dim, scale_factor=scale)) m.append(nn.Conv2d(mid_dim, out_dim, 3, 1, 1)) elif upsample == "pa_up": if (scale & (scale - 1)) == 0: for _ in range(int(math.log2(scale))): m.extend( [ nn.Upsample(scale_factor=2), nn.Conv2d(in_dim, mid_dim, 3, 1, 1), PA(mid_dim), nn.LeakyReLU(negative_slope=0.2, inplace=True), nn.Conv2d(mid_dim, mid_dim, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), ] ) in_dim = mid_dim elif scale == 3: m.extend( [ nn.Upsample(scale_factor=3), nn.Conv2d(in_dim, mid_dim, 3, 1, 1), PA(mid_dim), nn.LeakyReLU(negative_slope=0.2, inplace=True), nn.Conv2d(mid_dim, mid_dim, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), ] ) else: raise ValueError( f"scale {scale} is not supported. Supported scales: 2^n and 3." ) m.append(nn.Conv2d(mid_dim, out_dim, 3, 1, 1)) else: raise ValueError( f"An invalid Upsample was selected. Please choose one of {SampleMods}" ) super().__init__(*m) self.register_buffer( "MetaUpsample", torch.tensor( [ 3, # Block version, if you change something, please number from the end so that you can distinguish between authorized changes and third parties list(SampleMods.__args__).index(upsample), # UpSample method index scale, in_dim, out_dim, mid_dim, group, ], dtype=torch.uint8, ), ) class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6) -> None: super().__init__() self.scale = nn.Parameter(torch.ones(dim)) self.offset = nn.Parameter(torch.zeros(dim)) self.eps = nn.Parameter(torch.Tensor(torch.ones(1) * eps), requires_grad=False) self.rms = nn.Parameter( torch.Tensor(torch.ones(1) * (dim**-0.5)), requires_grad=False ) def forward(self, x: Tensor) -> Tensor: norm_x = torch.addcmul(self.eps, x.norm(2, dim=1, keepdim=True), self.rms) return torch.addcmul( self.offset[:, None, None], x.div(norm_x), self.scale[:, None, None] ) class CustomRFFT2(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor): y = torch.fft.rfft2(x, dim=(2, 3), norm="ortho") return torch.view_as_real(y) @staticmethod def symbolic(g, x: torch.Value): shp = g.op("Shape", x) iH = g.op("Constant", value_t=torch.tensor([2], dtype=torch.int64)) iW = g.op("Constant", value_t=torch.tensor([3], dtype=torch.int64)) nH = g.op("Gather", shp, iH, axis_i=0) nW = g.op("Gather", shp, iW, axis_i=0) axes_last = g.op("Constant", value_t=torch.tensor([4], dtype=torch.int64)) x_u = g.op("Unsqueeze", x, axes_last) zero = g.op("Sub", x_u, x_u) x_c = g.op("Concat", x_u, zero, axis_i=4) Hf = g.op("Cast", nH, to_i=torch.onnx.TensorProtoDataType.FLOAT) Wf = g.op("Cast", nW, to_i=torch.onnx.TensorProtoDataType.FLOAT) y = g.op("DFT", x_c, nW, axis_i=3, onesided_i=1) y = g.op("Div", y, g.op("Sqrt", Wf)) y = g.op("DFT", y, nH, axis_i=2, onesided_i=0) y = g.op("Div", y, g.op("Sqrt", Hf)) return y class CustomIRFFT2(torch.autograd.Function): @staticmethod def forward(ctx, x_ri: torch.Tensor): x_c = torch.view_as_complex(x_ri) return torch.fft.irfft2(x_c, dim=(2, 3), norm="ortho") @staticmethod def symbolic(g, x: torch.Value): shp = g.op("Shape", x) iH = g.op("Constant", value_t=torch.tensor([2], dtype=torch.int64)) iWr = g.op("Constant", value_t=torch.tensor([3], dtype=torch.int64)) nH = g.op("Gather", shp, iH, axis_i=0) nWr = g.op("Gather", shp, iWr, axis_i=0) one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64)) nW = g.op("Mul", g.op("Sub", nWr, one), two) Hf = g.op("Cast", nH, to_i=torch.onnx.TensorProtoDataType.FLOAT) Wf = g.op("Cast", nW, to_i=torch.onnx.TensorProtoDataType.FLOAT) yH = g.op("DFT", x, nH, axis_i=2, inverse_i=1, onesided_i=0) yH = g.op("Mul", yH, g.op("Sqrt", Hf)) start = g.op("Sub", nWr, two) start = g.op( "Squeeze", start, g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), ) limit = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) step = g.op("Constant", value_t=torch.tensor(-1, dtype=torch.int64)) idx_r = g.op("Range", start, limit, step) mirW = g.op("Gather", yH, idx_r, axis_i=3) maskW = g.op("Constant", value_t=torch.tensor([1.0, -1.0], dtype=torch.float32)) maskW = g.op( "Unsqueeze", maskW, g.op("Constant", value_t=torch.tensor([0, 1, 2, 3], dtype=torch.int64)), ) mirWc = g.op("Mul", mirW, maskW) x_full = g.op("Concat", yH, mirWc, axis_i=3) y = g.op("DFT", x_full, nW, axis_i=3, inverse_i=1, onesided_i=0) y = g.op("Mul", y, g.op("Sqrt", Wf)) s0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)) s1 = g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) axC = g.op("Constant", value_t=torch.tensor([4], dtype=torch.int64)) y = g.op("Slice", y, s0, s1, axC) y = g.op("Squeeze", y, axC) return y class CustomRfft2Wrap(nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x): if self.training: y = torch.fft.rfft2(x, dim=(2, 3), norm="ortho") return torch.view_as_real(y) else: return CustomRFFT2().apply(x) class CustomIrfft2Wrap(nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x): if self.training: x_c = torch.view_as_complex(x) # [B,C,H,Wr] return torch.fft.irfft2(x_c, dim=(2, 3), norm="ortho") # [B,C,H,W] else: return CustomIRFFT2().apply(x) class FourierUnit(nn.Module): def __init__(self, in_channels: int = 48, out_channels: int = 48) -> None: super().__init__() self.rn = RMSNorm(out_channels * 2) self.post_norm = RMSNorm(out_channels) self.fdc = nn.Conv2d( in_channels=in_channels * 2, out_channels=out_channels * 2, kernel_size=1, bias=True, ) self.fpe = nn.Conv2d( in_channels=in_channels * 2, out_channels=in_channels * 2, kernel_size=3, padding=1, groups=in_channels * 2, bias=True, ) self.gelu = nn.GELU() self.irfft2 = CustomIrfft2Wrap() self.rfft2 = CustomRfft2Wrap() def forward(self, x: Tensor) -> Tensor: orig_dtype = x.dtype x = x.to(torch.float32) b, c, h, w = x.shape ffted = self.rfft2(x) ffted = ffted.permute(0, 4, 1, 2, 3).contiguous() ffted = ffted.view(b, c * 2, h, -1).to(orig_dtype) ffted = self.rn(ffted) ffted = self.fpe(ffted) + ffted ffted = self.fdc(ffted) ffted = self.gelu(ffted) ffted = ffted.view(b, c, 2, h, -1).permute(0, 1, 3, 4, 2).contiguous().float() out = self.irfft2(ffted) out = self.post_norm(out.to(orig_dtype)) return out class InceptionConv2d(nn.Module): """Inception convolution""" def __init__( self, fu_dim: int = 24, gc: int = 8, square_kernel_size: int = 13, band_kernel_size: int = 17, ) -> None: super().__init__() self.fu = FourierUnit(fu_dim, fu_dim) self.convhw = nn.Conv2d( gc, gc, square_kernel_size, padding=square_kernel_size // 2 ) self.convw = nn.Conv2d( gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size // 2), ) self.convh = nn.Conv2d( gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size // 2, 0), ) def forward( self, x: Tensor, x_hw: Tensor, x_w: Tensor, xh: Tensor ) -> tuple[Tensor, Tensor, Tensor, Tensor]: return self.fu(x), self.convhw(x_hw), self.convw(x_w), self.convh(xh) class GatedCNNBlock(nn.Module): def __init__( self, dim: int = 64, expansion_ratio: float = 8 / 3, gc: int = 8, square_kernel_size: int = 13, band_kernel_size: int = 17, ) -> None: super().__init__() hidden = int(expansion_ratio * dim) // 8 * 8 self.norm = RMSNorm(dim) self.fc1 = nn.Conv2d(dim, hidden * 2, 3, 1, 1) self.act = nn.SiLU() self.split_indices = [hidden, hidden - dim, dim - gc * 3, gc, gc, gc] self.conv = InceptionConv2d( dim - gc * 3, gc, square_kernel_size, band_kernel_size ) self.fc2 = nn.Conv2d(hidden, dim, 3, 1, 1) def gated_forward(self, x: Tensor) -> Tensor: x = self.norm(x) x = self.fc1(x) g, i, c, c_hw, c_w, c_h = torch.split(x, self.split_indices, dim=1) c, c_hw, c_w, c_h = self.conv(c, c_hw, c_w, c_h) x = self.fc2(self.act(g) * torch.cat((i, c, c_hw, c_w, c_h), dim=1)) return x def forward(self, x: Tensor) -> Tensor: return self.gated_forward(x) + x # @ARCH_REGISTRY.register() class FIGSR(nn.Module): """Fourier Inception Gated Super Resolution""" def __init__( self, in_nc: int = 3, dim: int = 48, expansion_ratio: float = 8 / 3, scale: int = 4, # neosr style: # scale=upscale out_nc: int = 3, upsampler: SampleMods = "pixelshuffledirect", mid_dim: int = 32, n_blocks: int = 24, gc: int = 8, square_kernel_size: int = 13, band_kernel_size: int = 17, **kwargs, ) -> None: super().__init__() self.in_to_dim = nn.Conv2d(in_nc, dim, 3, 1, 1) self.pad = 2 self.gfisr_body_half = nn.Sequential( *[ GatedCNNBlock( dim, expansion_ratio, gc, square_kernel_size, band_kernel_size ) for _ in range(n_blocks // 2) ] ) self.gfisr_body_half_2 = nn.Sequential( *[ GatedCNNBlock( dim, expansion_ratio, gc, square_kernel_size, band_kernel_size ) for _ in range(n_blocks - n_blocks // 2) ] + [nn.Conv2d(dim, dim, 3, 1, 1)] ) self.cat_to_dim = nn.Conv2d(dim * 3, dim, 1) self.upscale = UniUpsampleV3( upsampler, scale, dim, out_nc, mid_dim, dysample_end_kernel=3 ) if upsampler == "pixelshuffledirect": weight = ICNR( self.upscale[0].weight, initializer=nn.init.kaiming_normal_, upscale_factor=scale, ) self.upscale[0].weight.data.copy_(weight) self.scale = scale self.shift = nn.Parameter(torch.ones(1, 3, 1, 1) * 0.5, requires_grad=True) self.scale_norm = nn.Parameter(torch.ones(1, 3, 1, 1) / 6, requires_grad=True) def load_state_dict(self, state_dict, strict=True, assign=True): state_dict["upscale.MetaUpsample"] = self.upscale.MetaUpsample return super().load_state_dict(state_dict, strict, assign) def forward(self, x: Tensor) -> Tensor: x = (x - self.shift) / self.scale_norm _, _, H, W = x.shape mod_pad_h = (self.pad - H % self.pad) % self.pad mod_pad_w = (self.pad - W % self.pad) % self.pad x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect") x = self.in_to_dim(x) x0 = self.gfisr_body_half(x) x1 = self.gfisr_body_half_2(x0) x = self.cat_to_dim(torch.cat([x1, x, x0], dim=1)) x = self.upscale(x)[:, :, : H * self.scale, : W * self.scale] return x * self.scale_norm + self.shift