|
|
| """
|
| BiRefNet .pth -> ONNX exporter (CPU/GPU), with robust deform_conv2d ONNX patch.
|
|
|
| Fixes:
|
| - deform_conv2d_onnx_exporter get_tensor_dim_size returning None (NoneType + int crash)
|
| - checkpoints saved with _orig_mod. prefix (torch.compile)
|
| - supports code_dir layouts:
|
| A) HuggingFace-style: code_dir/birefnet.py (class BiRefNet inside)
|
| B) GitHub-style: code_dir/models/birefnet.py + code_dir/utils.py
|
|
|
| Recommended baseline: torch==2.0.1, opset 17, fixed input size (e.g. 1024x1024).
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import importlib
|
| import inspect
|
| import os
|
| import re
|
| import sys
|
| from typing import Any, Dict, Iterable, List, Tuple
|
|
|
| import torch
|
| import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _patch_and_register_deform_conv2d() -> None:
|
| """
|
| Patch deform_conv2d_onnx_exporter.get_tensor_dim_size so it never returns None
|
| for H/W when possible (fallback to tensor type sizes/strides), then register the op.
|
|
|
| This specifically fixes:
|
| TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'
|
| at create_dcn_params(...): in_h = get_tensor_dim_size(input, 2) + ...
|
| """
|
| try:
|
| import deform_conv2d_onnx_exporter as d
|
| import torch.onnx.symbolic_helper as sym_help
|
| except Exception as e:
|
| print(f"[deform_conv2d] exporter not available ({type(e).__name__}: {e})")
|
| return
|
|
|
| if not hasattr(d, "get_tensor_dim_size"):
|
| print("[deform_conv2d] deform_conv2d_onnx_exporter.get_tensor_dim_size not found; cannot patch.")
|
| return
|
|
|
| orig_get = d.get_tensor_dim_size
|
|
|
| def patched_get_tensor_dim_size(tensor, dim: int):
|
|
|
| v = orig_get(tensor, dim)
|
| if v is not None:
|
| return v
|
|
|
|
|
| try:
|
| sizes = sym_help._get_tensor_sizes(tensor)
|
| if sizes is not None and len(sizes) > dim and sizes[dim] is not None:
|
| return int(sizes[dim])
|
| except Exception:
|
| pass
|
|
|
|
|
| try:
|
| import typing
|
| from torch import _C
|
|
|
| ttype = typing.cast(_C.TensorType, tensor.type())
|
| tsizes = ttype.sizes()
|
| if tsizes is not None and len(tsizes) > dim and tsizes[dim] is not None:
|
| return int(tsizes[dim])
|
|
|
| tstrides = ttype.strides()
|
|
|
| if tstrides is not None and len(tstrides) >= 4:
|
| s0, s1, s2, s3 = tstrides[0], tstrides[1], tstrides[2], tstrides[3]
|
|
|
| if dim == 3 and s2 is not None:
|
| return int(s2)
|
|
|
| if dim == 2 and s1 is not None and s2 not in (None, 0):
|
| return int(s1 // s2)
|
|
|
| if dim == 1 and s0 is not None and s1 not in (None, 0):
|
| return int(s0 // s1)
|
|
|
| if dim == 0:
|
|
|
| return 1
|
| except Exception:
|
| pass
|
|
|
|
|
| if dim == 0:
|
| return 1
|
|
|
| raise RuntimeError(
|
| f"[deform_conv2d] Could not infer static dim={dim} for a tensor during ONNX export "
|
| f"(got None from torch). This typically happens with dynamic axes or missing shape info. "
|
| f"Use a fixed input size (no dynamic axes) and export again."
|
| )
|
|
|
| d.get_tensor_dim_size = patched_get_tensor_dim_size
|
|
|
|
|
| try:
|
| d.register_deform_conv2d_onnx_op()
|
| print("[deform_conv2d] Patched get_tensor_dim_size + registered deform_conv2d ONNX op.")
|
| except Exception as e:
|
| print(f"[deform_conv2d] register_deform_conv2d_onnx_op failed ({type(e).__name__}: {e})")
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _ensure_importable_package_dir(code_dir: str) -> Tuple[str, str]:
|
| """
|
| Make code_dir importable as a package so relative imports inside it work.
|
| Used for HF-style code_dir that contains birefnet.py and BiRefNet_config.py.
|
| """
|
| code_dir = os.path.abspath(code_dir)
|
| parent = os.path.dirname(code_dir)
|
| pkg = os.path.basename(code_dir)
|
|
|
| init_py = os.path.join(code_dir, "__init__.py")
|
| if not os.path.exists(init_py):
|
| open(init_py, "a", encoding="utf-8").close()
|
|
|
| if not re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", pkg):
|
| safe_pkg = "birefnet_pkg"
|
| safe_dir = os.path.join(parent, safe_pkg)
|
| if not os.path.exists(safe_dir):
|
| os.symlink(code_dir, safe_dir)
|
| pkg = safe_pkg
|
| code_dir = safe_dir
|
| init_py = os.path.join(code_dir, "__init__.py")
|
| if not os.path.exists(init_py):
|
| open(init_py, "a", encoding="utf-8").close()
|
|
|
| if parent not in sys.path:
|
| sys.path.insert(0, parent)
|
|
|
| return pkg, code_dir
|
|
|
|
|
| def _detect_layout(code_dir: str) -> str:
|
| code_dir = os.path.abspath(code_dir)
|
| if os.path.isfile(os.path.join(code_dir, "models", "birefnet.py")) and os.path.isfile(os.path.join(code_dir, "utils.py")):
|
| return "github"
|
| if os.path.isfile(os.path.join(code_dir, "birefnet.py")):
|
| return "hf"
|
| raise FileNotFoundError(
|
| f"Could not detect BiRefNet layout in {code_dir}.\n"
|
| f"Expected either:\n"
|
| f" - GitHub layout: models/birefnet.py and utils.py\n"
|
| f" - HF layout: birefnet.py\n"
|
| )
|
|
|
|
|
def _import_birefnet(code_dir: str):
    """
    Import the BiRefNet class (and, for the GitHub layout, its
    check_state_dict helper) from ``code_dir``.

    Returns (layout, BiRefNet_class, check_state_dict_or_None).
    """
    layout = _detect_layout(code_dir)

    if layout == "github":
        # The repo uses absolute imports ("models.birefnet", "utils"), so the
        # code dir itself must be on sys.path.
        if code_dir not in sys.path:
            sys.path.insert(0, code_dir)
        from utils import check_state_dict
        from models.birefnet import BiRefNet
        return layout, BiRefNet, check_state_dict

    # HF layout: import birefnet.py as a submodule of a synthetic package so
    # its relative imports resolve.
    pkg, _ = _ensure_importable_package_dir(code_dir)
    module = importlib.import_module(f"{pkg}.birefnet")
    if not hasattr(module, "BiRefNet"):
        raise RuntimeError(f"BiRefNet class not found in {pkg}.birefnet")
    return layout, module.BiRefNet, None
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _extract_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
|
| if isinstance(obj, dict):
|
| if obj and all(torch.is_tensor(v) for v in obj.values()):
|
| return obj
|
| for k in ["state_dict", "model", "model_state_dict", "net", "params", "weights", "ema"]:
|
| if k in obj and isinstance(obj[k], dict) and obj[k] and all(torch.is_tensor(v) for v in obj[k].values()):
|
| return obj[k]
|
| for v in obj.values():
|
| if isinstance(v, dict) and v and all(torch.is_tensor(tv) for tv in v.values()):
|
| return v
|
| raise RuntimeError("Could not find a state_dict inside the checkpoint.")
|
|
|
|
|
| def _clean_state_dict_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
| prefixes = ["module.", "_orig_mod.", "model.", "net.", "state_dict."]
|
| out: Dict[str, torch.Tensor] = {}
|
| for k, v in sd.items():
|
| nk = k
|
| changed = True
|
| while changed:
|
| changed = False
|
| for p in prefixes:
|
| if nk.startswith(p):
|
| nk = nk[len(p):]
|
| changed = True
|
| out[nk] = v
|
| return out
|
|
|
|
|
| def _pretty_list(xs: List[str], n: int = 20) -> List[str]:
|
| return xs[:n] + (["..."] if len(xs) > n else [])
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _walk_tensors(x: Any) -> Iterable[torch.Tensor]:
|
| if torch.is_tensor(x):
|
| yield x
|
| return
|
| if isinstance(x, dict):
|
| for v in x.values():
|
| yield from _walk_tensors(v)
|
| elif isinstance(x, (list, tuple)):
|
| for v in x:
|
| yield from _walk_tensors(v)
|
|
|
|
|
| def _pick_output_tensor(model_out: Any, height: int, width: int) -> torch.Tensor:
|
| ts = list(_walk_tensors(model_out))
|
| if not ts:
|
| raise RuntimeError("Model forward returned no tensors.")
|
|
|
| for t in ts:
|
| if t.ndim == 4 and t.shape[1] in (1, 3) and t.shape[2] == height and t.shape[3] == width:
|
| return t
|
|
|
| for t in ts:
|
| if t.ndim == 4 and t.shape[2] == height and t.shape[3] == width:
|
| return t
|
|
|
| return max(ts, key=lambda z: z.numel())
|
|
|
|
|
class ExportWrapper(nn.Module):
    """
    Thin export adapter around a BiRefNet model.

    Runs the wrapped model and reduces whatever structure it returns to a
    single tensor (via _pick_output_tensor) so the exported ONNX graph has
    exactly one output.
    """

    def __init__(self, model: nn.Module, height: int, width: int):
        super().__init__()
        self.model = model
        # Target spatial size, used to select the right tensor among the
        # model's (possibly multiple) outputs.
        self.height = height
        self.width = width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Force a contiguous input before running the model.
        model_out = self.model(x.contiguous())
        return _pick_output_tensor(model_out, self.height, self.width)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: build BiRefNet from code_dir, load .pth weights, export ONNX."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--code_dir", required=True)
    ap.add_argument("--weights", required=True)
    ap.add_argument("--output", required=True)
    ap.add_argument("--width", type=int, default=1024)
    ap.add_argument("--height", type=int, default=1024)
    ap.add_argument("--opset", type=int, default=17)
    ap.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
    ap.add_argument("--skip_onnx_check", action="store_true")
    args = ap.parse_args()

    print("== Environment ==")
    print("Python:", sys.version.replace("\n", " "))
    print("Torch:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    print("Requested device:", args.device)

    # Fail fast rather than silently falling back to CPU.
    if args.device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("You asked for --device cuda but CUDA is not available.")

    device = torch.device(args.device)
    print("Using device:", device)

    # Must run before torch.onnx.export so the deform_conv2d symbolic is registered.
    _patch_and_register_deform_conv2d()

    layout, BiRefNet, check_state_dict = _import_birefnet(args.code_dir)
    print("BiRefNet layout detected:", layout)

    print("== Building model ==")
    kwargs = {}
    try:
        # Pass bb_pretrained=False when the constructor supports it —
        # presumably to skip a backbone-weights download; the checkpoint
        # loaded below supplies all weights anyway (TODO confirm against
        # the BiRefNet constructor in code_dir).
        sig = inspect.signature(BiRefNet)
        if "bb_pretrained" in sig.parameters:
            kwargs["bb_pretrained"] = False
    except Exception:
        pass

    model = BiRefNet(**kwargs) if kwargs else BiRefNet()
    model.eval().to(device)

    print("== Loading weights ==")
    ckpt = torch.load(args.weights, map_location="cpu")

    if layout == "github" and check_state_dict is not None:
        # GitHub layout ships its own key-normalization helper.
        sd = check_state_dict(ckpt)
        missing, unexpected = model.load_state_dict(sd, strict=False)
    else:
        # HF layout: unwrap the checkpoint and strip wrapper key prefixes ourselves.
        sd = _extract_state_dict(ckpt)
        sd = _clean_state_dict_keys(sd)
        missing, unexpected = model.load_state_dict(sd, strict=False)

    # strict=False: report key mismatches instead of raising.
    missing = list(missing)
    unexpected = list(unexpected)
    print(f"Loaded state_dict. Missing keys: {len(missing)} Unexpected keys: {len(unexpected)}")
    if missing:
        print(" (first 20 missing):", _pretty_list(missing, 20))
    if unexpected:
        print(" (first 20 unexpected):", _pretty_list(unexpected, 20))

    wrapper = ExportWrapper(model, height=args.height, width=args.width).eval().to(device)

    # Run the wrapper once so shape/forward problems surface before export.
    print("== Forward probe ==")
    dummy = torch.randn(1, 3, args.height, args.width, device=device)

    with torch.no_grad():
        out = wrapper(dummy)
    print("Picked output shape:", tuple(out.shape), "dtype:", out.dtype)

    print("== Exporting ONNX ==")
    out_path = os.path.abspath(args.output)
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

    # Deliberately no dynamic_axes: the deform_conv2d symbolic needs static
    # H/W (see _patch_and_register_deform_conv2d), so the input size is fixed.
    torch.onnx.export(
        wrapper,
        dummy,
        out_path,
        export_params=True,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        verbose=False,
    )

    print("Saved ONNX to:", out_path)

    if not args.skip_onnx_check:
        print("== Checking ONNX ==")
        # Imported lazily so the export itself works without onnx installed.
        import onnx
        m = onnx.load(out_path)
        onnx.checker.check_model(m)
        print("ONNX check: OK")

    # Best effort: report the file size; never fail the run over this.
    try:
        mb = os.path.getsize(out_path) / (1024 * 1024)
        print(f"ONNX size: {mb:.1f} MB")
    except Exception:
        pass
|
|
|
|
|
# Script entry point: run the exporter CLI when invoked directly.
if __name__ == "__main__":
    main()
|
|
|