"""Export a DepthPolyp checkpoint to an ONNX file."""

import argparse
import sys
from pathlib import Path

import torch

# Two directory levels above this file (presumably the repository root --
# confirm against the project layout). It is put first on sys.path so that
# the `model` package imported below resolves without installing the project.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

from model.depthpolyp import build_depthpolyp  # noqa: E402  (needs sys.path tweak above)


def load_checkpoint(path):
    """Load the object stored at *path* onto the CPU.

    ``weights_only=True`` is tried first. If ``torch.load`` raises
    ``TypeError`` (presumably a torch release that predates the keyword),
    the call is retried without it.

    Args:
        path: Location of the checkpoint file.

    Returns:
        Whatever ``torch.load`` returns for the file.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        # SECURITY NOTE: without ``weights_only`` torch.load performs a full
        # unpickle, which can execute arbitrary code embedded in the file.
        # Only run this script on checkpoints from a trusted source.
        return torch.load(path, map_location="cpu")


def _positive_int(text):
    """argparse ``type=`` callable: parse *text* as an int and require > 0."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {text!r}"
        ) from None
    if value <= 0:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value}"
        )
    return value


def parse_args():
    """Parse the command-line options.

    Returns:
        argparse.Namespace with ``checkpoint``, ``output``, ``image_size``
        and ``opset`` attributes.
    """
    parser = argparse.ArgumentParser(description="Export DepthPolyp to ONNX.")
    parser.add_argument("--checkpoint", default="checkpoints/DepthPolyp_Kvasir.pth")
    parser.add_argument("--output", default="checkpoints/DepthPolyp_Kvasir.onnx")
    # Reject zero/negative sizes here instead of failing later inside torch
    # with a much less helpful error.
    parser.add_argument("--image-size", type=_positive_int, default=224)
    parser.add_argument("--opset", type=int, default=17)
    return parser.parse_args()


def main():
    """Build the model, load the checkpoint and write the ONNX file."""
    args = parse_args()

    model = build_depthpolyp(
        encoder_name="b0",
        in_channels=3,
        num_classes=2,
        decoder_channels=256,
        activation=None,
    )
    state_dict = load_checkpoint(args.checkpoint)
    # strict=True: the checkpoint keys must match the model's keys exactly.
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Example input used for export: a single square 3-channel image.
    dummy = torch.randn(1, 3, args.image_size, args.image_size)
    torch.onnx.export(
        model,
        dummy,
        # Passed as str: older torch releases document `f` as a string or
        # file-like object, and this script otherwise tolerates older torch.
        str(output_path),
        input_names=["image"],
        # Two output names are supplied -- presumably the model returns a
        # segmentation output and a depth output, in that order (confirm
        # against build_depthpolyp).
        output_names=["segmentation", "depth"],
        opset_version=args.opset,
        do_constant_folding=True,
        # Only axis 0 is declared dynamic; the remaining axes are not.
        dynamic_axes={
            "image": {0: "batch"},
            "segmentation": {0: "batch"},
            "depth": {0: "batch"},
        },
    )
    print(f"Exported ONNX model to {output_path}")


if __name__ == "__main__":
    main()