# ImageUpscaleCPU / export_onnx.py
# Provenance: uploaded by zerovic, commit 79065b2 ("Create export_onnx.py"), 4.2 kB.
import os
import urllib.request
import torch
import torch.nn as nn
from torch.nn import functional as F
# =================================================================
# OFFICIAL REAL-ESRGAN 6B ANIME ARCHITECTURE DEFINITION
# =================================================================
class ResidualDenseBlock_5C(nn.Module):
    """Five-convolution residual dense block.

    Each of the first four 3x3 convolutions receives the block input
    concatenated (along the channel axis) with every earlier intermediate
    feature map and emits ``gc`` channels. The fifth convolution maps the full
    concatenation back to ``nf`` channels so the result can be added to the
    input.

    Args:
        nf: Number of channels of the block input and output.
        gc: Number of channels produced by each intermediate convolution.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # Input width grows by ``gc`` per layer because of the dense concat.
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)

    def forward(self, x):
        """Return ``x`` plus the dense-branch output scaled by 0.2."""
        # torch.nn.functional exposes the leaky ReLU as ``leaky_relu``;
        # there is no ``F.lrelu`` attribute.
        x1 = F.leaky_relu(self.conv1(x), 0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(torch.cat((x, x1), 1)), 0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(torch.cat((x, x1, x2), 1)), 0.2, inplace=True)
        x4 = F.leaky_relu(self.conv4(torch.cat((x, x1, x2, x3), 1)), 0.2, inplace=True)
        # No activation on the last convolution: it feeds the residual sum.
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x
class RRDB(nn.Module):
    """Residual-in-residual dense block: three dense blocks plus a scaled skip.

    The sub-blocks are registered as ``rdb1``/``rdb2``/``rdb3`` (lower case) so
    the resulting parameter names, e.g. ``body.0.rdb1.conv1.weight``, match the
    keys stored in the released Real-ESRGAN checkpoints; upper-case names make
    ``load_state_dict(..., strict=True)`` fail on every body parameter.

    Args:
        nf: Number of channels of the block input and output.
        gc: Growth channels forwarded to each dense block.
    """

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock_5C(nf, gc)
        self.rdb2 = ResidualDenseBlock_5C(nf, gc)
        self.rdb3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        """Return ``x`` plus the chained dense-block output scaled by 0.2."""
        return self.rdb3(self.rdb2(self.rdb1(x))) * 0.2 + x


class RRDBNet(nn.Module):
    """RRDB generator laid out like the Real-ESRGAN (BasicSR) ``RRDBNet``.

    The network always upsamples by 4 through two nearest-neighbour x2 stages
    (``conv_up1``, ``conv_up2``) followed by ``conv_hr`` and ``conv_last``.
    Overall scales of 2 and 1 are obtained by pixel-unshuffling the input by
    2 or 4 first, which multiplies the input channel count by 4 or 16.

    Args:
        in_nc: Channels of the input image.
        out_nc: Channels of the output image.
        nf: Feature channels used throughout the trunk.
        nb: Number of RRDB blocks in the trunk.
        gc: Growth channels inside every dense block.
        scale: Overall upscaling factor; must be 1, 2 or 4.

    Raises:
        ValueError: If ``scale`` is not 1, 2 or 4.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=6, gc=32, scale=4):
        super(RRDBNet, self).__init__()
        if scale not in (1, 2, 4):
            raise ValueError(f"scale must be 1, 2 or 4, got {scale!r}")
        self.scale = scale
        # Factor by which the input is pixel-unshuffled before the trunk.
        unshuffle = 4 // scale
        self.conv_first = nn.Conv2d(in_nc * unshuffle * unshuffle, nf, 3, 1, 1)
        self.body = nn.Sequential(*[RRDB(nf, gc) for _ in range(nb)])
        self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1)
        # Upsampling path: one convolution after each x2 interpolation.
        self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Upscale ``x`` (N, in_nc, H, W) to (N, out_nc, H*scale, W*scale)."""
        if self.scale != 4:
            # Trade spatial size for channels so the fixed x4 path nets x2 / x1.
            x = F.pixel_unshuffle(x, 4 // self.scale)
        fea = self.conv_first(x)
        # Long skip connection around the whole RRDB trunk.
        fea = fea + self.conv_body(self.body(fea))
        # Nearest-neighbour upsampling (not PixelShuffle), two x2 stages.
        fea = self.lrelu(self.conv_up1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(fea)))
        return out
# =================================================================
# DOWNLOAD AND EXPORT EXECUTOR
# =================================================================
def _download_weights(url, dest):
    """Fetch ``url`` into ``dest`` unless ``dest`` already exists.

    The payload is streamed into ``dest + ".part"`` and renamed only after the
    transfer completes, so an interrupted download never leaves a truncated
    file that a later run would mistake for a complete checkpoint.
    """
    if os.path.exists(dest):
        return
    print("Downloading official PyTorch .pth weights...")
    # Per-request header instead of install_opener(), which would change the
    # process-wide default opener.
    request = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
    partial = dest + ".part"
    try:
        with urllib.request.urlopen(request) as response, open(partial, "wb") as sink:
            for chunk in iter(lambda: response.read(1 << 20), b""):
                sink.write(chunk)
        os.replace(partial, dest)
    finally:
        # Only present here if the transfer or the rename failed.
        if os.path.exists(partial):
            os.remove(partial)


def _select_state_dict(checkpoint):
    """Return the parameter mapping held by a loaded checkpoint.

    Prefers the ``'params_ema'`` entry, then ``'params'``. When neither wrapper
    key is present the checkpoint is treated as the flat state dict itself
    (indexing it by its first key would yield a single tensor).
    """
    for key in ('params_ema', 'params'):
        if key in checkpoint:
            return checkpoint[key]
    return checkpoint


def main():
    """Download the Real-ESRGAN anime 6B weights and export them to ONNX."""
    pth_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
    pth_path = "RealESRGAN_x4plus_anime_6B.pth"
    onnx_path = "RealESRGAN_x4plus_anime_6B.onnx"

    _download_weights(pth_url, pth_path)

    print("Initializing RRDBNet Architecture (6 Blocks)...")
    model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=6, gc=32, scale=4)

    print("Loading weight checkpoints into model...")
    # The file comes from the network: weights_only=True limits unpickling to
    # tensors and plain containers instead of arbitrary objects.
    # Requires PyTorch >= 1.13.
    loadnet = torch.load(pth_path, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(_select_state_dict(loadnet), strict=True)
    model.eval()

    # Create dummy tensor representing [Batch, Channels, Height, Width]
    dummy_input = torch.randn(1, 3, 64, 64, dtype=torch.float32)

    print("Exporting to ONNX layout with DYNAMIC SHAPES...")
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {2: 'height', 3: 'width'},
            # Distinct symbols: the output is larger than the input, and a
            # shared name would declare the two dimensions equal.
            'output': {2: 'out_height', 3: 'out_width'},
        },
    )
    print(f"Successfully baked dynamic shape model to: {onnx_path}")


if __name__ == "__main__":
    main()