ONNX models are available in the `Files and versions` tab. You need both the `.onnx` file and the `.onnx.data` file (which holds the externally stored weights) to run inference with the model; keep them in the same folder.
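## How to convert to ONNX
1. Download the checkpoint (e.g. `table5_pidinet.pth`) from https://github.com/hellozhuo/pidinet#:~:text=vary%20too%20much)%3A-,Model,-ODS
2. Clone the PiDiNet repository: `git clone https://github.com/hellozhuo/pidinet.git`
3. Save the following script next to the cloned `pidinet` folder (e.g. as `pidinet_to_onnx.py`), so that `pidinet.models` is importable, and run it from the command line: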
```python
| """ | |
| Export a PiDiNet checkpoint to ONNX. | |
| Example: | |
| python pidinet_to_onnx.py \ | |
| --checkpoint table5_pidinet.pth \ | |
| --output pidinet_table5.onnx \ | |
| --config carv4 --sa --dil --height 512 --width 512 | |
| """ | |
| import argparse | |
| from types import SimpleNamespace | |
| import torch | |
| from pidinet.models import ( | |
| pidinet_converted, | |
| pidinet_small_converted, | |
| pidinet_tiny_converted, | |
| ) | |
| from pidinet.models.convert_pidinet import convert_pidinet | |
| MODEL_BUILDERS = { | |
| "base": pidinet_converted, | |
| "small": pidinet_small_converted, | |
| "tiny": pidinet_tiny_converted, | |
| } | |
def build_model(config: str, sa: bool, dil: bool, variant: str) -> torch.nn.Module:
    """Instantiate the converted (vanilla-conv) PiDiNet for ``variant``.

    Args:
        config: PiDiNet config name forwarded to the model factory.
        sa: whether the model is built with CSAM.
        dil: whether the model is built with CDCM.
        variant: key into ``MODEL_BUILDERS`` ("base", "small" or "tiny").

    Raises:
        ValueError: if ``variant`` is not a known key.
    """
    factory = MODEL_BUILDERS.get(variant)
    if factory is None:
        raise ValueError(f"Unsupported variant '{variant}' (choose from {list(MODEL_BUILDERS)})")
    # The pidinet factories read their options from attribute access, so a
    # lightweight namespace stands in for the repo's argparse result.
    factory_args = SimpleNamespace(config=config, sa=sa, dil=dil)
    return factory(factory_args)
def _read_checkpoint(ckpt_path: str):
    """Load a checkpoint onto the CPU and return its weights.

    If the loaded object has a ``"state_dict"`` entry, that entry is used.
    Otherwise the object itself is treated as the state dict. Either way, a
    leading ``module.`` prefix is stripped from the keys. Assumes the loaded
    object is a dict (``.get`` is called on it).

    NOTE: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    loaded = torch.load(ckpt_path, map_location="cpu")
    weights = loaded.get("state_dict", loaded)
    return _strip_module_prefix(weights)
def _infer_flags_from_state(state_dict):
    """Report which optional PiDiNet modules a state dict contains.

    Returns:
        ``(has_sa, has_dil)``: whether any key starts with ``"attentions."``
        (CSAM weights) or ``"dilations."`` (CDCM weights), respectively.
    """
    def _has_key_with(prefix):
        return any(name.startswith(prefix) for name in state_dict)

    return _has_key_with("attentions."), _has_key_with("dilations.")
def _strip_module_prefix(state_dict):
    """Remove a leading ``module.`` (added by DataParallel) from state-dict keys.

    Returns the input object unchanged when no key starts with ``module.``.
    Otherwise returns a new dict in which keys starting with ``module.`` lose
    that prefix and all other keys are kept verbatim.
    """
    prefix = "module."
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    # Slice instead of ``str.replace(prefix, "", 1)``. The replace form would
    # also delete a "module." found later inside a key that lacks the prefix.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
def export_onnx(model, dummy, output_path: str, opset: int):
    """Trace ``model`` with ``dummy`` and write the ONNX graph to ``output_path``.

    The graph has one input named ``image`` and five outputs named
    ``side1``..``side4`` and ``fused``. Assumes the model returns five tensors
    in that order. Dims 0, 2 and 3 (batch, height, width) of every input and
    output are marked dynamic.
    """
    output_names = [f"side{idx}" for idx in range(1, 5)] + ["fused"]
    shared_axes = {0: "batch", 2: "height", 3: "width"}
    # Give every tensor its own copy of the axis mapping.
    dynamic_axes = {tensor: dict(shared_axes) for tensor in ["image", *output_names]}
    torch.onnx.export(
        model,
        dummy,
        output_path,
        opset_version=opset,
        input_names=["image"],
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        do_constant_folding=True,
    )
def parse_args():
    """Define and parse the exporter's command-line options from ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Convert PiDiNet checkpoint to ONNX.")

    # Input/output paths and model configuration.
    parser.add_argument("--checkpoint", type=str, default="pidinet_model/table5_pidinet.pth",
                        help="Path to PiDiNet checkpoint (.pth).")
    parser.add_argument("--output", type=str, default="pidinet_table5.onnx",
                        help="Path to write ONNX file.")
    parser.add_argument("--config", type=str, default="carv4",
                        help="Model config name (see pidinet/models/config.py).")

    # Optional architecture modules.
    parser.add_argument("--sa", action="store_true", help="Use CSAM.")
    parser.add_argument("--dil", action="store_true", help="Use CDCM.")

    # Shape of the dummy tensor used for tracing.
    for flag, default, help_text in (
        ("--height", 512, "Dummy input height."),
        ("--width", 512, "Dummy input width."),
        ("--batch", 1, "Dummy batch size."),
    ):
        parser.add_argument(flag, type=int, default=default, help=help_text)

    # Export and runtime options.
    parser.add_argument("--opset", type=int, default=18,
                        help="ONNX opset version (>=18 recommended to avoid converter errors).")
    parser.add_argument("--cuda", action="store_true",
                        help="Export with the model on CUDA (optional).")
    parser.add_argument("--variant", choices=["base", "small", "tiny"], default="base",
                        help="Width of the PiDiNet: 'base' (table5_pidinet), 'small' (table5_pidinet-small), or 'tiny' (table5_pidinet-tiny).")
    parser.add_argument("--strict-flags", action="store_true",
                        help="Do not auto-adjust --sa/--dil based on checkpoint contents.")

    return parser.parse_args()
def _resolve_flags(requested_sa, requested_dil, inferred_sa, inferred_dil, strict):
    """Return the effective ``(sa, dil)`` pair used to build the model.

    ``requested_*`` come from the ``--sa``/``--dil`` CLI flags, and
    ``inferred_*`` come from the checkpoint keys. With ``strict``
    (``--strict-flags``), the CLI flags are used verbatim. Otherwise the
    checkpoint contents win and every disagreement with the CLI is reported.
    """
    if strict:
        return requested_sa, requested_dil
    if requested_sa and not inferred_sa:
        print("Checkpoint lacks attention layers; disabling --sa for this export.")
    elif inferred_sa and not requested_sa:
        print("Checkpoint has attention layers; enabling --sa for this export.")
    if requested_dil and not inferred_dil:
        print("Checkpoint lacks dilation modules; disabling --dil for this export.")
    elif inferred_dil and not requested_dil:
        print("Checkpoint has dilation modules; enabling --dil for this export.")
    return inferred_sa, inferred_dil


def main():
    """Export the checkpoint named on the command line to ONNX.

    Steps: load the checkpoint, decide the CSAM/CDCM flags, build the
    requested converted PiDiNet variant, and load the converted weights.
    Then trace the model with a random dummy input of shape
    (batch, 3, height, width).
    """
    args = parse_args()
    raw_state = _read_checkpoint(args.checkpoint)
    inferred_sa, inferred_dil = _infer_flags_from_state(raw_state)
    # Previously ``inferred or requested`` was applied even under
    # --strict-flags, which contradicted that flag's documented meaning.
    sa, dil = _resolve_flags(args.sa, args.dil, inferred_sa, inferred_dil, args.strict_flags)

    use_cuda = args.cuda and torch.cuda.is_available()
    if args.cuda and not use_cuda:
        print("CUDA requested but not available; exporting on CPU.")
    device = torch.device("cuda" if use_cuda else "cpu")

    print(f"Export settings -> variant: {args.variant}, sa: {sa}, dil: {dil}, config: {args.config}")
    model = build_model(args.config, sa, dil, args.variant)
    # convert_pidinet maps the checkpoint weights into the vanilla-conv layout
    # the converted model expects.
    model.load_state_dict(convert_pidinet(raw_state, args.config))
    model.eval().to(device)

    dummy = torch.randn(args.batch, 3, args.height, args.width, device=device)
    export_onnx(model, dummy, args.output, args.opset)
    print(f"Exported ONNX to {args.output}")
# Run the exporter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
```
## How to run inference with the PiDiNet ONNX model
Save the following script (e.g. as `test_pidinet_onnx.py`) and run it on an image:
```python
| """ | |
| Run the PiDiNet ONNX model on one image and save the fused edge map. | |
| Example: | |
| python test_pidinet_onnx.py \ | |
| --onnx model_PIDINET/pidinet_table5.onnx \ | |
| --image Images/example.jpg \ | |
| --output Results/example_edges.png | |
| """ | |
| import argparse | |
| from pathlib import Path | |
| import numpy as np | |
| import onnxruntime as ort | |
| from PIL import Image | |
| MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)[:, None, None] | |
| STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)[:, None, None] | |
def preprocess(img_path: Path) -> np.ndarray:
    """Load an image and turn it into a normalized model input.

    Args:
        img_path: path to an image readable by PIL; it is converted to RGB.

    Returns:
        A C-contiguous float32 array of shape (1, 3, H, W). Pixels are scaled
        to [0, 1] and then normalized per channel with ``MEAN``/``STD``.
    """
    # Close the file handle deterministically. ``convert`` returns a new,
    # fully loaded image, so it remains valid after the ``with`` block.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    arr = np.asarray(img, dtype=np.float32) / 255.0  # HWC in [0, 1]
    arr = arr.transpose(2, 0, 1)  # CHW (a strided view, not C-ordered)
    arr = (arr - MEAN) / STD
    # Give ONNX Runtime a dense C-ordered buffer explicitly instead of the
    # channel-strided layout inherited from the transpose.
    return np.ascontiguousarray(arr[None, ...])  # BCHW
def postprocess(edge_map: np.ndarray, out_path: Path):
    """Write an edge-probability map to ``out_path`` as an 8-bit image.

    Values are clipped to [0, 1], scaled by 255 and truncated to ``uint8``.
    Missing parent directories of ``out_path`` are created. PIL chooses the
    file format from the path. Assumes a 2-D map (as produced by ``main``).
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    pixels = (np.clip(edge_map, 0.0, 1.0) * 255.0).astype(np.uint8)
    Image.fromarray(pixels).save(out_path)
def parse_args():
    """Parse CLI options for single-image PiDiNet ONNX inference.

    Returns an ``argparse.Namespace`` in which ``onnx``, ``image`` and
    ``output`` are ``Path`` objects and ``provider`` is a string.
    ``--image`` is required.
    """
    cli = argparse.ArgumentParser(description="Test PiDiNet ONNX on a single image.")
    cli.add_argument("--onnx", type=Path, default=Path("model_PIDINET/pidinet_table5.onnx"),
                     help="Path to the PiDiNet ONNX file.")
    cli.add_argument("--image", type=Path, required=True,
                     help="Input image path.")
    cli.add_argument("--output", type=Path, default=Path("Results/pidinet_edges.png"),
                     help="Where to save the fused edge map.")
    cli.add_argument("--provider", type=str, default="CPUExecutionProvider",
                     help="ONNX Runtime provider (e.g., CPUExecutionProvider or CUDAExecutionProvider).")
    return cli.parse_args()
def main():
    """Run the ONNX model on ``--image`` and save the fused edge map to ``--output``."""
    opts = parse_args()
    sess = ort.InferenceSession(str(opts.onnx), providers=[opts.provider])
    batch = preprocess(opts.image)
    # The input name matches the export script's ``input_names=["image"]``.
    # Passing ``None`` requests every output. The export script lists
    # side1..side4 first and ``fused`` last.
    results = sess.run(None, {"image": batch})
    fused_map = np.array(results[-1])[0, 0]  # batch item 0, channel 0
    postprocess(fused_map, opts.output)
    print(f"Saved edge map to {opts.output}")
# Run inference only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
```