|
|
""" |
|
|
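Export a MODNet checkpoint as an ONNX model with:

    input shape: (batch_size, 3, height, width)
    output shape: (batch_size, 1, height, width)

batch_size, height and width are exported as dynamic axes, so the ONNX model
accepts inputs of any batch size and spatial resolution.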
|
|
|
|
|
Arguments: |
|
|
--ckpt-path: path of the checkpoint that will be converted |
|
|
--output-path: path for saving the ONNX model |
|
|
|
|
|
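Example:

    python -m <package>.export_onnx \
        --ckpt-path=modnet_photographic_portrait_matting.ckpt \
        --output-path=modnet_photographic_portrait_matting.onnx

    This module imports `modnet_onnx` relatively (`from . import modnet_onnx`).
    It must therefore be run as a module of its package (`python -m ...` from the
    package's parent directory). Running `python export_onnx.py` directly fails
    with an ImportError.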
|
|
""" |
|
|
|
|
|
import os |
|
|
import argparse |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.autograd import Variable |
|
|
|
|
|
from . import modnet_onnx |
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line interface: both paths are mandatory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt-path', type=str, required=True, help='path of the checkpoint that will be converted')
    parser.add_argument('--output-path', type=str, required=True, help='path for saving the ONNX model')
    args = parser.parse_args()

    # Fail fast with a non-zero exit status and the message on stderr.
    # (A bare `exit()` here would report success to the calling shell.)
    # `isfile` also rejects a directory passed by mistake, which `torch.load` cannot read.
    if not os.path.isfile(args.ckpt_path):
        raise SystemExit('Cannot find checkpoint path: {0}'.format(args.ckpt_path))

    # Prefer the GPU when one is available, but still allow exporting on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The network is wrapped in DataParallel before loading the weights.
    # Presumably this is because the checkpoint's keys carry DataParallel's `module.`
    # prefix -- TODO confirm against the checkpoint.
    # On a CPU-only host, DataParallel simply holds the module, so the keys still match.
    modnet = modnet_onnx.MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet).to(device)
    # map_location lets a checkpoint saved from a GPU be loaded on the chosen device.
    state_dict = torch.load(args.ckpt_path, map_location=device)
    modnet.load_state_dict(state_dict)
    modnet.eval()

    # Example input for torch.onnx.export.
    # Only the channel dimension (3) is fixed in the exported graph:
    # batch, height and width are declared dynamic below.
    batch_size = 1
    height = 512
    width = 512
    dummy_input = torch.randn(batch_size, 3, height, width, device=device)

    # Export the unwrapped network (`.module`) so the graph does not include DataParallel.
    torch.onnx.export(
        modnet.module, dummy_input, args.output_path, export_params=True,
        input_names=['input'], output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size', 2: 'height', 3: 'width'},
            'output': {0: 'batch_size', 2: 'height', 3: 'width'},
        },
    )
|
|
|