bdck committed on
Commit
796d12d
·
verified ·
1 Parent(s): 37e08ff

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +169 -1
README.md CHANGED
@@ -1,4 +1,172 @@
1
- How do inference the pidinet onnx:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  ```python
4
  """
 
1
+ ## How to convert to ONNX:
2
+ 1) download the model from https://github.com/hellozhuo/pidinet#:~:text=vary%20too%20much)%3A-,Model,-ODS
3
+ 2) Git clone the pidinet repo `git clone https://github.com/hellozhuo/pidinet.git`
4
+ 3) run the following code via CLI:
5
+
6
+ ```python
7
+ """
8
+ Export a PiDiNet checkpoint to ONNX.
9
+
10
+ Example:
11
+ python pidinet_to_onnx.py \
12
+ --checkpoint pidinet_model/table5_pidinet.pth \
13
+ --output pidinet_table5.onnx \
14
+ --config carv4 --sa --dil --height 512 --width 512
15
+ """
16
+
17
+ import argparse
18
+ from types import SimpleNamespace
19
+ import torch
20
+
21
+ from pidinet.models import (
22
+ pidinet_converted,
23
+ pidinet_small_converted,
24
+ pidinet_tiny_converted,
25
+ )
26
+ from pidinet.models.convert_pidinet import convert_pidinet
27
+
28
+
29
# CLI variant name -> factory for the corresponding converted PiDiNet width.
MODEL_BUILDERS = dict(
    base=pidinet_converted,
    small=pidinet_small_converted,
    tiny=pidinet_tiny_converted,
)
34
+
35
+
36
def build_model(config: str, sa: bool, dil: bool, variant: str) -> torch.nn.Module:
    """Instantiate the converted PiDiNet model (uses vanilla convolutions).

    Args:
        config: PDC configuration name (e.g. "carv4").
        sa: Whether to build the model with CSAM attention blocks.
        dil: Whether to build the model with CDCM dilation blocks.
        variant: Model width; must be a key of ``MODEL_BUILDERS``.

    Raises:
        ValueError: If ``variant`` is not a known model width.
    """
    if variant in MODEL_BUILDERS:
        opts = SimpleNamespace(config=config, sa=sa, dil=dil)
        return MODEL_BUILDERS[variant](opts)
    raise ValueError(f"Unsupported variant '{variant}' (choose from {list(MODEL_BUILDERS)})")
43
+
44
+
45
def _read_checkpoint(ckpt_path: str):
    """Load a checkpoint from disk and return its cleaned-up state dict.

    Accepts both a raw state dict and a training checkpoint that wraps the
    weights under a "state_dict" key; DataParallel "module." prefixes are
    stripped before returning.
    """
    loaded = torch.load(ckpt_path, map_location="cpu")
    state = loaded["state_dict"] if "state_dict" in loaded else loaded
    return _strip_module_prefix(state)
49
+
50
+
51
+ def _infer_flags_from_state(state_dict):
52
+ """Infer sa/dil from checkpoint contents."""
53
+ has_sa = any(k.startswith("attentions.") for k in state_dict)
54
+ has_dil = any(k.startswith("dilations.") for k in state_dict)
55
+ return has_sa, has_dil
56
+
57
+
58
+ def _strip_module_prefix(state_dict):
59
+ """Remove a leading 'module.' (from DataParallel) if present."""
60
+ if not any(k.startswith("module.") for k in state_dict.keys()):
61
+ return state_dict
62
+ return {k.replace("module.", "", 1): v for k, v in state_dict.items()}
63
+
64
+
65
def export_onnx(model, dummy, output_path: str, opset: int):
    """Serialize `model` to ONNX with dynamic batch/height/width axes.

    The exported graph takes one input named "image" and produces the four
    side edge maps plus the fused map ("side1".."side4", "fused"); every
    tensor gets dynamic batch (axis 0) and spatial (axes 2, 3) dimensions.
    """
    tensor_names = ["image", "side1", "side2", "side3", "side4", "fused"]
    spatial_axes = {0: "batch", 2: "height", 3: "width"}
    torch.onnx.export(
        model,
        dummy,
        output_path,
        opset_version=opset,
        input_names=tensor_names[:1],
        output_names=tensor_names[1:],
        dynamic_axes={name: dict(spatial_axes) for name in tensor_names},
        do_constant_folding=True,
    )
85
+
86
+
87
def parse_args(argv=None):
    """Parse CLI options for the PiDiNet -> ONNX export script.

    Args:
        argv: Optional explicit argument list. ``None`` (the default) keeps
            the standard argparse behavior of reading ``sys.argv[1:]``, so
            existing CLI usage is unchanged; passing a list makes the parser
            testable without touching ``sys.argv``.

    Returns:
        argparse.Namespace with the export settings.
    """
    parser = argparse.ArgumentParser(description="Convert PiDiNet checkpoint to ONNX.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="pidinet_model/table5_pidinet.pth",
        help="Path to PiDiNet checkpoint (.pth).",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="pidinet_table5.onnx",
        help="Path to write ONNX file.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="carv4",
        help="Model config name (see pidinet/models/config.py).",
    )
    parser.add_argument("--sa", action="store_true", help="Use CSAM.")
    parser.add_argument("--dil", action="store_true", help="Use CDCM.")
    parser.add_argument("--height", type=int, default=512, help="Dummy input height.")
    parser.add_argument("--width", type=int, default=512, help="Dummy input width.")
    parser.add_argument("--batch", type=int, default=1, help="Dummy batch size.")
    parser.add_argument(
        "--opset",
        type=int,
        default=18,
        help="ONNX opset version (>=18 recommended to avoid converter errors).",
    )
    parser.add_argument(
        "--cuda",
        action="store_true",
        help="Export with the model on CUDA (optional).",
    )
    parser.add_argument(
        "--variant",
        choices=["base", "small", "tiny"],
        default="base",
        help="Width of the PiDiNet: 'base' (table5_pidinet), 'small' (table5_pidinet-small), or 'tiny' (table5_pidinet-tiny).",
    )
    parser.add_argument(
        "--strict-flags",
        action="store_true",
        help="Do not auto-adjust --sa/--dil based on checkpoint contents.",
    )
    return parser.parse_args(argv)
135
+
136
+
137
def main():
    """Entry point: load the checkpoint, build the model, export to ONNX.

    Flag resolution: by default the sa/dil architecture flags are taken from
    what the checkpoint actually contains (with a console note when a
    user-supplied flag is dropped). With --strict-flags the CLI values are
    honored exactly — the previous code still computed
    ``sa = inferred_sa or args.sa`` before the strict check, so the
    checkpoint could auto-*enable* flags even in strict mode, contradicting
    the flag's documented purpose.
    """
    args = parse_args()

    raw_state = _read_checkpoint(args.checkpoint)
    inferred_sa, inferred_dil = _infer_flags_from_state(raw_state)

    if args.strict_flags:
        # Honor the CLI flags exactly; any mismatch with the checkpoint will
        # surface as a load_state_dict error instead of a silent adjustment.
        sa, dil = args.sa, args.dil
    else:
        # Trust the checkpoint: enable sa/dil when the weights are present,
        # and warn when a requested flag cannot be satisfied.
        sa, dil = inferred_sa, inferred_dil
        if args.sa and not inferred_sa:
            print("Checkpoint lacks attention layers; disabling --sa for this export.")
        if args.dil and not inferred_dil:
            print("Checkpoint lacks dilation modules; disabling --dil for this export.")

    device = torch.device("cuda" if args.cuda and torch.cuda.is_available() else "cpu")
    print(f"Export settings -> variant: {args.variant}, sa: {sa}, dil: {dil}, config: {args.config}")
    model = build_model(args.config, sa, dil, args.variant)
    model.load_state_dict(convert_pidinet(raw_state, args.config))
    model.eval().to(device)

    dummy = torch.randn(args.batch, 3, args.height, args.width, device=device)
    export_onnx(model, dummy, args.output, args.opset)

    print(f"Exported ONNX to {args.output}")


if __name__ == "__main__":
    main()
167
+ ```
168
+
169
+ ## How to run inference with the PiDiNet ONNX model:
170
 
171
  ```python
172
  """