File size: 1,741 Bytes
import argparse
import inspect
import sys
from pathlib import Path

import torch
# Directory two levels above this file (the script's parent's parent); it is
# put first on sys.path so the `model` package imported below resolves when
# this file is run directly as a script.
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from model.depthpolyp import build_depthpolyp
def load_checkpoint(path, map_location="cpu"):
    """Load a checkpoint with ``torch.load``, restricting unpickling when possible.

    If the installed ``torch.load`` accepts a ``weights_only`` keyword, the
    checkpoint is loaded with ``weights_only=True``. Older PyTorch builds
    without that keyword are called without it.

    Args:
        path: A filesystem path or file-like object accepted by ``torch.load``.
        map_location: Device remapping forwarded to ``torch.load``. Defaults
            to ``"cpu"``, so tensors are placed on the CPU regardless of the
            device they were saved from.

    Returns:
        The deserialized checkpoint object, exactly as ``torch.load`` returns it.
    """
    # Feature-detect the keyword rather than catching TypeError around the
    # call: a TypeError raised *while loading* would otherwise trigger a
    # silent retry without the weights_only restriction.
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, map_location=map_location, weights_only=True)
    return torch.load(path, map_location=map_location)
def _positive_int(text):
    """argparse ``type=`` converter: parse *text* as an int and require it to be >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse command-line options for the ONNX export.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, matching ``ArgumentParser.parse_args``.

    Returns:
        An ``argparse.Namespace`` with ``checkpoint``, ``output``,
        ``image_size`` and ``opset`` attributes.
    """
    parser = argparse.ArgumentParser(description="Export DepthPolyp to ONNX.")
    parser.add_argument(
        "--checkpoint",
        default="checkpoints/DepthPolyp_Kvasir.pth",
        help="checkpoint file to load (default: %(default)s)",
    )
    parser.add_argument(
        "--output",
        default="checkpoints/DepthPolyp_Kvasir.onnx",
        help="destination .onnx file; parent directories are created (default: %(default)s)",
    )
    parser.add_argument(
        "--image-size",
        type=_positive_int,
        default=224,
        help="height and width of the square example input (default: %(default)s)",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=17,
        help="ONNX opset version passed to torch.onnx.export (default: %(default)s)",
    )
    return parser.parse_args(argv)
def main():
    """Build DepthPolyp, load checkpoint weights into it, and export it to ONNX.

    Options come from ``parse_args``. The checkpoint is loaded strictly into
    a freshly built ``b0`` model in eval mode. The export has one input
    (``image``) and two outputs (``segmentation``, ``depth``); only their
    leading (batch) axis is dynamic.
    """
    args = parse_args()

    # NOTE(review): presumably these must match the architecture the
    # checkpoint was trained with; strict=True below turns any key mismatch
    # into an error.
    model_config = dict(
        encoder_name="b0",
        in_channels=3,
        num_classes=2,
        decoder_channels=256,
        activation=None,
    )
    model = build_depthpolyp(**model_config)
    model.load_state_dict(load_checkpoint(args.checkpoint), strict=True)
    model.eval()

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    side = args.image_size
    example_input = torch.randn(1, 3, side, side)

    input_names = ["image"]
    output_names = ["segmentation", "depth"]
    torch.onnx.export(
        model,
        example_input,
        output_path,
        input_names=input_names,
        output_names=output_names,
        opset_version=args.opset,
        do_constant_folding=True,
        # Only axis 0 is declared dynamic; the remaining axes keep the shape
        # seen during export (square `side` x `side` input).
        dynamic_axes={name: {0: "batch"} for name in input_names + output_names},
    )
    print(f"Exported ONNX model to {output_path}")
if __name__ == "__main__":
    # Run the export only when this file is executed as a script, not on import.
    main()
|