Spaces:
Running
Running
| import os | |
| import urllib.request | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| # ================================================================= | |
| # OFFICIAL REAL-ESRGAN 6B ANIME ARCHITECTURE DEFINITION | |
| # ================================================================= | |
class ResidualDenseBlock_5C(nn.Module):
    """Residual dense block built from five 3x3 convolutions.

    Each of ``conv1``..``conv4`` receives the block input concatenated with the
    outputs of all previous convolutions and produces ``gc`` channels. ``conv5``
    projects the full concatenation back to ``nf`` channels, and that result,
    scaled by 0.2, is added to the input.

    Args:
        nf: Number of input (and output) feature channels.
        gc: Growth channels produced by each of ``conv1``..``conv4``.
        bias: Whether the convolutions use a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)

    def forward(self, x):
        # ``torch.nn.functional`` has no ``lrelu``; the activation is
        # ``leaky_relu``. In-place is safe here because each call only
        # overwrites a freshly computed convolution output, never ``x``.
        x1 = F.leaky_relu(self.conv1(x), 0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(torch.cat((x, x1), 1)), 0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(torch.cat((x, x1, x2), 1)), 0.2, inplace=True)
        x4 = F.leaky_relu(self.conv4(torch.cat((x, x1, x2, x3), 1)), 0.2, inplace=True)
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Scale the residual branch by 0.2 before the skip connection.
        return x5 * 0.2 + x
class RRDB(nn.Module):
    """Residual-in-residual dense block: three chained dense blocks plus a skip.

    The submodules are named ``rdb1``..``rdb3`` so that the state-dict keys
    (``body.<i>.rdb1.conv1.weight`` ...) match BasicSR's ``RRDBNet``, which the
    Real-ESRGAN ``.pth`` checkpoints were saved from. With upper-case names,
    ``load_state_dict(strict=True)`` reports every block as missing/unexpected.

    Args:
        nf: Number of feature channels.
        gc: Growth channels passed to each inner dense block.
    """

    def __init__(self, nf, gc=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock_5C(nf, gc)
        self.rdb2 = ResidualDenseBlock_5C(nf, gc)
        self.rdb3 = ResidualDenseBlock_5C(nf, gc)

    # Read-only aliases for code that still uses the old attribute names.
    # Properties are not registered submodules, so they add no state-dict keys.
    @property
    def RDB1(self):
        return self.rdb1

    @property
    def RDB2(self):
        return self.rdb2

    @property
    def RDB3(self):
        return self.rdb3

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        # Scale the residual branch by 0.2 before the skip connection.
        return out * 0.2 + x
class RRDBNet(nn.Module):
    """Real-ESRGAN RRDB generator laid out like BasicSR's ``RRDBNet``.

    Pipeline: ``conv_first`` -> ``nb`` RRDB blocks -> ``conv_body`` with a global
    skip connection -> two nearest-neighbour 2x upsampling stages (``conv_up1``,
    ``conv_up2``, each followed by LeakyReLU) -> ``conv_hr`` + LeakyReLU ->
    ``conv_last``. The feature path always upsamples by 4x. For ``scale`` 2 or 1,
    the input is first folded with ``pixel_unshuffle`` (factor 2 or 4), so the
    output is ``scale`` times the input resolution.

    Args:
        in_nc: Number of input image channels.
        out_nc: Number of output image channels.
        nf: Number of feature channels.
        nb: Number of RRDB blocks (6 for the anime_6B checkpoint).
        gc: Growth channels inside each dense block.
        scale: Overall upscaling factor; one of 1, 2 or 4. For 1 and 2, input
            height/width must be divisible by ``4 // scale``.

    Raises:
        ValueError: If ``scale`` is not 1, 2 or 4.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=6, gc=32, scale=4):
        super().__init__()
        if scale not in (1, 2, 4):
            raise ValueError(f"scale must be 1, 2 or 4, got {scale!r}")
        self.scale = scale
        # pixel_unshuffle multiplies the channel count by factor**2.
        unshuffle = 4 // scale
        self.conv_first = nn.Conv2d(in_nc * unshuffle * unshuffle, nf, 3, 1, 1)
        self.body = nn.Sequential(*[RRDB(nf, gc) for _ in range(nb)])
        self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1)
        # Two 2x upsampling stages; both layers exist in the released weights.
        self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        unshuffle = 4 // self.scale
        # Space-to-depth so a 4x feature upsampler yields a `scale`x output.
        feat = F.pixel_unshuffle(x, unshuffle) if unshuffle > 1 else x
        feat = self.conv_first(feat)
        feat = feat + self.conv_body(self.body(feat))
        # Nearest-neighbour interpolation (not PixelShuffle), 2x per stage.
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))
| # ================================================================= | |
| # DOWNLOAD AND EXPORT EXECUTOR | |
| # ================================================================= | |
def _download_weights(url, dest):
    """Download ``url`` to ``dest`` without ever leaving a truncated ``dest``.

    The payload is streamed into ``dest + ".part"`` and renamed onto ``dest``
    only after the transfer completes. An interrupted download therefore cannot
    produce a partial file that a later ``os.path.exists(dest)`` check accepts.

    Args:
        url: HTTP(S) URL to fetch.
        dest: Final path of the downloaded file.

    Raises:
        urllib.error.URLError: On network/HTTP failure; any partial data is removed.
    """
    import shutil

    tmp_path = dest + ".part"
    # Per-request User-Agent instead of installing a process-wide opener.
    request = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
    try:
        with urllib.request.urlopen(request) as response, open(tmp_path, "wb") as out:
            shutil.copyfileobj(response, out)
        os.replace(tmp_path, dest)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _load_state_dict(path):
    """Load the checkpoint at ``path`` and return the state dict inside it.

    Prefers the ``params_ema`` entry, then ``params``. Any other checkpoint
    object is assumed to already be a plain state dict and is returned as-is.
    """
    cpu = torch.device('cpu')
    try:
        # weights_only=True refuses arbitrary pickled objects in a downloaded
        # file; the checkpoint is expected to contain only tensors.
        checkpoint = torch.load(path, map_location=cpu, weights_only=True)
    except TypeError:
        # torch < 1.13 does not accept the ``weights_only`` keyword.
        checkpoint = torch.load(path, map_location=cpu)
    for key in ('params_ema', 'params'):
        if key in checkpoint:
            return checkpoint[key]
    return checkpoint


