saliacoel commited on
Commit
6432d29
·
verified ·
1 Parent(s): 2a3805a

Upload export_birefnet_onnx.py

Browse files
Files changed (1) hide show
  1. export_birefnet_onnx.py +236 -175
export_birefnet_onnx.py CHANGED
@@ -1,30 +1,25 @@
1
  #!/usr/bin/env python3
2
  """
3
- Export BiRefNet .pth weights to ONNX.
4
-
5
- Works with the environment used by BiRefNet demo-style setups:
6
  - Python 3.10
7
- - torch==2.0.1+cu118
8
- - transformers==4.42.4 (IMPORTANT: newer transformers may require torch>=2.1 and will disable torch)
9
-
10
- Example:
11
- python export_birefnet_onnx.py \
12
- --code_dir birefnet_code \
13
- --weights weights/birefnet_finetuned_toonout.pth \
14
- --output birefnet_finetuned_toonout.onnx \
15
- --img_size 1024 \
16
- --opset 17 \
17
- --device cuda \
18
- --external_data
19
  """
20
 
21
  from __future__ import annotations
22
 
23
  import argparse
 
24
  import os
25
  import sys
26
- from pathlib import Path
27
- from typing import Any, Dict
28
 
29
  import torch
30
 
@@ -34,200 +29,266 @@ def _print_env(device: str) -> None:
34
  print("Python:", sys.version.replace("\n", " "))
35
  print("Torch:", torch.__version__)
36
  print("CUDA available:", torch.cuda.is_available())
37
- if device.startswith("cuda") and torch.cuda.is_available():
38
- idx = 0
39
  try:
40
  idx = torch.cuda.current_device()
 
41
  except Exception:
42
  pass
43
- try:
44
- name = torch.cuda.get_device_name(idx)
45
- except Exception:
46
- name = "cuda"
47
- print("CUDA device:", name)
48
 
49
 
50
- def _ensure_transformers_torch_backend_ok() -> None:
51
- # If transformers is installed but thinks torch is unavailable, BiRefNet will import dummy torch objects.
 
 
 
52
  try:
53
- import transformers # noqa
54
- from transformers.utils import is_torch_available # noqa
55
-
56
- if not is_torch_available():
57
- raise RuntimeError(
58
- "Transformers is installed but has DISABLED the PyTorch backend.\n"
59
- "This usually happens when your transformers version requires a newer torch.\n\n"
60
- "Fix (recommended):\n"
61
- " pip uninstall -y transformers tokenizers\n"
62
- " pip install transformers==4.42.4 huggingface_hub==0.23.4\n"
63
- )
64
- except ModuleNotFoundError:
65
- # BiRefNet HF-style code requires transformers; we'll let the import fail later with a clear error.
66
- pass
67
 
68
-
69
- def _try_register_deformconv_exporter() -> None:
70
- ok = False
71
- try:
72
- import deform_conv2d_onnx_exporter as d # type: ignore
73
-
74
- # Try common entry points
75
- for name in (
76
- "register_deform_conv2d_onnx_exporter",
77
- "register",
78
- "setup",
79
- ):
80
- fn = getattr(d, name, None)
81
- if callable(fn):
82
- fn()
83
- ok = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  break
85
- except Exception:
86
- ok = False
87
 
88
- print("DeformConv2d ONNX exporter:", "OK" if ok else "NOT LOADED (may fail if model uses DeformConv)")
 
89
 
 
 
 
 
 
90
 
91
- def _load_state_dict(weights_path: Path) -> Dict[str, Any]:
92
- sd = torch.load(str(weights_path), map_location="cpu")
93
- if isinstance(sd, dict) and "state_dict" in sd and isinstance(sd["state_dict"], dict):
94
- sd = sd["state_dict"]
95
 
96
- if not isinstance(sd, dict):
97
- raise ValueError("Weights file did not contain a state_dict-like dict.")
98
 
99
- # Remove 'module.' prefixes from DDP-trained checkpoints
100
- clean = {}
 
 
 
 
 
101
  for k, v in sd.items():
