Spaces:
Running on Zero
Running on Zero
File size: 7,136 Bytes
b701455 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 | from collections import OrderedDict
import functools
import math
import re
from typing import Union, Dict
import torch
import torch.nn as nn
from src.UltimateSDUpscale import USDU_util
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block.

    Runs the input through three ``ResidualDenseBlock_5C`` modules in
    sequence (``RDB1`` -> ``RDB2`` -> ``RDB3``) and adds the input back,
    scaling the chained result by 0.2.

    ``_convtype`` and ``_spectral_norm`` are accepted for signature
    compatibility but are not used.
    """

    def __init__(self, nf: int, kernel_size: int = 3, gc: int = 32, stride: int = 1,
                 bias: bool = True, pad_type: str = "zero", norm_type: str = None,
                 act_type: str = "leakyrelu", mode: USDU_util.ConvMode = "CNA",
                 _convtype: str = "Conv2D", _spectral_norm: bool = False,
                 plus: bool = False, c2x2: bool = False) -> None:
        super().__init__()
        # Every dense block shares exactly the same configuration.
        block_cfg = {
            "nf": nf,
            "kernel_size": kernel_size,
            "gc": gc,
            "stride": stride,
            "bias": bias,
            "pad_type": pad_type,
            "norm_type": norm_type,
            "act_type": act_type,
            "mode": mode,
            "plus": plus,
            "c2x2": c2x2,
        }
        self.RDB1 = ResidualDenseBlock_5C(**block_cfg)
        self.RDB2 = ResidualDenseBlock_5C(**block_cfg)
        self.RDB3 = ResidualDenseBlock_5C(**block_cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        # Scaled residual connection around the whole three-block chain.
        return out * 0.2 + x
class ResidualDenseBlock_5C(nn.Module):
    """Residual Dense Block with 5 densely connected convolutions.

    ``conv1``..``conv4`` each receive the channel-wise concatenation of the
    block input and every previous conv output, and produce ``gc`` channels.
    ``conv5`` maps the full concatenation back to ``nf`` channels with no
    activation. The block returns ``conv5_out * 0.2 + x``.

    When ``plus`` is true (ESRGAN+-style checkpoints, detected by
    ``RRDBNet`` from ``conv1x1`` keys), a bias-free 1x1 convolution
    ``conv1x1`` (``nf`` -> ``gc`` channels) is created. ``forward`` then adds
    two extra residual links: ``x2 += conv1x1(x)`` and ``x4 += x2``.
    """

    def __init__(self, nf: int = 64, kernel_size: int = 3, gc: int = 32, stride: int = 1,
                 bias: bool = True, pad_type: str = "zero", norm_type: str = None,
                 act_type: str = "leakyrelu", mode: USDU_util.ConvMode = "CNA",
                 plus: bool = False, c2x2: bool = False) -> None:
        super().__init__()
        # Extra 1x1 projection used only by the "plus" variant. Its weights are
        # expected under "...conv1x1.weight" in the loaded state dict.
        self.conv1x1 = (
            nn.Conv2d(nf, gc, kernel_size=1, stride=1, bias=False) if plus else None
        )
        cb = lambda inc, outc, act=act_type: USDU_util.conv_block(
            inc, outc, kernel_size, stride, bias=bias, pad_type=pad_type,
            norm_type=norm_type, act_type=act, mode=mode, c2x2=c2x2)
        self.conv1 = cb(nf, gc)
        self.conv2 = cb(nf + gc, gc)
        self.conv3 = cb(nf + 2 * gc, gc)
        self.conv4 = cb(nf + 3 * gc, gc)
        # Final conv has no activation so the residual sum stays linear.
        self.conv5 = cb(nf + 4 * gc, nf, act=None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1 is not None:
            x2 = x2 + self.conv1x1(x)
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1 is not None:
            x4 = x4 + x2
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x
class RRDBNet(nn.Module):
    """ESRGAN/Real-ESRGAN upscaling network whose shape is inferred from weights.

    The following are read from the keys and tensor shapes of ``state_dict``:
    - input/output channel counts;
    - filter count;
    - number of RRDB blocks;
    - upscale factor;
    - whether ``conv1x1`` ("plus") layers exist.

    Two key layouts are accepted:
    - the "new" layout (``conv_first``, ``RRDB_trunk.N.RDBM``/``body.N.rdbM``,
      ``upconvK``/``conv_upK``, ``HRconv``/``conv_hr``, ``conv_last``), which is
      renamed to the old layout before loading;
    - the "old" layout (``model.0``, ``model.1.sub.N.RDBM``, ...), which is used
      as is.

    Args:
        state_dict: Model weights. If they are nested under ``"params_ema"``
            or ``"params"``, that inner mapping is used.
        norm: Normalization type forwarded to every RRDB.
        act: Activation type for the RRDBs, upsampling blocks and the
            penultimate conv.
        upsampler: Stored on the instance only; ``upconv_block`` is always used.
        mode: Stored on the instance only.
    """

    def __init__(self, state_dict: Dict[str, torch.Tensor], norm: str = None,
                 act: str = "leakyrelu", upsampler: str = "upconv",
                 mode: USDU_util.ConvMode = "CNA") -> None:
        super().__init__()
        self.model_arch, self.sub_type = "ESRGAN", "SR"
        # Training checkpoints may nest the actual weights one level deep.
        # Without unwrapping, the key-based inference below finds no layers.
        for wrapper_key in ("params_ema", "params"):
            if wrapper_key in state_dict:
                state_dict = state_dict[wrapper_key]
                break
        self.state, self.norm, self.act, self.upsampler, self.mode = (
            state_dict, norm, act, upsampler, mode)
        # old-layout key (or regex replacement template) -> new-layout keys
        # (or regex patterns). "/NB/" stands for the RRDB block count.
        self.state_map = {
            "model.0.weight": ("conv_first.weight",),
            "model.0.bias": ("conv_first.bias",),
            "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
            "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
            r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
                r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
                r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)"),
        }
        # _new_to_old_arch needs num_blocks, so infer it from the raw keys first.
        self.num_blocks = self._get_num_blocks()
        self.plus = any("conv1x1" in k for k in self.state)
        self.state = self._new_to_old_arch(self.state)
        self.key_arr = list(self.state.keys())
        # Assumes the first key is the first conv's weight and the last key
        # belongs to the final conv (Conv2d weight layout: out, in, kH, kW).
        self.in_nc = self.state[self.key_arr[0]].shape[1]
        self.out_nc = self.state[self.key_arr[-1]].shape[0]
        self.scale = self._get_scale()
        self.num_filters = self.state[self.key_arr[0]].shape[0]
        self.supports_fp16 = self.supports_bfp16 = True
        self.min_size_restriction = self.shuffle_factor = None
        # scale is a power of two; math.log2 is exact there, while
        # math.log(x, 2) may round just below an integer before int().
        ups = [USDU_util.upconv_block(self.num_filters, self.num_filters, act_type=self.act)
               for _ in range(int(math.log2(self.scale)))]
        cb = lambda inc, outc, act=None: USDU_util.conv_block(inc, outc, 3, norm_type=None, act_type=act)
        self.model = USDU_util.sequential(
            cb(self.in_nc, self.num_filters),
            USDU_util.ShortcutBlock(USDU_util.sequential(
                *[RRDB(self.num_filters, 3, 32, norm_type=self.norm, act_type=self.act, plus=self.plus)
                  for _ in range(self.num_blocks)],
                cb(self.num_filters, self.num_filters))),
            *ups,
            cb(self.num_filters, self.num_filters, act=self.act),
            cb(self.num_filters, self.out_nc))
        # strict=False: keys without a matching parameter are ignored, and
        # parameters without a key keep their initial values.
        self.load_state_dict(self.state, strict=False)

    def _new_to_old_arch(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Rename new-layout keys to the old ESRGAN ``model.N...`` layout.

        A state dict that already contains ``model.0.weight`` is returned
        unchanged; the renaming rules below would otherwise discard every key.

        Returns:
            An OrderedDict sorted (stably) by the layer index, i.e. the second
            dot-separated field of each key.
        """
        if "model.0.weight" in state:
            return state
        # Resolve the /NB/ placeholder on a copy so self.state_map stays
        # reusable, keeping the same entry order.
        state_map = {k: v for k, v in self.state_map.items() if "/NB/" not in k}
        for kind in ("weight", "bias"):
            state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = (
                self.state_map[f"model.1.sub./NB/.{kind}"])
        old_state = OrderedDict()
        for old_key, new_keys in state_map.items():
            for new_key in new_keys:
                if r"\1" in old_key:
                    # Regex entry: new_key is a pattern, old_key its replacement.
                    for key, value in state.items():
                        renamed = re.sub(new_key, old_key, key)
                        if renamed != key:
                            old_state[renamed] = value
                elif new_key in state:
                    old_state[old_key] = state[new_key]
        # Upsampling convs: each upconv_block occupies three sequential slots,
        # and upconvK is mapped to slot 3*K.
        max_upconv = 0
        for key in state:
            if m := re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key):
                old_state[f"model.{int(m[2]) * 3}.{m[3]}"] = state[key]
                max_upconv = max(max_upconv, int(m[2]) * 3)
        # HR conv and the final conv sit 2 and 4 slots after the last upconv.
        for key in state:
            if key in ("HRconv.weight", "conv_hr.weight"):
                old_state[f"model.{max_upconv + 2}.weight"] = state[key]
            elif key in ("HRconv.bias", "conv_hr.bias"):
                old_state[f"model.{max_upconv + 2}.bias"] = state[key]
            elif key == "conv_last.weight":
                old_state[f"model.{max_upconv + 4}.weight"] = state[key]
            elif key == "conv_last.bias":
                old_state[f"model.{max_upconv + 4}.bias"] = state[key]
        return OrderedDict(sorted(old_state.items(), key=lambda x: int(x[0].split(".")[1])))

    def _get_scale(self, min_part: int = 6) -> int:
        """Return 2 ** (number of top-level ``model.N.weight`` keys with N > min_part)."""
        return 2 ** sum(1 for p in self.state if len((ps := p.split("."))[1:]) == 2
                        and int(ps[1]) > min_part and ps[2] == "weight")

    def _get_num_blocks(self) -> int:
        """Return 1 + the highest RRDB index found in the state-dict keys.

        Tries the new-layout regexes first, then the old ``model.N.sub.M.RDB``
        pattern; returns 1 if none of them match any key.
        """
        state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
            r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",)
        for sk in state_keys:
            if nbs := [int(m[1]) for k in self.state if (m := re.search(sk, k))]:
                return max(nbs) + 1
        return 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
# Concrete model classes (usable with isinstance) and matching typing aliases.
# RRDBNet is the only architecture provided by this module.
PyTorchSRModels = (RRDBNet,)
PyTorchSRModel = Union[RRDBNet]
PyTorchModels = tuple(PyTorchSRModels)
PyTorchModel = Union[PyTorchSRModel]
|