saliacoel committed on
Commit
34fc213
·
verified ·
1 Parent(s): d7390a9

Update export_birefnet_onnx.py

Browse files
Files changed (1) hide show
  1. export_birefnet_onnx.py +375 -297
export_birefnet_onnx.py CHANGED
@@ -1,297 +1,375 @@
1
- #!/usr/bin/env python3
2
- """
3
- BiRefNet (.pth) -> ONNX exporter that works with:
4
- - Python 3.10
5
- - torch==2.0.1 (+cu118 recommended)
6
- - transformers==4.42.4
7
-
8
- Fixes:
9
- - BiRefNet HF code uses relative imports (e.g. from .BiRefNet_config import ...),
10
- so --code_dir must be imported as a *package*.
11
- - Some public scripts pass use_external_data_format to torch.onnx.export, but
12
- torch 2.0.1 does NOT support that keyword.
13
- - Some checkpoints are saved from torch.compile and have keys prefixed with `_orig_mod.`.
14
- """
15
-
16
- from __future__ import annotations
17
-
18
- import argparse
19
- import importlib
20
- import os
21
- import sys
22
- from typing import Any, Dict, Iterable
23
-
24
- import torch
25
-
26
-
27
- def _print_env(device: str) -> None:
28
- print("== Environment ==")
29
- print("Python:", sys.version.replace("\n", " "))
30
- print("Torch:", torch.__version__)
31
- print("CUDA available:", torch.cuda.is_available())
32
- if torch.cuda.is_available():
33
- try:
34
- idx = torch.cuda.current_device()
35
- print("CUDA device:", torch.cuda.get_device_name(idx))
36
- except Exception:
37
- pass
38
- print("Requested device:", device)
39
-
40
-
41
- def _try_register_deform_conv2d() -> bool:
42
- """
43
- Optional: register ONNX symbolic for torchvision's DeformConv2d.
44
- Provided by deform-conv2d-onnx-exporter.
45
- """
46
- try:
47
- import deform_conv2d_onnx_exporter # type: ignore
48
-
49
- deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()
50
- print("DeformConv2d ONNX exporter: OK")
51
- return True
52
- except Exception as e:
53
- print("DeformConv2d ONNX exporter: NOT LOADED (may fail if model uses DeformConv)")
54
- print(" Reason:", repr(e))
55
- return False
56
-
57
-
58
- def _ensure_pkg_and_import(code_dir: str):
59
- """
60
- Make sure code_dir is a real python package, then import <pkg>.birefnet
61
- so that relative imports inside birefnet.py work.
62
- """
63
- code_dir = os.path.abspath(code_dir)
64
- if not os.path.isdir(code_dir):
65
- raise FileNotFoundError(f"--code_dir not found or not a directory: {code_dir}")
66
-
67
- init_py = os.path.join(code_dir, "__init__.py")
68
- if not os.path.exists(init_py):
69
- # create empty __init__.py to make it a package
70
- open(init_py, "a", encoding="utf-8").close()
71
-
72
- pkg_name = os.path.basename(code_dir.rstrip("/"))
73
- parent_dir = os.path.dirname(code_dir)
74
- if parent_dir not in sys.path:
75
- sys.path.insert(0, parent_dir)
76
-
77
- # Import as package to satisfy relative imports
78
- mod = importlib.import_module(f"{pkg_name}.birefnet")
79
- return mod, pkg_name
80
-
81
-
82
- def _extract_state_dict(ckpt_obj: Any) -> Dict[str, torch.Tensor]:
83
- """
84
- Accepts various checkpoint formats and returns a plain state_dict.
85
- """
86
- if isinstance(ckpt_obj, dict):
87
- # common nesting keys
88
- for k in ("state_dict", "model_state_dict", "model", "net", "params", "ema"):
89
- v = ckpt_obj.get(k, None)
90
- if isinstance(v, dict):
91
- ckpt_obj = v
92
- break
93
-
94
- if not isinstance(ckpt_obj, dict):
95
- raise RuntimeError("Unsupported checkpoint format: expected a dict/state_dict.")
96
-
97
- # At this point it should be {str: Tensor}
98
- sd: Dict[str, torch.Tensor] = {}
99
- for k, v in ckpt_obj.items():
100
- if isinstance(k, str) and torch.is_tensor(v):
101
- sd[k] = v
102
-
103
- if not sd:
104
- raise RuntimeError("Checkpoint dict contained no tensor parameters.")
105
- return sd
106
-
107
-
108
- def _normalize_state_dict_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
109
- """
110
- Fix common prefixes:
111
- - torch.compile checkpoints: `_orig_mod.`
112
- - DataParallel / DDP: `module.`
113
- """
114
- out: Dict[str, torch.Tensor] = {}
115
- for k, v in sd.items():
116
- nk = k
117
- if nk.startswith("_orig_mod."):
118
- nk = nk[len("_orig_mod.") :]
119
- if nk.startswith("module."):
120
- nk = nk[len("module.") :]
121
- out[nk] = v
122
- return out
123
-
124
-
125
- def _iter_tensors(x: Any) -> Iterable[torch.Tensor]:
126
- if torch.is_tensor(x):
127
- yield x
128
- elif isinstance(x, dict):
129
- for v in x.values():
130
- yield from _iter_tensors(v)
131
- elif isinstance(x, (list, tuple)):
132
- for v in x:
133
- yield from _iter_tensors(v)
134
-
135
-
136
- def _pick_best_output(out: Any, img_size: int | None) -> torch.Tensor:
137
- """
138
- BiRefNet forward can return nested structures (list/tuple/dict).
139
- We want a single mask tensor [N,1,H,W] if possible.
140
- """
141
- tensors = list(_iter_tensors(out))
142
- if not tensors:
143
- raise RuntimeError("Model forward produced no tensors to export.")
144
-
145
- # Prefer rank-4 tensors
146
- cands = [t for t in tensors if t.dim() == 4]
147
-
148
- # Prefer exact H/W match if provided
149
- if img_size is not None and cands:
150
- cands_hw = [t for t in cands if int(t.shape[-2]) == img_size and int(t.shape[-1]) == img_size]
151
- if cands_hw:
152
- cands = cands_hw
153
-
154
- # Prefer single-channel outputs
155
- if cands:
156
- cands_c1 = [t for t in cands if int(t.shape[1]) == 1]
157
- if cands_c1:
158
- cands = cands_c1
159
-
160
- return cands[0] if cands else tensors[0]
161
-
162
-
163
- class _ExportWrapper(torch.nn.Module):
164
- def __init__(self, model: torch.nn.Module, img_size: int | None):
165
- super().__init__()
166
- self.model = model
167
- self.img_size = img_size
168
-
169
- def forward(self, x: torch.Tensor) -> torch.Tensor:
170
- out = self.model(x)
171
- y = _pick_best_output(out, self.img_size)
172
- return y
173
-
174
-
175
- def main() -> None:
176
- p = argparse.ArgumentParser()
177
- p.add_argument("--code_dir", required=True, help="Folder that contains birefnet.py and BiRefNet_config.py")
178
- p.add_argument("--weights", required=True, help="Path to .pth weights")
179
- p.add_argument("--output", required=True, help="Output ONNX path, e.g. out.onnx")
180
- p.add_argument("--img_size", type=int, default=1024, help="Dummy input resolution (square), default 1024")
181
- p.add_argument("--opset", type=int, default=17, help="ONNX opset, default 17")
182
- p.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="cuda or cpu")
183
- p.add_argument("--dynamic", action="store_true", help="Export dynamic H/W axes (may break export)")
184
- p.add_argument(
185
- "--external_data",
186
- action="store_true",
187
- help="After export, re-save ONNX using external data (.onnx + .onnx.data).",
188
- )
189
- p.add_argument("--skip_onnx_check", action="store_true", help="Skip onnx.checker.check_model()")
190
- args = p.parse_args()
191
-
192
- _print_env(args.device)
193
- _try_register_deform_conv2d()
194
-
195
- # Import model properly (as a package)
196
- birefnet_mod, pkg_name = _ensure_pkg_and_import(args.code_dir)
197
- if not hasattr(birefnet_mod, "BiRefNet"):
198
- raise RuntimeError(f"BiRefNet class not found in {pkg_name}.birefnet")
199
- BiRefNet = getattr(birefnet_mod, "BiRefNet")
200
-
201
- print("== Building model ==")
202
- model = BiRefNet(bb_pretrained=False)
203
- model.eval()
204
-
205
- print("== Loading weights ==")
206
- ckpt = torch.load(args.weights, map_location="cpu")
207
- sd = _extract_state_dict(ckpt)
208
- sd = _normalize_state_dict_keys(sd)
209
-
210
- incompatible = model.load_state_dict(sd, strict=False)
211
- missing = list(getattr(incompatible, "missing_keys", []))
212
- unexpected = list(getattr(incompatible, "unexpected_keys", []))
213
- print(f"Loaded state_dict. Missing keys: {len(missing)} Unexpected keys: {len(unexpected)}")
214
- if missing:
215
- print(" (first 20 missing):", missing[:20])
216
- if unexpected:
217
- print(" (first 20 unexpected):", unexpected[:20])
218
-
219
- if args.device == "cuda":
220
- if not torch.cuda.is_available():
221
- raise RuntimeError("You asked for --device cuda but CUDA is not available.")
222
- model = model.to("cuda")
223
- dev = "cuda"
224
- else:
225
- model = model.to("cpu")
226
- dev = "cpu"
227
-
228
- wrapper = _ExportWrapper(model, img_size=args.img_size)
229
- wrapper.eval()
230
-
231
- dummy = torch.randn(1, 3, args.img_size, args.img_size, device=dev)
232
-
233
- print("== Forward probe ==")
234
- with torch.no_grad():
235
- probe_out = wrapper(dummy)
236
- print("Picked output tensor shape:", tuple(probe_out.shape), "dtype:", probe_out.dtype)
237
-
238
- print("== Exporting ONNX ==")
239
- out_path = os.path.abspath(args.output)
240
- os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
241
-
242
- input_names = ["input"]
243
- output_names = ["mask"]
244
- dynamic_axes = None
245
- if args.dynamic:
246
- dynamic_axes = {
247
- "input": {0: "batch", 2: "height", 3: "width"},
248
- "mask": {0: "batch", 2: "height", 3: "width"},
249
- }
250
-
251
- with torch.no_grad():
252
- # IMPORTANT: torch 2.0.1 does NOT support use_external_data_format.
253
- torch.onnx.export(
254
- wrapper,
255
- dummy,
256
- out_path,
257
- export_params=True,
258
- opset_version=args.opset,
259
- do_constant_folding=True,
260
- input_names=input_names,
261
- output_names=output_names,
262
- dynamic_axes=dynamic_axes,
263
- )
264
-
265
- print("Output:", out_path)
266
-
267
- if args.external_data or (not args.skip_onnx_check):
268
- import onnx # type: ignore
269
-
270
- print("== Loading ONNX ==")
271
- onnx_model = onnx.load(out_path)
272
-
273
- if not args.skip_onnx_check:
274
- print("== ONNX checker ==")
275
- onnx.checker.check_model(onnx_model)
276
- print("ONNX checker: OK")
277
-
278
- if args.external_data:
279
- print("== Saving external data ==")
280
- data_name = os.path.basename(out_path) + ".data"
281
- onnx.save_model(
282
- onnx_model,
283
- out_path,
284
- save_as_external_data=True,
285
- all_tensors_to_one_file=True,
286
- location=data_name,
287
- size_threshold=1024, # bytes; moves almost everything out
288
- )
289
- print("Saved external-data ONNX:")
290
- print(" Model:", out_path)
291
- print(" Data :", os.path.join(os.path.dirname(out_path), data_name))
292
-
293
- print("== Done ==")
294
-
295
-
296
- if __name__ == "__main__":
297
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BiRefNet .pth -> ONNX exporter (CPU/GPU), with robust deform_conv2d ONNX patch.
4
+
5
+ Fixes:
6
+ - deform_conv2d_onnx_exporter get_tensor_dim_size returning None (NoneType + int crash)
7
+ - checkpoints saved with _orig_mod. prefix (torch.compile)
8
+ - supports code_dir layouts:
9
+ A) HuggingFace-style: code_dir/birefnet.py (class BiRefNet inside)
10
+ B) GitHub-style: code_dir/models/birefnet.py + code_dir/utils.py
11
+
12
+ Recommended baseline: torch==2.0.1, opset 17, fixed input size (e.g. 1024x1024).
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import importlib
19
+ import inspect
20
+ import os
21
+ import re
22
+ import sys
23
+ from typing import Any, Dict, Iterable, List, Tuple
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
+ # -------------------------
30
+ # DeformConv2d ONNX patching
31
+ # -------------------------
32
+
33
+ def _patch_and_register_deform_conv2d() -> None:
34
+ """
35
+ Patch deform_conv2d_onnx_exporter.get_tensor_dim_size so it never returns None
36
+ for H/W when possible (fallback to tensor type sizes/strides), then register the op.
37
+
38
+ This specifically fixes:
39
+ TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'
40
+ at create_dcn_params(...): in_h = get_tensor_dim_size(input, 2) + ...
41
+ """
42
+ try:
43
+ import deform_conv2d_onnx_exporter as d
44
+ import torch.onnx.symbolic_helper as sym_help
45
+ except Exception as e:
46
+ print(f"[deform_conv2d] exporter not available ({type(e).__name__}: {e})")
47
+ return
48
+
49
+ if not hasattr(d, "get_tensor_dim_size"):
50
+ print("[deform_conv2d] deform_conv2d_onnx_exporter.get_tensor_dim_size not found; cannot patch.")
51
+ return
52
+
53
+ orig_get = d.get_tensor_dim_size
54
+
55
+ def patched_get_tensor_dim_size(tensor, dim: int):
56
+ # 1) Try original
57
+ v = orig_get(tensor, dim)
58
+ if v is not None:
59
+ return v
60
+
61
+ # 2) Try torch's internal tensor sizes helper (sometimes more available than _get_tensor_dim_size)
62
+ try:
63
+ sizes = sym_help._get_tensor_sizes(tensor) # type: ignore[attr-defined]
64
+ if sizes is not None and len(sizes) > dim and sizes[dim] is not None:
65
+ return int(sizes[dim])
66
+ except Exception:
67
+ pass
68
+
69
+ # 3) Try TensorType sizes/strides (Colab-style fallback)
70
+ try:
71
+ import typing
72
+ from torch import _C
73
+
74
+ ttype = typing.cast(_C.TensorType, tensor.type())
75
+ tsizes = ttype.sizes()
76
+ if tsizes is not None and len(tsizes) > dim and tsizes[dim] is not None:
77
+ return int(tsizes[dim])
78
+
79
+ tstrides = ttype.strides()
80
+ # For contiguous NCHW: strides = (C*H*W, H*W, W, 1)
81
+ if tstrides is not None and len(tstrides) >= 4:
82
+ s0, s1, s2, s3 = tstrides[0], tstrides[1], tstrides[2], tstrides[3]
83
+
84
+ if dim == 3 and s2 is not None:
85
+ return int(s2) # W
86
+
87
+ if dim == 2 and s1 is not None and s2 not in (None, 0):
88
+ return int(s1 // s2) # H = (H*W)/W
89
+
90
+ if dim == 1 and s0 is not None and s1 not in (None, 0):
91
+ return int(s0 // s1) # C = (C*H*W)/(H*W)
92
+
93
+ if dim == 0:
94
+ # We export with batch=1 dummy input; safe fallback.
95
+ return 1
96
+ except Exception:
97
+ pass
98
+
99
+ # 4) Last-resort: batch=1 fallback, otherwise hard error with actionable message
100
+ if dim == 0:
101
+ return 1
102
+
103
+ raise RuntimeError(
104
+ f"[deform_conv2d] Could not infer static dim={dim} for a tensor during ONNX export "
105
+ f"(got None from torch). This typically happens with dynamic axes or missing shape info. "
106
+ f"Use a fixed input size (no dynamic axes) and export again."
107
+ )
108
+
109
+ d.get_tensor_dim_size = patched_get_tensor_dim_size # type: ignore[assignment]
110
+
111
+ # Register op after patching so the symbolic uses our patched helper at runtime
112
+ try:
113
+ d.register_deform_conv2d_onnx_op()
114
+ print("[deform_conv2d] Patched get_tensor_dim_size + registered deform_conv2d ONNX op.")
115
+ except Exception as e:
116
+ print(f"[deform_conv2d] register_deform_conv2d_onnx_op failed ({type(e).__name__}: {e})")
117
+
118
+
119
+ # -------------------------
120
+ # BiRefNet importing helpers
121
+ # -------------------------
122
+
123
+ def _ensure_importable_package_dir(code_dir: str) -> Tuple[str, str]:
124
+ """
125
+ Make code_dir importable as a package so relative imports inside it work.
126
+ Used for HF-style code_dir that contains birefnet.py and BiRefNet_config.py.
127
+ """
128
+ code_dir = os.path.abspath(code_dir)
129
+ parent = os.path.dirname(code_dir)
130
+ pkg = os.path.basename(code_dir)
131
+
132
+ init_py = os.path.join(code_dir, "__init__.py")
133
+ if not os.path.exists(init_py):
134
+ open(init_py, "a", encoding="utf-8").close()
135
+
136
+ if not re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", pkg):
137
+ safe_pkg = "birefnet_pkg"
138
+ safe_dir = os.path.join(parent, safe_pkg)
139
+ if not os.path.exists(safe_dir):
140
+ os.symlink(code_dir, safe_dir)
141
+ pkg = safe_pkg
142
+ code_dir = safe_dir
143
+ init_py = os.path.join(code_dir, "__init__.py")
144
+ if not os.path.exists(init_py):
145
+ open(init_py, "a", encoding="utf-8").close()
146
+
147
+ if parent not in sys.path:
148
+ sys.path.insert(0, parent)
149
+
150
+ return pkg, code_dir
151
+
152
+
153
+ def _detect_layout(code_dir: str) -> str:
154
+ code_dir = os.path.abspath(code_dir)
155
+ if os.path.isfile(os.path.join(code_dir, "models", "birefnet.py")) and os.path.isfile(os.path.join(code_dir, "utils.py")):
156
+ return "github"
157
+ if os.path.isfile(os.path.join(code_dir, "birefnet.py")):
158
+ return "hf"
159
+ raise FileNotFoundError(
160
+ f"Could not detect BiRefNet layout in {code_dir}.\n"
161
+ f"Expected either:\n"
162
+ f" - GitHub layout: models/birefnet.py and utils.py\n"
163
+ f" - HF layout: birefnet.py\n"
164
+ )
165
+
166
+
167
+ def _import_birefnet(code_dir: str):
168
+ layout = _detect_layout(code_dir)
169
+
170
+ if layout == "github":
171
+ # Mirror Colab: `from utils import check_state_dict` and `from models.birefnet import BiRefNet`
172
+ if code_dir not in sys.path:
173
+ sys.path.insert(0, code_dir)
174
+ from utils import check_state_dict # type: ignore
175
+ from models.birefnet import BiRefNet # type: ignore
176
+ return layout, BiRefNet, check_state_dict
177
+
178
+ # HF layout
179
+ pkg, _ = _ensure_importable_package_dir(code_dir)
180
+ mod = importlib.import_module(f"{pkg}.birefnet")
181
+ if not hasattr(mod, "BiRefNet"):
182
+ raise RuntimeError(f"BiRefNet class not found in {pkg}.birefnet")
183
+ return layout, getattr(mod, "BiRefNet"), None
184
+
185
+
186
+ # -------------------------
187
+ # Weight loading helpers
188
+ # -------------------------
189
+
190
+ def _extract_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
191
+ if isinstance(obj, dict):
192
+ if obj and all(torch.is_tensor(v) for v in obj.values()):
193
+ return obj # type: ignore[return-value]
194
+ for k in ["state_dict", "model", "model_state_dict", "net", "params", "weights", "ema"]:
195
+ if k in obj and isinstance(obj[k], dict) and obj[k] and all(torch.is_tensor(v) for v in obj[k].values()):
196
+ return obj[k] # type: ignore[return-value]
197
+ for v in obj.values():
198
+ if isinstance(v, dict) and v and all(torch.is_tensor(tv) for tv in v.values()):
199
+ return v # type: ignore[return-value]
200
+ raise RuntimeError("Could not find a state_dict inside the checkpoint.")
201
+
202
+
203
+ def _clean_state_dict_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
204
+ prefixes = ["module.", "_orig_mod.", "model.", "net.", "state_dict."]
205
+ out: Dict[str, torch.Tensor] = {}
206
+ for k, v in sd.items():
207
+ nk = k
208
+ changed = True
209
+ while changed:
210
+ changed = False
211
+ for p in prefixes:
212
+ if nk.startswith(p):
213
+ nk = nk[len(p):]
214
+ changed = True
215
+ out[nk] = v
216
+ return out
217
+
218
+
219
+ def _pretty_list(xs: List[str], n: int = 20) -> List[str]:
220
+ return xs[:n] + (["..."] if len(xs) > n else [])
221
+
222
+
223
+ # -------------------------
224
+ # Output selection / wrapper
225
+ # -------------------------
226
+
227
+ def _walk_tensors(x: Any) -> Iterable[torch.Tensor]:
228
+ if torch.is_tensor(x):
229
+ yield x
230
+ return
231
+ if isinstance(x, dict):
232
+ for v in x.values():
233
+ yield from _walk_tensors(v)
234
+ elif isinstance(x, (list, tuple)):
235
+ for v in x:
236
+ yield from _walk_tensors(v)
237
+
238
+
239
+ def _pick_output_tensor(model_out: Any, img_size: int) -> torch.Tensor:
240
+ ts = list(_walk_tensors(model_out))
241
+ if not ts:
242
+ raise RuntimeError("Model forward returned no tensors.")
243
+ # Prefer (B,1,H,W) at img_size
244
+ for t in ts:
245
+ if t.ndim == 4 and t.shape[1] in (1, 3) and t.shape[2] == img_size and t.shape[3] == img_size:
246
+ return t
247
+ # Next: any 4D tensor with H,W == img_size
248
+ for t in ts:
249
+ if t.ndim == 4 and t.shape[2] == img_size and t.shape[3] == img_size:
250
+ return t
251
+ # Else: largest tensor
252
+ return max(ts, key=lambda z: z.numel())
253
+
254
+
255
+ class ExportWrapper(nn.Module):
256
+ def __init__(self, model: nn.Module, img_size: int):
257
+ super().__init__()
258
+ self.model = model
259
+ self.img_size = img_size
260
+
261
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
262
+ x = x.contiguous()
263
+ out = self.model(x)
264
+ return _pick_output_tensor(out, self.img_size)
265
+
266
+
267
+ # -------------------------
268
+ # Main
269
+ # -------------------------
270
+
271
+ def main() -> None:
272
+ ap = argparse.ArgumentParser()
273
+ ap.add_argument("--code_dir", required=True)
274
+ ap.add_argument("--weights", required=True)
275
+ ap.add_argument("--output", required=True)
276
+ ap.add_argument("--img_size", type=int, default=1024)
277
+ ap.add_argument("--opset", type=int, default=17)
278
+ ap.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
279
+ ap.add_argument("--skip_onnx_check", action="store_true")
280
+ args = ap.parse_args()
281
+
282
+ print("== Environment ==")
283
+ print("Python:", sys.version.replace("\n", " "))
284
+ print("Torch:", torch.__version__)
285
+ print("CUDA available:", torch.cuda.is_available())
286
+ print("Requested device:", args.device)
287
+
288
+ if args.device == "cuda" and not torch.cuda.is_available():
289
+ raise RuntimeError("You asked for --device cuda but CUDA is not available.")
290
+
291
+ device = torch.device(args.device)
292
+ print("Using device:", device)
293
+
294
+ # IMPORTANT: patch deform_conv2d exporter BEFORE export
295
+ _patch_and_register_deform_conv2d()
296
+
297
+ layout, BiRefNet, check_state_dict = _import_birefnet(args.code_dir)
298
+ print("BiRefNet layout detected:", layout)
299
+
300
+ print("== Building model ==")
301
+ kwargs = {}
302
+ try:
303
+ sig = inspect.signature(BiRefNet)
304
+ if "bb_pretrained" in sig.parameters:
305
+ kwargs["bb_pretrained"] = False
306
+ except Exception:
307
+ pass
308
+
309
+ model = BiRefNet(**kwargs) if kwargs else BiRefNet()
310
+ model.eval().to(device)
311
+
312
+ print("== Loading weights ==")
313
+ ckpt = torch.load(args.weights, map_location="cpu")
314
+
315
+ if layout == "github" and check_state_dict is not None:
316
+ # Colab-style path
317
+ sd = check_state_dict(ckpt)
318
+ missing, unexpected = model.load_state_dict(sd, strict=False)
319
+ else:
320
+ # HF-style path
321
+ sd = _extract_state_dict(ckpt)
322
+ sd = _clean_state_dict_keys(sd)
323
+ missing, unexpected = model.load_state_dict(sd, strict=False)
324
+
325
+ missing = list(missing)
326
+ unexpected = list(unexpected)
327
+ print(f"Loaded state_dict. Missing keys: {len(missing)} Unexpected keys: {len(unexpected)}")
328
+ if missing:
329
+ print(" (first 20 missing):", _pretty_list(missing, 20))
330
+ if unexpected:
331
+ print(" (first 20 unexpected):", _pretty_list(unexpected, 20))
332
+
333
+ wrapper = ExportWrapper(model, img_size=args.img_size).eval().to(device)
334
+
335
+ print("== Forward probe ==")
336
+ dummy = torch.randn(1, 3, args.img_size, args.img_size, device=device)
337
+ with torch.no_grad():
338
+ out = wrapper(dummy)
339
+ print("Picked output shape:", tuple(out.shape), "dtype:", out.dtype)
340
+
341
+ print("== Exporting ONNX ==")
342
+ out_path = os.path.abspath(args.output)
343
+ os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
344
+
345
+ # NOTE: No dynamic_axes by default (keeps shapes static and avoids shape None issues).
346
+ torch.onnx.export(
347
+ wrapper,
348
+ dummy,
349
+ out_path,
350
+ export_params=True,
351
+ opset_version=args.opset,
352
+ do_constant_folding=True,
353
+ input_names=["input"],
354
+ output_names=["output"],
355
+ verbose=False,
356
+ )
357
+
358
+ print("Saved ONNX to:", out_path)
359
+
360
+ if not args.skip_onnx_check:
361
+ print("== Checking ONNX ==")
362
+ import onnx
363
+ m = onnx.load(out_path)
364
+ onnx.checker.check_model(m)
365
+ print("ONNX check: OK")
366
+
367
+ try:
368
+ mb = os.path.getsize(out_path) / (1024 * 1024)
369
+ print(f"ONNX size: {mb:.1f} MB")
370
+ except Exception:
371
+ pass
372
+
373
+
374
+ if __name__ == "__main__":
375
+ main()