102
- nk = k[7:] if k.startswith("module.") else k
103
- clean[nk] = v
104
- return clean
105
-
106
-
107
- class _OutputWrapper(torch.nn.Module):
108
- def __init__(self, model: torch.nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  super().__init__()
110
  self.model = model
 
111
 
112
  def forward(self, x: torch.Tensor) -> torch.Tensor:
113
- y = self.model(x)
114
-
115
- # BiRefNet sometimes returns lists/tuples/dicts; ONNX export wants a tensor.
116
- if isinstance(y, torch.Tensor):
117
- return y
118
-
119
- if isinstance(y, (list, tuple)) and len(y) > 0:
120
- # Most BiRefNet variants put the final prediction at the end
121
- last = y[-1]
122
- if isinstance(last, torch.Tensor):
123
- return last
124
- # fallback: first tensor found
125
- for item in y:
126
- if isinstance(item, torch.Tensor):
127
- return item
128
-
129
- if isinstance(y, dict):
130
- for key in ("pred", "mask", "out", "logits"):
131
- if key in y and isinstance(y[key], torch.Tensor):
132
- return y[key]
133
- # fallback: first tensor value
134
- for v in y.values():
135
- if isinstance(v, torch.Tensor):
136
- return v
137
-
138
- raise TypeError(f"Model forward returned unsupported type for ONNX export: {type(y)}")
139
 
140
 
141
  def main() -> None:
142
- ap = argparse.ArgumentParser()
143
- ap.add_argument("--code_dir", type=str, required=True, help="Folder containing birefnet.py (downloaded BiRefNet code)")
144
- ap.add_argument("--weights", type=str, required=True, help="Path to .pth weights")
145
- ap.add_argument("--output", type=str, required=True, help="Output .onnx path")
146
- ap.add_argument("--img_size", type=int, default=1024, help="Square input size (default 1024)")
147
- ap.add_argument("--opset", type=int, default=17, help="ONNX opset (default 17)")
148
- ap.add_argument("--device", type=str, default="cuda", help="cuda or cpu")
149
- ap.add_argument("--half", action="store_true", help="Export in fp16 (not always supported)")
150
- ap.add_argument("--dynamic_axes", action="store_true", help="Enable dynamic batch axis")
151
- ap.add_argument("--external_data", action="store_true", help="Use external data format (for >2GB models)")
152
- args = ap.parse_args()
153
-
154
- code_dir = Path(args.code_dir).resolve()
155
- weights_path = Path(args.weights).resolve()
156
- out_path = Path(args.output).resolve()
157
 
158
  _print_env(args.device)
159
- _ensure_transformers_torch_backend_ok()
160
- _try_register_deformconv_exporter()
161
 
162
- if not code_dir.exists():
163
- raise FileNotFoundError(f"--code_dir not found: {code_dir}")
164
- if not weights_path.exists():
165
- raise FileNotFoundError(f"--weights not found: {weights_path}")
166
-
167
- sys.path.insert(0, str(code_dir))
168
- try:
169
- from birefnet import BiRefNet # type: ignore
170
- except Exception as e:
171
- raise RuntimeError(
172
- "Failed to import BiRefNet from your --code_dir.\n"
173
- f"code_dir={code_dir}\n"
174
- "Make sure birefnet.py exists there.\n"
175
- f"Original error: {e}"
176
- )
177
 
178
  print("== Building model ==")
179
  model = BiRefNet(bb_pretrained=False)
 
180
 
181
  print("== Loading weights ==")
182
- sd = _load_state_dict(weights_path)
183
- missing, unexpected = model.load_state_dict(sd, strict=False)
184
- print(f"Missing keys: {len(missing)}; Unexpected keys: {len(unexpected)}")
185
- if len(missing) > 0:
186
- print(" (missing example):", missing[:10])
187
- if len(unexpected) > 0:
188
- print(" (unexpected example):", unexpected[:10])
189
-
190
- device = torch.device(args.device if args.device != "cuda" else ("cuda" if torch.cuda.is_available() else "cpu"))
191
- model.eval().to(device)
192
-
193
- dtype = torch.float16 if args.half else torch.float32
194
- wrapper = _OutputWrapper(model).to(device)
195
-
196
- dummy = torch.randn(1, 3, args.img_size, args.img_size, device=device, dtype=dtype)
197
-
198
- # Quick forward to ensure it runs
 
 
 
 
 
 
 
 
 
 
 
199
  with torch.no_grad():
200
- _ = wrapper(dummy)
201
-
202
- out_path.parent.mkdir(parents=True, exist_ok=True)
203
 
204
  print("== Exporting ONNX ==")
205
- input_names = ["input"]
206
- output_names = ["output"]
207
 
 
 
208
  dynamic_axes = None
209
- if args.dynamic_axes:
210
- dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
211
-
212
- # External data format is needed if model > 2GB.
213
- use_external_data_format = bool(args.external_data)
214
-
215
- torch.onnx.export(
216
- wrapper,
217
- dummy,
218
- str(out_path),
219
- export_params=True,
220
- opset_version=args.opset,
221
- do_constant_folding=True,
222
- input_names=input_names,
223
- output_names=output_names,
224
- dynamic_axes=dynamic_axes,
225
- use_external_data_format=use_external_data_format,
226
- )
 
227
 
228
- print("Saved:", out_path)
229
- if use_external_data_format:
230
- print("External data format enabled: you may also have an .onnx.data file next to the ONNX.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  print("== Done ==")
233
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ BiRefNet (.pth) -> ONNX exporter that works with:
 
 
4
  - Python 3.10
5
+ - torch==2.0.1 (+cu118 recommended)
6
+ - transformers==4.42.4
7
+
8
+ Fixes:
9
+ - BiRefNet HF code uses relative imports (e.g. from .BiRefNet_config import ...),
10
+ so --code_dir must be imported as a *package*.
11
+ - Some public scripts pass use_external_data_format to torch.onnx.export, but
12
+ torch 2.0.1 does NOT support that keyword.
13
+ - Some checkpoints are saved from torch.compile and have keys prefixed with `_orig_mod.`.
 
 
 
14
  """
15
 
16
  from __future__ import annotations
17
 
18
  import argparse
19
+ import importlib
20
  import os
21
  import sys
22
+ from typing import Any, Dict, Iterable
 
23
 
24
  import torch
25
 
 
29
  print("Python:", sys.version.replace("\n", " "))
30
  print("Torch:", torch.__version__)
31
  print("CUDA available:", torch.cuda.is_available())
32
+ if torch.cuda.is_available():
 
33
  try:
34
  idx = torch.cuda.current_device()
35
+ print("CUDA device:", torch.cuda.get_device_name(idx))
36
  except Exception:
37
  pass
38
+ print("Requested device:", device)
 
 
 
 
39
 
40
 
41
+ def _try_register_deform_conv2d() -> bool:
42
+ """
43
+ Optional: register ONNX symbolic for torchvision's DeformConv2d.
44
+ Provided by deform-conv2d-onnx-exporter.
45
+ """
46
  try:
47
+ import deform_conv2d_onnx_exporter # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()
50
+ print("DeformConv2d ONNX exporter: OK")
51
+ return True
52
+ except Exception as e:
53
+ print("DeformConv2d ONNX exporter: NOT LOADED (may fail if model uses DeformConv)")
54
+ print(" Reason:", repr(e))
55
+ return False
56
+
57
+
58
+ def _ensure_pkg_and_import(code_dir: str):
59
+ """
60
+ Make sure code_dir is a real python package, then import <pkg>.birefnet
61
+ so that relative imports inside birefnet.py work.
62
+ """
63
+ code_dir = os.path.abspath(code_dir)
64
+ if not os.path.isdir(code_dir):
65
+ raise FileNotFoundError(f"--code_dir not found or not a directory: {code_dir}")
66
+
67
+ init_py = os.path.join(code_dir, "__init__.py")
68
+ if not os.path.exists(init_py):
69
+ # create empty __init__.py to make it a package
70
+ open(init_py, "a", encoding="utf-8").close()
71
+
72
+ pkg_name = os.path.basename(code_dir.rstrip("/"))
73
+ parent_dir = os.path.dirname(code_dir)
74
+ if parent_dir not in sys.path:
75
+ sys.path.insert(0, parent_dir)
76
+
77
+ # Import as package to satisfy relative imports
78
+ mod = importlib.import_module(f"{pkg_name}.birefnet")
79
+ return mod, pkg_name
80
+
81
+
82
+ def _extract_state_dict(ckpt_obj: Any) -> Dict[str, torch.Tensor]:
83
+ """
84
+ Accepts various checkpoint formats and returns a plain state_dict.
85
+ """
86
+ if isinstance(ckpt_obj, dict):
87
+ # common nesting keys
88
+ for k in ("state_dict", "model_state_dict", "model", "net", "params", "ema"):
89
+ v = ckpt_obj.get(k, None)
90
+ if isinstance(v, dict):
91
+ ckpt_obj = v
92
  break
 
 
93
 
94
+ if not isinstance(ckpt_obj, dict):
95
+ raise RuntimeError("Unsupported checkpoint format: expected a dict/state_dict.")
96
 
97
+ # At this point it should be {str: Tensor}
98
+ sd: Dict[str, torch.Tensor] = {}
99
+ for k, v in ckpt_obj.items():
100
+ if isinstance(k, str) and torch.is_tensor(v):
101
+ sd[k] = v
102
 
103
+ if not sd:
104
+ raise RuntimeError("Checkpoint dict contained no tensor parameters.")
105
+ return sd
 
106
 
 
 
107
 
108
+ def _normalize_state_dict_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
109
+ """
110
+ Fix common prefixes:
111
+ - torch.compile checkpoints: `_orig_mod.`
112
+ - DataParallel / DDP: `module.`
113
+ """
114
+ out: Dict[str, torch.Tensor] = {}
115
  for k, v in sd.items():
116
+ nk = k
117
+ if nk.startswith("_orig_mod."):
118
+ nk = nk[len("_orig_mod.") :]
119
+ if nk.startswith("module."):
120
+ nk = nk[len("module.") :]
121
+ out[nk] = v
122
+ return out
123
+
124
+
125
+ def _iter_tensors(x: Any) -> Iterable[torch.Tensor]:
126
+ if torch.is_tensor(x):
127
+ yield x
128
+ elif isinstance(x, dict):
129
+ for v in x.values():
130
+ yield from _iter_tensors(v)
131
+ elif isinstance(x, (list, tuple)):
132
+ for v in x:
133
+ yield from _iter_tensors(v)
134
+
135
+
136
+ def _pick_best_output(out: Any, img_size: int | None) -> torch.Tensor:
137
+ """
138
+ BiRefNet forward can return nested structures (list/tuple/dict).
139
+ We want a single mask tensor [N,1,H,W] if possible.
140
+ """
141
+ tensors = list(_iter_tensors(out))
142
+ if not tensors:
143
+ raise RuntimeError("Model forward produced no tensors to export.")
144
+
145
+ # Prefer rank-4 tensors
146
+ cands = [t for t in tensors if t.dim() == 4]
147
+
148
+ # Prefer exact H/W match if provided
149
+ if img_size is not None and cands:
150
+ cands_hw = [t for t in cands if int(t.shape[-2]) == img_size and int(t.shape[-1]) == img_size]
151
+ if cands_hw:
152
+ cands = cands_hw
153
+
154
+ # Prefer single-channel outputs
155
+ if cands:
156
+ cands_c1 = [t for t in cands if int(t.shape[1]) == 1]
157
+ if cands_c1:
158
+ cands = cands_c1
159
+
160
+ return cands[0] if cands else tensors[0]
161
+
162
+
163
class _ExportWrapper(torch.nn.Module):
    """
    Adapter that makes the wrapped model's forward return exactly one tensor,
    as torch.onnx.export requires. The underlying BiRefNet forward may return
    nested lists/tuples/dicts; `_pick_best_output` selects the mask tensor.
    """

    def __init__(self, model: torch.nn.Module, img_size: int | None):
        super().__init__()
        self.model = model
        # Target spatial size used as a tie-breaker when picking the output.
        self.img_size = img_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raw = self.model(x)
        return _pick_best_output(raw, self.img_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
 
175
  def main() -> None:
176
+ p = argparse.ArgumentParser()
177
+ p.add_argument("--code_dir", required=True, help="Folder that contains birefnet.py and BiRefNet_config.py")
178
+ p.add_argument("--weights", required=True, help="Path to .pth weights")
179
+ p.add_argument("--output", required=True, help="Output ONNX path, e.g. out.onnx")
180
+ p.add_argument("--img_size", type=int, default=1024, help="Dummy input resolution (square), default 1024")
181
+ p.add_argument("--opset", type=int, default=17, help="ONNX opset, default 17")
182
+ p.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="cuda or cpu")
183
+ p.add_argument("--dynamic", action="store_true", help="Export dynamic H/W axes (may break export)")
184
+ p.add_argument(
185
+ "--external_data",
186
+ action="store_true",
187
+ help="After export, re-save ONNX using external data (.onnx + .onnx.data).",
188
+ )
189
+ p.add_argument("--skip_onnx_check", action="store_true", help="Skip onnx.checker.check_model()")
190
+ args = p.parse_args()
191
 
192
  _print_env(args.device)
193
+ _try_register_deform_conv2d()
 
194
 
195
+ # Import model properly (as a package)
196
+ birefnet_mod, pkg_name = _ensure_pkg_and_import(args.code_dir)
197
+ if not hasattr(birefnet_mod, "BiRefNet"):
198
+ raise RuntimeError(f"BiRefNet class not found in {pkg_name}.birefnet")
199
+ BiRefNet = getattr(birefnet_mod, "BiRefNet")
 
 
 
 
 
 
 
 
 
 
200
 
201
  print("== Building model ==")
202
  model = BiRefNet(bb_pretrained=False)
203
+ model.eval()
204
 
205
  print("== Loading weights ==")
206
+ ckpt = torch.load(args.weights, map_location="cpu")
207
+ sd = _extract_state_dict(ckpt)
208
+ sd = _normalize_state_dict_keys(sd)
209
+
210
+ incompatible = model.load_state_dict(sd, strict=False)
211
+ missing = list(getattr(incompatible, "missing_keys", []))
212
+ unexpected = list(getattr(incompatible, "unexpected_keys", []))
213
+ print(f"Loaded state_dict. Missing keys: {len(missing)} Unexpected keys: {len(unexpected)}")
214
+ if missing:
215
+ print(" (first 20 missing):", missing[:20])
216
+ if unexpected:
217
+ print(" (first 20 unexpected):", unexpected[:20])
218
+
219
+ if args.device == "cuda":
220
+ if not torch.cuda.is_available():
221
+ raise RuntimeError("You asked for --device cuda but CUDA is not available.")
222
+ model = model.to("cuda")
223
+ dev = "cuda"
224
+ else:
225
+ model = model.to("cpu")
226
+ dev = "cpu"
227
+
228
+ wrapper = _ExportWrapper(model, img_size=args.img_size)
229
+ wrapper.eval()
230
+
231
+ dummy = torch.randn(1, 3, args.img_size, args.img_size, device=dev)
232
+
233
+ print("== Forward probe ==")
234
  with torch.no_grad():
235
+ probe_out = wrapper(dummy)
236
+ print("Picked output tensor shape:", tuple(probe_out.shape), "dtype:", probe_out.dtype)
 
237
 
238
  print("== Exporting ONNX ==")
239
+ out_path = os.path.abspath(args.output)
240
+ os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
241
 
242
+ input_names = ["input"]
243
+ output_names = ["mask"]
244
  dynamic_axes = None
245
+ if args.dynamic:
246
+ dynamic_axes = {
247
+ "input": {0: "batch", 2: "height", 3: "width"},
248
+ "mask": {0: "batch", 2: "height", 3: "width"},
249
+ }
250
+
251
+ with torch.no_grad():
252
+ # IMPORTANT: torch 2.0.1 does NOT support use_external_data_format.
253
+ torch.onnx.export(
254
+ wrapper,
255
+ dummy,
256
+ out_path,
257
+ export_params=True,
258
+ opset_version=args.opset,
259
+ do_constant_folding=True,
260
+ input_names=input_names,
261
+ output_names=output_names,
262
+ dynamic_axes=dynamic_axes,
263
+ )
264
 
265
+ print("Output:", out_path)
266
+
267
+ if args.external_data or (not args.skip_onnx_check):
268
+ import onnx # type: ignore
269
+
270
+ print("== Loading ONNX ==")
271
+ onnx_model = onnx.load(out_path)
272
+
273
+ if not args.skip_onnx_check:
274
+ print("== ONNX checker ==")
275
+ onnx.checker.check_model(onnx_model)
276
+ print("ONNX checker: OK")
277
+
278
+ if args.external_data:
279
+ print("== Saving external data ==")
280
+ data_name = os.path.basename(out_path) + ".data"
281
+ onnx.save_model(
282
+ onnx_model,
283
+ out_path,
284
+ save_as_external_data=True,
285
+ all_tensors_to_one_file=True,
286
+ location=data_name,
287
+ size_threshold=1024, # bytes; moves almost everything out
288
+ )
289
+ print("Saved external-data ONNX:")
290
+ print(" Model:", out_path)
291
+ print(" Data :", os.path.join(os.path.dirname(out_path), data_name))
292
 
293
  print("== Done ==")
294