# DepthPolyp / scripts/export_onnx.py
# Author: ReaganWZY
# Commit 5acc7ae (verified): "Upload DepthPolyp model artifacts"
import argparse
import inspect
import sys
from pathlib import Path

import torch
# Put the repository root (the directory above ``scripts/``) at the front of
# the import path so the local ``model`` package resolves when this file is
# executed directly as a script.
_THIS_FILE = Path(__file__).resolve()
REPO_ROOT = _THIS_FILE.parents[1]
sys.path[0:0] = [str(REPO_ROOT)]
from model.depthpolyp import build_depthpolyp
def load_checkpoint(path):
    """Load a serialized checkpoint from *path*, mapping all tensors to the CPU.

    When the installed ``torch.load`` accepts ``weights_only``, the file is
    loaded with ``weights_only=True``, which restricts unpickling to tensors
    and plain containers. Older PyTorch releases lack that option; they fall
    back to a regular (full-pickle) load.

    Args:
        path: Location of the checkpoint, in any form ``torch.load`` accepts.

    Returns:
        The object stored in the checkpoint. In this script the result is fed
        to ``load_state_dict``, so it is expected to be a state dict.
    """
    # Decide support up front instead of catching TypeError around the call:
    # a TypeError raised *during* a restricted load would otherwise trigger a
    # silent retry with unrestricted unpickling.
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, map_location="cpu", weights_only=True)
    # SECURITY: a full pickle load can execute arbitrary code. On these older
    # PyTorch versions, only load checkpoints from trusted sources.
    return torch.load(path, map_location="cpu")
def _positive_int(text):
    """argparse ``type=`` converter: parse *text* as an integer strictly greater than 0."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {text!r}"
        ) from None
    if value <= 0:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value}"
        )
    return value


def parse_args(argv=None):
    """Parse the command-line options of the ONNX export script.

    Args:
        argv: Argument list to parse. ``None`` (the default) reads
            ``sys.argv[1:]``, matching the previous no-argument behaviour.

    Returns:
        An ``argparse.Namespace`` with ``checkpoint``, ``output``,
        ``image_size`` and ``opset`` attributes.
    """
    parser = argparse.ArgumentParser(description="Export DepthPolyp to ONNX.")
    parser.add_argument(
        "--checkpoint",
        default="checkpoints/DepthPolyp_Kvasir.pth",
        help="Path to the trained PyTorch state dict to export.",
    )
    parser.add_argument(
        "--output",
        default="checkpoints/DepthPolyp_Kvasir.onnx",
        help="Destination .onnx file (parent directories are created).",
    )
    # Validate sizes here so a bad value fails with a clear CLI error instead
    # of deep inside tensor creation or the exporter.
    parser.add_argument(
        "--image-size",
        type=_positive_int,
        default=224,
        help="Height and width of the square dummy input used for tracing.",
    )
    parser.add_argument(
        "--opset",
        type=_positive_int,
        default=17,
        help="ONNX opset version passed to torch.onnx.export.",
    )
    return parser.parse_args(argv)
def _build_model():
    """Construct the DepthPolyp architecture whose weights this script exports.

    The hyper-parameters are fixed here and must match the checkpoint being
    loaded. A mismatch surfaces as an error from the strict
    ``load_state_dict`` call in ``main``.
    """
    return build_depthpolyp(
        encoder_name="b0",
        in_channels=3,
        num_classes=2,
        decoder_channels=256,
        activation=None,
    )


def _export_onnx(model, output_path, image_size, opset):
    """Trace *model* on a random (1, 3, image_size, image_size) input and write it as ONNX.

    Only axis 0 (batch) of the input and outputs is declared dynamic, so the
    exported graph keeps the spatial size fixed at *image_size*.

    Args:
        model: The module to export; expected to already be in eval mode.
        output_path: ``pathlib.Path`` of the .onnx file to write. Missing
            parent directories are created.
        image_size: Height and width of the dummy input.
        opset: ONNX opset version for the exporter.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    dummy = torch.randn(1, 3, image_size, image_size)
    torch.onnx.export(
        model,
        dummy,
        # Older torch.onnx.export releases only document str / file-like
        # targets, so pass a plain string rather than a Path.
        str(output_path),
        input_names=["image"],
        # NOTE(review): assumes the model's forward returns
        # (segmentation, depth) in this order. Confirm against model.depthpolyp.
        output_names=["segmentation", "depth"],
        opset_version=opset,
        do_constant_folding=True,
        dynamic_axes={
            "image": {0: "batch"},
            "segmentation": {0: "batch"},
            "depth": {0: "batch"},
        },
    )


def main():
    """Build DepthPolyp, load the checkpoint given on the CLI, and export it to ONNX."""
    args = parse_args()
    model = _build_model()
    state_dict = load_checkpoint(args.checkpoint)
    # strict=True: fail loudly on any missing/unexpected key instead of
    # exporting a partially initialised network.
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    output_path = Path(args.output)
    _export_onnx(model, output_path, args.image_size, args.opset)
    print(f"Exported ONNX model to {output_path}")
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()