File size: 1,741 Bytes
5acc7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import sys
from pathlib import Path

import torch

REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

from model.depthpolyp import build_depthpolyp


def load_checkpoint(path):
    """Load a checkpoint onto the CPU, preferring PyTorch's restricted loader.

    ``weights_only=True`` limits unpickling to tensors and plain containers, so
    a tampered checkpoint cannot run arbitrary code during loading. Only when
    the installed PyTorch does not accept that keyword at all (older releases)
    do we fall back to the legacy full-pickle loader. Other errors raised by
    ``torch.load`` propagate unchanged instead of silently triggering the
    unsafe fallback.

    Args:
        path: Filesystem path (str or ``pathlib.Path``) of the checkpoint.

    Returns:
        The deserialized checkpoint object, with tensors mapped to the CPU.
    """
    import inspect

    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, map_location="cpu", weights_only=True)
    # Legacy PyTorch without ``weights_only``: this performs a full pickle
    # load, so only use checkpoints from trusted sources.
    return torch.load(path, map_location="cpu")


def parse_args(argv=None):
    """Parse command-line options for the ONNX export.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list lets other code
            and tests reuse the parser without touching ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with ``checkpoint``, ``output``, ``image_size``
        and ``opset`` attributes.

    Raises:
        SystemExit: On invalid arguments, including a non-positive
            ``--image-size`` (argparse's standard usage-error exit).
    """
    parser = argparse.ArgumentParser(description="Export DepthPolyp to ONNX.")
    parser.add_argument(
        "--checkpoint",
        default="checkpoints/DepthPolyp_Kvasir.pth",
        help="Path to the trained model weights to load.",
    )
    parser.add_argument(
        "--output",
        default="checkpoints/DepthPolyp_Kvasir.onnx",
        help="Destination .onnx file; parent directories are created.",
    )
    parser.add_argument(
        "--image-size",
        type=int,
        default=224,
        help="Height and width of the square example input used for export.",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=17,
        help="ONNX opset version to target.",
    )
    args = parser.parse_args(argv)
    # Fail fast with a clear CLI message rather than an opaque error when the
    # example input tensor is created later.
    if args.image_size <= 0:
        parser.error(
            f"--image-size must be a positive integer, got {args.image_size}"
        )
    return args


def main():
    """Build DepthPolyp, load trained weights, and write an ONNX graph.

    The model is constructed with a fixed configuration (``b0`` encoder,
    3 input channels, 2 classes, 256 decoder channels, no activation). The
    checkpoint is loaded with ``strict=True``, so it must match exactly. The
    export traces a random square example input and marks only the batch
    dimension as dynamic; the spatial size is fixed to ``--image-size``.
    """
    opts = parse_args()

    net = build_depthpolyp(
        encoder_name="b0",
        in_channels=3,
        num_classes=2,
        decoder_channels=256,
        activation=None,
    )
    net.load_state_dict(load_checkpoint(opts.checkpoint), strict=True)
    # Inference mode before tracing, so train-time behaviour (e.g. dropout or
    # batch-norm updates, if the model has them) is not baked into the graph.
    net.eval()

    destination = Path(opts.output)
    destination.parent.mkdir(parents=True, exist_ok=True)

    side = opts.image_size
    example_input = torch.randn(1, 3, side, side)
    io_names = ("image", "segmentation", "depth")
    # Axis 0 of the input and of both outputs is exported as a symbolic batch.
    batch_axes = {name: {0: "batch"} for name in io_names}
    torch.onnx.export(
        net,
        example_input,
        destination,
        input_names=[io_names[0]],
        output_names=list(io_names[1:]),
        opset_version=opts.opset,
        do_constant_folding=True,
        dynamic_axes=batch_axes,
    )
    print(f"Exported ONNX model to {destination}")


# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()