saliacoel commited on
Commit
eac2ec4
·
verified ·
1 Parent(s): 34fc213

Upload export_birefnet_onnx.py

Browse files
Files changed (1) hide show
  1. export_birefnet_onnx.py +380 -375
export_birefnet_onnx.py CHANGED
@@ -1,375 +1,380 @@
1
- #!/usr/bin/env python3
2
- """
3
- BiRefNet .pth -> ONNX exporter (CPU/GPU), with robust deform_conv2d ONNX patch.
4
-
5
- Fixes:
6
- - deform_conv2d_onnx_exporter get_tensor_dim_size returning None (NoneType + int crash)
7
- - checkpoints saved with _orig_mod. prefix (torch.compile)
8
- - supports code_dir layouts:
9
- A) HuggingFace-style: code_dir/birefnet.py (class BiRefNet inside)
10
- B) GitHub-style: code_dir/models/birefnet.py + code_dir/utils.py
11
-
12
- Recommended baseline: torch==2.0.1, opset 17, fixed input size (e.g. 1024x1024).
13
- """
14
-
15
- from __future__ import annotations
16
-
17
- import argparse
18
- import importlib
19
- import inspect
20
- import os
21
- import re
22
- import sys
23
- from typing import Any, Dict, Iterable, List, Tuple
24
-
25
- import torch
26
- import torch.nn as nn
27
-
28
-
29
- # -------------------------
30
- # DeformConv2d ONNX patching
31
- # -------------------------
32
-
33
- def _patch_and_register_deform_conv2d() -> None:
34
- """
35
- Patch deform_conv2d_onnx_exporter.get_tensor_dim_size so it never returns None
36
- for H/W when possible (fallback to tensor type sizes/strides), then register the op.
37
-
38
- This specifically fixes:
39
- TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'
40
- at create_dcn_params(...): in_h = get_tensor_dim_size(input, 2) + ...
41
- """
42
- try:
43
- import deform_conv2d_onnx_exporter as d
44
- import torch.onnx.symbolic_helper as sym_help
45
- except Exception as e:
46
- print(f"[deform_conv2d] exporter not available ({type(e).__name__}: {e})")
47
- return
48
-
49
- if not hasattr(d, "get_tensor_dim_size"):
50
- print("[deform_conv2d] deform_conv2d_onnx_exporter.get_tensor_dim_size not found; cannot patch.")
51
- return
52
-
53
- orig_get = d.get_tensor_dim_size
54
-
55
- def patched_get_tensor_dim_size(tensor, dim: int):
56
- # 1) Try original
57
- v = orig_get(tensor, dim)
58
- if v is not None:
59
- return v
60
-
61
- # 2) Try torch's internal tensor sizes helper (sometimes more available than _get_tensor_dim_size)
62
- try:
63
- sizes = sym_help._get_tensor_sizes(tensor) # type: ignore[attr-defined]
64
- if sizes is not None and len(sizes) > dim and sizes[dim] is not None:
65
- return int(sizes[dim])
66
- except Exception:
67
- pass
68
-
69
- # 3) Try TensorType sizes/strides (Colab-style fallback)
70
- try:
71
- import typing
72
- from torch import _C
73
-
74
- ttype = typing.cast(_C.TensorType, tensor.type())
75
- tsizes = ttype.sizes()
76
- if tsizes is not None and len(tsizes) > dim and tsizes[dim] is not None:
77
- return int(tsizes[dim])
78
-
79
- tstrides = ttype.strides()
80
- # For contiguous NCHW: strides = (C*H*W, H*W, W, 1)
81
- if tstrides is not None and len(tstrides) >= 4:
82
- s0, s1, s2, s3 = tstrides[0], tstrides[1], tstrides[2], tstrides[3]
83
-
84
- if dim == 3 and s2 is not None:
85
- return int(s2) # W
86
-
87
- if dim == 2 and s1 is not None and s2 not in (None, 0):
88
- return int(s1 // s2) # H = (H*W)/W
89
-
90
- if dim == 1 and s0 is not None and s1 not in (None, 0):
91
- return int(s0 // s1) # C = (C*H*W)/(H*W)
92
-
93
- if dim == 0:
94
- # We export with batch=1 dummy input; safe fallback.
95
- return 1
96
- except Exception:
97
- pass
98
-
99
- # 4) Last-resort: batch=1 fallback, otherwise hard error with actionable message
100
- if dim == 0:
101
- return 1
102
-
103
- raise RuntimeError(
104
- f"[deform_conv2d] Could not infer static dim={dim} for a tensor during ONNX export "
105
- f"(got None from torch). This typically happens with dynamic axes or missing shape info. "
106
- f"Use a fixed input size (no dynamic axes) and export again."
107
- )
108
-
109
- d.get_tensor_dim_size = patched_get_tensor_dim_size # type: ignore[assignment]
110
-
111
- # Register op after patching so the symbolic uses our patched helper at runtime
112
- try:
113
- d.register_deform_conv2d_onnx_op()
114
- print("[deform_conv2d] Patched get_tensor_dim_size + registered deform_conv2d ONNX op.")
115
- except Exception as e:
116
- print(f"[deform_conv2d] register_deform_conv2d_onnx_op failed ({type(e).__name__}: {e})")
117
-
118
-
119
- # -------------------------
120
- # BiRefNet importing helpers
121
- # -------------------------
122
-
123
- def _ensure_importable_package_dir(code_dir: str) -> Tuple[str, str]:
124
- """
125
- Make code_dir importable as a package so relative imports inside it work.
126
- Used for HF-style code_dir that contains birefnet.py and BiRefNet_config.py.
127
- """
128
- code_dir = os.path.abspath(code_dir)
129
- parent = os.path.dirname(code_dir)
130
- pkg = os.path.basename(code_dir)
131
-
132
- init_py = os.path.join(code_dir, "__init__.py")
133
- if not os.path.exists(init_py):
134
- open(init_py, "a", encoding="utf-8").close()
135
-
136
- if not re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", pkg):
137
- safe_pkg = "birefnet_pkg"
138
- safe_dir = os.path.join(parent, safe_pkg)
139
- if not os.path.exists(safe_dir):
140
- os.symlink(code_dir, safe_dir)
141
- pkg = safe_pkg
142
- code_dir = safe_dir
143
- init_py = os.path.join(code_dir, "__init__.py")
144
- if not os.path.exists(init_py):
145
- open(init_py, "a", encoding="utf-8").close()
146
-
147
- if parent not in sys.path:
148
- sys.path.insert(0, parent)
149
-
150
- return pkg, code_dir
151
-
152
-
153
- def _detect_layout(code_dir: str) -> str:
154
- code_dir = os.path.abspath(code_dir)
155
- if os.path.isfile(os.path.join(code_dir, "models", "birefnet.py")) and os.path.isfile(os.path.join(code_dir, "utils.py")):
156
- return "github"
157
- if os.path.isfile(os.path.join(code_dir, "birefnet.py")):
158
- return "hf"
159
- raise FileNotFoundError(
160
- f"Could not detect BiRefNet layout in {code_dir}.\n"
161
- f"Expected either:\n"
162
- f" - GitHub layout: models/birefnet.py and utils.py\n"
163
- f" - HF layout: birefnet.py\n"
164
- )
165
-
166
-
167
- def _import_birefnet(code_dir: str):
168
- layout = _detect_layout(code_dir)
169
-
170
- if layout == "github":
171
- # Mirror Colab: `from utils import check_state_dict` and `from models.birefnet import BiRefNet`
172
- if code_dir not in sys.path:
173
- sys.path.insert(0, code_dir)
174
- from utils import check_state_dict # type: ignore
175
- from models.birefnet import BiRefNet # type: ignore
176
- return layout, BiRefNet, check_state_dict
177
-
178
- # HF layout
179
- pkg, _ = _ensure_importable_package_dir(code_dir)
180
- mod = importlib.import_module(f"{pkg}.birefnet")
181
- if not hasattr(mod, "BiRefNet"):
182
- raise RuntimeError(f"BiRefNet class not found in {pkg}.birefnet")
183
- return layout, getattr(mod, "BiRefNet"), None
184
-
185
-
186
- # -------------------------
187
- # Weight loading helpers
188
- # -------------------------
189
-
190
- def _extract_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
191
- if isinstance(obj, dict):
192
- if obj and all(torch.is_tensor(v) for v in obj.values()):
193
- return obj # type: ignore[return-value]
194
- for k in ["state_dict", "model", "model_state_dict", "net", "params", "weights", "ema"]:
195
- if k in obj and isinstance(obj[k], dict) and obj[k] and all(torch.is_tensor(v) for v in obj[k].values()):
196
- return obj[k] # type: ignore[return-value]
197
- for v in obj.values():
198
- if isinstance(v, dict) and v and all(torch.is_tensor(tv) for tv in v.values()):
199
- return v # type: ignore[return-value]
200
- raise RuntimeError("Could not find a state_dict inside the checkpoint.")
201
-
202
-
203
- def _clean_state_dict_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
204
- prefixes = ["module.", "_orig_mod.", "model.", "net.", "state_dict."]
205
- out: Dict[str, torch.Tensor] = {}
206
- for k, v in sd.items():
207
- nk = k
208
- changed = True
209
- while changed:
210
- changed = False
211
- for p in prefixes:
212
- if nk.startswith(p):
213
- nk = nk[len(p):]
214
- changed = True
215
- out[nk] = v
216
- return out
217
-
218
-
219
- def _pretty_list(xs: List[str], n: int = 20) -> List[str]:
220
- return xs[:n] + (["..."] if len(xs) > n else [])
221
-
222
-
223
- # -------------------------
224
- # Output selection / wrapper
225
- # -------------------------
226
-
227
- def _walk_tensors(x: Any) -> Iterable[torch.Tensor]:
228
- if torch.is_tensor(x):
229
- yield x
230
- return
231
- if isinstance(x, dict):
232
- for v in x.values():
233
- yield from _walk_tensors(v)
234
- elif isinstance(x, (list, tuple)):
235
- for v in x:
236
- yield from _walk_tensors(v)
237
-
238
-
239
- def _pick_output_tensor(model_out: Any, img_size: int) -> torch.Tensor:
240
- ts = list(_walk_tensors(model_out))
241
- if not ts:
242
- raise RuntimeError("Model forward returned no tensors.")
243
- # Prefer (B,1,H,W) at img_size
244
- for t in ts:
245
- if t.ndim == 4 and t.shape[1] in (1, 3) and t.shape[2] == img_size and t.shape[3] == img_size:
246
- return t
247
- # Next: any 4D tensor with H,W == img_size
248
- for t in ts:
249
- if t.ndim == 4 and t.shape[2] == img_size and t.shape[3] == img_size:
250
- return t
251
- # Else: largest tensor
252
- return max(ts, key=lambda z: z.numel())
253
-
254
-
255
- class ExportWrapper(nn.Module):
256
- def __init__(self, model: nn.Module, img_size: int):
257
- super().__init__()
258
- self.model = model
259
- self.img_size = img_size
260
-
261
- def forward(self, x: torch.Tensor) -> torch.Tensor:
262
- x = x.contiguous()
263
- out = self.model(x)
264
- return _pick_output_tensor(out, self.img_size)
265
-
266
-
267
- # -------------------------
268
- # Main
269
- # -------------------------
270
-
271
- def main() -> None:
272
- ap = argparse.ArgumentParser()
273
- ap.add_argument("--code_dir", required=True)
274
- ap.add_argument("--weights", required=True)
275
- ap.add_argument("--output", required=True)
276
- ap.add_argument("--img_size", type=int, default=1024)
277
- ap.add_argument("--opset", type=int, default=17)
278
- ap.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
279
- ap.add_argument("--skip_onnx_check", action="store_true")
280
- args = ap.parse_args()
281
-
282
- print("== Environment ==")
283
- print("Python:", sys.version.replace("\n", " "))
284
- print("Torch:", torch.__version__)
285
- print("CUDA available:", torch.cuda.is_available())
286
- print("Requested device:", args.device)
287
-
288
- if args.device == "cuda" and not torch.cuda.is_available():
289
- raise RuntimeError("You asked for --device cuda but CUDA is not available.")
290
-
291
- device = torch.device(args.device)
292
- print("Using device:", device)
293
-
294
- # IMPORTANT: patch deform_conv2d exporter BEFORE export
295
- _patch_and_register_deform_conv2d()
296
-
297
- layout, BiRefNet, check_state_dict = _import_birefnet(args.code_dir)
298
- print("BiRefNet layout detected:", layout)
299
-
300
- print("== Building model ==")
301
- kwargs = {}
302
- try:
303
- sig = inspect.signature(BiRefNet)
304
- if "bb_pretrained" in sig.parameters:
305
- kwargs["bb_pretrained"] = False
306
- except Exception:
307
- pass
308
-
309
- model = BiRefNet(**kwargs) if kwargs else BiRefNet()
310
- model.eval().to(device)
311
-
312
- print("== Loading weights ==")
313
- ckpt = torch.load(args.weights, map_location="cpu")
314
-
315
- if layout == "github" and check_state_dict is not None:
316
- # Colab-style path
317
- sd = check_state_dict(ckpt)
318
- missing, unexpected = model.load_state_dict(sd, strict=False)
319
- else:
320
- # HF-style path
321
- sd = _extract_state_dict(ckpt)
322
- sd = _clean_state_dict_keys(sd)
323
- missing, unexpected = model.load_state_dict(sd, strict=False)
324
-
325
- missing = list(missing)
326
- unexpected = list(unexpected)
327
- print(f"Loaded state_dict. Missing keys: {len(missing)} Unexpected keys: {len(unexpected)}")
328
- if missing:
329
- print(" (first 20 missing):", _pretty_list(missing, 20))
330
- if unexpected:
331
- print(" (first 20 unexpected):", _pretty_list(unexpected, 20))
332
-
333
- wrapper = ExportWrapper(model, img_size=args.img_size).eval().to(device)
334
-
335
- print("== Forward probe ==")
336
- dummy = torch.randn(1, 3, args.img_size, args.img_size, device=device)
337
- with torch.no_grad():
338
- out = wrapper(dummy)
339
- print("Picked output shape:", tuple(out.shape), "dtype:", out.dtype)
340
-
341
- print("== Exporting ONNX ==")
342
- out_path = os.path.abspath(args.output)
343
- os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
344
-
345
- # NOTE: No dynamic_axes by default (keeps shapes static and avoids shape None issues).
346
- torch.onnx.export(
347
- wrapper,
348
- dummy,
349
- out_path,
350
- export_params=True,
351
- opset_version=args.opset,
352
- do_constant_folding=True,
353
- input_names=["input"],
354
- output_names=["output"],
355
- verbose=False,
356
- )
357
-
358
- print("Saved ONNX to:", out_path)
359
-
360
- if not args.skip_onnx_check:
361
- print("== Checking ONNX ==")
362
- import onnx
363
- m = onnx.load(out_path)
364
- onnx.checker.check_model(m)
365
- print("ONNX check: OK")
366
-
367
- try:
368
- mb = os.path.getsize(out_path) / (1024 * 1024)
369
- print(f"ONNX size: {mb:.1f} MB")
370
- except Exception:
371
- pass
372
-
373
-
374
- if __name__ == "__main__":
375
- main()
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BiRefNet .pth -> ONNX exporter (CPU/GPU), with robust deform_conv2d ONNX patch.
4
+
5
+ Fixes:
6
+ - deform_conv2d_onnx_exporter get_tensor_dim_size returning None (NoneType + int crash)
7
+ - checkpoints saved with _orig_mod. prefix (torch.compile)
8
+ - supports code_dir layouts:
9
+ A) HuggingFace-style: code_dir/birefnet.py (class BiRefNet inside)
10
+ B) GitHub-style: code_dir/models/birefnet.py + code_dir/utils.py
11
+
12
+ Recommended baseline: torch==2.0.1, opset 17, fixed input size (e.g. 1024x1024).
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import importlib
19
+ import inspect
20
+ import os
21
+ import re
22
+ import sys
23
+ from typing import Any, Dict, Iterable, List, Tuple
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
+ # -------------------------
30
+ # DeformConv2d ONNX patching
31
+ # -------------------------
32
+
33
+ def _patch_and_register_deform_conv2d() -> None:
34
+ """
35
+ Patch deform_conv2d_onnx_exporter.get_tensor_dim_size so it never returns None
36
+ for H/W when possible (fallback to tensor type sizes/strides), then register the op.
37
+
38
+ This specifically fixes:
39
+ TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'
40
+ at create_dcn_params(...): in_h = get_tensor_dim_size(input, 2) + ...
41
+ """
42
+ try:
43
+ import deform_conv2d_onnx_exporter as d
44
+ import torch.onnx.symbolic_helper as sym_help
45
+ except Exception as e:
46
+ print(f"[deform_conv2d] exporter not available ({type(e).__name__}: {e})")
47
+ return
48
+
49
+ if not hasattr(d, "get_tensor_dim_size"):
50
+ print("[deform_conv2d] deform_conv2d_onnx_exporter.get_tensor_dim_size not found; cannot patch.")
51
+ return
52
+
53
+ orig_get = d.get_tensor_dim_size
54
+
55
+ def patched_get_tensor_dim_size(tensor, dim: int):
56
+ # 1) Try original
57
+ v = orig_get(tensor, dim)
58
+ if v is not None:
59
+ return v
60
+
61
+ # 2) Try torch's internal tensor sizes helper (sometimes more available than _get_tensor_dim_size)
62
+ try:
63
+ sizes = sym_help._get_tensor_sizes(tensor) # type: ignore[attr-defined]
64
+ if sizes is not None and len(sizes) > dim and sizes[dim] is not None:
65
+ return int(sizes[dim])
66
+ except Exception:
67
+ pass
68
+
69
+ # 3) Try TensorType sizes/strides (Colab-style fallback)
70
+ try:
71
+ import typing
72
+ from torch import _C
73
+
74
+ ttype = typing.cast(_C.TensorType, tensor.type())
75
+ tsizes = ttype.sizes()
76
+ if tsizes is not None and len(tsizes) > dim and tsizes[dim] is not None:
77
+ return int(tsizes[dim])
78
+
79
+ tstrides = ttype.strides()
80
+ # For contiguous NCHW: strides = (C*H*W, H*W, W, 1)
81
+ if tstrides is not None and len(tstrides) >= 4:
82
+ s0, s1, s2, s3 = tstrides[0], tstrides[1], tstrides[2], tstrides[3]
83
+
84
+ if dim == 3 and s2 is not None:
85
+ return int(s2) # W
86
+
87
+ if dim == 2 and s1 is not None and s2 not in (None, 0):
88
+ return int(s1 // s2) # H = (H*W)/W
89
+
90
+ if dim == 1 and s0 is not None and s1 not in (None, 0):
91
+ return int(s0 // s1) # C = (C*H*W)/(H*W)
92
+
93
+ if dim == 0:
94
+ # We export with batch=1 dummy input; safe fallback.
95
+ return 1
96
+ except Exception:
97
+ pass
98
+
99
+ # 4) Last-resort: batch=1 fallback, otherwise hard error with actionable message
100
+ if dim == 0:
101
+ return 1
102
+
103
+ raise RuntimeError(
104
+ f"[deform_conv2d] Could not infer static dim={dim} for a tensor during ONNX export "
105
+ f"(got None from torch). This typically happens with dynamic axes or missing shape info. "
106
+ f"Use a fixed input size (no dynamic axes) and export again."
107
+ )
108
+
109
+ d.get_tensor_dim_size = patched_get_tensor_dim_size # type: ignore[assignment]
110
+
111
+ # Register op after patching so the symbolic uses our patched helper at runtime
112
+ try:
113
+ d.register_deform_conv2d_onnx_op()
114
+ print("[deform_conv2d] Patched get_tensor_dim_size + registered deform_conv2d ONNX op.")
115
+ except Exception as e:
116
+ print(f"[deform_conv2d] register_deform_conv2d_onnx_op failed ({type(e).__name__}: {e})")
117
+
118
+
119
+ # -------------------------
120
+ # BiRefNet importing helpers
121
+ # -------------------------
122
+
123
+ def _ensure_importable_package_dir(code_dir: str) -> Tuple[str, str]:
124
+ """
125
+ Make code_dir importable as a package so relative imports inside it work.
126
+ Used for HF-style code_dir that contains birefnet.py and BiRefNet_config.py.
127
+ """
128
+ code_dir = os.path.abspath(code_dir)
129
+ parent = os.path.dirname(code_dir)
130
+ pkg = os.path.basename(code_dir)
131
+
132
+ init_py = os.path.join(code_dir, "__init__.py")
133
+ if not os.path.exists(init_py):
134
+ open(init_py, "a", encoding="utf-8").close()
135
+
136
+ if not re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", pkg):
137
+ safe_pkg = "birefnet_pkg"
138
+ safe_dir = os.path.join(parent, safe_pkg)
139
+ if not os.path.exists(safe_dir):
140
+ os.symlink(code_dir, safe_dir)
141
+ pkg = safe_pkg
142
+ code_dir = safe_dir
143
+ init_py = os.path.join(code_dir, "__init__.py")
144
+ if not os.path.exists(init_py):
145
+ open(init_py, "a", encoding="utf-8").close()
146
+
147
+ if parent not in sys.path:
148
+ sys.path.insert(0, parent)
149
+
150
+ return pkg, code_dir
151
+
152
+
153
+ def _detect_layout(code_dir: str) -> str:
154
+ code_dir = os.path.abspath(code_dir)
155
+ if os.path.isfile(os.path.join(code_dir, "models", "birefnet.py")) and os.path.isfile(os.path.join(code_dir, "utils.py")):
156
+ return "github"
157
+ if os.path.isfile(os.path.join(code_dir, "birefnet.py")):
158
+ return "hf"
159
+ raise FileNotFoundError(
160
+ f"Could not detect BiRefNet layout in {code_dir}.\n"
161
+ f"Expected either:\n"
162
+ f" - GitHub layout: models/birefnet.py and utils.py\n"
163
+ f" - HF layout: birefnet.py\n"
164
+ )
165
+
166
+
167
def _import_birefnet(code_dir: str):
    """
    Import the BiRefNet class from ``code_dir``.

    Returns ``(layout, BiRefNet, check_state_dict_or_None)``; the third element
    is only populated for the GitHub layout, which ships a ``check_state_dict``
    helper in ``utils.py``.
    """
    layout = _detect_layout(code_dir)

    if layout == "github":
        # Mirror Colab: plain top-level imports resolved from code_dir itself.
        if code_dir not in sys.path:
            sys.path.insert(0, code_dir)
        from utils import check_state_dict  # type: ignore
        from models.birefnet import BiRefNet  # type: ignore
        return layout, BiRefNet, check_state_dict

    # HF layout: import birefnet as a submodule of the (package-ified) dir.
    pkg, _ = _ensure_importable_package_dir(code_dir)
    module = importlib.import_module(f"{pkg}.birefnet")
    klass = getattr(module, "BiRefNet", None)
    if klass is None:
        raise RuntimeError(f"BiRefNet class not found in {pkg}.birefnet")
    return layout, klass, None
184
+
185
+
186
+ # -------------------------
187
+ # Weight loading helpers
188
+ # -------------------------
189
+
190
+ def _extract_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
191
+ if isinstance(obj, dict):
192
+ if obj and all(torch.is_tensor(v) for v in obj.values()):
193
+ return obj # type: ignore[return-value]
194
+ for k in ["state_dict", "model", "model_state_dict", "net", "params", "weights", "ema"]:
195
+ if k in obj and isinstance(obj[k], dict) and obj[k] and all(torch.is_tensor(v) for v in obj[k].values()):
196
+ return obj[k] # type: ignore[return-value]
197
+ for v in obj.values():
198
+ if isinstance(v, dict) and v and all(torch.is_tensor(tv) for tv in v.values()):
199
+ return v # type: ignore[return-value]
200
+ raise RuntimeError("Could not find a state_dict inside the checkpoint.")
201
+
202
+
203
+ def _clean_state_dict_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
204
+ prefixes = ["module.", "_orig_mod.", "model.", "net.", "state_dict."]
205
+ out: Dict[str, torch.Tensor] = {}
206
+ for k, v in sd.items():
207
+ nk = k
208
+ changed = True
209
+ while changed:
210
+ changed = False
211
+ for p in prefixes:
212
+ if nk.startswith(p):
213
+ nk = nk[len(p):]
214
+ changed = True
215
+ out[nk] = v
216
+ return out
217
+
218
+
219
+ def _pretty_list(xs: List[str], n: int = 20) -> List[str]:
220
+ return xs[:n] + (["..."] if len(xs) > n else [])
221
+
222
+
223
+ # -------------------------
224
+ # Output selection / wrapper
225
+ # -------------------------
226
+
227
+ def _walk_tensors(x: Any) -> Iterable[torch.Tensor]:
228
+ if torch.is_tensor(x):
229
+ yield x
230
+ return
231
+ if isinstance(x, dict):
232
+ for v in x.values():
233
+ yield from _walk_tensors(v)
234
+ elif isinstance(x, (list, tuple)):
235
+ for v in x:
236
+ yield from _walk_tensors(v)
237
+
238
+
239
def _pick_output_tensor(model_out: Any, height: int, width: int) -> torch.Tensor:
    """
    Choose the most likely saliency-map tensor from an arbitrary model output.

    Preference order: a (B, 1|3, height, width) map, then any 4D tensor at
    that spatial size, then the largest tensor by element count.
    Raises RuntimeError when the output contains no tensors at all.
    """
    tensors = list(_walk_tensors(model_out))
    if not tensors:
        raise RuntimeError("Model forward returned no tensors.")

    def at_target_size(t: torch.Tensor) -> bool:
        return t.ndim == 4 and t.shape[2] == height and t.shape[3] == width

    preferred = [t for t in tensors if at_target_size(t) and t.shape[1] in (1, 3)]
    if preferred:
        return preferred[0]

    sized = [t for t in tensors if at_target_size(t)]
    if sized:
        return sized[0]

    return max(tensors, key=lambda t: t.numel())
253
+
254
+
255
class ExportWrapper(nn.Module):
    """Wraps a BiRefNet model so ONNX export sees exactly one tensor output."""

    def __init__(self, model: nn.Module, height: int, width: int):
        super().__init__()
        self.model = model
        self.height = height
        self.width = width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Contiguity helps the stride-based shape fallback used when exporting
        # deform_conv2d (see _patch_and_register_deform_conv2d).
        prediction = self.model(x.contiguous())
        return _pick_output_tensor(prediction, self.height, self.width)
266
+
267
+
268
+
269
+ # -------------------------
270
+ # Main
271
+ # -------------------------
272
+
273
def main() -> None:
    """
    CLI entry point: build BiRefNet from --code_dir, load --weights, and export
    a fixed-size ONNX model to --output (optionally validated with onnx.checker).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--code_dir", required=True)
    ap.add_argument("--weights", required=True)
    ap.add_argument("--output", required=True)
    ap.add_argument("--width", type=int, default=1024)
    ap.add_argument("--height", type=int, default=1024)
    ap.add_argument("--opset", type=int, default=17)
    ap.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
    ap.add_argument("--skip_onnx_check", action="store_true")
    args = ap.parse_args()

    print("== Environment ==")
    print("Python:", sys.version.replace("\n", " "))
    print("Torch:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    print("Requested device:", args.device)

    # Fail fast instead of letting torch.device("cuda") error later.
    if args.device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("You asked for --device cuda but CUDA is not available.")

    device = torch.device(args.device)
    print("Using device:", device)

    # IMPORTANT: patch deform_conv2d exporter BEFORE export
    _patch_and_register_deform_conv2d()

    # check_state_dict is None for the HF layout (see _import_birefnet).
    layout, BiRefNet, check_state_dict = _import_birefnet(args.code_dir)
    print("BiRefNet layout detected:", layout)

    print("== Building model ==")
    kwargs = {}
    # Disable backbone-weight downloads when the constructor supports it;
    # the checkpoint loaded below supplies all weights anyway.
    try:
        sig = inspect.signature(BiRefNet)
        if "bb_pretrained" in sig.parameters:
            kwargs["bb_pretrained"] = False
    except Exception:
        pass

    model = BiRefNet(**kwargs) if kwargs else BiRefNet()
    model.eval().to(device)

    print("== Loading weights ==")
    ckpt = torch.load(args.weights, map_location="cpu")

    if layout == "github" and check_state_dict is not None:
        # Colab-style path
        sd = check_state_dict(ckpt)
        missing, unexpected = model.load_state_dict(sd, strict=False)
    else:
        # HF-style path: unwrap the checkpoint, then strip wrapper prefixes.
        sd = _extract_state_dict(ckpt)
        sd = _clean_state_dict_keys(sd)
        missing, unexpected = model.load_state_dict(sd, strict=False)

    # strict=False above: report mismatches instead of raising.
    missing = list(missing)
    unexpected = list(unexpected)
    print(f"Loaded state_dict. Missing keys: {len(missing)} Unexpected keys: {len(unexpected)}")
    if missing:
        print(" (first 20 missing):", _pretty_list(missing, 20))
    if unexpected:
        print(" (first 20 unexpected):", _pretty_list(unexpected, 20))

    wrapper = ExportWrapper(model, height=args.height, width=args.width).eval().to(device)

    # Run one forward pass before export to surface model errors early and to
    # show which output tensor the wrapper selected.
    print("== Forward probe ==")
    dummy = torch.randn(1, 3, args.height, args.width, device=device)

    with torch.no_grad():
        out = wrapper(dummy)
    print("Picked output shape:", tuple(out.shape), "dtype:", out.dtype)

    print("== Exporting ONNX ==")
    out_path = os.path.abspath(args.output)
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

    # NOTE: No dynamic_axes by default (keeps shapes static and avoids shape None issues).
    torch.onnx.export(
        wrapper,
        dummy,
        out_path,
        export_params=True,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        verbose=False,
    )

    print("Saved ONNX to:", out_path)

    if not args.skip_onnx_check:
        print("== Checking ONNX ==")
        # Imported lazily so the exporter works without onnx installed when
        # --skip_onnx_check is passed.
        import onnx
        m = onnx.load(out_path)
        onnx.checker.check_model(m)
        print("ONNX check: OK")

    # Size report is purely informational; never let it fail the run.
    try:
        mb = os.path.getsize(out_path) / (1024 * 1024)
        print(f"ONNX size: {mb:.1f} MB")
    except Exception:
        pass
377
+
378
+
379
+ if __name__ == "__main__":
380
+ main()