#!/usr/bin/env python3
"""
BiRefNet .pth -> ONNX exporter (CPU/GPU), with robust deform_conv2d ONNX patch.
Fixes:
- deform_conv2d_onnx_exporter get_tensor_dim_size returning None (NoneType + int crash)
- checkpoints saved with _orig_mod. prefix (torch.compile)
- supports code_dir layouts:
A) HuggingFace-style: code_dir/birefnet.py (class BiRefNet inside)
B) GitHub-style: code_dir/models/birefnet.py + code_dir/utils.py
Recommended baseline: torch==2.0.1, opset 17, fixed input size (e.g. 1024x1024).
"""
from __future__ import annotations
import argparse
import importlib
import inspect
import os
import re
import sys
from typing import Any, Dict, Iterable, List, Tuple
import torch
import torch.nn as nn
# -------------------------
# DeformConv2d ONNX patching
# -------------------------
def _patch_and_register_deform_conv2d() -> None:
"""
    Patch deform_conv2d_onnx_exporter.get_tensor_dim_size so that, whenever the size can be
    recovered (falling back to tensor type sizes/strides), it does not return None for H/W;
    then register the ONNX op.
This specifically fixes:
TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'
at create_dcn_params(...): in_h = get_tensor_dim_size(input, 2) + ...
"""
try:
import deform_conv2d_onnx_exporter as d
import torch.onnx.symbolic_helper as sym_help
except Exception as e:
print(f"[deform_conv2d] exporter not available ({type(e).__name__}: {e})")
return
if not hasattr(d, "get_tensor_dim_size"):
print("[deform_conv2d] deform_conv2d_onnx_exporter.get_tensor_dim_size not found; cannot patch.")
return
orig_get = d.get_tensor_dim_size
def patched_get_tensor_dim_size(tensor, dim: int):
# 1) Try original
v = orig_get(tensor, dim)
if v is not None:
return v
        # 2) Try torch's internal symbolic-helper sizes (often populated even when the exporter's own lookup returns None)
try:
sizes = sym_help._get_tensor_sizes(tensor) # type: ignore[attr-defined]
if sizes is not None and len(sizes) > dim and sizes[dim] is not None:
return int(sizes[dim])
except Exception:
pass
# 3) Try TensorType sizes/strides (Colab-style fallback)
try:
import typing
from torch import _C
ttype = typing.cast(_C.TensorType, tensor.type())
tsizes = ttype.sizes()
if tsizes is not None and len(tsizes) > dim and tsizes[dim] is not None:
return int(tsizes[dim])
tstrides = ttype.strides()
# For contiguous NCHW: strides = (C*H*W, H*W, W, 1)
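            # Worked example (assuming a 1x3x1024x1024 input): strides = (3145728, 1048576, 1024, 1),
            # so W = 1024, H = 1048576 // 1024 = 1024, C = 3145728 // 1048576 = 3.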
if tstrides is not None and len(tstrides) >= 4:
s0, s1, s2, s3 = tstrides[0], tstrides[1], tstrides[2], tstrides[3]
if dim == 3 and s2 is not None:
return int(s2) # W
if dim == 2 and s1 is not None and s2 not in (None, 0):
return int(s1 // s2) # H = (H*W)/W
if dim == 1 and s0 is not None and s1 not in (None, 0):
return int(s0 // s1) # C = (C*H*W)/(H*W)
if dim == 0:
# We export with batch=1 dummy input; safe fallback.
return 1
except Exception:
pass
# 4) Last-resort: batch=1 fallback, otherwise hard error with actionable message
if dim == 0:
return 1
raise RuntimeError(
f"[deform_conv2d] Could not infer static dim={dim} for a tensor during ONNX export "
f"(got None from torch). This typically happens with dynamic axes or missing shape info. "
f"Use a fixed input size (no dynamic axes) and export again."
)
d.get_tensor_dim_size = patched_get_tensor_dim_size # type: ignore[assignment]
# Register op after patching so the symbolic uses our patched helper at runtime
try:
d.register_deform_conv2d_onnx_op()
print("[deform_conv2d] Patched get_tensor_dim_size + registered deform_conv2d ONNX op.")
except Exception as e:
print(f"[deform_conv2d] register_deform_conv2d_onnx_op failed ({type(e).__name__}: {e})")
# -------------------------
# BiRefNet importing helpers
# -------------------------
def _ensure_importable_package_dir(code_dir: str) -> Tuple[str, str]:
"""
Make code_dir importable as a package so relative imports inside it work.
Used for HF-style code_dir that contains birefnet.py and BiRefNet_config.py.
"""
code_dir = os.path.abspath(code_dir)
parent = os.path.dirname(code_dir)
pkg = os.path.basename(code_dir)
init_py = os.path.join(code_dir, "__init__.py")
if not os.path.exists(init_py):
open(init_py, "a", encoding="utf-8").close()
if not re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", pkg):
safe_pkg = "birefnet_pkg"
safe_dir = os.path.join(parent, safe_pkg)
if not os.path.exists(safe_dir):
os.symlink(code_dir, safe_dir)
pkg = safe_pkg
code_dir = safe_dir
init_py = os.path.join(code_dir, "__init__.py")
if not os.path.exists(init_py):
open(init_py, "a", encoding="utf-8").close()
if parent not in sys.path:
sys.path.insert(0, parent)
return pkg, code_dir
def _detect_layout(code_dir: str) -> str:
code_dir = os.path.abspath(code_dir)
if os.path.isfile(os.path.join(code_dir, "models", "birefnet.py")) and os.path.isfile(os.path.join(code_dir, "utils.py")):
return "github"
if os.path.isfile(os.path.join(code_dir, "birefnet.py")):
return "hf"
raise FileNotFoundError(
f"Could not detect BiRefNet layout in {code_dir}.\n"
f"Expected either:\n"
f" - GitHub layout: models/birefnet.py and utils.py\n"
f" - HF layout: birefnet.py\n"
)
def _import_birefnet(code_dir: str):
layout = _detect_layout(code_dir)
if layout == "github":
# Mirror Colab: `from utils import check_state_dict` and `from models.birefnet import BiRefNet`
if code_dir not in sys.path:
sys.path.insert(0, code_dir)
from utils import check_state_dict # type: ignore
from models.birefnet import BiRefNet # type: ignore
return layout, BiRefNet, check_state_dict
# HF layout
pkg, _ = _ensure_importable_package_dir(code_dir)
mod = importlib.import_module(f"{pkg}.birefnet")
if not hasattr(mod, "BiRefNet"):
raise RuntimeError(f"BiRefNet class not found in {pkg}.birefnet")
return layout, getattr(mod, "BiRefNet"), None
# -------------------------
# Weight loading helpers
# -------------------------
def _extract_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
if isinstance(obj, dict):
if obj and all(torch.is_tensor(v) for v in obj.values()):
return obj # type: ignore[return-value]
for k in ["state_dict", "model", "model_state_dict", "net", "params", "weights", "ema"]:
if k in obj and isinstance(obj[k], dict) and obj[k] and all(torch.is_tensor(v) for v in obj[k].values()):
return obj[k] # type: ignore[return-value]
for v in obj.values():
if isinstance(v, dict) and v and all(torch.is_tensor(tv) for tv in v.values()):
return v # type: ignore[return-value]
raise RuntimeError("Could not find a state_dict inside the checkpoint.")
def _clean_state_dict_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
prefixes = ["module.", "_orig_mod.", "model.", "net.", "state_dict."]
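    # Example: "_orig_mod.module.decoder.conv.weight" -> "decoder.conv.weight"
    # (matching prefixes are stripped repeatedly until none remain; the key name here is illustrative).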
out: Dict[str, torch.Tensor] = {}
for k, v in sd.items():
nk = k
changed = True
while changed:
changed = False
for p in prefixes:
if nk.startswith(p):
nk = nk[len(p):]
changed = True
out[nk] = v
return out
def _pretty_list(xs: List[str], n: int = 20) -> List[str]:
return xs[:n] + (["..."] if len(xs) > n else [])
# -------------------------
# Output selection / wrapper
# -------------------------
def _walk_tensors(x: Any) -> Iterable[torch.Tensor]:
if torch.is_tensor(x):
yield x
return
if isinstance(x, dict):
for v in x.values():
yield from _walk_tensors(v)
elif isinstance(x, (list, tuple)):
for v in x:
yield from _walk_tensors(v)
def _pick_output_tensor(model_out: Any, height: int, width: int) -> torch.Tensor:
ts = list(_walk_tensors(model_out))
if not ts:
raise RuntimeError("Model forward returned no tensors.")
    # Prefer a 4D (B,1,H,W) or (B,3,H,W) map at the requested (height, width)
for t in ts:
if t.ndim == 4 and t.shape[1] in (1, 3) and t.shape[2] == height and t.shape[3] == width:
return t
# Next: any 4D tensor with H,W == (height,width)
for t in ts:
if t.ndim == 4 and t.shape[2] == height and t.shape[3] == width:
return t
# Else: largest tensor
return max(ts, key=lambda z: z.numel())
class ExportWrapper(nn.Module):
def __init__(self, model: nn.Module, height: int, width: int):
super().__init__()
self.model = model
self.height = height
self.width = width
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.contiguous()
out = self.model(x)
return _pick_output_tensor(out, self.height, self.width)
# -------------------------
# Main
# -------------------------
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--code_dir", required=True)
ap.add_argument("--weights", required=True)
ap.add_argument("--output", required=True)
ap.add_argument("--width", type=int, default=1024)
ap.add_argument("--height", type=int, default=1024)
ap.add_argument("--opset", type=int, default=17)
ap.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
ap.add_argument("--skip_onnx_check", action="store_true")
args = ap.parse_args()
print("== Environment ==")
print("Python:", sys.version.replace("\n", " "))
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Requested device:", args.device)
if args.device == "cuda" and not torch.cuda.is_available():
raise RuntimeError("You asked for --device cuda but CUDA is not available.")
device = torch.device(args.device)
print("Using device:", device)
# IMPORTANT: patch deform_conv2d exporter BEFORE export
_patch_and_register_deform_conv2d()
layout, BiRefNet, check_state_dict = _import_birefnet(args.code_dir)
print("BiRefNet layout detected:", layout)
print("== Building model ==")
kwargs = {}
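    # If the constructor exposes bb_pretrained, disable it so the backbone does not try to fetch
    # pretrained weights; the full checkpoint loaded below is expected to supply every parameter.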
try:
sig = inspect.signature(BiRefNet)
if "bb_pretrained" in sig.parameters:
kwargs["bb_pretrained"] = False
except Exception:
pass
model = BiRefNet(**kwargs) if kwargs else BiRefNet()
model.eval().to(device)
print("== Loading weights ==")
ckpt = torch.load(args.weights, map_location="cpu")
if layout == "github" and check_state_dict is not None:
# Colab-style path
sd = check_state_dict(ckpt)
missing, unexpected = model.load_state_dict(sd, strict=False)
else:
# HF-style path
sd = _extract_state_dict(ckpt)
sd = _clean_state_dict_keys(sd)
missing, unexpected = model.load_state_dict(sd, strict=False)
missing = list(missing)
unexpected = list(unexpected)
print(f"Loaded state_dict. Missing keys: {len(missing)} Unexpected keys: {len(unexpected)}")
if missing:
print(" (first 20 missing):", _pretty_list(missing, 20))
if unexpected:
print(" (first 20 unexpected):", _pretty_list(unexpected, 20))
wrapper = ExportWrapper(model, height=args.height, width=args.width).eval().to(device)
print("== Forward probe ==")
dummy = torch.randn(1, 3, args.height, args.width, device=device)
with torch.no_grad():
out = wrapper(dummy)
print("Picked output shape:", tuple(out.shape), "dtype:", out.dtype)
print("== Exporting ONNX ==")
out_path = os.path.abspath(args.output)
os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
# NOTE: No dynamic_axes by default (keeps shapes static and avoids shape None issues).
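    # If batch-size flexibility were needed, torch.onnx.export accepts e.g.
    #   dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
    # but with deform_conv2d this tends to reintroduce the missing-shape (None) problem this
    # script works around, so it is deliberately left out here.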
torch.onnx.export(
wrapper,
dummy,
out_path,
export_params=True,
opset_version=args.opset,
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
verbose=False,
)
print("Saved ONNX to:", out_path)
if not args.skip_onnx_check:
print("== Checking ONNX ==")
import onnx
m = onnx.load(out_path)
onnx.checker.check_model(m)
print("ONNX check: OK")
try:
mb = os.path.getsize(out_path) / (1024 * 1024)
print(f"ONNX size: {mb:.1f} MB")
except Exception:
pass
if __name__ == "__main__":
main()