from collections import OrderedDict
import functools
import math
import re
from typing import Dict, Optional, Union

import torch
import torch.nn as nn

from src.UltimateSDUpscale import USDU_util


class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Stacks three ResidualDenseBlock_5C blocks and adds a residually scaled
    (x0.2) skip connection around the trio, following the ESRGAN design.
    """

    def __init__(self, nf: int, kernel_size: int = 3, gc: int = 32,
                 stride: int = 1, bias: bool = True, pad_type: str = "zero",
                 norm_type: Optional[str] = None, act_type: str = "leakyrelu",
                 mode: USDU_util.ConvMode = "CNA", _convtype: str = "Conv2D",
                 _spectral_norm: bool = False, plus: bool = False,
                 c2x2: bool = False) -> None:
        super().__init__()
        # All three dense blocks share identical hyperparameters.
        args = (nf, kernel_size, gc, stride, bias, pad_type, norm_type,
                act_type, mode)
        self.RDB1 = ResidualDenseBlock_5C(*args, plus=plus, c2x2=c2x2)
        self.RDB2 = ResidualDenseBlock_5C(*args, plus=plus, c2x2=c2x2)
        self.RDB3 = ResidualDenseBlock_5C(*args, plus=plus, c2x2=c2x2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual scaling by 0.2 is the ESRGAN convention for stability.
        return self.RDB3(self.RDB2(self.RDB1(x))) * 0.2 + x


class ResidualDenseBlock_5C(nn.Module):
    """Residual Dense Block with 5 Convolutions.

    Each conv receives the concatenation of the block input and every
    preceding conv's output (dense connectivity); the last conv maps back
    to ``nf`` channels with no activation, and the block output is a
    residually scaled (x0.2) sum with the input.
    """

    def __init__(self, nf: int = 64, kernel_size: int = 3, gc: int = 32,
                 stride: int = 1, bias: bool = True, pad_type: str = "zero",
                 norm_type: Optional[str] = None, act_type: str = "leakyrelu",
                 mode: USDU_util.ConvMode = "CNA", plus: bool = False,
                 c2x2: bool = False) -> None:
        super().__init__()
        # Placeholder for the ESRGAN+ 1x1 conv; stays None in this build but
        # keeps state-dict key layout compatible. TODO(review): confirm the
        # "+" variant is intentionally unimplemented here.
        self.conv1x1 = None
        cb = lambda inc, outc, act=act_type: USDU_util.conv_block(
            inc, outc, kernel_size, stride, bias=bias, pad_type=pad_type,
            norm_type=norm_type, act_type=act, mode=mode, c2x2=c2x2)
        self.conv1 = cb(nf, gc)
        self.conv2 = cb(nf + gc, gc)
        self.conv3 = cb(nf + 2 * gc, gc)
        self.conv4 = cb(nf + 3 * gc, gc)
        # Final fusion conv: no activation, back to nf channels.
        self.conv5 = cb(nf + 4 * gc, nf, act=None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDBNet(nn.Module):
    """ESRGAN/Real-ESRGAN upscaling network.

    Built directly from a checkpoint ``state_dict``: architecture
    hyperparameters (block count, channel counts, scale) are inferred from
    the state-dict keys/shapes, new-style keys are remapped to the old
    ``model.N...`` layout, and the weights are loaded into the assembled
    model.
    """

    def __init__(self, state_dict: Dict[str, torch.Tensor],
                 norm: Optional[str] = None, act: str = "leakyrelu",
                 upsampler: str = "upconv",
                 mode: USDU_util.ConvMode = "CNA") -> None:
        super().__init__()
        self.model_arch, self.sub_type = "ESRGAN", "SR"
        self.state, self.norm, self.act, self.upsampler, self.mode = \
            state_dict, norm, act, upsampler, mode
        # Maps old-arch key (or substitution template) -> candidate new-arch
        # keys/regexes. "/NB/" is a placeholder replaced by the detected
        # number of blocks in _new_to_old_arch.
        self.state_map = {
            "model.0.weight": ("conv_first.weight",),
            "model.0.bias": ("conv_first.bias",),
            "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
            "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
            r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
                r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
                r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)"),
        }
        self.num_blocks = self._get_num_blocks()
        # ESRGAN+ checkpoints carry conv1x1 layers.
        self.plus = any("conv1x1" in k for k in self.state)
        self.state = self._new_to_old_arch(self.state)
        self.key_arr = list(self.state.keys())
        # First conv's in-channels / last conv's out-channels give the I/O
        # channel counts of the whole network.
        self.in_nc = self.state[self.key_arr[0]].shape[1]
        self.out_nc = self.state[self.key_arr[-1]].shape[0]
        self.scale = self._get_scale()
        self.num_filters = self.state[self.key_arr[0]].shape[0]
        self.supports_fp16 = self.supports_bfp16 = True
        self.min_size_restriction = self.shuffle_factor = None
        # One upconv block per factor-of-2 of the total scale.
        ups = [USDU_util.upconv_block(self.num_filters, self.num_filters,
                                      act_type=self.act)
               for _ in range(int(math.log(self.scale, 2)))]
        cb = lambda inc, outc, act=None: USDU_util.conv_block(
            inc, outc, 3, norm_type=None, act_type=act)
        self.model = USDU_util.sequential(
            cb(self.in_nc, self.num_filters),
            USDU_util.ShortcutBlock(USDU_util.sequential(
                *[RRDB(self.num_filters, 3, 32, norm_type=self.norm,
                       act_type=self.act, plus=self.plus)
                  for _ in range(self.num_blocks)],
                cb(self.num_filters, self.num_filters))),
            *ups,
            cb(self.num_filters, self.num_filters, act=self.act),
            cb(self.num_filters, self.out_nc))
        # strict=False: keys absent from the converted dict keep their
        # freshly initialized values.
        self.load_state_dict(self.state, strict=False)

    def _new_to_old_arch(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Convert new arch state dict to old format."""
        # Resolve the /NB/ placeholder now that num_blocks is known.
        for kind in ("weight", "bias"):
            self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = \
                self.state_map[f"model.1.sub./NB/.{kind}"]
            del self.state_map[f"model.1.sub./NB/.{kind}"]
        old_state = OrderedDict()
        for old_key, new_keys in self.state_map.items():
            for new_key in new_keys:
                if r"\1" in old_key:
                    # Template with backrefs: rewrite every matching key.
                    for k, v in state.items():
                        sub = re.sub(new_key, old_key, k)
                        if sub != k:
                            old_state[sub] = v
                elif new_key in state:
                    old_state[old_key] = state[new_key]
        # Upsampling convs: upconvN maps to model.(3N); track the highest
        # index so the HR/final convs can be placed after it.
        max_upconv = 0
        for key in state:
            if m := re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key):
                old_state[f"model.{int(m[2]) * 3}.{m[3]}"] = state[key]
                max_upconv = max(max_upconv, int(m[2]) * 3)
        for key in state:
            if key in ("HRconv.weight", "conv_hr.weight"):
                old_state[f"model.{max_upconv + 2}.weight"] = state[key]
            elif key in ("HRconv.bias", "conv_hr.bias"):
                old_state[f"model.{max_upconv + 2}.bias"] = state[key]
            elif key == "conv_last.weight":
                old_state[f"model.{max_upconv + 4}.weight"] = state[key]
            elif key == "conv_last.bias":
                old_state[f"model.{max_upconv + 4}.bias"] = state[key]
        # Sort by the numeric model index so key_arr[0]/[-1] are the
        # first/last convs.
        return OrderedDict(
            sorted(old_state.items(), key=lambda x: int(x[0].split(".")[1])))

    def _get_scale(self, min_part: int = 6) -> int:
        """Get upscale factor.

        Counts the ``model.N.weight`` keys with N > min_part (each such
        layer corresponds to one x2 upsample) and returns 2 ** count.
        """
        return 2 ** sum(
            1 for p in self.state
            if len((ps := p.split("."))[1:]) == 2
            and int(ps[1]) > min_part and ps[2] == "weight")

    def _get_num_blocks(self) -> int:
        """Get number of RRDB blocks from the highest block index found."""
        state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
            r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",)
        for sk in state_keys:
            # m[1] is the zero-based block index; +1 converts to a count.
            if nbs := [int(m[1]) for k in self.state if (m := re.search(sk, k))]:
                return max(nbs) + 1
        return 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


PyTorchSRModels = (RRDBNet,)
PyTorchSRModel = Union[RRDBNet,]
PyTorchModels = (*PyTorchSRModels,)
PyTorchModel = Union[PyTorchSRModel]