def main():
    """Fetch the anime 6B weights (if absent) and export them to dynamic-shape ONNX."""
    pth_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
    pth_path = "RealESRGAN_x4plus_anime_6B.pth"
    onnx_path = "RealESRGAN_x4plus_anime_6B.onnx"

    if not os.path.exists(pth_path):
        print("Downloading official PyTorch .pth weights...")
        _download_weights(pth_url, pth_path)

    print("Initializing RRDBNet Architecture (6 Blocks)...")
    model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=6, gc=32, scale=4)

    print("Loading weight checkpoints into model...")
    model.load_state_dict(_load_state_dict(pth_path), strict=True)
    model.eval()

    # Dummy tensor laid out as [Batch, Channels, Height, Width].
    dummy_input = torch.randn(1, 3, 64, 64, dtype=torch.float32)

    print("Exporting to ONNX layout with DYNAMIC SHAPES...")
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        # The output is 4x the input spatially, so it needs its own symbolic
        # dim names; reusing 'height'/'width' would declare them equal.
        dynamic_axes={
            'input': {0: 'batch', 2: 'height', 3: 'width'},
            'output': {0: 'batch', 2: 'out_height', 3: 'out_width'},
        },
    )
    print(f"Successfully baked dynamic shape model to: {onnx_path}")


if __name__ == "__main__":
    main()